GNU Linux-libre 6.8.9-gnu
[releases.git] / net / xfrm / xfrm_policy.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * xfrm_policy.c
4  *
5  * Changes:
6  *      Mitsuru KANDA @USAGI
7  *      Kazunori MIYAZAWA @USAGI
8  *      Kunihiro Ishiguro <kunihiro@ipinfusion.com>
9  *              IPv6 support
10  *      Kazunori MIYAZAWA @USAGI
11  *      YOSHIFUJI Hideaki
12  *              Split up af-specific portion
13  *      Derek Atkins <derek@ihtfp.com>          Add the post_input processor
14  *
15  */
16
17 #include <linux/err.h>
18 #include <linux/slab.h>
19 #include <linux/kmod.h>
20 #include <linux/list.h>
21 #include <linux/spinlock.h>
22 #include <linux/workqueue.h>
23 #include <linux/notifier.h>
24 #include <linux/netdevice.h>
25 #include <linux/netfilter.h>
26 #include <linux/module.h>
27 #include <linux/cache.h>
28 #include <linux/cpu.h>
29 #include <linux/audit.h>
30 #include <linux/rhashtable.h>
31 #include <linux/if_tunnel.h>
32 #include <net/dst.h>
33 #include <net/flow.h>
34 #include <net/inet_ecn.h>
35 #include <net/xfrm.h>
36 #include <net/ip.h>
37 #include <net/gre.h>
38 #if IS_ENABLED(CONFIG_IPV6_MIP6)
39 #include <net/mip6.h>
40 #endif
41 #ifdef CONFIG_XFRM_STATISTICS
42 #include <net/snmp.h>
43 #endif
44 #ifdef CONFIG_XFRM_ESPINTCP
45 #include <net/espintcp.h>
46 #endif
47
48 #include "xfrm_hash.h"
49
50 #define XFRM_QUEUE_TMO_MIN ((unsigned)(HZ/10))
51 #define XFRM_QUEUE_TMO_MAX ((unsigned)(60*HZ))
52 #define XFRM_MAX_QUEUE_LEN      100
53
54 struct xfrm_flo {
55         struct dst_entry *dst_orig;
56         u8 flags;
57 };
58
59 /* prefixes smaller than this are stored in lists, not trees. */
60 #define INEXACT_PREFIXLEN_IPV4  16
61 #define INEXACT_PREFIXLEN_IPV6  48
62
63 struct xfrm_pol_inexact_node {
64         struct rb_node node;
65         union {
66                 xfrm_address_t addr;
67                 struct rcu_head rcu;
68         };
69         u8 prefixlen;
70
71         struct rb_root root;
72
73         /* the policies matching this node, can be empty list */
74         struct hlist_head hhead;
75 };
76
77 /* xfrm inexact policy search tree:
78  * xfrm_pol_inexact_bin = hash(dir,type,family,if_id);
79  *  |
80  * +---- root_d: sorted by daddr:prefix
81  * |                 |
82  * |        xfrm_pol_inexact_node
83  * |                 |
84  * |                 +- root: sorted by saddr/prefix
85  * |                 |              |
86  * |                 |         xfrm_pol_inexact_node
87  * |                 |              |
88  * |                 |              + root: unused
89  * |                 |              |
90  * |                 |              + hhead: saddr:daddr policies
91  * |                 |
92  * |                 +- coarse policies and all any:daddr policies
93  * |
94  * +---- root_s: sorted by saddr:prefix
95  * |                 |
96  * |        xfrm_pol_inexact_node
97  * |                 |
98  * |                 + root: unused
99  * |                 |
100  * |                 + hhead: saddr:any policies
101  * |
102  * +---- coarse policies and all any:any policies
103  *
104  * Lookups return four candidate lists:
105  * 1. any:any list from top-level xfrm_pol_inexact_bin
106  * 2. any:daddr list from daddr tree
107  * 3. saddr:daddr list from 2nd level daddr tree
108  * 4. saddr:any list from saddr tree
109  *
110  * This result set then needs to be searched for the policy with
111  * the lowest priority.  If two results have same prio, youngest one wins.
112  */
113
114 struct xfrm_pol_inexact_key {
115         possible_net_t net;
116         u32 if_id;
117         u16 family;
118         u8 dir, type;
119 };
120
121 struct xfrm_pol_inexact_bin {
122         struct xfrm_pol_inexact_key k;
123         struct rhash_head head;
124         /* list containing '*:*' policies */
125         struct hlist_head hhead;
126
127         seqcount_spinlock_t count;
128         /* tree sorted by daddr/prefix */
129         struct rb_root root_d;
130
131         /* tree sorted by saddr/prefix */
132         struct rb_root root_s;
133
134         /* slow path below */
135         struct list_head inexact_bins;
136         struct rcu_head rcu;
137 };
138
139 enum xfrm_pol_inexact_candidate_type {
140         XFRM_POL_CAND_BOTH,
141         XFRM_POL_CAND_SADDR,
142         XFRM_POL_CAND_DADDR,
143         XFRM_POL_CAND_ANY,
144
145         XFRM_POL_CAND_MAX,
146 };
147
148 struct xfrm_pol_inexact_candidates {
149         struct hlist_head *res[XFRM_POL_CAND_MAX];
150 };
151
152 struct xfrm_flow_keys {
153         struct flow_dissector_key_basic basic;
154         struct flow_dissector_key_control control;
155         union {
156                 struct flow_dissector_key_ipv4_addrs ipv4;
157                 struct flow_dissector_key_ipv6_addrs ipv6;
158         } addrs;
159         struct flow_dissector_key_ip ip;
160         struct flow_dissector_key_icmp icmp;
161         struct flow_dissector_key_ports ports;
162         struct flow_dissector_key_keyid gre;
163 };
164
165 static struct flow_dissector xfrm_session_dissector __ro_after_init;
166
167 static DEFINE_SPINLOCK(xfrm_if_cb_lock);
168 static struct xfrm_if_cb const __rcu *xfrm_if_cb __read_mostly;
169
170 static DEFINE_SPINLOCK(xfrm_policy_afinfo_lock);
171 static struct xfrm_policy_afinfo const __rcu *xfrm_policy_afinfo[AF_INET6 + 1]
172                                                 __read_mostly;
173
174 static struct kmem_cache *xfrm_dst_cache __ro_after_init;
175
176 static struct rhashtable xfrm_policy_inexact_table;
177 static const struct rhashtable_params xfrm_pol_inexact_params;
178
179 static void xfrm_init_pmtu(struct xfrm_dst **bundle, int nr);
180 static int stale_bundle(struct dst_entry *dst);
181 static int xfrm_bundle_ok(struct xfrm_dst *xdst);
182 static void xfrm_policy_queue_process(struct timer_list *t);
183
184 static void __xfrm_policy_link(struct xfrm_policy *pol, int dir);
185 static struct xfrm_policy *__xfrm_policy_unlink(struct xfrm_policy *pol,
186                                                 int dir);
187
188 static struct xfrm_pol_inexact_bin *
189 xfrm_policy_inexact_lookup(struct net *net, u8 type, u16 family, u8 dir,
190                            u32 if_id);
191
192 static struct xfrm_pol_inexact_bin *
193 xfrm_policy_inexact_lookup_rcu(struct net *net,
194                                u8 type, u16 family, u8 dir, u32 if_id);
195 static struct xfrm_policy *
196 xfrm_policy_insert_list(struct hlist_head *chain, struct xfrm_policy *policy,
197                         bool excl);
198 static void xfrm_policy_insert_inexact_list(struct hlist_head *chain,
199                                             struct xfrm_policy *policy);
200
201 static bool
202 xfrm_policy_find_inexact_candidates(struct xfrm_pol_inexact_candidates *cand,
203                                     struct xfrm_pol_inexact_bin *b,
204                                     const xfrm_address_t *saddr,
205                                     const xfrm_address_t *daddr);
206
207 static inline bool xfrm_pol_hold_rcu(struct xfrm_policy *policy)
208 {
209         return refcount_inc_not_zero(&policy->refcnt);
210 }
211
212 static inline bool
213 __xfrm4_selector_match(const struct xfrm_selector *sel, const struct flowi *fl)
214 {
215         const struct flowi4 *fl4 = &fl->u.ip4;
216
217         return  addr4_match(fl4->daddr, sel->daddr.a4, sel->prefixlen_d) &&
218                 addr4_match(fl4->saddr, sel->saddr.a4, sel->prefixlen_s) &&
219                 !((xfrm_flowi_dport(fl, &fl4->uli) ^ sel->dport) & sel->dport_mask) &&
220                 !((xfrm_flowi_sport(fl, &fl4->uli) ^ sel->sport) & sel->sport_mask) &&
221                 (fl4->flowi4_proto == sel->proto || !sel->proto) &&
222                 (fl4->flowi4_oif == sel->ifindex || !sel->ifindex);
223 }
224
225 static inline bool
226 __xfrm6_selector_match(const struct xfrm_selector *sel, const struct flowi *fl)
227 {
228         const struct flowi6 *fl6 = &fl->u.ip6;
229
230         return  addr_match(&fl6->daddr, &sel->daddr, sel->prefixlen_d) &&
231                 addr_match(&fl6->saddr, &sel->saddr, sel->prefixlen_s) &&
232                 !((xfrm_flowi_dport(fl, &fl6->uli) ^ sel->dport) & sel->dport_mask) &&
233                 !((xfrm_flowi_sport(fl, &fl6->uli) ^ sel->sport) & sel->sport_mask) &&
234                 (fl6->flowi6_proto == sel->proto || !sel->proto) &&
235                 (fl6->flowi6_oif == sel->ifindex || !sel->ifindex);
236 }
237
238 bool xfrm_selector_match(const struct xfrm_selector *sel, const struct flowi *fl,
239                          unsigned short family)
240 {
241         switch (family) {
242         case AF_INET:
243                 return __xfrm4_selector_match(sel, fl);
244         case AF_INET6:
245                 return __xfrm6_selector_match(sel, fl);
246         }
247         return false;
248 }
249
250 static const struct xfrm_policy_afinfo *xfrm_policy_get_afinfo(unsigned short family)
251 {
252         const struct xfrm_policy_afinfo *afinfo;
253
254         if (unlikely(family >= ARRAY_SIZE(xfrm_policy_afinfo)))
255                 return NULL;
256         rcu_read_lock();
257         afinfo = rcu_dereference(xfrm_policy_afinfo[family]);
258         if (unlikely(!afinfo))
259                 rcu_read_unlock();
260         return afinfo;
261 }
262
263 /* Called with rcu_read_lock(). */
264 static const struct xfrm_if_cb *xfrm_if_get_cb(void)
265 {
266         return rcu_dereference(xfrm_if_cb);
267 }
268
269 struct dst_entry *__xfrm_dst_lookup(struct net *net, int tos, int oif,
270                                     const xfrm_address_t *saddr,
271                                     const xfrm_address_t *daddr,
272                                     int family, u32 mark)
273 {
274         const struct xfrm_policy_afinfo *afinfo;
275         struct dst_entry *dst;
276
277         afinfo = xfrm_policy_get_afinfo(family);
278         if (unlikely(afinfo == NULL))
279                 return ERR_PTR(-EAFNOSUPPORT);
280
281         dst = afinfo->dst_lookup(net, tos, oif, saddr, daddr, mark);
282
283         rcu_read_unlock();
284
285         return dst;
286 }
287 EXPORT_SYMBOL(__xfrm_dst_lookup);
288
289 static inline struct dst_entry *xfrm_dst_lookup(struct xfrm_state *x,
290                                                 int tos, int oif,
291                                                 xfrm_address_t *prev_saddr,
292                                                 xfrm_address_t *prev_daddr,
293                                                 int family, u32 mark)
294 {
295         struct net *net = xs_net(x);
296         xfrm_address_t *saddr = &x->props.saddr;
297         xfrm_address_t *daddr = &x->id.daddr;
298         struct dst_entry *dst;
299
300         if (x->type->flags & XFRM_TYPE_LOCAL_COADDR) {
301                 saddr = x->coaddr;
302                 daddr = prev_daddr;
303         }
304         if (x->type->flags & XFRM_TYPE_REMOTE_COADDR) {
305                 saddr = prev_saddr;
306                 daddr = x->coaddr;
307         }
308
309         dst = __xfrm_dst_lookup(net, tos, oif, saddr, daddr, family, mark);
310
311         if (!IS_ERR(dst)) {
312                 if (prev_saddr != saddr)
313                         memcpy(prev_saddr, saddr,  sizeof(*prev_saddr));
314                 if (prev_daddr != daddr)
315                         memcpy(prev_daddr, daddr,  sizeof(*prev_daddr));
316         }
317
318         return dst;
319 }
320
321 static inline unsigned long make_jiffies(long secs)
322 {
323         if (secs >= (MAX_SCHEDULE_TIMEOUT-1)/HZ)
324                 return MAX_SCHEDULE_TIMEOUT-1;
325         else
326                 return secs*HZ;
327 }
328
329 static void xfrm_policy_timer(struct timer_list *t)
330 {
331         struct xfrm_policy *xp = from_timer(xp, t, timer);
332         time64_t now = ktime_get_real_seconds();
333         time64_t next = TIME64_MAX;
334         int warn = 0;
335         int dir;
336
337         read_lock(&xp->lock);
338
339         if (unlikely(xp->walk.dead))
340                 goto out;
341
342         dir = xfrm_policy_id2dir(xp->index);
343
344         if (xp->lft.hard_add_expires_seconds) {
345                 time64_t tmo = xp->lft.hard_add_expires_seconds +
346                         xp->curlft.add_time - now;
347                 if (tmo <= 0)
348                         goto expired;
349                 if (tmo < next)
350                         next = tmo;
351         }
352         if (xp->lft.hard_use_expires_seconds) {
353                 time64_t tmo = xp->lft.hard_use_expires_seconds +
354                         (READ_ONCE(xp->curlft.use_time) ? : xp->curlft.add_time) - now;
355                 if (tmo <= 0)
356                         goto expired;
357                 if (tmo < next)
358                         next = tmo;
359         }
360         if (xp->lft.soft_add_expires_seconds) {
361                 time64_t tmo = xp->lft.soft_add_expires_seconds +
362                         xp->curlft.add_time - now;
363                 if (tmo <= 0) {
364                         warn = 1;
365                         tmo = XFRM_KM_TIMEOUT;
366                 }
367                 if (tmo < next)
368                         next = tmo;
369         }
370         if (xp->lft.soft_use_expires_seconds) {
371                 time64_t tmo = xp->lft.soft_use_expires_seconds +
372                         (READ_ONCE(xp->curlft.use_time) ? : xp->curlft.add_time) - now;
373                 if (tmo <= 0) {
374                         warn = 1;
375                         tmo = XFRM_KM_TIMEOUT;
376                 }
377                 if (tmo < next)
378                         next = tmo;
379         }
380
381         if (warn)
382                 km_policy_expired(xp, dir, 0, 0);
383         if (next != TIME64_MAX &&
384             !mod_timer(&xp->timer, jiffies + make_jiffies(next)))
385                 xfrm_pol_hold(xp);
386
387 out:
388         read_unlock(&xp->lock);
389         xfrm_pol_put(xp);
390         return;
391
392 expired:
393         read_unlock(&xp->lock);
394         if (!xfrm_policy_delete(xp, dir))
395                 km_policy_expired(xp, dir, 1, 0);
396         xfrm_pol_put(xp);
397 }
398
399 /* Allocate xfrm_policy. Not used here, it is supposed to be used by pfkeyv2
400  * SPD calls.
401  */
402
403 struct xfrm_policy *xfrm_policy_alloc(struct net *net, gfp_t gfp)
404 {
405         struct xfrm_policy *policy;
406
407         policy = kzalloc(sizeof(struct xfrm_policy), gfp);
408
409         if (policy) {
410                 write_pnet(&policy->xp_net, net);
411                 INIT_LIST_HEAD(&policy->walk.all);
412                 INIT_HLIST_NODE(&policy->bydst_inexact_list);
413                 INIT_HLIST_NODE(&policy->bydst);
414                 INIT_HLIST_NODE(&policy->byidx);
415                 rwlock_init(&policy->lock);
416                 refcount_set(&policy->refcnt, 1);
417                 skb_queue_head_init(&policy->polq.hold_queue);
418                 timer_setup(&policy->timer, xfrm_policy_timer, 0);
419                 timer_setup(&policy->polq.hold_timer,
420                             xfrm_policy_queue_process, 0);
421         }
422         return policy;
423 }
424 EXPORT_SYMBOL(xfrm_policy_alloc);
425
426 static void xfrm_policy_destroy_rcu(struct rcu_head *head)
427 {
428         struct xfrm_policy *policy = container_of(head, struct xfrm_policy, rcu);
429
430         security_xfrm_policy_free(policy->security);
431         kfree(policy);
432 }
433
434 /* Destroy xfrm_policy: descendant resources must be released to this moment. */
435
436 void xfrm_policy_destroy(struct xfrm_policy *policy)
437 {
438         BUG_ON(!policy->walk.dead);
439
440         if (del_timer(&policy->timer) || del_timer(&policy->polq.hold_timer))
441                 BUG();
442
443         xfrm_dev_policy_free(policy);
444         call_rcu(&policy->rcu, xfrm_policy_destroy_rcu);
445 }
446 EXPORT_SYMBOL(xfrm_policy_destroy);
447
448 /* Rule must be locked. Release descendant resources, announce
449  * entry dead. The rule must be unlinked from lists to the moment.
450  */
451
452 static void xfrm_policy_kill(struct xfrm_policy *policy)
453 {
454         write_lock_bh(&policy->lock);
455         policy->walk.dead = 1;
456         write_unlock_bh(&policy->lock);
457
458         atomic_inc(&policy->genid);
459
460         if (del_timer(&policy->polq.hold_timer))
461                 xfrm_pol_put(policy);
462         skb_queue_purge(&policy->polq.hold_queue);
463
464         if (del_timer(&policy->timer))
465                 xfrm_pol_put(policy);
466
467         xfrm_pol_put(policy);
468 }
469
470 static unsigned int xfrm_policy_hashmax __read_mostly = 1 * 1024 * 1024;
471
472 static inline unsigned int idx_hash(struct net *net, u32 index)
473 {
474         return __idx_hash(index, net->xfrm.policy_idx_hmask);
475 }
476
477 /* calculate policy hash thresholds */
478 static void __get_hash_thresh(struct net *net,
479                               unsigned short family, int dir,
480                               u8 *dbits, u8 *sbits)
481 {
482         switch (family) {
483         case AF_INET:
484                 *dbits = net->xfrm.policy_bydst[dir].dbits4;
485                 *sbits = net->xfrm.policy_bydst[dir].sbits4;
486                 break;
487
488         case AF_INET6:
489                 *dbits = net->xfrm.policy_bydst[dir].dbits6;
490                 *sbits = net->xfrm.policy_bydst[dir].sbits6;
491                 break;
492
493         default:
494                 *dbits = 0;
495                 *sbits = 0;
496         }
497 }
498
499 static struct hlist_head *policy_hash_bysel(struct net *net,
500                                             const struct xfrm_selector *sel,
501                                             unsigned short family, int dir)
502 {
503         unsigned int hmask = net->xfrm.policy_bydst[dir].hmask;
504         unsigned int hash;
505         u8 dbits;
506         u8 sbits;
507
508         __get_hash_thresh(net, family, dir, &dbits, &sbits);
509         hash = __sel_hash(sel, family, hmask, dbits, sbits);
510
511         if (hash == hmask + 1)
512                 return NULL;
513
514         return rcu_dereference_check(net->xfrm.policy_bydst[dir].table,
515                      lockdep_is_held(&net->xfrm.xfrm_policy_lock)) + hash;
516 }
517
518 static struct hlist_head *policy_hash_direct(struct net *net,
519                                              const xfrm_address_t *daddr,
520                                              const xfrm_address_t *saddr,
521                                              unsigned short family, int dir)
522 {
523         unsigned int hmask = net->xfrm.policy_bydst[dir].hmask;
524         unsigned int hash;
525         u8 dbits;
526         u8 sbits;
527
528         __get_hash_thresh(net, family, dir, &dbits, &sbits);
529         hash = __addr_hash(daddr, saddr, family, hmask, dbits, sbits);
530
531         return rcu_dereference_check(net->xfrm.policy_bydst[dir].table,
532                      lockdep_is_held(&net->xfrm.xfrm_policy_lock)) + hash;
533 }
534
535 static void xfrm_dst_hash_transfer(struct net *net,
536                                    struct hlist_head *list,
537                                    struct hlist_head *ndsttable,
538                                    unsigned int nhashmask,
539                                    int dir)
540 {
541         struct hlist_node *tmp, *entry0 = NULL;
542         struct xfrm_policy *pol;
543         unsigned int h0 = 0;
544         u8 dbits;
545         u8 sbits;
546
547 redo:
548         hlist_for_each_entry_safe(pol, tmp, list, bydst) {
549                 unsigned int h;
550
551                 __get_hash_thresh(net, pol->family, dir, &dbits, &sbits);
552                 h = __addr_hash(&pol->selector.daddr, &pol->selector.saddr,
553                                 pol->family, nhashmask, dbits, sbits);
554                 if (!entry0 || pol->xdo.type == XFRM_DEV_OFFLOAD_PACKET) {
555                         hlist_del_rcu(&pol->bydst);
556                         hlist_add_head_rcu(&pol->bydst, ndsttable + h);
557                         h0 = h;
558                 } else {
559                         if (h != h0)
560                                 continue;
561                         hlist_del_rcu(&pol->bydst);
562                         hlist_add_behind_rcu(&pol->bydst, entry0);
563                 }
564                 entry0 = &pol->bydst;
565         }
566         if (!hlist_empty(list)) {
567                 entry0 = NULL;
568                 goto redo;
569         }
570 }
571
572 static void xfrm_idx_hash_transfer(struct hlist_head *list,
573                                    struct hlist_head *nidxtable,
574                                    unsigned int nhashmask)
575 {
576         struct hlist_node *tmp;
577         struct xfrm_policy *pol;
578
579         hlist_for_each_entry_safe(pol, tmp, list, byidx) {
580                 unsigned int h;
581
582                 h = __idx_hash(pol->index, nhashmask);
583                 hlist_add_head(&pol->byidx, nidxtable+h);
584         }
585 }
586
587 static unsigned long xfrm_new_hash_mask(unsigned int old_hmask)
588 {
589         return ((old_hmask + 1) << 1) - 1;
590 }
591
592 static void xfrm_bydst_resize(struct net *net, int dir)
593 {
594         unsigned int hmask = net->xfrm.policy_bydst[dir].hmask;
595         unsigned int nhashmask = xfrm_new_hash_mask(hmask);
596         unsigned int nsize = (nhashmask + 1) * sizeof(struct hlist_head);
597         struct hlist_head *ndst = xfrm_hash_alloc(nsize);
598         struct hlist_head *odst;
599         int i;
600
601         if (!ndst)
602                 return;
603
604         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
605         write_seqcount_begin(&net->xfrm.xfrm_policy_hash_generation);
606
607         odst = rcu_dereference_protected(net->xfrm.policy_bydst[dir].table,
608                                 lockdep_is_held(&net->xfrm.xfrm_policy_lock));
609
610         for (i = hmask; i >= 0; i--)
611                 xfrm_dst_hash_transfer(net, odst + i, ndst, nhashmask, dir);
612
613         rcu_assign_pointer(net->xfrm.policy_bydst[dir].table, ndst);
614         net->xfrm.policy_bydst[dir].hmask = nhashmask;
615
616         write_seqcount_end(&net->xfrm.xfrm_policy_hash_generation);
617         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
618
619         synchronize_rcu();
620
621         xfrm_hash_free(odst, (hmask + 1) * sizeof(struct hlist_head));
622 }
623
624 static void xfrm_byidx_resize(struct net *net)
625 {
626         unsigned int hmask = net->xfrm.policy_idx_hmask;
627         unsigned int nhashmask = xfrm_new_hash_mask(hmask);
628         unsigned int nsize = (nhashmask + 1) * sizeof(struct hlist_head);
629         struct hlist_head *oidx = net->xfrm.policy_byidx;
630         struct hlist_head *nidx = xfrm_hash_alloc(nsize);
631         int i;
632
633         if (!nidx)
634                 return;
635
636         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
637
638         for (i = hmask; i >= 0; i--)
639                 xfrm_idx_hash_transfer(oidx + i, nidx, nhashmask);
640
641         net->xfrm.policy_byidx = nidx;
642         net->xfrm.policy_idx_hmask = nhashmask;
643
644         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
645
646         xfrm_hash_free(oidx, (hmask + 1) * sizeof(struct hlist_head));
647 }
648
649 static inline int xfrm_bydst_should_resize(struct net *net, int dir, int *total)
650 {
651         unsigned int cnt = net->xfrm.policy_count[dir];
652         unsigned int hmask = net->xfrm.policy_bydst[dir].hmask;
653
654         if (total)
655                 *total += cnt;
656
657         if ((hmask + 1) < xfrm_policy_hashmax &&
658             cnt > hmask)
659                 return 1;
660
661         return 0;
662 }
663
664 static inline int xfrm_byidx_should_resize(struct net *net, int total)
665 {
666         unsigned int hmask = net->xfrm.policy_idx_hmask;
667
668         if ((hmask + 1) < xfrm_policy_hashmax &&
669             total > hmask)
670                 return 1;
671
672         return 0;
673 }
674
675 void xfrm_spd_getinfo(struct net *net, struct xfrmk_spdinfo *si)
676 {
677         si->incnt = net->xfrm.policy_count[XFRM_POLICY_IN];
678         si->outcnt = net->xfrm.policy_count[XFRM_POLICY_OUT];
679         si->fwdcnt = net->xfrm.policy_count[XFRM_POLICY_FWD];
680         si->inscnt = net->xfrm.policy_count[XFRM_POLICY_IN+XFRM_POLICY_MAX];
681         si->outscnt = net->xfrm.policy_count[XFRM_POLICY_OUT+XFRM_POLICY_MAX];
682         si->fwdscnt = net->xfrm.policy_count[XFRM_POLICY_FWD+XFRM_POLICY_MAX];
683         si->spdhcnt = net->xfrm.policy_idx_hmask;
684         si->spdhmcnt = xfrm_policy_hashmax;
685 }
686 EXPORT_SYMBOL(xfrm_spd_getinfo);
687
688 static DEFINE_MUTEX(hash_resize_mutex);
689 static void xfrm_hash_resize(struct work_struct *work)
690 {
691         struct net *net = container_of(work, struct net, xfrm.policy_hash_work);
692         int dir, total;
693
694         mutex_lock(&hash_resize_mutex);
695
696         total = 0;
697         for (dir = 0; dir < XFRM_POLICY_MAX; dir++) {
698                 if (xfrm_bydst_should_resize(net, dir, &total))
699                         xfrm_bydst_resize(net, dir);
700         }
701         if (xfrm_byidx_should_resize(net, total))
702                 xfrm_byidx_resize(net);
703
704         mutex_unlock(&hash_resize_mutex);
705 }
706
707 /* Make sure *pol can be inserted into fastbin.
708  * Useful to check that later insert requests will be successful
709  * (provided xfrm_policy_lock is held throughout).
710  */
711 static struct xfrm_pol_inexact_bin *
712 xfrm_policy_inexact_alloc_bin(const struct xfrm_policy *pol, u8 dir)
713 {
714         struct xfrm_pol_inexact_bin *bin, *prev;
715         struct xfrm_pol_inexact_key k = {
716                 .family = pol->family,
717                 .type = pol->type,
718                 .dir = dir,
719                 .if_id = pol->if_id,
720         };
721         struct net *net = xp_net(pol);
722
723         lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
724
725         write_pnet(&k.net, net);
726         bin = rhashtable_lookup_fast(&xfrm_policy_inexact_table, &k,
727                                      xfrm_pol_inexact_params);
728         if (bin)
729                 return bin;
730
731         bin = kzalloc(sizeof(*bin), GFP_ATOMIC);
732         if (!bin)
733                 return NULL;
734
735         bin->k = k;
736         INIT_HLIST_HEAD(&bin->hhead);
737         bin->root_d = RB_ROOT;
738         bin->root_s = RB_ROOT;
739         seqcount_spinlock_init(&bin->count, &net->xfrm.xfrm_policy_lock);
740
741         prev = rhashtable_lookup_get_insert_key(&xfrm_policy_inexact_table,
742                                                 &bin->k, &bin->head,
743                                                 xfrm_pol_inexact_params);
744         if (!prev) {
745                 list_add(&bin->inexact_bins, &net->xfrm.inexact_bins);
746                 return bin;
747         }
748
749         kfree(bin);
750
751         return IS_ERR(prev) ? NULL : prev;
752 }
753
754 static bool xfrm_pol_inexact_addr_use_any_list(const xfrm_address_t *addr,
755                                                int family, u8 prefixlen)
756 {
757         if (xfrm_addr_any(addr, family))
758                 return true;
759
760         if (family == AF_INET6 && prefixlen < INEXACT_PREFIXLEN_IPV6)
761                 return true;
762
763         if (family == AF_INET && prefixlen < INEXACT_PREFIXLEN_IPV4)
764                 return true;
765
766         return false;
767 }
768
769 static bool
770 xfrm_policy_inexact_insert_use_any_list(const struct xfrm_policy *policy)
771 {
772         const xfrm_address_t *addr;
773         bool saddr_any, daddr_any;
774         u8 prefixlen;
775
776         addr = &policy->selector.saddr;
777         prefixlen = policy->selector.prefixlen_s;
778
779         saddr_any = xfrm_pol_inexact_addr_use_any_list(addr,
780                                                        policy->family,
781                                                        prefixlen);
782         addr = &policy->selector.daddr;
783         prefixlen = policy->selector.prefixlen_d;
784         daddr_any = xfrm_pol_inexact_addr_use_any_list(addr,
785                                                        policy->family,
786                                                        prefixlen);
787         return saddr_any && daddr_any;
788 }
789
790 static void xfrm_pol_inexact_node_init(struct xfrm_pol_inexact_node *node,
791                                        const xfrm_address_t *addr, u8 prefixlen)
792 {
793         node->addr = *addr;
794         node->prefixlen = prefixlen;
795 }
796
797 static struct xfrm_pol_inexact_node *
798 xfrm_pol_inexact_node_alloc(const xfrm_address_t *addr, u8 prefixlen)
799 {
800         struct xfrm_pol_inexact_node *node;
801
802         node = kzalloc(sizeof(*node), GFP_ATOMIC);
803         if (node)
804                 xfrm_pol_inexact_node_init(node, addr, prefixlen);
805
806         return node;
807 }
808
809 static int xfrm_policy_addr_delta(const xfrm_address_t *a,
810                                   const xfrm_address_t *b,
811                                   u8 prefixlen, u16 family)
812 {
813         u32 ma, mb, mask;
814         unsigned int pdw, pbi;
815         int delta = 0;
816
817         switch (family) {
818         case AF_INET:
819                 if (prefixlen == 0)
820                         return 0;
821                 mask = ~0U << (32 - prefixlen);
822                 ma = ntohl(a->a4) & mask;
823                 mb = ntohl(b->a4) & mask;
824                 if (ma < mb)
825                         delta = -1;
826                 else if (ma > mb)
827                         delta = 1;
828                 break;
829         case AF_INET6:
830                 pdw = prefixlen >> 5;
831                 pbi = prefixlen & 0x1f;
832
833                 if (pdw) {
834                         delta = memcmp(a->a6, b->a6, pdw << 2);
835                         if (delta)
836                                 return delta;
837                 }
838                 if (pbi) {
839                         mask = ~0U << (32 - pbi);
840                         ma = ntohl(a->a6[pdw]) & mask;
841                         mb = ntohl(b->a6[pdw]) & mask;
842                         if (ma < mb)
843                                 delta = -1;
844                         else if (ma > mb)
845                                 delta = 1;
846                 }
847                 break;
848         default:
849                 break;
850         }
851
852         return delta;
853 }
854
855 static void xfrm_policy_inexact_list_reinsert(struct net *net,
856                                               struct xfrm_pol_inexact_node *n,
857                                               u16 family)
858 {
859         unsigned int matched_s, matched_d;
860         struct xfrm_policy *policy, *p;
861
862         matched_s = 0;
863         matched_d = 0;
864
865         list_for_each_entry_reverse(policy, &net->xfrm.policy_all, walk.all) {
866                 struct hlist_node *newpos = NULL;
867                 bool matches_s, matches_d;
868
869                 if (policy->walk.dead || !policy->bydst_reinsert)
870                         continue;
871
872                 WARN_ON_ONCE(policy->family != family);
873
874                 policy->bydst_reinsert = false;
875                 hlist_for_each_entry(p, &n->hhead, bydst) {
876                         if (policy->priority > p->priority)
877                                 newpos = &p->bydst;
878                         else if (policy->priority == p->priority &&
879                                  policy->pos > p->pos)
880                                 newpos = &p->bydst;
881                         else
882                                 break;
883                 }
884
885                 if (newpos && policy->xdo.type != XFRM_DEV_OFFLOAD_PACKET)
886                         hlist_add_behind_rcu(&policy->bydst, newpos);
887                 else
888                         hlist_add_head_rcu(&policy->bydst, &n->hhead);
889
890                 /* paranoia checks follow.
891                  * Check that the reinserted policy matches at least
892                  * saddr or daddr for current node prefix.
893                  *
894                  * Matching both is fine, matching saddr in one policy
895                  * (but not daddr) and then matching only daddr in another
896                  * is a bug.
897                  */
898                 matches_s = xfrm_policy_addr_delta(&policy->selector.saddr,
899                                                    &n->addr,
900                                                    n->prefixlen,
901                                                    family) == 0;
902                 matches_d = xfrm_policy_addr_delta(&policy->selector.daddr,
903                                                    &n->addr,
904                                                    n->prefixlen,
905                                                    family) == 0;
906                 if (matches_s && matches_d)
907                         continue;
908
909                 WARN_ON_ONCE(!matches_s && !matches_d);
910                 if (matches_s)
911                         matched_s++;
912                 if (matches_d)
913                         matched_d++;
914                 WARN_ON_ONCE(matched_s && matched_d);
915         }
916 }
917
/*
 * Insert the detached node @n into the rb-tree @new.
 *
 * During the descent, @n is compared against each tree node at the shorter
 * of the two prefix lengths. If some @node compares equal at that length,
 * @n is absorbed into it:
 *  - @n's policies are unlinked from @n and flagged with ->bydst_reinsert;
 *  - @node's prefixlen is set to the shorter length;
 *  - the flagged policies are re-linked into @node's list by
 *    xfrm_policy_inexact_list_reinsert();
 *  - @n is freed via RCU.
 *
 * If the two prefix lengths were equal, that is the end. Otherwise @node
 * is erased from @new and the insertion restarts from the root with @node
 * as the node to insert.
 *
 * If no equal node is found, @n is linked at the leaf position reached.
 * @n is expected to have no subtree of its own (WARN otherwise).
 */
static void xfrm_policy_inexact_node_reinsert(struct net *net,
					      struct xfrm_pol_inexact_node *n,
					      struct rb_root *new,
					      u16 family)
{
	struct xfrm_pol_inexact_node *node;
	struct rb_node **p, *parent;

	/* we should not have another subtree here */
	WARN_ON_ONCE(!RB_EMPTY_ROOT(&n->root));
restart:
	parent = NULL;
	p = &new->rb_node;
	while (*p) {
		u8 prefixlen;
		int delta;

		parent = *p;
		node = rb_entry(*p, struct xfrm_pol_inexact_node, node);

		/* compare only the bits covered by both prefixes */
		prefixlen = min(node->prefixlen, n->prefixlen);

		delta = xfrm_policy_addr_delta(&n->addr, &node->addr,
					       prefixlen, family);
		if (delta < 0) {
			p = &parent->rb_left;
		} else if (delta > 0) {
			p = &parent->rb_right;
		} else {
			bool same_prefixlen = node->prefixlen == n->prefixlen;
			struct xfrm_policy *tmp;

			/* detach n's policies; the flag makes
			 * xfrm_policy_inexact_list_reinsert() pick them up
			 * and link them into node.
			 */
			hlist_for_each_entry(tmp, &n->hhead, bydst) {
				tmp->bydst_reinsert = true;
				hlist_del_rcu(&tmp->bydst);
			}

			node->prefixlen = prefixlen;

			xfrm_policy_inexact_list_reinsert(net, node, family);

			if (same_prefixlen) {
				kfree_rcu(n, rcu);
				return;
			}

			/* p was not advanced in this iteration, so *p is
			 * still &node->node: take node out of the tree and
			 * reinsert it from the top with its (possibly
			 * shortened) prefix. Presumably this is needed because
			 * it may now overlap other nodes in @new.
			 */
			rb_erase(*p, new);
			kfree_rcu(n, rcu);
			n = node;
			goto restart;
		}
	}

	rb_link_node_rcu(&n->node, parent, p);
	rb_insert_color(&n->node, new);
}
974
975 /* merge nodes v and n */
976 static void xfrm_policy_inexact_node_merge(struct net *net,
977                                            struct xfrm_pol_inexact_node *v,
978                                            struct xfrm_pol_inexact_node *n,
979                                            u16 family)
980 {
981         struct xfrm_pol_inexact_node *node;
982         struct xfrm_policy *tmp;
983         struct rb_node *rnode;
984
985         /* To-be-merged node v has a subtree.
986          *
987          * Dismantle it and insert its nodes to n->root.
988          */
989         while ((rnode = rb_first(&v->root)) != NULL) {
990                 node = rb_entry(rnode, struct xfrm_pol_inexact_node, node);
991                 rb_erase(&node->node, &v->root);
992                 xfrm_policy_inexact_node_reinsert(net, node, &n->root,
993                                                   family);
994         }
995
996         hlist_for_each_entry(tmp, &v->hhead, bydst) {
997                 tmp->bydst_reinsert = true;
998                 hlist_del_rcu(&tmp->bydst);
999         }
1000
1001         xfrm_policy_inexact_list_reinsert(net, n, family);
1002 }
1003
/*
 * Find or create the node for @addr/@prefixlen in tree @root.
 *
 * If an existing node's prefix contains @addr and its prefixlen is
 * <= @prefixlen, that node is returned.
 *
 * Otherwise, every node that @addr matches at the shorter @prefixlen (a
 * node more specific than the new prefix) is erased from @root:
 *  - the first such node is re-initialised via xfrm_pol_inexact_node_init()
 *    with @addr/@prefixlen and kept as @cached;
 *  - any later ones are merged into @cached and freed via RCU.
 * After each removal the descent restarts from the root.
 *
 * Finally @cached, or a freshly allocated node if none was found, is
 * linked into @root and returned. Returns NULL only if that allocation
 * fails. @dir is not referenced in this function.
 */
static struct xfrm_pol_inexact_node *
xfrm_policy_inexact_insert_node(struct net *net,
				struct rb_root *root,
				xfrm_address_t *addr,
				u16 family, u8 prefixlen, u8 dir)
{
	struct xfrm_pol_inexact_node *cached = NULL;
	struct rb_node **p, *parent = NULL;
	struct xfrm_pol_inexact_node *node;

	p = &root->rb_node;
	while (*p) {
		int delta;

		parent = *p;
		node = rb_entry(*p, struct xfrm_pol_inexact_node, node);

		/* does addr fall inside this node's (shorter or equal) prefix? */
		delta = xfrm_policy_addr_delta(addr, &node->addr,
					       node->prefixlen,
					       family);
		if (delta == 0 && prefixlen >= node->prefixlen) {
			WARN_ON_ONCE(cached); /* ipsec policies got lost */
			return node;
		}

		/* step down; p is reset to the root below if this node
		 * gets removed
		 */
		if (delta < 0)
			p = &parent->rb_left;
		else
			p = &parent->rb_right;

		if (prefixlen < node->prefixlen) {
			/* compare again at the new, shorter prefix length */
			delta = xfrm_policy_addr_delta(addr, &node->addr,
						       prefixlen,
						       family);
			if (delta)
				continue;

			/* This node is a subnet of the new prefix. It needs
			 * to be removed and re-inserted with the smaller
			 * prefix and all nodes that are now also covered
			 * by the reduced prefixlen.
			 */
			rb_erase(&node->node, root);

			if (!cached) {
				xfrm_pol_inexact_node_init(node, addr,
							   prefixlen);
				cached = node;
			} else {
				/* This node also falls within the new
				 * prefixlen. Merge the to-be-reinserted
				 * node and this one.
				 */
				xfrm_policy_inexact_node_merge(net, node,
							       cached, family);
				kfree_rcu(node, rcu);
			}

			/* restart */
			p = &root->rb_node;
			parent = NULL;
		}
	}

	node = cached;
	if (!node) {
		node = xfrm_pol_inexact_node_alloc(addr, prefixlen);
		if (!node)
			return NULL;
	}

	rb_link_node_rcu(&node->node, parent, p);
	rb_insert_color(&node->node, root);

	return node;
}
1080
1081 static void xfrm_policy_inexact_gc_tree(struct rb_root *r, bool rm)
1082 {
1083         struct xfrm_pol_inexact_node *node;
1084         struct rb_node *rn = rb_first(r);
1085
1086         while (rn) {
1087                 node = rb_entry(rn, struct xfrm_pol_inexact_node, node);
1088
1089                 xfrm_policy_inexact_gc_tree(&node->root, rm);
1090                 rn = rb_next(rn);
1091
1092                 if (!hlist_empty(&node->hhead) || !RB_EMPTY_ROOT(&node->root)) {
1093                         WARN_ON_ONCE(rm);
1094                         continue;
1095                 }
1096
1097                 rb_erase(&node->node, r);
1098                 kfree_rcu(node, rcu);
1099         }
1100 }
1101
1102 static void __xfrm_policy_inexact_prune_bin(struct xfrm_pol_inexact_bin *b, bool net_exit)
1103 {
1104         write_seqcount_begin(&b->count);
1105         xfrm_policy_inexact_gc_tree(&b->root_d, net_exit);
1106         xfrm_policy_inexact_gc_tree(&b->root_s, net_exit);
1107         write_seqcount_end(&b->count);
1108
1109         if (!RB_EMPTY_ROOT(&b->root_d) || !RB_EMPTY_ROOT(&b->root_s) ||
1110             !hlist_empty(&b->hhead)) {
1111                 WARN_ON_ONCE(net_exit);
1112                 return;
1113         }
1114
1115         if (rhashtable_remove_fast(&xfrm_policy_inexact_table, &b->head,
1116                                    xfrm_pol_inexact_params) == 0) {
1117                 list_del(&b->inexact_bins);
1118                 kfree_rcu(b, rcu);
1119         }
1120 }
1121
1122 static void xfrm_policy_inexact_prune_bin(struct xfrm_pol_inexact_bin *b)
1123 {
1124         struct net *net = read_pnet(&b->k.net);
1125
1126         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1127         __xfrm_policy_inexact_prune_bin(b, false);
1128         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1129 }
1130
1131 static void __xfrm_policy_inexact_flush(struct net *net)
1132 {
1133         struct xfrm_pol_inexact_bin *bin, *t;
1134
1135         lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
1136
1137         list_for_each_entry_safe(bin, t, &net->xfrm.inexact_bins, inexact_bins)
1138                 __xfrm_policy_inexact_prune_bin(bin, false);
1139 }
1140
1141 static struct hlist_head *
1142 xfrm_policy_inexact_alloc_chain(struct xfrm_pol_inexact_bin *bin,
1143                                 struct xfrm_policy *policy, u8 dir)
1144 {
1145         struct xfrm_pol_inexact_node *n;
1146         struct net *net;
1147
1148         net = xp_net(policy);
1149         lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
1150
1151         if (xfrm_policy_inexact_insert_use_any_list(policy))
1152                 return &bin->hhead;
1153
1154         if (xfrm_pol_inexact_addr_use_any_list(&policy->selector.daddr,
1155                                                policy->family,
1156                                                policy->selector.prefixlen_d)) {
1157                 write_seqcount_begin(&bin->count);
1158                 n = xfrm_policy_inexact_insert_node(net,
1159                                                     &bin->root_s,
1160                                                     &policy->selector.saddr,
1161                                                     policy->family,
1162                                                     policy->selector.prefixlen_s,
1163                                                     dir);
1164                 write_seqcount_end(&bin->count);
1165                 if (!n)
1166                         return NULL;
1167
1168                 return &n->hhead;
1169         }
1170
1171         /* daddr is fixed */
1172         write_seqcount_begin(&bin->count);
1173         n = xfrm_policy_inexact_insert_node(net,
1174                                             &bin->root_d,
1175                                             &policy->selector.daddr,
1176                                             policy->family,
1177                                             policy->selector.prefixlen_d, dir);
1178         write_seqcount_end(&bin->count);
1179         if (!n)
1180                 return NULL;
1181
1182         /* saddr is wildcard */
1183         if (xfrm_pol_inexact_addr_use_any_list(&policy->selector.saddr,
1184                                                policy->family,
1185                                                policy->selector.prefixlen_s))
1186                 return &n->hhead;
1187
1188         write_seqcount_begin(&bin->count);
1189         n = xfrm_policy_inexact_insert_node(net,
1190                                             &n->root,
1191                                             &policy->selector.saddr,
1192                                             policy->family,
1193                                             policy->selector.prefixlen_s, dir);
1194         write_seqcount_end(&bin->count);
1195         if (!n)
1196                 return NULL;
1197
1198         return &n->hhead;
1199 }
1200
1201 static struct xfrm_policy *
1202 xfrm_policy_inexact_insert(struct xfrm_policy *policy, u8 dir, int excl)
1203 {
1204         struct xfrm_pol_inexact_bin *bin;
1205         struct xfrm_policy *delpol;
1206         struct hlist_head *chain;
1207         struct net *net;
1208
1209         bin = xfrm_policy_inexact_alloc_bin(policy, dir);
1210         if (!bin)
1211                 return ERR_PTR(-ENOMEM);
1212
1213         net = xp_net(policy);
1214         lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
1215
1216         chain = xfrm_policy_inexact_alloc_chain(bin, policy, dir);
1217         if (!chain) {
1218                 __xfrm_policy_inexact_prune_bin(bin, false);
1219                 return ERR_PTR(-ENOMEM);
1220         }
1221
1222         delpol = xfrm_policy_insert_list(chain, policy, excl);
1223         if (delpol && excl) {
1224                 __xfrm_policy_inexact_prune_bin(bin, false);
1225                 return ERR_PTR(-EEXIST);
1226         }
1227
1228         chain = &net->xfrm.policy_inexact[dir];
1229         xfrm_policy_insert_inexact_list(chain, policy);
1230
1231         if (delpol)
1232                 __xfrm_policy_inexact_prune_bin(bin, false);
1233
1234         return delpol;
1235 }
1236
/*
 * Work handler for net->xfrm.policy_hthresh.work: rehash all non-socket
 * policies of a netns using the current selector prefix-length thresholds.
 *
 * Under hash_resize_mutex, the thresholds are first read in a seqlock
 * retry loop. Then, with xfrm_policy_lock held and
 * xfrm_policy_hash_generation in write mode:
 *  1. inexact bins/chains are pre-allocated for policies. On allocation
 *     failure it jumps to out_unlock before anything has been torn down;
 *  2. every policy is unlinked from the inexact lists and the bydst hash
 *     chains, and the new thresholds are stored in policy_bydst[];
 *  3. live policies are re-inserted in creation order, into a hash chain
 *     or, when policy_hash_bysel() returns NULL, via
 *     xfrm_policy_inexact_insert().
 * Finally, empty inexact bins are pruned.
 */
static void xfrm_hash_rebuild(struct work_struct *work)
{
	struct net *net = container_of(work, struct net,
				       xfrm.policy_hthresh.work);
	unsigned int hmask;
	struct xfrm_policy *pol;
	struct xfrm_policy *policy;
	struct hlist_head *chain;
	struct hlist_head *odst;
	struct hlist_node *newpos;
	int i;
	int dir;
	unsigned seq;
	u8 lbits4, rbits4, lbits6, rbits6;

	mutex_lock(&hash_resize_mutex);

	/* read selector prefixlen thresholds */
	do {
		seq = read_seqbegin(&net->xfrm.policy_hthresh.lock);

		lbits4 = net->xfrm.policy_hthresh.lbits4;
		rbits4 = net->xfrm.policy_hthresh.rbits4;
		lbits6 = net->xfrm.policy_hthresh.lbits6;
		rbits6 = net->xfrm.policy_hthresh.rbits6;
	} while (read_seqretry(&net->xfrm.policy_hthresh.lock, seq));

	spin_lock_bh(&net->xfrm.xfrm_policy_lock);
	write_seqcount_begin(&net->xfrm.xfrm_policy_hash_generation);

	/* make sure that we can insert the indirect policies again before
	 * we start with destructive action.
	 */
	list_for_each_entry(policy, &net->xfrm.policy_all, walk.all) {
		struct xfrm_pol_inexact_bin *bin;
		u8 dbits, sbits;

		if (policy->walk.dead)
			continue;

		dir = xfrm_policy_id2dir(policy->index);
		if (dir >= XFRM_POLICY_MAX)
			continue;

		/* out: dst = remote, src = local; in/fwd: the reverse */
		if ((dir & XFRM_POLICY_MASK) == XFRM_POLICY_OUT) {
			if (policy->family == AF_INET) {
				dbits = rbits4;
				sbits = lbits4;
			} else {
				dbits = rbits6;
				sbits = lbits6;
			}
		} else {
			if (policy->family == AF_INET) {
				dbits = lbits4;
				sbits = rbits4;
			} else {
				dbits = lbits6;
				sbits = rbits6;
			}
		}

		/* NOTE(review): this skips policies whose prefixes are
		 * shorter than the new thresholds. If policy_hash_bysel()
		 * treats exactly those as inexact (not visible here --
		 * confirm), the pre-allocation covers the opposite set from
		 * the "indirect policies" named in the comment above.
		 */
		if (policy->selector.prefixlen_d < dbits ||
		    policy->selector.prefixlen_s < sbits)
			continue;

		bin = xfrm_policy_inexact_alloc_bin(policy, dir);
		if (!bin)
			goto out_unlock;

		if (!xfrm_policy_inexact_alloc_chain(bin, policy, dir))
			goto out_unlock;
	}

	/* reset the bydst and inexact table in all directions */
	for (dir = 0; dir < XFRM_POLICY_MAX; dir++) {
		struct hlist_node *n;

		/* inexact policies: unlink from the bin chain (bydst) and
		 * from the per-direction inexact list
		 */
		hlist_for_each_entry_safe(policy, n,
					  &net->xfrm.policy_inexact[dir],
					  bydst_inexact_list) {
			hlist_del_rcu(&policy->bydst);
			hlist_del_init(&policy->bydst_inexact_list);
		}

		/* empty every bucket of the bydst hash table */
		hmask = net->xfrm.policy_bydst[dir].hmask;
		odst = net->xfrm.policy_bydst[dir].table;
		for (i = hmask; i >= 0; i--) {
			hlist_for_each_entry_safe(policy, n, odst + i, bydst)
				hlist_del_rcu(&policy->bydst);
		}
		if ((dir & XFRM_POLICY_MASK) == XFRM_POLICY_OUT) {
			/* dir out => dst = remote, src = local */
			net->xfrm.policy_bydst[dir].dbits4 = rbits4;
			net->xfrm.policy_bydst[dir].sbits4 = lbits4;
			net->xfrm.policy_bydst[dir].dbits6 = rbits6;
			net->xfrm.policy_bydst[dir].sbits6 = lbits6;
		} else {
			/* dir in/fwd => dst = local, src = remote */
			net->xfrm.policy_bydst[dir].dbits4 = lbits4;
			net->xfrm.policy_bydst[dir].sbits4 = rbits4;
			net->xfrm.policy_bydst[dir].dbits6 = lbits6;
			net->xfrm.policy_bydst[dir].sbits6 = rbits6;
		}
	}

	/* re-insert all policies by order of creation */
	list_for_each_entry_reverse(policy, &net->xfrm.policy_all, walk.all) {
		if (policy->walk.dead)
			continue;
		dir = xfrm_policy_id2dir(policy->index);
		if (dir >= XFRM_POLICY_MAX) {
			/* skip socket policies */
			continue;
		}
		newpos = NULL;
		chain = policy_hash_bysel(net, &policy->selector,
					  policy->family, dir);

		/* no exact hash chain: goes to the inexact structures */
		if (!chain) {
			void *p = xfrm_policy_inexact_insert(policy, dir, 0);

			WARN_ONCE(IS_ERR(p), "reinsert: %ld\n", PTR_ERR(p));
			continue;
		}

		/* place after the last entry with priority <= ours;
		 * packet-offload policies always go to the head
		 */
		hlist_for_each_entry(pol, chain, bydst) {
			if (policy->priority >= pol->priority)
				newpos = &pol->bydst;
			else
				break;
		}
		if (newpos && policy->xdo.type != XFRM_DEV_OFFLOAD_PACKET)
			hlist_add_behind_rcu(&policy->bydst, newpos);
		else
			hlist_add_head_rcu(&policy->bydst, chain);
	}

out_unlock:
	/* drop bins/nodes left empty (including unused pre-allocations) */
	__xfrm_policy_inexact_flush(net);
	write_seqcount_end(&net->xfrm.xfrm_policy_hash_generation);
	spin_unlock_bh(&net->xfrm.xfrm_policy_lock);

	mutex_unlock(&hash_resize_mutex);
}
1382
1383 void xfrm_policy_hash_rebuild(struct net *net)
1384 {
1385         schedule_work(&net->xfrm.policy_hthresh.work);
1386 }
1387 EXPORT_SYMBOL(xfrm_policy_hash_rebuild);
1388
1389 /* Generate new index... KAME seems to generate them ordered by cost
1390  * of an absolute inpredictability of ordering of rules. This will not pass. */
1391 static u32 xfrm_gen_index(struct net *net, int dir, u32 index)
1392 {
1393         for (;;) {
1394                 struct hlist_head *list;
1395                 struct xfrm_policy *p;
1396                 u32 idx;
1397                 int found;
1398
1399                 if (!index) {
1400                         idx = (net->xfrm.idx_generator | dir);
1401                         net->xfrm.idx_generator += 8;
1402                 } else {
1403                         idx = index;
1404                         index = 0;
1405                 }
1406
1407                 if (idx == 0)
1408                         idx = 8;
1409                 list = net->xfrm.policy_byidx + idx_hash(net, idx);
1410                 found = 0;
1411                 hlist_for_each_entry(p, list, byidx) {
1412                         if (p->index == idx) {
1413                                 found = 1;
1414                                 break;
1415                         }
1416                 }
1417                 if (!found)
1418                         return idx;
1419         }
1420 }
1421
1422 static inline int selector_cmp(struct xfrm_selector *s1, struct xfrm_selector *s2)
1423 {
1424         u32 *p1 = (u32 *) s1;
1425         u32 *p2 = (u32 *) s2;
1426         int len = sizeof(struct xfrm_selector) / sizeof(u32);
1427         int i;
1428
1429         for (i = 0; i < len; i++) {
1430                 if (p1[i] != p2[i])
1431                         return 1;
1432         }
1433
1434         return 0;
1435 }
1436
1437 static void xfrm_policy_requeue(struct xfrm_policy *old,
1438                                 struct xfrm_policy *new)
1439 {
1440         struct xfrm_policy_queue *pq = &old->polq;
1441         struct sk_buff_head list;
1442
1443         if (skb_queue_empty(&pq->hold_queue))
1444                 return;
1445
1446         __skb_queue_head_init(&list);
1447
1448         spin_lock_bh(&pq->hold_queue.lock);
1449         skb_queue_splice_init(&pq->hold_queue, &list);
1450         if (del_timer(&pq->hold_timer))
1451                 xfrm_pol_put(old);
1452         spin_unlock_bh(&pq->hold_queue.lock);
1453
1454         pq = &new->polq;
1455
1456         spin_lock_bh(&pq->hold_queue.lock);
1457         skb_queue_splice(&list, &pq->hold_queue);
1458         pq->timeout = XFRM_QUEUE_TMO_MIN;
1459         if (!mod_timer(&pq->hold_timer, jiffies))
1460                 xfrm_pol_hold(new);
1461         spin_unlock_bh(&pq->hold_queue.lock);
1462 }
1463
1464 static inline bool xfrm_policy_mark_match(const struct xfrm_mark *mark,
1465                                           struct xfrm_policy *pol)
1466 {
1467         return mark->v == pol->mark.v && mark->m == pol->mark.m;
1468 }
1469
1470 static u32 xfrm_pol_bin_key(const void *data, u32 len, u32 seed)
1471 {
1472         const struct xfrm_pol_inexact_key *k = data;
1473         u32 a = k->type << 24 | k->dir << 16 | k->family;
1474
1475         return jhash_3words(a, k->if_id, net_hash_mix(read_pnet(&k->net)),
1476                             seed);
1477 }
1478
1479 static u32 xfrm_pol_bin_obj(const void *data, u32 len, u32 seed)
1480 {
1481         const struct xfrm_pol_inexact_bin *b = data;
1482
1483         return xfrm_pol_bin_key(&b->k, 0, seed);
1484 }
1485
1486 static int xfrm_pol_bin_cmp(struct rhashtable_compare_arg *arg,
1487                             const void *ptr)
1488 {
1489         const struct xfrm_pol_inexact_key *key = arg->key;
1490         const struct xfrm_pol_inexact_bin *b = ptr;
1491         int ret;
1492
1493         if (!net_eq(read_pnet(&b->k.net), read_pnet(&key->net)))
1494                 return -1;
1495
1496         ret = b->k.dir ^ key->dir;
1497         if (ret)
1498                 return ret;
1499
1500         ret = b->k.type ^ key->type;
1501         if (ret)
1502                 return ret;
1503
1504         ret = b->k.family ^ key->family;
1505         if (ret)
1506                 return ret;
1507
1508         return b->k.if_id ^ key->if_id;
1509 }
1510
1511 static const struct rhashtable_params xfrm_pol_inexact_params = {
1512         .head_offset            = offsetof(struct xfrm_pol_inexact_bin, head),
1513         .hashfn                 = xfrm_pol_bin_key,
1514         .obj_hashfn             = xfrm_pol_bin_obj,
1515         .obj_cmpfn              = xfrm_pol_bin_cmp,
1516         .automatic_shrinking    = true,
1517 };
1518
1519 static void xfrm_policy_insert_inexact_list(struct hlist_head *chain,
1520                                             struct xfrm_policy *policy)
1521 {
1522         struct xfrm_policy *pol, *delpol = NULL;
1523         struct hlist_node *newpos = NULL;
1524         int i = 0;
1525
1526         hlist_for_each_entry(pol, chain, bydst_inexact_list) {
1527                 if (pol->type == policy->type &&
1528                     pol->if_id == policy->if_id &&
1529                     !selector_cmp(&pol->selector, &policy->selector) &&
1530                     xfrm_policy_mark_match(&policy->mark, pol) &&
1531                     xfrm_sec_ctx_match(pol->security, policy->security) &&
1532                     !WARN_ON(delpol)) {
1533                         delpol = pol;
1534                         if (policy->priority > pol->priority)
1535                                 continue;
1536                 } else if (policy->priority >= pol->priority) {
1537                         newpos = &pol->bydst_inexact_list;
1538                         continue;
1539                 }
1540                 if (delpol)
1541                         break;
1542         }
1543
1544         if (newpos && policy->xdo.type != XFRM_DEV_OFFLOAD_PACKET)
1545                 hlist_add_behind_rcu(&policy->bydst_inexact_list, newpos);
1546         else
1547                 hlist_add_head_rcu(&policy->bydst_inexact_list, chain);
1548
1549         hlist_for_each_entry(pol, chain, bydst_inexact_list) {
1550                 pol->pos = i;
1551                 i++;
1552         }
1553 }
1554
1555 static struct xfrm_policy *xfrm_policy_insert_list(struct hlist_head *chain,
1556                                                    struct xfrm_policy *policy,
1557                                                    bool excl)
1558 {
1559         struct xfrm_policy *pol, *newpos = NULL, *delpol = NULL;
1560
1561         hlist_for_each_entry(pol, chain, bydst) {
1562                 if (pol->type == policy->type &&
1563                     pol->if_id == policy->if_id &&
1564                     !selector_cmp(&pol->selector, &policy->selector) &&
1565                     xfrm_policy_mark_match(&policy->mark, pol) &&
1566                     xfrm_sec_ctx_match(pol->security, policy->security) &&
1567                     !WARN_ON(delpol)) {
1568                         if (excl)
1569                                 return ERR_PTR(-EEXIST);
1570                         delpol = pol;
1571                         if (policy->priority > pol->priority)
1572                                 continue;
1573                 } else if (policy->priority >= pol->priority) {
1574                         newpos = pol;
1575                         continue;
1576                 }
1577                 if (delpol)
1578                         break;
1579         }
1580
1581         if (newpos && policy->xdo.type != XFRM_DEV_OFFLOAD_PACKET)
1582                 hlist_add_behind_rcu(&policy->bydst, &newpos->bydst);
1583         else
1584                 /* Packet offload policies enter to the head
1585                  * to speed-up lookups.
1586                  */
1587                 hlist_add_head_rcu(&policy->bydst, chain);
1588
1589         return delpol;
1590 }
1591
1592 int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl)
1593 {
1594         struct net *net = xp_net(policy);
1595         struct xfrm_policy *delpol;
1596         struct hlist_head *chain;
1597
1598         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1599         chain = policy_hash_bysel(net, &policy->selector, policy->family, dir);
1600         if (chain)
1601                 delpol = xfrm_policy_insert_list(chain, policy, excl);
1602         else
1603                 delpol = xfrm_policy_inexact_insert(policy, dir, excl);
1604
1605         if (IS_ERR(delpol)) {
1606                 spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1607                 return PTR_ERR(delpol);
1608         }
1609
1610         __xfrm_policy_link(policy, dir);
1611
1612         /* After previous checking, family can either be AF_INET or AF_INET6 */
1613         if (policy->family == AF_INET)
1614                 rt_genid_bump_ipv4(net);
1615         else
1616                 rt_genid_bump_ipv6(net);
1617
1618         if (delpol) {
1619                 xfrm_policy_requeue(delpol, policy);
1620                 __xfrm_policy_unlink(delpol, dir);
1621         }
1622         policy->index = delpol ? delpol->index : xfrm_gen_index(net, dir, policy->index);
1623         hlist_add_head(&policy->byidx, net->xfrm.policy_byidx+idx_hash(net, policy->index));
1624         policy->curlft.add_time = ktime_get_real_seconds();
1625         policy->curlft.use_time = 0;
1626         if (!mod_timer(&policy->timer, jiffies + HZ))
1627                 xfrm_pol_hold(policy);
1628         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1629
1630         if (delpol)
1631                 xfrm_policy_kill(delpol);
1632         else if (xfrm_bydst_should_resize(net, dir, NULL))
1633                 schedule_work(&net->xfrm.policy_hash_work);
1634
1635         return 0;
1636 }
1637 EXPORT_SYMBOL(xfrm_policy_insert);
1638
1639 static struct xfrm_policy *
1640 __xfrm_policy_bysel_ctx(struct hlist_head *chain, const struct xfrm_mark *mark,
1641                         u32 if_id, u8 type, int dir, struct xfrm_selector *sel,
1642                         struct xfrm_sec_ctx *ctx)
1643 {
1644         struct xfrm_policy *pol;
1645
1646         if (!chain)
1647                 return NULL;
1648
1649         hlist_for_each_entry(pol, chain, bydst) {
1650                 if (pol->type == type &&
1651                     pol->if_id == if_id &&
1652                     xfrm_policy_mark_match(mark, pol) &&
1653                     !selector_cmp(sel, &pol->selector) &&
1654                     xfrm_sec_ctx_match(ctx, pol->security))
1655                         return pol;
1656         }
1657
1658         return NULL;
1659 }
1660
/*
 * xfrm_policy_bysel_ctx - look up a policy by selector and security context
 * @net: network namespace to search
 * @mark: mark compared with xfrm_policy_mark_match()
 * @if_id: interface id the policy must carry
 * @type: policy type the policy must have
 * @dir: policy direction; selects the table and is passed to the unlink
 * @sel: selector that must compare equal (selector_cmp() == 0)
 * @ctx: security context that must match the policy's context
 * @delete: if non-zero, unlink and kill the policy once found
 * @err: set to 0, or to the error from security_xfrm_policy_delete()
 *
 * If the selector maps to an exact hash bucket, that bucket is searched.
 * Otherwise the inexact bin for (type, family, dir, if_id) is consulted.
 * Among all of its candidate lists, the matching policy with the lowest
 * ->pos (the oldest one) wins.
 *
 * Returns the policy with a reference held, or NULL.  When @delete is set
 * and the LSM refuses the deletion, the policy is still returned,
 * referenced and still linked, with *@err set.  Callers must therefore
 * check *@err as well as the return value.
 */
struct xfrm_policy *
xfrm_policy_bysel_ctx(struct net *net, const struct xfrm_mark *mark, u32 if_id,
		      u8 type, int dir, struct xfrm_selector *sel,
		      struct xfrm_sec_ctx *ctx, int delete, int *err)
{
	struct xfrm_pol_inexact_bin *bin = NULL;
	struct xfrm_policy *pol, *ret = NULL;
	struct hlist_head *chain;

	*err = 0;
	spin_lock_bh(&net->xfrm.xfrm_policy_lock);
	chain = policy_hash_bysel(net, sel, sel->family, dir);
	if (!chain) {
		/* No exact bucket for this selector: search the inexact bin. */
		struct xfrm_pol_inexact_candidates cand;
		int i;

		bin = xfrm_policy_inexact_lookup(net, type,
						 sel->family, dir, if_id);
		if (!bin) {
			spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
			return NULL;
		}

		if (!xfrm_policy_find_inexact_candidates(&cand, bin,
							 &sel->saddr,
							 &sel->daddr)) {
			spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
			return NULL;
		}

		/* Several candidate lists may match; keep the lowest pos. */
		pol = NULL;
		for (i = 0; i < ARRAY_SIZE(cand.res); i++) {
			struct xfrm_policy *tmp;

			tmp = __xfrm_policy_bysel_ctx(cand.res[i], mark,
						      if_id, type, dir,
						      sel, ctx);
			if (!tmp)
				continue;

			if (!pol || tmp->pos < pol->pos)
				pol = tmp;
		}
	} else {
		pol = __xfrm_policy_bysel_ctx(chain, mark, if_id, type, dir,
					      sel, ctx);
	}

	if (pol) {
		xfrm_pol_hold(pol);
		if (delete) {
			*err = security_xfrm_policy_delete(pol->security);
			if (*err) {
				/* LSM veto: return the held, still-linked policy. */
				spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
				return pol;
			}
			__xfrm_policy_unlink(pol, dir);
		}
		ret = pol;
	}
	spin_unlock_bh(&net->xfrm.xfrm_policy_lock);

	/* Teardown is deferred until after xfrm_policy_lock is released. */
	if (ret && delete)
		xfrm_policy_kill(ret);
	if (bin && delete)
		xfrm_policy_inexact_prune_bin(bin);
	return ret;
}
EXPORT_SYMBOL(xfrm_policy_bysel_ctx);
1730
1731 struct xfrm_policy *
1732 xfrm_policy_byid(struct net *net, const struct xfrm_mark *mark, u32 if_id,
1733                  u8 type, int dir, u32 id, int delete, int *err)
1734 {
1735         struct xfrm_policy *pol, *ret;
1736         struct hlist_head *chain;
1737
1738         *err = -ENOENT;
1739         if (xfrm_policy_id2dir(id) != dir)
1740                 return NULL;
1741
1742         *err = 0;
1743         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1744         chain = net->xfrm.policy_byidx + idx_hash(net, id);
1745         ret = NULL;
1746         hlist_for_each_entry(pol, chain, byidx) {
1747                 if (pol->type == type && pol->index == id &&
1748                     pol->if_id == if_id && xfrm_policy_mark_match(mark, pol)) {
1749                         xfrm_pol_hold(pol);
1750                         if (delete) {
1751                                 *err = security_xfrm_policy_delete(
1752                                                                 pol->security);
1753                                 if (*err) {
1754                                         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1755                                         return pol;
1756                                 }
1757                                 __xfrm_policy_unlink(pol, dir);
1758                         }
1759                         ret = pol;
1760                         break;
1761                 }
1762         }
1763         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1764
1765         if (ret && delete)
1766                 xfrm_policy_kill(ret);
1767         return ret;
1768 }
1769 EXPORT_SYMBOL(xfrm_policy_byid);
1770
1771 #ifdef CONFIG_SECURITY_NETWORK_XFRM
1772 static inline int
1773 xfrm_policy_flush_secctx_check(struct net *net, u8 type, bool task_valid)
1774 {
1775         struct xfrm_policy *pol;
1776         int err = 0;
1777
1778         list_for_each_entry(pol, &net->xfrm.policy_all, walk.all) {
1779                 if (pol->walk.dead ||
1780                     xfrm_policy_id2dir(pol->index) >= XFRM_POLICY_MAX ||
1781                     pol->type != type)
1782                         continue;
1783
1784                 err = security_xfrm_policy_delete(pol->security);
1785                 if (err) {
1786                         xfrm_audit_policy_delete(pol, 0, task_valid);
1787                         return err;
1788                 }
1789         }
1790         return err;
1791 }
1792
1793 static inline int xfrm_dev_policy_flush_secctx_check(struct net *net,
1794                                                      struct net_device *dev,
1795                                                      bool task_valid)
1796 {
1797         struct xfrm_policy *pol;
1798         int err = 0;
1799
1800         list_for_each_entry(pol, &net->xfrm.policy_all, walk.all) {
1801                 if (pol->walk.dead ||
1802                     xfrm_policy_id2dir(pol->index) >= XFRM_POLICY_MAX ||
1803                     pol->xdo.dev != dev)
1804                         continue;
1805
1806                 err = security_xfrm_policy_delete(pol->security);
1807                 if (err) {
1808                         xfrm_audit_policy_delete(pol, 0, task_valid);
1809                         return err;
1810                 }
1811         }
1812         return err;
1813 }
1814 #else
1815 static inline int
1816 xfrm_policy_flush_secctx_check(struct net *net, u8 type, bool task_valid)
1817 {
1818         return 0;
1819 }
1820
1821 static inline int xfrm_dev_policy_flush_secctx_check(struct net *net,
1822                                                      struct net_device *dev,
1823                                                      bool task_valid)
1824 {
1825         return 0;
1826 }
1827 #endif
1828
/*
 * xfrm_policy_flush - remove every policy of @type from @net
 * @net: network namespace
 * @type: policy type to flush
 * @task_valid: forwarded to the audit hooks
 *
 * The LSM is consulted first through xfrm_policy_flush_secctx_check().  If
 * it refuses any single policy, nothing is removed and that error is
 * returned.  Entries whose index decodes to a direction >= XFRM_POLICY_MAX
 * are skipped; these are socket policies, which are indexed with
 * XFRM_POLICY_MAX + dir.
 *
 * Returns 0 if at least one policy was removed, -ESRCH if none matched.
 */
int xfrm_policy_flush(struct net *net, u8 type, bool task_valid)
{
	int dir, err = 0, cnt = 0;
	struct xfrm_policy *pol;

	spin_lock_bh(&net->xfrm.xfrm_policy_lock);

	err = xfrm_policy_flush_secctx_check(net, type, task_valid);
	if (err)
		goto out;

again:
	list_for_each_entry(pol, &net->xfrm.policy_all, walk.all) {
		if (pol->walk.dead)
			continue;

		dir = xfrm_policy_id2dir(pol->index);
		if (dir >= XFRM_POLICY_MAX ||
		    pol->type != type)
			continue;

		/*
		 * Teardown runs with the lock dropped, so the list may change
		 * underneath us.  Restart the scan from the head after each
		 * removal.
		 */
		__xfrm_policy_unlink(pol, dir);
		spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
		xfrm_dev_policy_delete(pol);
		cnt++;
		xfrm_audit_policy_delete(pol, 1, task_valid);
		xfrm_policy_kill(pol);
		spin_lock_bh(&net->xfrm.xfrm_policy_lock);
		goto again;
	}
	if (cnt)
		__xfrm_policy_inexact_flush(net);
	else
		err = -ESRCH;
out:
	spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
	return err;
}
EXPORT_SYMBOL(xfrm_policy_flush);
1868
/*
 * xfrm_dev_policy_flush - remove every policy bound to @dev in @net
 * @net: network namespace
 * @dev: device whose policies (pol->xdo.dev == @dev) are removed
 * @task_valid: forwarded to the audit hooks
 *
 * This mirrors xfrm_policy_flush(), but selects policies by offload
 * device instead of type.  The LSM pre-check can veto the whole flush.
 * Socket policies (direction >= XFRM_POLICY_MAX) are skipped.
 *
 * Returns 0 if at least one policy was removed, -ESRCH if none matched,
 * or the LSM's error.
 */
int xfrm_dev_policy_flush(struct net *net, struct net_device *dev,
			  bool task_valid)
{
	int dir, err = 0, cnt = 0;
	struct xfrm_policy *pol;

	spin_lock_bh(&net->xfrm.xfrm_policy_lock);

	err = xfrm_dev_policy_flush_secctx_check(net, dev, task_valid);
	if (err)
		goto out;

again:
	list_for_each_entry(pol, &net->xfrm.policy_all, walk.all) {
		if (pol->walk.dead)
			continue;

		dir = xfrm_policy_id2dir(pol->index);
		if (dir >= XFRM_POLICY_MAX ||
		    pol->xdo.dev != dev)
			continue;

		/*
		 * The lock is dropped for teardown, so restart the scan from
		 * the head once the lock has been re-taken.
		 */
		__xfrm_policy_unlink(pol, dir);
		spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
		xfrm_dev_policy_delete(pol);
		cnt++;
		xfrm_audit_policy_delete(pol, 1, task_valid);
		xfrm_policy_kill(pol);
		spin_lock_bh(&net->xfrm.xfrm_policy_lock);
		goto again;
	}
	if (cnt)
		__xfrm_policy_inexact_flush(net);
	else
		err = -ESRCH;
out:
	spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
	return err;
}
EXPORT_SYMBOL(xfrm_dev_policy_flush);
1909
/*
 * xfrm_policy_walk - resumable iteration over the policies of @net
 * @net: network namespace
 * @walk: walker state prepared by xfrm_policy_walk_init()
 * @func: called as func(policy, dir, walk->seq, data) for every live policy
 *        of walk->type (all types for XFRM_POLICY_TYPE_ANY)
 * @data: opaque pointer passed through to @func
 *
 * The walker owns a placeholder entry (walk->walk) that can be parked in
 * net->xfrm.policy_all.  When @func returns non-zero, the placeholder is
 * parked just before the policy that failed.  The next call then resumes
 * with that same policy.
 *
 * Returns:
 * - -EINVAL for an invalid walk->type.
 * - @func's error when the walk was interrupted.
 * - -ENOENT if no policy was ever passed to @func (walk->seq still 0).
 * - 0 otherwise, including for a walker that has already finished.
 */
int xfrm_policy_walk(struct net *net, struct xfrm_policy_walk *walk,
		     int (*func)(struct xfrm_policy *, int, int, void*),
		     void *data)
{
	struct xfrm_policy *pol;
	struct xfrm_policy_walk_entry *x;
	int error = 0;

	if (walk->type >= XFRM_POLICY_TYPE_MAX &&
	    walk->type != XFRM_POLICY_TYPE_ANY)
		return -EINVAL;

	/* Not parked, but already produced entries: the walk is complete. */
	if (list_empty(&walk->walk.all) && walk->seq != 0)
		return 0;

	spin_lock_bh(&net->xfrm.xfrm_policy_lock);
	/* Start at the list head, or right after the parked placeholder. */
	if (list_empty(&walk->walk.all))
		x = list_first_entry(&net->xfrm.policy_all, struct xfrm_policy_walk_entry, all);
	else
		x = list_first_entry(&walk->walk.all,
				     struct xfrm_policy_walk_entry, all);

	list_for_each_entry_from(x, &net->xfrm.policy_all, all) {
		/* Skips dead policies and walker placeholders (dead == 1). */
		if (x->dead)
			continue;
		pol = container_of(x, struct xfrm_policy, walk);
		if (walk->type != XFRM_POLICY_TYPE_ANY &&
		    walk->type != pol->type)
			continue;
		error = func(pol, xfrm_policy_id2dir(pol->index),
			     walk->seq, data);
		if (error) {
			/* Park before x so the next call retries x. */
			list_move_tail(&walk->walk.all, &x->all);
			goto out;
		}
		walk->seq++;
	}
	if (walk->seq == 0) {
		error = -ENOENT;
		goto out;
	}
	/* Reached the end: unpark the placeholder. */
	list_del_init(&walk->walk.all);
out:
	spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
	return error;
}
EXPORT_SYMBOL(xfrm_policy_walk);
1957
1958 void xfrm_policy_walk_init(struct xfrm_policy_walk *walk, u8 type)
1959 {
1960         INIT_LIST_HEAD(&walk->walk.all);
1961         walk->walk.dead = 1;
1962         walk->type = type;
1963         walk->seq = 0;
1964 }
1965 EXPORT_SYMBOL(xfrm_policy_walk_init);
1966
1967 void xfrm_policy_walk_done(struct xfrm_policy_walk *walk, struct net *net)
1968 {
1969         if (list_empty(&walk->walk.all))
1970                 return;
1971
1972         spin_lock_bh(&net->xfrm.xfrm_policy_lock); /*FIXME where is net? */
1973         list_del(&walk->walk.all);
1974         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1975 }
1976 EXPORT_SYMBOL(xfrm_policy_walk_done);
1977
1978 /*
1979  * Find policy to apply to this flow.
1980  *
1981  * Returns 0 if policy found, else an -errno.
1982  */
1983 static int xfrm_policy_match(const struct xfrm_policy *pol,
1984                              const struct flowi *fl,
1985                              u8 type, u16 family, u32 if_id)
1986 {
1987         const struct xfrm_selector *sel = &pol->selector;
1988         int ret = -ESRCH;
1989         bool match;
1990
1991         if (pol->family != family ||
1992             pol->if_id != if_id ||
1993             (fl->flowi_mark & pol->mark.m) != pol->mark.v ||
1994             pol->type != type)
1995                 return ret;
1996
1997         match = xfrm_selector_match(sel, fl, family);
1998         if (match)
1999                 ret = security_xfrm_policy_lookup(pol->security, fl->flowi_secid);
2000         return ret;
2001 }
2002
2003 static struct xfrm_pol_inexact_node *
2004 xfrm_policy_lookup_inexact_addr(const struct rb_root *r,
2005                                 seqcount_spinlock_t *count,
2006                                 const xfrm_address_t *addr, u16 family)
2007 {
2008         const struct rb_node *parent;
2009         int seq;
2010
2011 again:
2012         seq = read_seqcount_begin(count);
2013
2014         parent = rcu_dereference_raw(r->rb_node);
2015         while (parent) {
2016                 struct xfrm_pol_inexact_node *node;
2017                 int delta;
2018
2019                 node = rb_entry(parent, struct xfrm_pol_inexact_node, node);
2020
2021                 delta = xfrm_policy_addr_delta(addr, &node->addr,
2022                                                node->prefixlen, family);
2023                 if (delta < 0) {
2024                         parent = rcu_dereference_raw(parent->rb_left);
2025                         continue;
2026                 } else if (delta > 0) {
2027                         parent = rcu_dereference_raw(parent->rb_right);
2028                         continue;
2029                 }
2030
2031                 return node;
2032         }
2033
2034         if (read_seqcount_retry(count, seq))
2035                 goto again;
2036
2037         return NULL;
2038 }
2039
2040 static bool
2041 xfrm_policy_find_inexact_candidates(struct xfrm_pol_inexact_candidates *cand,
2042                                     struct xfrm_pol_inexact_bin *b,
2043                                     const xfrm_address_t *saddr,
2044                                     const xfrm_address_t *daddr)
2045 {
2046         struct xfrm_pol_inexact_node *n;
2047         u16 family;
2048
2049         if (!b)
2050                 return false;
2051
2052         family = b->k.family;
2053         memset(cand, 0, sizeof(*cand));
2054         cand->res[XFRM_POL_CAND_ANY] = &b->hhead;
2055
2056         n = xfrm_policy_lookup_inexact_addr(&b->root_d, &b->count, daddr,
2057                                             family);
2058         if (n) {
2059                 cand->res[XFRM_POL_CAND_DADDR] = &n->hhead;
2060                 n = xfrm_policy_lookup_inexact_addr(&n->root, &b->count, saddr,
2061                                                     family);
2062                 if (n)
2063                         cand->res[XFRM_POL_CAND_BOTH] = &n->hhead;
2064         }
2065
2066         n = xfrm_policy_lookup_inexact_addr(&b->root_s, &b->count, saddr,
2067                                             family);
2068         if (n)
2069                 cand->res[XFRM_POL_CAND_SADDR] = &n->hhead;
2070
2071         return true;
2072 }
2073
2074 static struct xfrm_pol_inexact_bin *
2075 xfrm_policy_inexact_lookup_rcu(struct net *net, u8 type, u16 family,
2076                                u8 dir, u32 if_id)
2077 {
2078         struct xfrm_pol_inexact_key k = {
2079                 .family = family,
2080                 .type = type,
2081                 .dir = dir,
2082                 .if_id = if_id,
2083         };
2084
2085         write_pnet(&k.net, net);
2086
2087         return rhashtable_lookup(&xfrm_policy_inexact_table, &k,
2088                                  xfrm_pol_inexact_params);
2089 }
2090
2091 static struct xfrm_pol_inexact_bin *
2092 xfrm_policy_inexact_lookup(struct net *net, u8 type, u16 family,
2093                            u8 dir, u32 if_id)
2094 {
2095         struct xfrm_pol_inexact_bin *bin;
2096
2097         lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
2098
2099         rcu_read_lock();
2100         bin = xfrm_policy_inexact_lookup_rcu(net, type, family, dir, if_id);
2101         rcu_read_unlock();
2102
2103         return bin;
2104 }
2105
2106 static struct xfrm_policy *
2107 __xfrm_policy_eval_candidates(struct hlist_head *chain,
2108                               struct xfrm_policy *prefer,
2109                               const struct flowi *fl,
2110                               u8 type, u16 family, u32 if_id)
2111 {
2112         u32 priority = prefer ? prefer->priority : ~0u;
2113         struct xfrm_policy *pol;
2114
2115         if (!chain)
2116                 return NULL;
2117
2118         hlist_for_each_entry_rcu(pol, chain, bydst) {
2119                 int err;
2120
2121                 if (pol->priority > priority)
2122                         break;
2123
2124                 err = xfrm_policy_match(pol, fl, type, family, if_id);
2125                 if (err) {
2126                         if (err != -ESRCH)
2127                                 return ERR_PTR(err);
2128
2129                         continue;
2130                 }
2131
2132                 if (prefer) {
2133                         /* matches.  Is it older than *prefer? */
2134                         if (pol->priority == priority &&
2135                             prefer->pos < pol->pos)
2136                                 return prefer;
2137                 }
2138
2139                 return pol;
2140         }
2141
2142         return NULL;
2143 }
2144
2145 static struct xfrm_policy *
2146 xfrm_policy_eval_candidates(struct xfrm_pol_inexact_candidates *cand,
2147                             struct xfrm_policy *prefer,
2148                             const struct flowi *fl,
2149                             u8 type, u16 family, u32 if_id)
2150 {
2151         struct xfrm_policy *tmp;
2152         int i;
2153
2154         for (i = 0; i < ARRAY_SIZE(cand->res); i++) {
2155                 tmp = __xfrm_policy_eval_candidates(cand->res[i],
2156                                                     prefer,
2157                                                     fl, type, family, if_id);
2158                 if (!tmp)
2159                         continue;
2160
2161                 if (IS_ERR(tmp))
2162                         return tmp;
2163                 prefer = tmp;
2164         }
2165
2166         return prefer;
2167 }
2168
/*
 * Find the policy of @type that applies to flow @fl in direction @dir.
 *
 * The first match on the exact hash chain for (daddr, saddr) is taken.
 * Unless that match is a packet-offload policy, the inexact candidates are
 * then evaluated, and they may replace it.  The lookup runs under RCU and
 * restarts if the policy hash generation changed meanwhile, or if no
 * reference could be taken on the winner.
 *
 * Returns a referenced policy, NULL (no match, or the flow has no usable
 * addresses for @family), or an unreferenced ERR_PTR() from the security
 * hook.
 */
static struct xfrm_policy *xfrm_policy_lookup_bytype(struct net *net, u8 type,
						     const struct flowi *fl,
						     u16 family, u8 dir,
						     u32 if_id)
{
	struct xfrm_pol_inexact_candidates cand;
	const xfrm_address_t *daddr, *saddr;
	struct xfrm_pol_inexact_bin *bin;
	struct xfrm_policy *pol, *ret;
	struct hlist_head *chain;
	unsigned int sequence;
	int err;

	daddr = xfrm_flowi_daddr(fl, family);
	saddr = xfrm_flowi_saddr(fl, family);
	if (unlikely(!daddr || !saddr))
		return NULL;

	rcu_read_lock();
 retry:
	/* Pick the exact chain under a stable hash generation snapshot. */
	do {
		sequence = read_seqcount_begin(&net->xfrm.xfrm_policy_hash_generation);
		chain = policy_hash_direct(net, daddr, saddr, family, dir);
	} while (read_seqcount_retry(&net->xfrm.xfrm_policy_hash_generation, sequence));

	/* The first matching policy on the exact chain is the provisional winner. */
	ret = NULL;
	hlist_for_each_entry_rcu(pol, chain, bydst) {
		err = xfrm_policy_match(pol, fl, type, family, if_id);
		if (err) {
			if (err == -ESRCH)
				continue;
			else {
				ret = ERR_PTR(err);
				goto fail;
			}
		} else {
			ret = pol;
			break;
		}
	}
	/* An exact packet-offload hit is final; skip the inexact search. */
	if (ret && ret->xdo.type == XFRM_DEV_OFFLOAD_PACKET)
		goto skip_inexact;

	bin = xfrm_policy_inexact_lookup_rcu(net, type, family, dir, if_id);
	if (!bin || !xfrm_policy_find_inexact_candidates(&cand, bin, saddr,
							 daddr))
		goto skip_inexact;

	/* Inexact candidates compete with ret; ret itself may come back. */
	pol = xfrm_policy_eval_candidates(&cand, ret, fl, type,
					  family, if_id);
	if (pol) {
		ret = pol;
		if (IS_ERR(pol))
			goto fail;
	}

skip_inexact:
	/* Tables were rebuilt during the lookup: start over. */
	if (read_seqcount_retry(&net->xfrm.xfrm_policy_hash_generation, sequence))
		goto retry;

	/* No reference could be taken on the winner; look it up again. */
	if (ret && !xfrm_pol_hold_rcu(ret))
		goto retry;
fail:
	rcu_read_unlock();

	return ret;
}
2236
2237 static struct xfrm_policy *xfrm_policy_lookup(struct net *net,
2238                                               const struct flowi *fl,
2239                                               u16 family, u8 dir, u32 if_id)
2240 {
2241 #ifdef CONFIG_XFRM_SUB_POLICY
2242         struct xfrm_policy *pol;
2243
2244         pol = xfrm_policy_lookup_bytype(net, XFRM_POLICY_TYPE_SUB, fl, family,
2245                                         dir, if_id);
2246         if (pol != NULL)
2247                 return pol;
2248 #endif
2249         return xfrm_policy_lookup_bytype(net, XFRM_POLICY_TYPE_MAIN, fl, family,
2250                                          dir, if_id);
2251 }
2252
/*
 * Look up the @dir policy attached to socket @sk for flow @fl.
 *
 * The socket policy applies only if all of the following hold:
 * - its family equals @family;
 * - its selector matches @fl;
 * - the socket mark matches the policy mark;
 * - its if_id equals @if_id;
 * - security_xfrm_policy_lookup() accepts it.
 *
 * Returns the policy with a reference held, NULL if it does not apply (or
 * the LSM returned -ESRCH), or ERR_PTR() of any other LSM error.  If no
 * reference can be taken, sk->sk_policy[dir] is re-read and the checks
 * repeat.
 */
static struct xfrm_policy *xfrm_sk_policy_lookup(const struct sock *sk, int dir,
						 const struct flowi *fl,
						 u16 family, u32 if_id)
{
	struct xfrm_policy *pol;

	rcu_read_lock();
 again:
	pol = rcu_dereference(sk->sk_policy[dir]);
	if (pol != NULL) {
		bool match;
		int err = 0;

		if (pol->family != family) {
			pol = NULL;
			goto out;
		}

		match = xfrm_selector_match(&pol->selector, fl, family);
		if (match) {
			if ((READ_ONCE(sk->sk_mark) & pol->mark.m) != pol->mark.v ||
			    pol->if_id != if_id) {
				pol = NULL;
				goto out;
			}
			err = security_xfrm_policy_lookup(pol->security,
						      fl->flowi_secid);
			if (!err) {
				/* Could not take a reference: re-read the slot. */
				if (!xfrm_pol_hold_rcu(pol))
					goto again;
			} else if (err == -ESRCH) {
				pol = NULL;
			} else {
				pol = ERR_PTR(err);
			}
		} else
			pol = NULL;
	}
out:
	rcu_read_unlock();
	return pol;
}
2295
2296 static void __xfrm_policy_link(struct xfrm_policy *pol, int dir)
2297 {
2298         struct net *net = xp_net(pol);
2299
2300         list_add(&pol->walk.all, &net->xfrm.policy_all);
2301         net->xfrm.policy_count[dir]++;
2302         xfrm_pol_hold(pol);
2303 }
2304
2305 static struct xfrm_policy *__xfrm_policy_unlink(struct xfrm_policy *pol,
2306                                                 int dir)
2307 {
2308         struct net *net = xp_net(pol);
2309
2310         if (list_empty(&pol->walk.all))
2311                 return NULL;
2312
2313         /* Socket policies are not hashed. */
2314         if (!hlist_unhashed(&pol->bydst)) {
2315                 hlist_del_rcu(&pol->bydst);
2316                 hlist_del_init(&pol->bydst_inexact_list);
2317                 hlist_del(&pol->byidx);
2318         }
2319
2320         list_del_init(&pol->walk.all);
2321         net->xfrm.policy_count[dir]--;
2322
2323         return pol;
2324 }
2325
2326 static void xfrm_sk_policy_link(struct xfrm_policy *pol, int dir)
2327 {
2328         __xfrm_policy_link(pol, XFRM_POLICY_MAX + dir);
2329 }
2330
2331 static void xfrm_sk_policy_unlink(struct xfrm_policy *pol, int dir)
2332 {
2333         __xfrm_policy_unlink(pol, XFRM_POLICY_MAX + dir);
2334 }
2335
2336 int xfrm_policy_delete(struct xfrm_policy *pol, int dir)
2337 {
2338         struct net *net = xp_net(pol);
2339
2340         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
2341         pol = __xfrm_policy_unlink(pol, dir);
2342         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
2343         if (pol) {
2344                 xfrm_dev_policy_delete(pol);
2345                 xfrm_policy_kill(pol);
2346                 return 0;
2347         }
2348         return -ENOENT;
2349 }
2350 EXPORT_SYMBOL(xfrm_policy_delete);
2351
/*
 * Install @pol as the @dir policy of socket @sk, replacing and killing any
 * previous one.  @pol may be NULL to only remove the current policy.
 *
 * Under xfrm_policy_lock, the new policy is timestamped, given a socket
 * policy index and linked before it is published with rcu_assign_pointer().
 * Then xfrm_policy_requeue(old, new) is called (presumably to move work
 * queued on the old policy to the new one), and the old policy is
 * unlinked.  xfrm_policy_kill() on the old policy runs after the lock is
 * dropped.
 *
 * Returns 0, or -EINVAL for a non-main policy when CONFIG_XFRM_SUB_POLICY
 * is set.
 */
int xfrm_sk_policy_insert(struct sock *sk, int dir, struct xfrm_policy *pol)
{
	struct net *net = sock_net(sk);
	struct xfrm_policy *old_pol;

#ifdef CONFIG_XFRM_SUB_POLICY
	if (pol && pol->type != XFRM_POLICY_TYPE_MAIN)
		return -EINVAL;
#endif

	spin_lock_bh(&net->xfrm.xfrm_policy_lock);
	old_pol = rcu_dereference_protected(sk->sk_policy[dir],
				lockdep_is_held(&net->xfrm.xfrm_policy_lock));
	if (pol) {
		pol->curlft.add_time = ktime_get_real_seconds();
		pol->index = xfrm_gen_index(net, XFRM_POLICY_MAX+dir, 0);
		xfrm_sk_policy_link(pol, dir);
	}
	/* Publish only after the new policy is fully set up and linked. */
	rcu_assign_pointer(sk->sk_policy[dir], pol);
	if (old_pol) {
		if (pol)
			xfrm_policy_requeue(old_pol, pol);

		/* Unlinking succeeds always. This is the only function
		 * allowed to delete or replace socket policy.
		 */
		xfrm_sk_policy_unlink(old_pol, dir);
	}
	spin_unlock_bh(&net->xfrm.xfrm_policy_lock);

	if (old_pol) {
		xfrm_policy_kill(old_pol);
	}
	return 0;
}
2387
2388 static struct xfrm_policy *clone_policy(const struct xfrm_policy *old, int dir)
2389 {
2390         struct xfrm_policy *newp = xfrm_policy_alloc(xp_net(old), GFP_ATOMIC);
2391         struct net *net = xp_net(old);
2392
2393         if (newp) {
2394                 newp->selector = old->selector;
2395                 if (security_xfrm_policy_clone(old->security,
2396                                                &newp->security)) {
2397                         kfree(newp);
2398                         return NULL;  /* ENOMEM */
2399                 }
2400                 newp->lft = old->lft;
2401                 newp->curlft = old->curlft;
2402                 newp->mark = old->mark;
2403                 newp->if_id = old->if_id;
2404                 newp->action = old->action;
2405                 newp->flags = old->flags;
2406                 newp->xfrm_nr = old->xfrm_nr;
2407                 newp->index = old->index;
2408                 newp->type = old->type;
2409                 newp->family = old->family;
2410                 memcpy(newp->xfrm_vec, old->xfrm_vec,
2411                        newp->xfrm_nr*sizeof(struct xfrm_tmpl));
2412                 spin_lock_bh(&net->xfrm.xfrm_policy_lock);
2413                 xfrm_sk_policy_link(newp, dir);
2414                 spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
2415                 xfrm_pol_put(newp);
2416         }
2417         return newp;
2418 }
2419
2420 int __xfrm_sk_clone_policy(struct sock *sk, const struct sock *osk)
2421 {
2422         const struct xfrm_policy *p;
2423         struct xfrm_policy *np;
2424         int i, ret = 0;
2425
2426         rcu_read_lock();
2427         for (i = 0; i < 2; i++) {
2428                 p = rcu_dereference(osk->sk_policy[i]);
2429                 if (p) {
2430                         np = clone_policy(p, i);
2431                         if (unlikely(!np)) {
2432                                 ret = -ENOMEM;
2433                                 break;
2434                         }
2435                         rcu_assign_pointer(sk->sk_policy[i], np);
2436                 }
2437         }
2438         rcu_read_unlock();
2439         return ret;
2440 }
2441
2442 static int
2443 xfrm_get_saddr(struct net *net, int oif, xfrm_address_t *local,
2444                xfrm_address_t *remote, unsigned short family, u32 mark)
2445 {
2446         int err;
2447         const struct xfrm_policy_afinfo *afinfo = xfrm_policy_get_afinfo(family);
2448
2449         if (unlikely(afinfo == NULL))
2450                 return -EINVAL;
2451         err = afinfo->get_saddr(net, oif, local, remote, mark);
2452         rcu_read_unlock();
2453         return err;
2454 }
2455
/*
 * Resolve the templates of a single policy into xfrm states for flow @fl.
 *
 * @xfrm receives one referenced state for each template that currently
 * has a VALID state; optional templates without one are skipped.
 *
 * Tunnel and beet templates use the template's own addresses.  When the
 * template's source is unspecified, a source is picked with
 * xfrm_get_saddr().  Once a template resolves, its remote/local addresses
 * become daddr/saddr for the following template.
 *
 * Returns the number of states stored in @xfrm, or a negative errno from
 * the first mandatory template that could not be satisfied:
 * - -EINVAL if its state is XFRM_STATE_ERROR;
 * - -EAGAIN if the state is not yet valid, or if xfrm_state_find()
 *   returned -ESRCH;
 * - otherwise the error from xfrm_get_saddr() or xfrm_state_find().
 * On failure, every reference already placed in @xfrm is dropped.
 */

static int
xfrm_tmpl_resolve_one(struct xfrm_policy *policy, const struct flowi *fl,
		      struct xfrm_state **xfrm, unsigned short family)
{
	struct net *net = xp_net(policy);
	int nx;
	int i, error;
	xfrm_address_t *daddr = xfrm_flowi_daddr(fl, family);
	xfrm_address_t *saddr = xfrm_flowi_saddr(fl, family);
	xfrm_address_t tmp;

	for (nx = 0, i = 0; i < policy->xfrm_nr; i++) {
		struct xfrm_state *x;
		xfrm_address_t *remote = daddr;
		xfrm_address_t *local  = saddr;
		struct xfrm_tmpl *tmpl = &policy->xfrm_vec[i];

		if (tmpl->mode == XFRM_MODE_TUNNEL ||
		    tmpl->mode == XFRM_MODE_BEET) {
			remote = &tmpl->id.daddr;
			local = &tmpl->saddr;
			/* Template leaves the source open: ask the routing layer. */
			if (xfrm_addr_any(local, tmpl->encap_family)) {
				error = xfrm_get_saddr(net, fl->flowi_oif,
						       &tmp, remote,
						       tmpl->encap_family, 0);
				if (error)
					goto fail;
				local = &tmp;
			}
		}

		x = xfrm_state_find(remote, local, fl, tmpl, policy, &error,
				    family, policy->if_id);

		if (x && x->km.state == XFRM_STATE_VALID) {
			xfrm[nx++] = x;
			/* The next template is resolved against these addresses. */
			daddr = remote;
			saddr = local;
			continue;
		}
		if (x) {
			error = (x->km.state == XFRM_STATE_ERROR ?
				 -EINVAL : -EAGAIN);
			xfrm_state_put(x);
		} else if (error == -ESRCH) {
			error = -EAGAIN;
		}

		if (!tmpl->optional)
			goto fail;
	}
	return nx;

fail:
	/* Drop the references collected so far. */
	for (nx--; nx >= 0; nx--)
		xfrm_state_put(xfrm[nx]);
	return error;
}
2516
2517 static int
2518 xfrm_tmpl_resolve(struct xfrm_policy **pols, int npols, const struct flowi *fl,
2519                   struct xfrm_state **xfrm, unsigned short family)
2520 {
2521         struct xfrm_state *tp[XFRM_MAX_DEPTH];
2522         struct xfrm_state **tpp = (npols > 1) ? tp : xfrm;
2523         int cnx = 0;
2524         int error;
2525         int ret;
2526         int i;
2527
2528         for (i = 0; i < npols; i++) {
2529                 if (cnx + pols[i]->xfrm_nr >= XFRM_MAX_DEPTH) {
2530                         error = -ENOBUFS;
2531                         goto fail;
2532                 }
2533
2534                 ret = xfrm_tmpl_resolve_one(pols[i], fl, &tpp[cnx], family);
2535                 if (ret < 0) {
2536                         error = ret;
2537                         goto fail;
2538                 } else
2539                         cnx += ret;
2540         }
2541
2542         /* found states are sorted for outbound processing */
2543         if (npols > 1)
2544                 xfrm_state_sort(xfrm, tpp, cnx, family);
2545
2546         return cnx;
2547
2548  fail:
2549         for (cnx--; cnx >= 0; cnx--)
2550                 xfrm_state_put(tpp[cnx]);
2551         return error;
2552
2553 }
2554
2555 static int xfrm_get_tos(const struct flowi *fl, int family)
2556 {
2557         if (family == AF_INET)
2558                 return IPTOS_RT_MASK & fl->u.ip4.flowi4_tos;
2559
2560         return 0;
2561 }
2562
2563 static inline struct xfrm_dst *xfrm_alloc_dst(struct net *net, int family)
2564 {
2565         const struct xfrm_policy_afinfo *afinfo = xfrm_policy_get_afinfo(family);
2566         struct dst_ops *dst_ops;
2567         struct xfrm_dst *xdst;
2568
2569         if (!afinfo)
2570                 return ERR_PTR(-EINVAL);
2571
2572         switch (family) {
2573         case AF_INET:
2574                 dst_ops = &net->xfrm.xfrm4_dst_ops;
2575                 break;
2576 #if IS_ENABLED(CONFIG_IPV6)
2577         case AF_INET6:
2578                 dst_ops = &net->xfrm.xfrm6_dst_ops;
2579                 break;
2580 #endif
2581         default:
2582                 BUG();
2583         }
2584         xdst = dst_alloc(dst_ops, NULL, DST_OBSOLETE_NONE, 0);
2585
2586         if (likely(xdst)) {
2587                 memset_after(xdst, 0, u.dst);
2588         } else
2589                 xdst = ERR_PTR(-ENOBUFS);
2590
2591         rcu_read_unlock();
2592
2593         return xdst;
2594 }
2595
2596 static void xfrm_init_path(struct xfrm_dst *path, struct dst_entry *dst,
2597                            int nfheader_len)
2598 {
2599         if (dst->ops->family == AF_INET6) {
2600                 struct rt6_info *rt = (struct rt6_info *)dst;
2601                 path->path_cookie = rt6_get_cookie(rt);
2602                 path->u.rt6.rt6i_nfheader_len = nfheader_len;
2603         }
2604 }
2605
2606 static inline int xfrm_fill_dst(struct xfrm_dst *xdst, struct net_device *dev,
2607                                 const struct flowi *fl)
2608 {
2609         const struct xfrm_policy_afinfo *afinfo =
2610                 xfrm_policy_get_afinfo(xdst->u.dst.ops->family);
2611         int err;
2612
2613         if (!afinfo)
2614                 return -EINVAL;
2615
2616         err = afinfo->fill_dst(xdst, dev, fl);
2617
2618         rcu_read_unlock();
2619
2620         return err;
2621 }
2622
2623
2624 /* Allocate chain of dst_entry's, attach known xfrm's, calculate
2625  * all the metrics... Shortly, bundle a bundle.
2626  */
2627
2628 static struct dst_entry *xfrm_bundle_create(struct xfrm_policy *policy,
2629                                             struct xfrm_state **xfrm,
2630                                             struct xfrm_dst **bundle,
2631                                             int nx,
2632                                             const struct flowi *fl,
2633                                             struct dst_entry *dst)
2634 {
2635         const struct xfrm_state_afinfo *afinfo;
2636         const struct xfrm_mode *inner_mode;
2637         struct net *net = xp_net(policy);
2638         unsigned long now = jiffies;
2639         struct net_device *dev;
2640         struct xfrm_dst *xdst_prev = NULL;
2641         struct xfrm_dst *xdst0 = NULL;
2642         int i = 0;
2643         int err;
2644         int header_len = 0;
2645         int nfheader_len = 0;
2646         int trailer_len = 0;
2647         int tos;
2648         int family = policy->selector.family;
2649         xfrm_address_t saddr, daddr;
2650
2651         xfrm_flowi_addr_get(fl, &saddr, &daddr, family);
2652
2653         tos = xfrm_get_tos(fl, family);
2654
2655         dst_hold(dst);
2656
2657         for (; i < nx; i++) {
2658                 struct xfrm_dst *xdst = xfrm_alloc_dst(net, family);
2659                 struct dst_entry *dst1 = &xdst->u.dst;
2660
2661                 err = PTR_ERR(xdst);
2662                 if (IS_ERR(xdst)) {
2663                         dst_release(dst);
2664                         goto put_states;
2665                 }
2666
2667                 bundle[i] = xdst;
2668                 if (!xdst_prev)
2669                         xdst0 = xdst;
2670                 else
2671                         /* Ref count is taken during xfrm_alloc_dst()
2672                          * No need to do dst_clone() on dst1
2673                          */
2674                         xfrm_dst_set_child(xdst_prev, &xdst->u.dst);
2675
2676                 if (xfrm[i]->sel.family == AF_UNSPEC) {
2677                         inner_mode = xfrm_ip2inner_mode(xfrm[i],
2678                                                         xfrm_af2proto(family));
2679                         if (!inner_mode) {
2680                                 err = -EAFNOSUPPORT;
2681                                 dst_release(dst);
2682                                 goto put_states;
2683                         }
2684                 } else
2685                         inner_mode = &xfrm[i]->inner_mode;
2686
2687                 xdst->route = dst;
2688                 dst_copy_metrics(dst1, dst);
2689
2690                 if (xfrm[i]->props.mode != XFRM_MODE_TRANSPORT) {
2691                         __u32 mark = 0;
2692                         int oif;
2693
2694                         if (xfrm[i]->props.smark.v || xfrm[i]->props.smark.m)
2695                                 mark = xfrm_smark_get(fl->flowi_mark, xfrm[i]);
2696
2697                         if (xfrm[i]->xso.type != XFRM_DEV_OFFLOAD_PACKET)
2698                                 family = xfrm[i]->props.family;
2699
2700                         oif = fl->flowi_oif ? : fl->flowi_l3mdev;
2701                         dst = xfrm_dst_lookup(xfrm[i], tos, oif,
2702                                               &saddr, &daddr, family, mark);
2703                         err = PTR_ERR(dst);
2704                         if (IS_ERR(dst))
2705                                 goto put_states;
2706                 } else
2707                         dst_hold(dst);
2708
2709                 dst1->xfrm = xfrm[i];
2710                 xdst->xfrm_genid = xfrm[i]->genid;
2711
2712                 dst1->obsolete = DST_OBSOLETE_FORCE_CHK;
2713                 dst1->lastuse = now;
2714
2715                 dst1->input = dst_discard;
2716
2717                 rcu_read_lock();
2718                 afinfo = xfrm_state_afinfo_get_rcu(inner_mode->family);
2719                 if (likely(afinfo))
2720                         dst1->output = afinfo->output;
2721                 else
2722                         dst1->output = dst_discard_out;
2723                 rcu_read_unlock();
2724
2725                 xdst_prev = xdst;
2726
2727                 header_len += xfrm[i]->props.header_len;
2728                 if (xfrm[i]->type->flags & XFRM_TYPE_NON_FRAGMENT)
2729                         nfheader_len += xfrm[i]->props.header_len;
2730                 trailer_len += xfrm[i]->props.trailer_len;
2731         }
2732
2733         xfrm_dst_set_child(xdst_prev, dst);
2734         xdst0->path = dst;
2735
2736         err = -ENODEV;
2737         dev = dst->dev;
2738         if (!dev)
2739                 goto free_dst;
2740
2741         xfrm_init_path(xdst0, dst, nfheader_len);
2742         xfrm_init_pmtu(bundle, nx);
2743
2744         for (xdst_prev = xdst0; xdst_prev != (struct xfrm_dst *)dst;
2745              xdst_prev = (struct xfrm_dst *) xfrm_dst_child(&xdst_prev->u.dst)) {
2746                 err = xfrm_fill_dst(xdst_prev, dev, fl);
2747                 if (err)
2748                         goto free_dst;
2749
2750                 xdst_prev->u.dst.header_len = header_len;
2751                 xdst_prev->u.dst.trailer_len = trailer_len;
2752                 header_len -= xdst_prev->u.dst.xfrm->props.header_len;
2753                 trailer_len -= xdst_prev->u.dst.xfrm->props.trailer_len;
2754         }
2755
2756         return &xdst0->u.dst;
2757
2758 put_states:
2759         for (; i < nx; i++)
2760                 xfrm_state_put(xfrm[i]);
2761 free_dst:
2762         if (xdst0)
2763                 dst_release_immediate(&xdst0->u.dst);
2764
2765         return ERR_PTR(err);
2766 }
2767
2768 static int xfrm_expand_policies(const struct flowi *fl, u16 family,
2769                                 struct xfrm_policy **pols,
2770                                 int *num_pols, int *num_xfrms)
2771 {
2772         int i;
2773
2774         if (*num_pols == 0 || !pols[0]) {
2775                 *num_pols = 0;
2776                 *num_xfrms = 0;
2777                 return 0;
2778         }
2779         if (IS_ERR(pols[0])) {
2780                 *num_pols = 0;
2781                 return PTR_ERR(pols[0]);
2782         }
2783
2784         *num_xfrms = pols[0]->xfrm_nr;
2785
2786 #ifdef CONFIG_XFRM_SUB_POLICY
2787         if (pols[0]->action == XFRM_POLICY_ALLOW &&
2788             pols[0]->type != XFRM_POLICY_TYPE_MAIN) {
2789                 pols[1] = xfrm_policy_lookup_bytype(xp_net(pols[0]),
2790                                                     XFRM_POLICY_TYPE_MAIN,
2791                                                     fl, family,
2792                                                     XFRM_POLICY_OUT,
2793                                                     pols[0]->if_id);
2794                 if (pols[1]) {
2795                         if (IS_ERR(pols[1])) {
2796                                 xfrm_pols_put(pols, *num_pols);
2797                                 *num_pols = 0;
2798                                 return PTR_ERR(pols[1]);
2799                         }
2800                         (*num_pols)++;
2801                         (*num_xfrms) += pols[1]->xfrm_nr;
2802                 }
2803         }
2804 #endif
2805         for (i = 0; i < *num_pols; i++) {
2806                 if (pols[i]->action != XFRM_POLICY_ALLOW) {
2807                         *num_xfrms = -1;
2808                         break;
2809                 }
2810         }
2811
2812         return 0;
2813
2814 }
2815
2816 static struct xfrm_dst *
2817 xfrm_resolve_and_create_bundle(struct xfrm_policy **pols, int num_pols,
2818                                const struct flowi *fl, u16 family,
2819                                struct dst_entry *dst_orig)
2820 {
2821         struct net *net = xp_net(pols[0]);
2822         struct xfrm_state *xfrm[XFRM_MAX_DEPTH];
2823         struct xfrm_dst *bundle[XFRM_MAX_DEPTH];
2824         struct xfrm_dst *xdst;
2825         struct dst_entry *dst;
2826         int err;
2827
2828         /* Try to instantiate a bundle */
2829         err = xfrm_tmpl_resolve(pols, num_pols, fl, xfrm, family);
2830         if (err <= 0) {
2831                 if (err == 0)
2832                         return NULL;
2833
2834                 if (err != -EAGAIN)
2835                         XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTPOLERROR);
2836                 return ERR_PTR(err);
2837         }
2838
2839         dst = xfrm_bundle_create(pols[0], xfrm, bundle, err, fl, dst_orig);
2840         if (IS_ERR(dst)) {
2841                 XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTBUNDLEGENERROR);
2842                 return ERR_CAST(dst);
2843         }
2844
2845         xdst = (struct xfrm_dst *)dst;
2846         xdst->num_xfrms = err;
2847         xdst->num_pols = num_pols;
2848         memcpy(xdst->pols, pols, sizeof(struct xfrm_policy *) * num_pols);
2849         xdst->policy_genid = atomic_read(&pols[0]->genid);
2850
2851         return xdst;
2852 }
2853
2854 static void xfrm_policy_queue_process(struct timer_list *t)
2855 {
2856         struct sk_buff *skb;
2857         struct sock *sk;
2858         struct dst_entry *dst;
2859         struct xfrm_policy *pol = from_timer(pol, t, polq.hold_timer);
2860         struct net *net = xp_net(pol);
2861         struct xfrm_policy_queue *pq = &pol->polq;
2862         struct flowi fl;
2863         struct sk_buff_head list;
2864         __u32 skb_mark;
2865
2866         spin_lock(&pq->hold_queue.lock);
2867         skb = skb_peek(&pq->hold_queue);
2868         if (!skb) {
2869                 spin_unlock(&pq->hold_queue.lock);
2870                 goto out;
2871         }
2872         dst = skb_dst(skb);
2873         sk = skb->sk;
2874
2875         /* Fixup the mark to support VTI. */
2876         skb_mark = skb->mark;
2877         skb->mark = pol->mark.v;
2878         xfrm_decode_session(net, skb, &fl, dst->ops->family);
2879         skb->mark = skb_mark;
2880         spin_unlock(&pq->hold_queue.lock);
2881
2882         dst_hold(xfrm_dst_path(dst));
2883         dst = xfrm_lookup(net, xfrm_dst_path(dst), &fl, sk, XFRM_LOOKUP_QUEUE);
2884         if (IS_ERR(dst))
2885                 goto purge_queue;
2886
2887         if (dst->flags & DST_XFRM_QUEUE) {
2888                 dst_release(dst);
2889
2890                 if (pq->timeout >= XFRM_QUEUE_TMO_MAX)
2891                         goto purge_queue;
2892
2893                 pq->timeout = pq->timeout << 1;
2894                 if (!mod_timer(&pq->hold_timer, jiffies + pq->timeout))
2895                         xfrm_pol_hold(pol);
2896                 goto out;
2897         }
2898
2899         dst_release(dst);
2900
2901         __skb_queue_head_init(&list);
2902
2903         spin_lock(&pq->hold_queue.lock);
2904         pq->timeout = 0;
2905         skb_queue_splice_init(&pq->hold_queue, &list);
2906         spin_unlock(&pq->hold_queue.lock);
2907
2908         while (!skb_queue_empty(&list)) {
2909                 skb = __skb_dequeue(&list);
2910
2911                 /* Fixup the mark to support VTI. */
2912                 skb_mark = skb->mark;
2913                 skb->mark = pol->mark.v;
2914                 xfrm_decode_session(net, skb, &fl, skb_dst(skb)->ops->family);
2915                 skb->mark = skb_mark;
2916
2917                 dst_hold(xfrm_dst_path(skb_dst(skb)));
2918                 dst = xfrm_lookup(net, xfrm_dst_path(skb_dst(skb)), &fl, skb->sk, 0);
2919                 if (IS_ERR(dst)) {
2920                         kfree_skb(skb);
2921                         continue;
2922                 }
2923
2924                 nf_reset_ct(skb);
2925                 skb_dst_drop(skb);
2926                 skb_dst_set(skb, dst);
2927
2928                 dst_output(net, skb->sk, skb);
2929         }
2930
2931 out:
2932         xfrm_pol_put(pol);
2933         return;
2934
2935 purge_queue:
2936         pq->timeout = 0;
2937         skb_queue_purge(&pq->hold_queue);
2938         xfrm_pol_put(pol);
2939 }
2940
2941 static int xdst_queue_output(struct net *net, struct sock *sk, struct sk_buff *skb)
2942 {
2943         unsigned long sched_next;
2944         struct dst_entry *dst = skb_dst(skb);
2945         struct xfrm_dst *xdst = (struct xfrm_dst *) dst;
2946         struct xfrm_policy *pol = xdst->pols[0];
2947         struct xfrm_policy_queue *pq = &pol->polq;
2948
2949         if (unlikely(skb_fclone_busy(sk, skb))) {
2950                 kfree_skb(skb);
2951                 return 0;
2952         }
2953
2954         if (pq->hold_queue.qlen > XFRM_MAX_QUEUE_LEN) {
2955                 kfree_skb(skb);
2956                 return -EAGAIN;
2957         }
2958
2959         skb_dst_force(skb);
2960
2961         spin_lock_bh(&pq->hold_queue.lock);
2962
2963         if (!pq->timeout)
2964                 pq->timeout = XFRM_QUEUE_TMO_MIN;
2965
2966         sched_next = jiffies + pq->timeout;
2967
2968         if (del_timer(&pq->hold_timer)) {
2969                 if (time_before(pq->hold_timer.expires, sched_next))
2970                         sched_next = pq->hold_timer.expires;
2971                 xfrm_pol_put(pol);
2972         }
2973
2974         __skb_queue_tail(&pq->hold_queue, skb);
2975         if (!mod_timer(&pq->hold_timer, sched_next))
2976                 xfrm_pol_hold(pol);
2977
2978         spin_unlock_bh(&pq->hold_queue.lock);
2979
2980         return 0;
2981 }
2982
2983 static struct xfrm_dst *xfrm_create_dummy_bundle(struct net *net,
2984                                                  struct xfrm_flo *xflo,
2985                                                  const struct flowi *fl,
2986                                                  int num_xfrms,
2987                                                  u16 family)
2988 {
2989         int err;
2990         struct net_device *dev;
2991         struct dst_entry *dst;
2992         struct dst_entry *dst1;
2993         struct xfrm_dst *xdst;
2994
2995         xdst = xfrm_alloc_dst(net, family);
2996         if (IS_ERR(xdst))
2997                 return xdst;
2998
2999         if (!(xflo->flags & XFRM_LOOKUP_QUEUE) ||
3000             net->xfrm.sysctl_larval_drop ||
3001             num_xfrms <= 0)
3002                 return xdst;
3003
3004         dst = xflo->dst_orig;
3005         dst1 = &xdst->u.dst;
3006         dst_hold(dst);
3007         xdst->route = dst;
3008
3009         dst_copy_metrics(dst1, dst);
3010
3011         dst1->obsolete = DST_OBSOLETE_FORCE_CHK;
3012         dst1->flags |= DST_XFRM_QUEUE;
3013         dst1->lastuse = jiffies;
3014
3015         dst1->input = dst_discard;
3016         dst1->output = xdst_queue_output;
3017
3018         dst_hold(dst);
3019         xfrm_dst_set_child(xdst, dst);
3020         xdst->path = dst;
3021
3022         xfrm_init_path((struct xfrm_dst *)dst1, dst, 0);
3023
3024         err = -ENODEV;
3025         dev = dst->dev;
3026         if (!dev)
3027                 goto free_dst;
3028
3029         err = xfrm_fill_dst(xdst, dev, fl);
3030         if (err)
3031                 goto free_dst;
3032
3033 out:
3034         return xdst;
3035
3036 free_dst:
3037         dst_release(dst1);
3038         xdst = ERR_PTR(err);
3039         goto out;
3040 }
3041
3042 static struct xfrm_dst *xfrm_bundle_lookup(struct net *net,
3043                                            const struct flowi *fl,
3044                                            u16 family, u8 dir,
3045                                            struct xfrm_flo *xflo, u32 if_id)
3046 {
3047         struct xfrm_policy *pols[XFRM_POLICY_TYPE_MAX];
3048         int num_pols = 0, num_xfrms = 0, err;
3049         struct xfrm_dst *xdst;
3050
3051         /* Resolve policies to use if we couldn't get them from
3052          * previous cache entry */
3053         num_pols = 1;
3054         pols[0] = xfrm_policy_lookup(net, fl, family, dir, if_id);
3055         err = xfrm_expand_policies(fl, family, pols,
3056                                            &num_pols, &num_xfrms);
3057         if (err < 0)
3058                 goto inc_error;
3059         if (num_pols == 0)
3060                 return NULL;
3061         if (num_xfrms <= 0)
3062                 goto make_dummy_bundle;
3063
3064         xdst = xfrm_resolve_and_create_bundle(pols, num_pols, fl, family,
3065                                               xflo->dst_orig);
3066         if (IS_ERR(xdst)) {
3067                 err = PTR_ERR(xdst);
3068                 if (err == -EREMOTE) {
3069                         xfrm_pols_put(pols, num_pols);
3070                         return NULL;
3071                 }
3072
3073                 if (err != -EAGAIN)
3074                         goto error;
3075                 goto make_dummy_bundle;
3076         } else if (xdst == NULL) {
3077                 num_xfrms = 0;
3078                 goto make_dummy_bundle;
3079         }
3080
3081         return xdst;
3082
3083 make_dummy_bundle:
3084         /* We found policies, but there's no bundles to instantiate:
3085          * either because the policy blocks, has no transformations or
3086          * we could not build template (no xfrm_states).*/
3087         xdst = xfrm_create_dummy_bundle(net, xflo, fl, num_xfrms, family);
3088         if (IS_ERR(xdst)) {
3089                 xfrm_pols_put(pols, num_pols);
3090                 return ERR_CAST(xdst);
3091         }
3092         xdst->num_pols = num_pols;
3093         xdst->num_xfrms = num_xfrms;
3094         memcpy(xdst->pols, pols, sizeof(struct xfrm_policy *) * num_pols);
3095
3096         return xdst;
3097
3098 inc_error:
3099         XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTPOLERROR);
3100 error:
3101         xfrm_pols_put(pols, num_pols);
3102         return ERR_PTR(err);
3103 }
3104
3105 static struct dst_entry *make_blackhole(struct net *net, u16 family,
3106                                         struct dst_entry *dst_orig)
3107 {
3108         const struct xfrm_policy_afinfo *afinfo = xfrm_policy_get_afinfo(family);
3109         struct dst_entry *ret;
3110
3111         if (!afinfo) {
3112                 dst_release(dst_orig);
3113                 return ERR_PTR(-EINVAL);
3114         } else {
3115                 ret = afinfo->blackhole_route(net, dst_orig);
3116         }
3117         rcu_read_unlock();
3118
3119         return ret;
3120 }
3121
3122 /* Finds/creates a bundle for given flow and if_id
3123  *
3124  * At the moment we eat a raw IP route. Mostly to speed up lookups
3125  * on interfaces with disabled IPsec.
3126  *
3127  * xfrm_lookup uses an if_id of 0 by default, and is provided for
3128  * compatibility
3129  */
3130 struct dst_entry *xfrm_lookup_with_ifid(struct net *net,
3131                                         struct dst_entry *dst_orig,
3132                                         const struct flowi *fl,
3133                                         const struct sock *sk,
3134                                         int flags, u32 if_id)
3135 {
3136         struct xfrm_policy *pols[XFRM_POLICY_TYPE_MAX];
3137         struct xfrm_dst *xdst;
3138         struct dst_entry *dst, *route;
3139         u16 family = dst_orig->ops->family;
3140         u8 dir = XFRM_POLICY_OUT;
3141         int i, err, num_pols, num_xfrms = 0, drop_pols = 0;
3142
3143         dst = NULL;
3144         xdst = NULL;
3145         route = NULL;
3146
3147         sk = sk_const_to_full_sk(sk);
3148         if (sk && sk->sk_policy[XFRM_POLICY_OUT]) {
3149                 num_pols = 1;
3150                 pols[0] = xfrm_sk_policy_lookup(sk, XFRM_POLICY_OUT, fl, family,
3151                                                 if_id);
3152                 err = xfrm_expand_policies(fl, family, pols,
3153                                            &num_pols, &num_xfrms);
3154                 if (err < 0)
3155                         goto dropdst;
3156
3157                 if (num_pols) {
3158                         if (num_xfrms <= 0) {
3159                                 drop_pols = num_pols;
3160                                 goto no_transform;
3161                         }
3162
3163                         xdst = xfrm_resolve_and_create_bundle(
3164                                         pols, num_pols, fl,
3165                                         family, dst_orig);
3166
3167                         if (IS_ERR(xdst)) {
3168                                 xfrm_pols_put(pols, num_pols);
3169                                 err = PTR_ERR(xdst);
3170                                 if (err == -EREMOTE)
3171                                         goto nopol;
3172
3173                                 goto dropdst;
3174                         } else if (xdst == NULL) {
3175                                 num_xfrms = 0;
3176                                 drop_pols = num_pols;
3177                                 goto no_transform;
3178                         }
3179
3180                         route = xdst->route;
3181                 }
3182         }
3183
3184         if (xdst == NULL) {
3185                 struct xfrm_flo xflo;
3186
3187                 xflo.dst_orig = dst_orig;
3188                 xflo.flags = flags;
3189
3190                 /* To accelerate a bit...  */
3191                 if (!if_id && ((dst_orig->flags & DST_NOXFRM) ||
3192                                !net->xfrm.policy_count[XFRM_POLICY_OUT]))
3193                         goto nopol;
3194
3195                 xdst = xfrm_bundle_lookup(net, fl, family, dir, &xflo, if_id);
3196                 if (xdst == NULL)
3197                         goto nopol;
3198                 if (IS_ERR(xdst)) {
3199                         err = PTR_ERR(xdst);
3200                         goto dropdst;
3201                 }
3202
3203                 num_pols = xdst->num_pols;
3204                 num_xfrms = xdst->num_xfrms;
3205                 memcpy(pols, xdst->pols, sizeof(struct xfrm_policy *) * num_pols);
3206                 route = xdst->route;
3207         }
3208
3209         dst = &xdst->u.dst;
3210         if (route == NULL && num_xfrms > 0) {
3211                 /* The only case when xfrm_bundle_lookup() returns a
3212                  * bundle with null route, is when the template could
3213                  * not be resolved. It means policies are there, but
3214                  * bundle could not be created, since we don't yet
3215                  * have the xfrm_state's. We need to wait for KM to
3216                  * negotiate new SA's or bail out with error.*/
3217                 if (net->xfrm.sysctl_larval_drop) {
3218                         XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTNOSTATES);
3219                         err = -EREMOTE;
3220                         goto error;
3221                 }
3222
3223                 err = -EAGAIN;
3224
3225                 XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTNOSTATES);
3226                 goto error;
3227         }
3228
3229 no_transform:
3230         if (num_pols == 0)
3231                 goto nopol;
3232
3233         if ((flags & XFRM_LOOKUP_ICMP) &&
3234             !(pols[0]->flags & XFRM_POLICY_ICMP)) {
3235                 err = -ENOENT;
3236                 goto error;
3237         }
3238
3239         for (i = 0; i < num_pols; i++)
3240                 WRITE_ONCE(pols[i]->curlft.use_time, ktime_get_real_seconds());
3241
3242         if (num_xfrms < 0) {
3243                 /* Prohibit the flow */
3244                 XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTPOLBLOCK);
3245                 err = -EPERM;
3246                 goto error;
3247         } else if (num_xfrms > 0) {
3248                 /* Flow transformed */
3249                 dst_release(dst_orig);
3250         } else {
3251                 /* Flow passes untransformed */
3252                 dst_release(dst);
3253                 dst = dst_orig;
3254         }
3255 ok:
3256         xfrm_pols_put(pols, drop_pols);
3257         if (dst && dst->xfrm &&
3258             dst->xfrm->props.mode == XFRM_MODE_TUNNEL)
3259                 dst->flags |= DST_XFRM_TUNNEL;
3260         return dst;
3261
3262 nopol:
3263         if ((!dst_orig->dev || !(dst_orig->dev->flags & IFF_LOOPBACK)) &&
3264             net->xfrm.policy_default[dir] == XFRM_USERPOLICY_BLOCK) {
3265                 err = -EPERM;
3266                 goto error;
3267         }
3268         if (!(flags & XFRM_LOOKUP_ICMP)) {
3269                 dst = dst_orig;
3270                 goto ok;
3271         }
3272         err = -ENOENT;
3273 error:
3274         dst_release(dst);
3275 dropdst:
3276         if (!(flags & XFRM_LOOKUP_KEEP_DST_REF))
3277                 dst_release(dst_orig);
3278         xfrm_pols_put(pols, drop_pols);
3279         return ERR_PTR(err);
3280 }
3281 EXPORT_SYMBOL(xfrm_lookup_with_ifid);
3282
3283 /* Main function: finds/creates a bundle for given flow.
3284  *
3285  * At the moment we eat a raw IP route. Mostly to speed up lookups
3286  * on interfaces with disabled IPsec.
3287  */
3288 struct dst_entry *xfrm_lookup(struct net *net, struct dst_entry *dst_orig,
3289                               const struct flowi *fl, const struct sock *sk,
3290                               int flags)
3291 {
3292         return xfrm_lookup_with_ifid(net, dst_orig, fl, sk, flags, 0);
3293 }
3294 EXPORT_SYMBOL(xfrm_lookup);
3295
3296 /* Callers of xfrm_lookup_route() must ensure a call to dst_output().
3297  * Otherwise we may send out blackholed packets.
3298  */
3299 struct dst_entry *xfrm_lookup_route(struct net *net, struct dst_entry *dst_orig,
3300                                     const struct flowi *fl,
3301                                     const struct sock *sk, int flags)
3302 {
3303         struct dst_entry *dst = xfrm_lookup(net, dst_orig, fl, sk,
3304                                             flags | XFRM_LOOKUP_QUEUE |
3305                                             XFRM_LOOKUP_KEEP_DST_REF);
3306
3307         if (PTR_ERR(dst) == -EREMOTE)
3308                 return make_blackhole(net, dst_orig->ops->family, dst_orig);
3309
3310         if (IS_ERR(dst))
3311                 dst_release(dst_orig);
3312
3313         return dst;
3314 }
3315 EXPORT_SYMBOL(xfrm_lookup_route);
3316
3317 static inline int
3318 xfrm_secpath_reject(int idx, struct sk_buff *skb, const struct flowi *fl)
3319 {
3320         struct sec_path *sp = skb_sec_path(skb);
3321         struct xfrm_state *x;
3322
3323         if (!sp || idx < 0 || idx >= sp->len)
3324                 return 0;
3325         x = sp->xvec[idx];
3326         if (!x->type->reject)
3327                 return 0;
3328         return x->type->reject(x, skb, fl);
3329 }
3330
3331 /* When skb is transformed back to its "native" form, we have to
3332  * check policy restrictions. At the moment we make this in maximally
3333  * stupid way. Shame on me. :-) Of course, connected sockets must
3334  * have policy cached at them.
3335  */
3336
3337 static inline int
3338 xfrm_state_ok(const struct xfrm_tmpl *tmpl, const struct xfrm_state *x,
3339               unsigned short family, u32 if_id)
3340 {
3341         if (xfrm_state_kern(x))
3342                 return tmpl->optional && !xfrm_state_addr_cmp(tmpl, x, tmpl->encap_family);
3343         return  x->id.proto == tmpl->id.proto &&
3344                 (x->id.spi == tmpl->id.spi || !tmpl->id.spi) &&
3345                 (x->props.reqid == tmpl->reqid || !tmpl->reqid) &&
3346                 x->props.mode == tmpl->mode &&
3347                 (tmpl->allalgs || (tmpl->aalgos & (1<<x->props.aalgo)) ||
3348                  !(xfrm_id_proto_match(tmpl->id.proto, IPSEC_PROTO_ANY))) &&
3349                 !(x->props.mode != XFRM_MODE_TRANSPORT &&
3350                   xfrm_state_addr_cmp(tmpl, x, family)) &&
3351                 (if_id == 0 || if_id == x->if_id);
3352 }
3353
3354 /*
3355  * 0 or more than 0 is returned when validation is succeeded (either bypass
3356  * because of optional transport mode, or next index of the matched secpath
3357  * state with the template.
3358  * -1 is returned when no matching template is found.
3359  * Otherwise "-2 - errored_index" is returned.
3360  */
3361 static inline int
3362 xfrm_policy_ok(const struct xfrm_tmpl *tmpl, const struct sec_path *sp, int start,
3363                unsigned short family, u32 if_id)
3364 {
3365         int idx = start;
3366
3367         if (tmpl->optional) {
3368                 if (tmpl->mode == XFRM_MODE_TRANSPORT)
3369                         return start;
3370         } else
3371                 start = -1;
3372         for (; idx < sp->len; idx++) {
3373                 if (xfrm_state_ok(tmpl, sp->xvec[idx], family, if_id))
3374                         return ++idx;
3375                 if (sp->xvec[idx]->props.mode != XFRM_MODE_TRANSPORT) {
3376                         if (idx < sp->verified_cnt) {
3377                                 /* Secpath entry previously verified, consider optional and
3378                                  * continue searching
3379                                  */
3380                                 continue;
3381                         }
3382
3383                         if (start == -1)
3384                                 start = -2-idx;
3385                         break;
3386                 }
3387         }
3388         return start;
3389 }
3390
3391 static void
3392 decode_session4(const struct xfrm_flow_keys *flkeys, struct flowi *fl, bool reverse)
3393 {
3394         struct flowi4 *fl4 = &fl->u.ip4;
3395
3396         memset(fl4, 0, sizeof(struct flowi4));
3397
3398         if (reverse) {
3399                 fl4->saddr = flkeys->addrs.ipv4.dst;
3400                 fl4->daddr = flkeys->addrs.ipv4.src;
3401                 fl4->fl4_sport = flkeys->ports.dst;
3402                 fl4->fl4_dport = flkeys->ports.src;
3403         } else {
3404                 fl4->saddr = flkeys->addrs.ipv4.src;
3405                 fl4->daddr = flkeys->addrs.ipv4.dst;
3406                 fl4->fl4_sport = flkeys->ports.src;
3407                 fl4->fl4_dport = flkeys->ports.dst;
3408         }
3409
3410         switch (flkeys->basic.ip_proto) {
3411         case IPPROTO_GRE:
3412                 fl4->fl4_gre_key = flkeys->gre.keyid;
3413                 break;
3414         case IPPROTO_ICMP:
3415                 fl4->fl4_icmp_type = flkeys->icmp.type;
3416                 fl4->fl4_icmp_code = flkeys->icmp.code;
3417                 break;
3418         }
3419
3420         fl4->flowi4_proto = flkeys->basic.ip_proto;
3421         fl4->flowi4_tos = flkeys->ip.tos & ~INET_ECN_MASK;
3422 }
3423
3424 #if IS_ENABLED(CONFIG_IPV6)
3425 static void
3426 decode_session6(const struct xfrm_flow_keys *flkeys, struct flowi *fl, bool reverse)
3427 {
3428         struct flowi6 *fl6 = &fl->u.ip6;
3429
3430         memset(fl6, 0, sizeof(struct flowi6));
3431
3432         if (reverse) {
3433                 fl6->saddr = flkeys->addrs.ipv6.dst;
3434                 fl6->daddr = flkeys->addrs.ipv6.src;
3435                 fl6->fl6_sport = flkeys->ports.dst;
3436                 fl6->fl6_dport = flkeys->ports.src;
3437         } else {
3438                 fl6->saddr = flkeys->addrs.ipv6.src;
3439                 fl6->daddr = flkeys->addrs.ipv6.dst;
3440                 fl6->fl6_sport = flkeys->ports.src;
3441                 fl6->fl6_dport = flkeys->ports.dst;
3442         }
3443
3444         switch (flkeys->basic.ip_proto) {
3445         case IPPROTO_GRE:
3446                 fl6->fl6_gre_key = flkeys->gre.keyid;
3447                 break;
3448         case IPPROTO_ICMPV6:
3449                 fl6->fl6_icmp_type = flkeys->icmp.type;
3450                 fl6->fl6_icmp_code = flkeys->icmp.code;
3451                 break;
3452         }
3453
3454         fl6->flowi6_proto = flkeys->basic.ip_proto;
3455 }
3456 #endif
3457
3458 int __xfrm_decode_session(struct net *net, struct sk_buff *skb, struct flowi *fl,
3459                           unsigned int family, int reverse)
3460 {
3461         struct xfrm_flow_keys flkeys;
3462
3463         memset(&flkeys, 0, sizeof(flkeys));
3464         __skb_flow_dissect(net, skb, &xfrm_session_dissector, &flkeys,
3465                            NULL, 0, 0, 0, FLOW_DISSECTOR_F_STOP_AT_ENCAP);
3466
3467         switch (family) {
3468         case AF_INET:
3469                 decode_session4(&flkeys, fl, reverse);
3470                 break;
3471 #if IS_ENABLED(CONFIG_IPV6)
3472         case AF_INET6:
3473                 decode_session6(&flkeys, fl, reverse);
3474                 break;
3475 #endif
3476         default:
3477                 return -EAFNOSUPPORT;
3478         }
3479
3480         fl->flowi_mark = skb->mark;
3481         if (reverse) {
3482                 fl->flowi_oif = skb->skb_iif;
3483         } else {
3484                 int oif = 0;
3485
3486                 if (skb_dst(skb) && skb_dst(skb)->dev)
3487                         oif = skb_dst(skb)->dev->ifindex;
3488
3489                 fl->flowi_oif = oif;
3490         }
3491
3492         return security_xfrm_decode_session(skb, &fl->flowi_secid);
3493 }
3494 EXPORT_SYMBOL(__xfrm_decode_session);
3495
3496 static inline int secpath_has_nontransport(const struct sec_path *sp, int k, int *idxp)
3497 {
3498         for (; k < sp->len; k++) {
3499                 if (sp->xvec[k]->props.mode != XFRM_MODE_TRANSPORT) {
3500                         *idxp = k;
3501                         return 1;
3502                 }
3503         }
3504
3505         return 0;
3506 }
3507
3508 int __xfrm_policy_check(struct sock *sk, int dir, struct sk_buff *skb,
3509                         unsigned short family)
3510 {
3511         struct net *net = dev_net(skb->dev);
3512         struct xfrm_policy *pol;
3513         struct xfrm_policy *pols[XFRM_POLICY_TYPE_MAX];
3514         int npols = 0;
3515         int xfrm_nr;
3516         int pi;
3517         int reverse;
3518         struct flowi fl;
3519         int xerr_idx = -1;
3520         const struct xfrm_if_cb *ifcb;
3521         struct sec_path *sp;
3522         u32 if_id = 0;
3523
3524         rcu_read_lock();
3525         ifcb = xfrm_if_get_cb();
3526
3527         if (ifcb) {
3528                 struct xfrm_if_decode_session_result r;
3529
3530                 if (ifcb->decode_session(skb, family, &r)) {
3531                         if_id = r.if_id;
3532                         net = r.net;
3533                 }
3534         }
3535         rcu_read_unlock();
3536
3537         reverse = dir & ~XFRM_POLICY_MASK;
3538         dir &= XFRM_POLICY_MASK;
3539
3540         if (__xfrm_decode_session(net, skb, &fl, family, reverse) < 0) {
3541                 XFRM_INC_STATS(net, LINUX_MIB_XFRMINHDRERROR);
3542                 return 0;
3543         }
3544
3545         nf_nat_decode_session(skb, &fl, family);
3546
3547         /* First, check used SA against their selectors. */
3548         sp = skb_sec_path(skb);
3549         if (sp) {
3550                 int i;
3551
3552                 for (i = sp->len - 1; i >= 0; i--) {
3553                         struct xfrm_state *x = sp->xvec[i];
3554                         if (!xfrm_selector_match(&x->sel, &fl, family)) {
3555                                 XFRM_INC_STATS(net, LINUX_MIB_XFRMINSTATEMISMATCH);
3556                                 return 0;
3557                         }
3558                 }
3559         }
3560
3561         pol = NULL;
3562         sk = sk_to_full_sk(sk);
3563         if (sk && sk->sk_policy[dir]) {
3564                 pol = xfrm_sk_policy_lookup(sk, dir, &fl, family, if_id);
3565                 if (IS_ERR(pol)) {
3566                         XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLERROR);
3567                         return 0;
3568                 }
3569         }
3570
3571         if (!pol)
3572                 pol = xfrm_policy_lookup(net, &fl, family, dir, if_id);
3573
3574         if (IS_ERR(pol)) {
3575                 XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLERROR);
3576                 return 0;
3577         }
3578
3579         if (!pol) {
3580                 if (net->xfrm.policy_default[dir] == XFRM_USERPOLICY_BLOCK) {
3581                         XFRM_INC_STATS(net, LINUX_MIB_XFRMINNOPOLS);
3582                         return 0;
3583                 }
3584
3585                 if (sp && secpath_has_nontransport(sp, 0, &xerr_idx)) {
3586                         xfrm_secpath_reject(xerr_idx, skb, &fl);
3587                         XFRM_INC_STATS(net, LINUX_MIB_XFRMINNOPOLS);
3588                         return 0;
3589                 }
3590                 return 1;
3591         }
3592
3593         /* This lockless write can happen from different cpus. */
3594         WRITE_ONCE(pol->curlft.use_time, ktime_get_real_seconds());
3595
3596         pols[0] = pol;
3597         npols++;
3598 #ifdef CONFIG_XFRM_SUB_POLICY
3599         if (pols[0]->type != XFRM_POLICY_TYPE_MAIN) {
3600                 pols[1] = xfrm_policy_lookup_bytype(net, XFRM_POLICY_TYPE_MAIN,
3601                                                     &fl, family,
3602                                                     XFRM_POLICY_IN, if_id);
3603                 if (pols[1]) {
3604                         if (IS_ERR(pols[1])) {
3605                                 XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLERROR);
3606                                 xfrm_pol_put(pols[0]);
3607                                 return 0;
3608                         }
3609                         /* This write can happen from different cpus. */
3610                         WRITE_ONCE(pols[1]->curlft.use_time,
3611                                    ktime_get_real_seconds());
3612                         npols++;
3613                 }
3614         }
3615 #endif
3616
3617         if (pol->action == XFRM_POLICY_ALLOW) {
3618                 static struct sec_path dummy;
3619                 struct xfrm_tmpl *tp[XFRM_MAX_DEPTH];
3620                 struct xfrm_tmpl *stp[XFRM_MAX_DEPTH];
3621                 struct xfrm_tmpl **tpp = tp;
3622                 int ti = 0;
3623                 int i, k;
3624
3625                 sp = skb_sec_path(skb);
3626                 if (!sp)
3627                         sp = &dummy;
3628
3629                 for (pi = 0; pi < npols; pi++) {
3630                         if (pols[pi] != pol &&
3631                             pols[pi]->action != XFRM_POLICY_ALLOW) {
3632                                 XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLBLOCK);
3633                                 goto reject;
3634                         }
3635                         if (ti + pols[pi]->xfrm_nr >= XFRM_MAX_DEPTH) {
3636                                 XFRM_INC_STATS(net, LINUX_MIB_XFRMINBUFFERERROR);
3637                                 goto reject_error;
3638                         }
3639                         for (i = 0; i < pols[pi]->xfrm_nr; i++)
3640                                 tpp[ti++] = &pols[pi]->xfrm_vec[i];
3641                 }
3642                 xfrm_nr = ti;
3643
3644                 if (npols > 1) {
3645                         xfrm_tmpl_sort(stp, tpp, xfrm_nr, family);
3646                         tpp = stp;
3647                 }
3648
3649                 /* For each tunnel xfrm, find the first matching tmpl.
3650                  * For each tmpl before that, find corresponding xfrm.
3651                  * Order is _important_. Later we will implement
3652                  * some barriers, but at the moment barriers
3653                  * are implied between each two transformations.
3654                  * Upon success, marks secpath entries as having been
3655                  * verified to allow them to be skipped in future policy
3656                  * checks (e.g. nested tunnels).
3657                  */
3658                 for (i = xfrm_nr-1, k = 0; i >= 0; i--) {
3659                         k = xfrm_policy_ok(tpp[i], sp, k, family, if_id);
3660                         if (k < 0) {
3661                                 if (k < -1)
3662                                         /* "-2 - errored_index" returned */
3663                                         xerr_idx = -(2+k);
3664                                 XFRM_INC_STATS(net, LINUX_MIB_XFRMINTMPLMISMATCH);
3665                                 goto reject;
3666                         }
3667                 }
3668
3669                 if (secpath_has_nontransport(sp, k, &xerr_idx)) {
3670                         XFRM_INC_STATS(net, LINUX_MIB_XFRMINTMPLMISMATCH);
3671                         goto reject;
3672                 }
3673
3674                 xfrm_pols_put(pols, npols);
3675                 sp->verified_cnt = k;
3676
3677                 return 1;
3678         }
3679         XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLBLOCK);
3680
3681 reject:
3682         xfrm_secpath_reject(xerr_idx, skb, &fl);
3683 reject_error:
3684         xfrm_pols_put(pols, npols);
3685         return 0;
3686 }
3687 EXPORT_SYMBOL(__xfrm_policy_check);
3688
3689 int __xfrm_route_forward(struct sk_buff *skb, unsigned short family)
3690 {
3691         struct net *net = dev_net(skb->dev);
3692         struct flowi fl;
3693         struct dst_entry *dst;
3694         int res = 1;
3695
3696         if (xfrm_decode_session(net, skb, &fl, family) < 0) {
3697                 XFRM_INC_STATS(net, LINUX_MIB_XFRMFWDHDRERROR);
3698                 return 0;
3699         }
3700
3701         skb_dst_force(skb);
3702         if (!skb_dst(skb)) {
3703                 XFRM_INC_STATS(net, LINUX_MIB_XFRMFWDHDRERROR);
3704                 return 0;
3705         }
3706
3707         dst = xfrm_lookup(net, skb_dst(skb), &fl, NULL, XFRM_LOOKUP_QUEUE);
3708         if (IS_ERR(dst)) {
3709                 res = 0;
3710                 dst = NULL;
3711         }
3712         skb_dst_set(skb, dst);
3713         return res;
3714 }
3715 EXPORT_SYMBOL(__xfrm_route_forward);
3716
3717 /* Optimize later using cookies and generation ids. */
3718
3719 static struct dst_entry *xfrm_dst_check(struct dst_entry *dst, u32 cookie)
3720 {
3721         /* Code (such as __xfrm4_bundle_create()) sets dst->obsolete
3722          * to DST_OBSOLETE_FORCE_CHK to force all XFRM destinations to
3723          * get validated by dst_ops->check on every use.  We do this
3724          * because when a normal route referenced by an XFRM dst is
3725          * obsoleted we do not go looking around for all parent
3726          * referencing XFRM dsts so that we can invalidate them.  It
3727          * is just too much work.  Instead we make the checks here on
3728          * every use.  For example:
3729          *
3730          *      XFRM dst A --> IPv4 dst X
3731          *
3732          * X is the "xdst->route" of A (X is also the "dst->path" of A
3733          * in this example).  If X is marked obsolete, "A" will not
3734          * notice.  That's what we are validating here via the
3735          * stale_bundle() check.
3736          *
3737          * When a dst is removed from the fib tree, DST_OBSOLETE_DEAD will
3738          * be marked on it.
3739          * This will force stale_bundle() to fail on any xdst bundle with
3740          * this dst linked in it.
3741          */
3742         if (dst->obsolete < 0 && !stale_bundle(dst))
3743                 return dst;
3744
3745         return NULL;
3746 }
3747
/* stale_bundle - 1 if the bundle headed by @dst fails xfrm_bundle_ok(), else 0. */
static int stale_bundle(struct dst_entry *dst)
{
	struct xfrm_dst *xdst = (struct xfrm_dst *)dst;

	return xfrm_bundle_ok(xdst) ? 0 : 1;
}
3752
3753 void xfrm_dst_ifdown(struct dst_entry *dst, struct net_device *dev)
3754 {
3755         while ((dst = xfrm_dst_child(dst)) && dst->xfrm && dst->dev == dev) {
3756                 dst->dev = blackhole_netdev;
3757                 dev_hold(dst->dev);
3758                 dev_put(dev);
3759         }
3760 }
3761 EXPORT_SYMBOL(xfrm_dst_ifdown);
3762
/*
 * xfrm_link_failure - dst_ops->link_failure hook for XFRM bundles.
 * Intentionally empty.
 */
static void xfrm_link_failure(struct sk_buff *skb)
{
	/* Impossible. Such dst must be popped before reaches point of failure. */
}
3767
3768 static struct dst_entry *xfrm_negative_advice(struct dst_entry *dst)
3769 {
3770         if (dst) {
3771                 if (dst->obsolete) {
3772                         dst_release(dst);
3773                         dst = NULL;
3774                 }
3775         }
3776         return dst;
3777 }
3778
3779 static void xfrm_init_pmtu(struct xfrm_dst **bundle, int nr)
3780 {
3781         while (nr--) {
3782                 struct xfrm_dst *xdst = bundle[nr];
3783                 u32 pmtu, route_mtu_cached;
3784                 struct dst_entry *dst;
3785
3786                 dst = &xdst->u.dst;
3787                 pmtu = dst_mtu(xfrm_dst_child(dst));
3788                 xdst->child_mtu_cached = pmtu;
3789
3790                 pmtu = xfrm_state_mtu(dst->xfrm, pmtu);
3791
3792                 route_mtu_cached = dst_mtu(xdst->route);
3793                 xdst->route_mtu_cached = route_mtu_cached;
3794
3795                 if (pmtu > route_mtu_cached)
3796                         pmtu = route_mtu_cached;
3797
3798                 dst_metric_set(dst, RTAX_MTU, pmtu);
3799         }
3800 }
3801
3802 /* Check that the bundle accepts the flow and its components are
3803  * still valid.
3804  */
3805
3806 static int xfrm_bundle_ok(struct xfrm_dst *first)
3807 {
3808         struct xfrm_dst *bundle[XFRM_MAX_DEPTH];
3809         struct dst_entry *dst = &first->u.dst;
3810         struct xfrm_dst *xdst;
3811         int start_from, nr;
3812         u32 mtu;
3813
3814         if (!dst_check(xfrm_dst_path(dst), ((struct xfrm_dst *)dst)->path_cookie) ||
3815             (dst->dev && !netif_running(dst->dev)))
3816                 return 0;
3817
3818         if (dst->flags & DST_XFRM_QUEUE)
3819                 return 1;
3820
3821         start_from = nr = 0;
3822         do {
3823                 struct xfrm_dst *xdst = (struct xfrm_dst *)dst;
3824
3825                 if (dst->xfrm->km.state != XFRM_STATE_VALID)
3826                         return 0;
3827                 if (xdst->xfrm_genid != dst->xfrm->genid)
3828                         return 0;
3829                 if (xdst->num_pols > 0 &&
3830                     xdst->policy_genid != atomic_read(&xdst->pols[0]->genid))
3831                         return 0;
3832
3833                 bundle[nr++] = xdst;
3834
3835                 mtu = dst_mtu(xfrm_dst_child(dst));
3836                 if (xdst->child_mtu_cached != mtu) {
3837                         start_from = nr;
3838                         xdst->child_mtu_cached = mtu;
3839                 }
3840
3841                 if (!dst_check(xdst->route, xdst->route_cookie))
3842                         return 0;
3843                 mtu = dst_mtu(xdst->route);
3844                 if (xdst->route_mtu_cached != mtu) {
3845                         start_from = nr;
3846                         xdst->route_mtu_cached = mtu;
3847                 }
3848
3849                 dst = xfrm_dst_child(dst);
3850         } while (dst->xfrm);
3851
3852         if (likely(!start_from))
3853                 return 1;
3854
3855         xdst = bundle[start_from - 1];
3856         mtu = xdst->child_mtu_cached;
3857         while (start_from--) {
3858                 dst = &xdst->u.dst;
3859
3860                 mtu = xfrm_state_mtu(dst->xfrm, mtu);
3861                 if (mtu > xdst->route_mtu_cached)
3862                         mtu = xdst->route_mtu_cached;
3863                 dst_metric_set(dst, RTAX_MTU, mtu);
3864                 if (!start_from)
3865                         break;
3866
3867                 xdst = bundle[start_from - 1];
3868                 xdst->child_mtu_cached = mtu;
3869         }
3870
3871         return 1;
3872 }
3873
/* xfrm_default_advmss - advertised MSS of a bundle is that of its path route. */
static unsigned int xfrm_default_advmss(const struct dst_entry *dst)
{
	const struct dst_entry *path = xfrm_dst_path(dst);

	return dst_metric_advmss(path);
}
3878
3879 static unsigned int xfrm_mtu(const struct dst_entry *dst)
3880 {
3881         unsigned int mtu = dst_metric_raw(dst, RTAX_MTU);
3882
3883         return mtu ? : dst_mtu(xfrm_dst_path(dst));
3884 }
3885
3886 static const void *xfrm_get_dst_nexthop(const struct dst_entry *dst,
3887                                         const void *daddr)
3888 {
3889         while (dst->xfrm) {
3890                 const struct xfrm_state *xfrm = dst->xfrm;
3891
3892                 dst = xfrm_dst_child(dst);
3893
3894                 if (xfrm->props.mode == XFRM_MODE_TRANSPORT)
3895                         continue;
3896                 if (xfrm->type->flags & XFRM_TYPE_REMOTE_COADDR)
3897                         daddr = xfrm->coaddr;
3898                 else if (!(xfrm->type->flags & XFRM_TYPE_LOCAL_COADDR))
3899                         daddr = &xfrm->id.daddr;
3900         }
3901         return daddr;
3902 }
3903
3904 static struct neighbour *xfrm_neigh_lookup(const struct dst_entry *dst,
3905                                            struct sk_buff *skb,
3906                                            const void *daddr)
3907 {
3908         const struct dst_entry *path = xfrm_dst_path(dst);
3909
3910         if (!skb)
3911                 daddr = xfrm_get_dst_nexthop(dst, daddr);
3912         return path->ops->neigh_lookup(path, skb, daddr);
3913 }
3914
3915 static void xfrm_confirm_neigh(const struct dst_entry *dst, const void *daddr)
3916 {
3917         const struct dst_entry *path = xfrm_dst_path(dst);
3918
3919         daddr = xfrm_get_dst_nexthop(dst, daddr);
3920         path->ops->confirm_neigh(path, daddr);
3921 }
3922
3923 int xfrm_policy_register_afinfo(const struct xfrm_policy_afinfo *afinfo, int family)
3924 {
3925         int err = 0;
3926
3927         if (WARN_ON(family >= ARRAY_SIZE(xfrm_policy_afinfo)))
3928                 return -EAFNOSUPPORT;
3929
3930         spin_lock(&xfrm_policy_afinfo_lock);
3931         if (unlikely(xfrm_policy_afinfo[family] != NULL))
3932                 err = -EEXIST;
3933         else {
3934                 struct dst_ops *dst_ops = afinfo->dst_ops;
3935                 if (likely(dst_ops->kmem_cachep == NULL))
3936                         dst_ops->kmem_cachep = xfrm_dst_cache;
3937                 if (likely(dst_ops->check == NULL))
3938                         dst_ops->check = xfrm_dst_check;
3939                 if (likely(dst_ops->default_advmss == NULL))
3940                         dst_ops->default_advmss = xfrm_default_advmss;
3941                 if (likely(dst_ops->mtu == NULL))
3942                         dst_ops->mtu = xfrm_mtu;
3943                 if (likely(dst_ops->negative_advice == NULL))
3944                         dst_ops->negative_advice = xfrm_negative_advice;
3945                 if (likely(dst_ops->link_failure == NULL))
3946                         dst_ops->link_failure = xfrm_link_failure;
3947                 if (likely(dst_ops->neigh_lookup == NULL))
3948                         dst_ops->neigh_lookup = xfrm_neigh_lookup;
3949                 if (likely(!dst_ops->confirm_neigh))
3950                         dst_ops->confirm_neigh = xfrm_confirm_neigh;
3951                 rcu_assign_pointer(xfrm_policy_afinfo[family], afinfo);
3952         }
3953         spin_unlock(&xfrm_policy_afinfo_lock);
3954
3955         return err;
3956 }
3957 EXPORT_SYMBOL(xfrm_policy_register_afinfo);
3958
3959 void xfrm_policy_unregister_afinfo(const struct xfrm_policy_afinfo *afinfo)
3960 {
3961         struct dst_ops *dst_ops = afinfo->dst_ops;
3962         int i;
3963
3964         for (i = 0; i < ARRAY_SIZE(xfrm_policy_afinfo); i++) {
3965                 if (xfrm_policy_afinfo[i] != afinfo)
3966                         continue;
3967                 RCU_INIT_POINTER(xfrm_policy_afinfo[i], NULL);
3968                 break;
3969         }
3970
3971         synchronize_rcu();
3972
3973         dst_ops->kmem_cachep = NULL;
3974         dst_ops->check = NULL;
3975         dst_ops->negative_advice = NULL;
3976         dst_ops->link_failure = NULL;
3977 }
3978 EXPORT_SYMBOL(xfrm_policy_unregister_afinfo);
3979
/*
 * xfrm_if_register_cb - publish the xfrm interface callbacks.
 *
 * The pointer is published with rcu_assign_pointer() while holding
 * xfrm_if_cb_lock. Readers such as __xfrm_policy_check() fetch it through
 * xfrm_if_get_cb() inside an RCU read-side section.
 */
void xfrm_if_register_cb(const struct xfrm_if_cb *ifcb)
{
	spin_lock(&xfrm_if_cb_lock);
	rcu_assign_pointer(xfrm_if_cb, ifcb);
	spin_unlock(&xfrm_if_cb_lock);
}
EXPORT_SYMBOL(xfrm_if_register_cb);
3987
/*
 * xfrm_if_unregister_cb - withdraw the xfrm interface callbacks.
 *
 * Clears the pointer and then waits for an RCU grace period, so no reader
 * can still be using the old callbacks when this returns.
 * NOTE(review): unlike registration, this does not take xfrm_if_cb_lock.
 */
void xfrm_if_unregister_cb(void)
{
	RCU_INIT_POINTER(xfrm_if_cb, NULL);
	synchronize_rcu();
}
EXPORT_SYMBOL(xfrm_if_unregister_cb);
3994
3995 #ifdef CONFIG_XFRM_STATISTICS
3996 static int __net_init xfrm_statistics_init(struct net *net)
3997 {
3998         int rv;
3999         net->mib.xfrm_statistics = alloc_percpu(struct linux_xfrm_mib);
4000         if (!net->mib.xfrm_statistics)
4001                 return -ENOMEM;
4002         rv = xfrm_proc_init(net);
4003         if (rv < 0)
4004                 free_percpu(net->mib.xfrm_statistics);
4005         return rv;
4006 }
4007
4008 static void xfrm_statistics_fini(struct net *net)
4009 {
4010         xfrm_proc_fini(net);
4011         free_percpu(net->mib.xfrm_statistics);
4012 }
4013 #else
4014 static int __net_init xfrm_statistics_init(struct net *net)
4015 {
4016         return 0;
4017 }
4018
4019 static void xfrm_statistics_fini(struct net *net)
4020 {
4021 }
4022 #endif
4023
4024 static int __net_init xfrm_policy_init(struct net *net)
4025 {
4026         unsigned int hmask, sz;
4027         int dir, err;
4028
4029         if (net_eq(net, &init_net)) {
4030                 xfrm_dst_cache = kmem_cache_create("xfrm_dst_cache",
4031                                            sizeof(struct xfrm_dst),
4032                                            0, SLAB_HWCACHE_ALIGN|SLAB_PANIC,
4033                                            NULL);
4034                 err = rhashtable_init(&xfrm_policy_inexact_table,
4035                                       &xfrm_pol_inexact_params);
4036                 BUG_ON(err);
4037         }
4038
4039         hmask = 8 - 1;
4040         sz = (hmask+1) * sizeof(struct hlist_head);
4041
4042         net->xfrm.policy_byidx = xfrm_hash_alloc(sz);
4043         if (!net->xfrm.policy_byidx)
4044                 goto out_byidx;
4045         net->xfrm.policy_idx_hmask = hmask;
4046
4047         for (dir = 0; dir < XFRM_POLICY_MAX; dir++) {
4048                 struct xfrm_policy_hash *htab;
4049
4050                 net->xfrm.policy_count[dir] = 0;
4051                 net->xfrm.policy_count[XFRM_POLICY_MAX + dir] = 0;
4052                 INIT_HLIST_HEAD(&net->xfrm.policy_inexact[dir]);
4053
4054                 htab = &net->xfrm.policy_bydst[dir];
4055                 htab->table = xfrm_hash_alloc(sz);
4056                 if (!htab->table)
4057                         goto out_bydst;
4058                 htab->hmask = hmask;
4059                 htab->dbits4 = 32;
4060                 htab->sbits4 = 32;
4061                 htab->dbits6 = 128;
4062                 htab->sbits6 = 128;
4063         }
4064         net->xfrm.policy_hthresh.lbits4 = 32;
4065         net->xfrm.policy_hthresh.rbits4 = 32;
4066         net->xfrm.policy_hthresh.lbits6 = 128;
4067         net->xfrm.policy_hthresh.rbits6 = 128;
4068
4069         seqlock_init(&net->xfrm.policy_hthresh.lock);
4070
4071         INIT_LIST_HEAD(&net->xfrm.policy_all);
4072         INIT_LIST_HEAD(&net->xfrm.inexact_bins);
4073         INIT_WORK(&net->xfrm.policy_hash_work, xfrm_hash_resize);
4074         INIT_WORK(&net->xfrm.policy_hthresh.work, xfrm_hash_rebuild);
4075         return 0;
4076
4077 out_bydst:
4078         for (dir--; dir >= 0; dir--) {
4079                 struct xfrm_policy_hash *htab;
4080
4081                 htab = &net->xfrm.policy_bydst[dir];
4082                 xfrm_hash_free(htab->table, sz);
4083         }
4084         xfrm_hash_free(net->xfrm.policy_byidx, sz);
4085 out_byidx:
4086         return -ENOMEM;
4087 }
4088
4089 static void xfrm_policy_fini(struct net *net)
4090 {
4091         struct xfrm_pol_inexact_bin *b, *t;
4092         unsigned int sz;
4093         int dir;
4094
4095         flush_work(&net->xfrm.policy_hash_work);
4096 #ifdef CONFIG_XFRM_SUB_POLICY
4097         xfrm_policy_flush(net, XFRM_POLICY_TYPE_SUB, false);
4098 #endif
4099         xfrm_policy_flush(net, XFRM_POLICY_TYPE_MAIN, false);
4100
4101         WARN_ON(!list_empty(&net->xfrm.policy_all));
4102
4103         for (dir = 0; dir < XFRM_POLICY_MAX; dir++) {
4104                 struct xfrm_policy_hash *htab;
4105
4106                 WARN_ON(!hlist_empty(&net->xfrm.policy_inexact[dir]));
4107
4108                 htab = &net->xfrm.policy_bydst[dir];
4109                 sz = (htab->hmask + 1) * sizeof(struct hlist_head);
4110                 WARN_ON(!hlist_empty(htab->table));
4111                 xfrm_hash_free(htab->table, sz);
4112         }
4113
4114         sz = (net->xfrm.policy_idx_hmask + 1) * sizeof(struct hlist_head);
4115         WARN_ON(!hlist_empty(net->xfrm.policy_byidx));
4116         xfrm_hash_free(net->xfrm.policy_byidx, sz);
4117
4118         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
4119         list_for_each_entry_safe(b, t, &net->xfrm.inexact_bins, inexact_bins)
4120                 __xfrm_policy_inexact_prune_bin(b, true);
4121         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
4122 }
4123
4124 static int __net_init xfrm_net_init(struct net *net)
4125 {
4126         int rv;
4127
4128         /* Initialize the per-net locks here */
4129         spin_lock_init(&net->xfrm.xfrm_state_lock);
4130         spin_lock_init(&net->xfrm.xfrm_policy_lock);
4131         seqcount_spinlock_init(&net->xfrm.xfrm_policy_hash_generation, &net->xfrm.xfrm_policy_lock);
4132         mutex_init(&net->xfrm.xfrm_cfg_mutex);
4133         net->xfrm.policy_default[XFRM_POLICY_IN] = XFRM_USERPOLICY_ACCEPT;
4134         net->xfrm.policy_default[XFRM_POLICY_FWD] = XFRM_USERPOLICY_ACCEPT;
4135         net->xfrm.policy_default[XFRM_POLICY_OUT] = XFRM_USERPOLICY_ACCEPT;
4136
4137         rv = xfrm_statistics_init(net);
4138         if (rv < 0)
4139                 goto out_statistics;
4140         rv = xfrm_state_init(net);
4141         if (rv < 0)
4142                 goto out_state;
4143         rv = xfrm_policy_init(net);
4144         if (rv < 0)
4145                 goto out_policy;
4146         rv = xfrm_sysctl_init(net);
4147         if (rv < 0)
4148                 goto out_sysctl;
4149
4150         return 0;
4151
4152 out_sysctl:
4153         xfrm_policy_fini(net);
4154 out_policy:
4155         xfrm_state_fini(net);
4156 out_state:
4157         xfrm_statistics_fini(net);
4158 out_statistics:
4159         return rv;
4160 }
4161
/*
 * xfrm_net_exit - pernet exit hook: tears down what xfrm_net_init() set
 * up, in the reverse order.
 */
static void __net_exit xfrm_net_exit(struct net *net)
{
	xfrm_sysctl_fini(net);
	xfrm_policy_fini(net);
	xfrm_state_fini(net);
	xfrm_statistics_fini(net);
}
4169
/* Per-network-namespace hooks, registered from xfrm_init(). */
static struct pernet_operations __net_initdata xfrm_net_ops = {
	.init = xfrm_net_init,
	.exit = xfrm_net_exit,
};
4174
4175 static const struct flow_dissector_key xfrm_flow_dissector_keys[] = {
4176         {
4177                 .key_id = FLOW_DISSECTOR_KEY_CONTROL,
4178                 .offset = offsetof(struct xfrm_flow_keys, control),
4179         },
4180         {
4181                 .key_id = FLOW_DISSECTOR_KEY_BASIC,
4182                 .offset = offsetof(struct xfrm_flow_keys, basic),
4183         },
4184         {
4185                 .key_id = FLOW_DISSECTOR_KEY_IPV4_ADDRS,
4186                 .offset = offsetof(struct xfrm_flow_keys, addrs.ipv4),
4187         },
4188         {
4189                 .key_id = FLOW_DISSECTOR_KEY_IPV6_ADDRS,
4190                 .offset = offsetof(struct xfrm_flow_keys, addrs.ipv6),
4191         },
4192         {
4193                 .key_id = FLOW_DISSECTOR_KEY_PORTS,
4194                 .offset = offsetof(struct xfrm_flow_keys, ports),
4195         },
4196         {
4197                 .key_id = FLOW_DISSECTOR_KEY_GRE_KEYID,
4198                 .offset = offsetof(struct xfrm_flow_keys, gre),
4199         },
4200         {
4201                 .key_id = FLOW_DISSECTOR_KEY_IP,
4202                 .offset = offsetof(struct xfrm_flow_keys, ip),
4203         },
4204         {
4205                 .key_id = FLOW_DISSECTOR_KEY_ICMP,
4206                 .offset = offsetof(struct xfrm_flow_keys, icmp),
4207         },
4208 };
4209
/*
 * Boot-time initialisation of the xfrm core.
 *
 * Sets up xfrm_session_dissector from xfrm_flow_dissector_keys and
 * registers xfrm_net_ops.  It then calls xfrm_dev_init() and
 * xfrm_input_init(), espintcp_init() when CONFIG_XFRM_ESPINTCP is set,
 * and finally register_xfrm_state_bpf().
 *
 * NOTE(review): the return value of register_pernet_subsys() is not
 * checked, so a registration failure here is silently ignored.
 */
void __init xfrm_init(void)
{
	skb_flow_dissector_init(&xfrm_session_dissector,
				xfrm_flow_dissector_keys,
				ARRAY_SIZE(xfrm_flow_dissector_keys));

	register_pernet_subsys(&xfrm_net_ops);
	xfrm_dev_init();
	xfrm_input_init();

#ifdef CONFIG_XFRM_ESPINTCP
	espintcp_init();
#endif

	register_xfrm_state_bpf();
}
4226
4227 #ifdef CONFIG_AUDITSYSCALL
4228 static void xfrm_audit_common_policyinfo(struct xfrm_policy *xp,
4229                                          struct audit_buffer *audit_buf)
4230 {
4231         struct xfrm_sec_ctx *ctx = xp->security;
4232         struct xfrm_selector *sel = &xp->selector;
4233
4234         if (ctx)
4235                 audit_log_format(audit_buf, " sec_alg=%u sec_doi=%u sec_obj=%s",
4236                                  ctx->ctx_alg, ctx->ctx_doi, ctx->ctx_str);
4237
4238         switch (sel->family) {
4239         case AF_INET:
4240                 audit_log_format(audit_buf, " src=%pI4", &sel->saddr.a4);
4241                 if (sel->prefixlen_s != 32)
4242                         audit_log_format(audit_buf, " src_prefixlen=%d",
4243                                          sel->prefixlen_s);
4244                 audit_log_format(audit_buf, " dst=%pI4", &sel->daddr.a4);
4245                 if (sel->prefixlen_d != 32)
4246                         audit_log_format(audit_buf, " dst_prefixlen=%d",
4247                                          sel->prefixlen_d);
4248                 break;
4249         case AF_INET6:
4250                 audit_log_format(audit_buf, " src=%pI6", sel->saddr.a6);
4251                 if (sel->prefixlen_s != 128)
4252                         audit_log_format(audit_buf, " src_prefixlen=%d",
4253                                          sel->prefixlen_s);
4254                 audit_log_format(audit_buf, " dst=%pI6", sel->daddr.a6);
4255                 if (sel->prefixlen_d != 128)
4256                         audit_log_format(audit_buf, " dst_prefixlen=%d",
4257                                          sel->prefixlen_d);
4258                 break;
4259         }
4260 }
4261
4262 void xfrm_audit_policy_add(struct xfrm_policy *xp, int result, bool task_valid)
4263 {
4264         struct audit_buffer *audit_buf;
4265
4266         audit_buf = xfrm_audit_start("SPD-add");
4267         if (audit_buf == NULL)
4268                 return;
4269         xfrm_audit_helper_usrinfo(task_valid, audit_buf);
4270         audit_log_format(audit_buf, " res=%u", result);
4271         xfrm_audit_common_policyinfo(xp, audit_buf);
4272         audit_log_end(audit_buf);
4273 }
4274 EXPORT_SYMBOL_GPL(xfrm_audit_policy_add);
4275
4276 void xfrm_audit_policy_delete(struct xfrm_policy *xp, int result,
4277                               bool task_valid)
4278 {
4279         struct audit_buffer *audit_buf;
4280
4281         audit_buf = xfrm_audit_start("SPD-delete");
4282         if (audit_buf == NULL)
4283                 return;
4284         xfrm_audit_helper_usrinfo(task_valid, audit_buf);
4285         audit_log_format(audit_buf, " res=%u", result);
4286         xfrm_audit_common_policyinfo(xp, audit_buf);
4287         audit_log_end(audit_buf);
4288 }
4289 EXPORT_SYMBOL_GPL(xfrm_audit_policy_delete);
4290 #endif
4291
4292 #ifdef CONFIG_XFRM_MIGRATE
4293 static bool xfrm_migrate_selector_match(const struct xfrm_selector *sel_cmp,
4294                                         const struct xfrm_selector *sel_tgt)
4295 {
4296         if (sel_cmp->proto == IPSEC_ULPROTO_ANY) {
4297                 if (sel_tgt->family == sel_cmp->family &&
4298                     xfrm_addr_equal(&sel_tgt->daddr, &sel_cmp->daddr,
4299                                     sel_cmp->family) &&
4300                     xfrm_addr_equal(&sel_tgt->saddr, &sel_cmp->saddr,
4301                                     sel_cmp->family) &&
4302                     sel_tgt->prefixlen_d == sel_cmp->prefixlen_d &&
4303                     sel_tgt->prefixlen_s == sel_cmp->prefixlen_s) {
4304                         return true;
4305                 }
4306         } else {
4307                 if (memcmp(sel_tgt, sel_cmp, sizeof(*sel_tgt)) == 0) {
4308                         return true;
4309                 }
4310         }
4311         return false;
4312 }
4313
4314 static struct xfrm_policy *xfrm_migrate_policy_find(const struct xfrm_selector *sel,
4315                                                     u8 dir, u8 type, struct net *net, u32 if_id)
4316 {
4317         struct xfrm_policy *pol, *ret = NULL;
4318         struct hlist_head *chain;
4319         u32 priority = ~0U;
4320
4321         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
4322         chain = policy_hash_direct(net, &sel->daddr, &sel->saddr, sel->family, dir);
4323         hlist_for_each_entry(pol, chain, bydst) {
4324                 if ((if_id == 0 || pol->if_id == if_id) &&
4325                     xfrm_migrate_selector_match(sel, &pol->selector) &&
4326                     pol->type == type) {
4327                         ret = pol;
4328                         priority = ret->priority;
4329                         break;
4330                 }
4331         }
4332         chain = &net->xfrm.policy_inexact[dir];
4333         hlist_for_each_entry(pol, chain, bydst_inexact_list) {
4334                 if ((pol->priority >= priority) && ret)
4335                         break;
4336
4337                 if ((if_id == 0 || pol->if_id == if_id) &&
4338                     xfrm_migrate_selector_match(sel, &pol->selector) &&
4339                     pol->type == type) {
4340                         ret = pol;
4341                         break;
4342                 }
4343         }
4344
4345         xfrm_pol_hold(ret);
4346
4347         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
4348
4349         return ret;
4350 }
4351
4352 static int migrate_tmpl_match(const struct xfrm_migrate *m, const struct xfrm_tmpl *t)
4353 {
4354         int match = 0;
4355
4356         if (t->mode == m->mode && t->id.proto == m->proto &&
4357             (m->reqid == 0 || t->reqid == m->reqid)) {
4358                 switch (t->mode) {
4359                 case XFRM_MODE_TUNNEL:
4360                 case XFRM_MODE_BEET:
4361                         if (xfrm_addr_equal(&t->id.daddr, &m->old_daddr,
4362                                             m->old_family) &&
4363                             xfrm_addr_equal(&t->saddr, &m->old_saddr,
4364                                             m->old_family)) {
4365                                 match = 1;
4366                         }
4367                         break;
4368                 case XFRM_MODE_TRANSPORT:
4369                         /* in case of transport mode, template does not store
4370                            any IP addresses, hence we just compare mode and
4371                            protocol */
4372                         match = 1;
4373                         break;
4374                 default:
4375                         break;
4376                 }
4377         }
4378         return match;
4379 }
4380
4381 /* update endpoint address(es) of template(s) */
4382 static int xfrm_policy_migrate(struct xfrm_policy *pol,
4383                                struct xfrm_migrate *m, int num_migrate,
4384                                struct netlink_ext_ack *extack)
4385 {
4386         struct xfrm_migrate *mp;
4387         int i, j, n = 0;
4388
4389         write_lock_bh(&pol->lock);
4390         if (unlikely(pol->walk.dead)) {
4391                 /* target policy has been deleted */
4392                 NL_SET_ERR_MSG(extack, "Target policy not found");
4393                 write_unlock_bh(&pol->lock);
4394                 return -ENOENT;
4395         }
4396
4397         for (i = 0; i < pol->xfrm_nr; i++) {
4398                 for (j = 0, mp = m; j < num_migrate; j++, mp++) {
4399                         if (!migrate_tmpl_match(mp, &pol->xfrm_vec[i]))
4400                                 continue;
4401                         n++;
4402                         if (pol->xfrm_vec[i].mode != XFRM_MODE_TUNNEL &&
4403                             pol->xfrm_vec[i].mode != XFRM_MODE_BEET)
4404                                 continue;
4405                         /* update endpoints */
4406                         memcpy(&pol->xfrm_vec[i].id.daddr, &mp->new_daddr,
4407                                sizeof(pol->xfrm_vec[i].id.daddr));
4408                         memcpy(&pol->xfrm_vec[i].saddr, &mp->new_saddr,
4409                                sizeof(pol->xfrm_vec[i].saddr));
4410                         pol->xfrm_vec[i].encap_family = mp->new_family;
4411                         /* flush bundles */
4412                         atomic_inc(&pol->genid);
4413                 }
4414         }
4415
4416         write_unlock_bh(&pol->lock);
4417
4418         if (!n)
4419                 return -ENODATA;
4420
4421         return 0;
4422 }
4423
4424 static int xfrm_migrate_check(const struct xfrm_migrate *m, int num_migrate,
4425                               struct netlink_ext_ack *extack)
4426 {
4427         int i, j;
4428
4429         if (num_migrate < 1 || num_migrate > XFRM_MAX_DEPTH) {
4430                 NL_SET_ERR_MSG(extack, "Invalid number of SAs to migrate, must be 0 < num <= XFRM_MAX_DEPTH (6)");
4431                 return -EINVAL;
4432         }
4433
4434         for (i = 0; i < num_migrate; i++) {
4435                 if (xfrm_addr_any(&m[i].new_daddr, m[i].new_family) ||
4436                     xfrm_addr_any(&m[i].new_saddr, m[i].new_family)) {
4437                         NL_SET_ERR_MSG(extack, "Addresses in the MIGRATE attribute's list cannot be null");
4438                         return -EINVAL;
4439                 }
4440
4441                 /* check if there is any duplicated entry */
4442                 for (j = i + 1; j < num_migrate; j++) {
4443                         if (!memcmp(&m[i].old_daddr, &m[j].old_daddr,
4444                                     sizeof(m[i].old_daddr)) &&
4445                             !memcmp(&m[i].old_saddr, &m[j].old_saddr,
4446                                     sizeof(m[i].old_saddr)) &&
4447                             m[i].proto == m[j].proto &&
4448                             m[i].mode == m[j].mode &&
4449                             m[i].reqid == m[j].reqid &&
4450                             m[i].old_family == m[j].old_family) {
4451                                 NL_SET_ERR_MSG(extack, "Entries in the MIGRATE attribute's list must be unique");
4452                                 return -EINVAL;
4453                         }
4454                 }
4455         }
4456
4457         return 0;
4458 }
4459
4460 int xfrm_migrate(const struct xfrm_selector *sel, u8 dir, u8 type,
4461                  struct xfrm_migrate *m, int num_migrate,
4462                  struct xfrm_kmaddress *k, struct net *net,
4463                  struct xfrm_encap_tmpl *encap, u32 if_id,
4464                  struct netlink_ext_ack *extack)
4465 {
4466         int i, err, nx_cur = 0, nx_new = 0;
4467         struct xfrm_policy *pol = NULL;
4468         struct xfrm_state *x, *xc;
4469         struct xfrm_state *x_cur[XFRM_MAX_DEPTH];
4470         struct xfrm_state *x_new[XFRM_MAX_DEPTH];
4471         struct xfrm_migrate *mp;
4472
4473         /* Stage 0 - sanity checks */
4474         err = xfrm_migrate_check(m, num_migrate, extack);
4475         if (err < 0)
4476                 goto out;
4477
4478         if (dir >= XFRM_POLICY_MAX) {
4479                 NL_SET_ERR_MSG(extack, "Invalid policy direction");
4480                 err = -EINVAL;
4481                 goto out;
4482         }
4483
4484         /* Stage 1 - find policy */
4485         pol = xfrm_migrate_policy_find(sel, dir, type, net, if_id);
4486         if (!pol) {
4487                 NL_SET_ERR_MSG(extack, "Target policy not found");
4488                 err = -ENOENT;
4489                 goto out;
4490         }
4491
4492         /* Stage 2 - find and update state(s) */
4493         for (i = 0, mp = m; i < num_migrate; i++, mp++) {
4494                 if ((x = xfrm_migrate_state_find(mp, net, if_id))) {
4495                         x_cur[nx_cur] = x;
4496                         nx_cur++;
4497                         xc = xfrm_state_migrate(x, mp, encap);
4498                         if (xc) {
4499                                 x_new[nx_new] = xc;
4500                                 nx_new++;
4501                         } else {
4502                                 err = -ENODATA;
4503                                 goto restore_state;
4504                         }
4505                 }
4506         }
4507
4508         /* Stage 3 - update policy */
4509         err = xfrm_policy_migrate(pol, m, num_migrate, extack);
4510         if (err < 0)
4511                 goto restore_state;
4512
4513         /* Stage 4 - delete old state(s) */
4514         if (nx_cur) {
4515                 xfrm_states_put(x_cur, nx_cur);
4516                 xfrm_states_delete(x_cur, nx_cur);
4517         }
4518
4519         /* Stage 5 - announce */
4520         km_migrate(sel, dir, type, m, num_migrate, k, encap);
4521
4522         xfrm_pol_put(pol);
4523
4524         return 0;
4525 out:
4526         return err;
4527
4528 restore_state:
4529         if (pol)
4530                 xfrm_pol_put(pol);
4531         if (nx_cur)
4532                 xfrm_states_put(x_cur, nx_cur);
4533         if (nx_new)
4534                 xfrm_states_delete(x_new, nx_new);
4535
4536         return err;
4537 }
4538 EXPORT_SYMBOL(xfrm_migrate);
4539 #endif