Linux 6.7-rc7
[linux-modified.git] / net / xfrm / xfrm_policy.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * xfrm_policy.c
4  *
5  * Changes:
6  *      Mitsuru KANDA @USAGI
7  *      Kazunori MIYAZAWA @USAGI
8  *      Kunihiro Ishiguro <kunihiro@ipinfusion.com>
9  *              IPv6 support
10  *      Kazunori MIYAZAWA @USAGI
11  *      YOSHIFUJI Hideaki
12  *              Split up af-specific portion
13  *      Derek Atkins <derek@ihtfp.com>          Add the post_input processor
14  *
15  */
16
17 #include <linux/err.h>
18 #include <linux/slab.h>
19 #include <linux/kmod.h>
20 #include <linux/list.h>
21 #include <linux/spinlock.h>
22 #include <linux/workqueue.h>
23 #include <linux/notifier.h>
24 #include <linux/netdevice.h>
25 #include <linux/netfilter.h>
26 #include <linux/module.h>
27 #include <linux/cache.h>
28 #include <linux/cpu.h>
29 #include <linux/audit.h>
30 #include <linux/rhashtable.h>
31 #include <linux/if_tunnel.h>
32 #include <net/dst.h>
33 #include <net/flow.h>
34 #include <net/inet_ecn.h>
35 #include <net/xfrm.h>
36 #include <net/ip.h>
37 #include <net/gre.h>
38 #if IS_ENABLED(CONFIG_IPV6_MIP6)
39 #include <net/mip6.h>
40 #endif
41 #ifdef CONFIG_XFRM_STATISTICS
42 #include <net/snmp.h>
43 #endif
44 #ifdef CONFIG_XFRM_ESPINTCP
45 #include <net/espintcp.h>
46 #endif
47
48 #include "xfrm_hash.h"
49
/* Queue timeout bounds in jiffies: HZ/10 (a tenth of a second) up to 60*HZ. */
#define XFRM_QUEUE_TMO_MIN	((unsigned int)(HZ / 10))
#define XFRM_QUEUE_TMO_MAX	((unsigned int)(60 * HZ))
/* Queue length cap; the users of these limits are outside this part of the file. */
#define XFRM_MAX_QUEUE_LEN	100
53
/*
 * Small carrier pairing a dst entry pointer with a flags byte.
 * NOTE(review): the users are outside this part of the file; judging by the
 * field name, dst_orig is presumably the original (pre-xfrm) route - confirm.
 */
struct xfrm_flo {
	struct dst_entry *dst_orig;
	u8 flags;
};
58
/*
 * Prefixes smaller than this are stored in lists, not trees.
 * One threshold per address family; compared against selector prefix
 * lengths in xfrm_pol_inexact_addr_use_any_list().
 */
#define INEXACT_PREFIXLEN_IPV4	16
#define INEXACT_PREFIXLEN_IPV6	48
62
/*
 * One node of an inexact-policy search tree, keyed by addr/prefixlen.
 */
struct xfrm_pol_inexact_node {
	/* linkage into the rb-tree this node belongs to */
	struct rb_node node;
	/*
	 * addr and rcu share storage.
	 * NOTE(review): rcu is presumably only used once the node is being
	 * freed and addr is no longer needed - confirm against the free path.
	 */
	union {
		xfrm_address_t addr;
		struct rcu_head rcu;
	};
	u8 prefixlen;

	/* nested tree; see the layout comment below for when it is used */
	struct rb_root root;

	/* the policies matching this node, can be empty list */
	struct hlist_head hhead;
};
76
77 /* xfrm inexact policy search tree:
78  * xfrm_pol_inexact_bin = hash(dir,type,family,if_id);
79  *  |
80  * +---- root_d: sorted by daddr:prefix
81  * |                 |
82  * |        xfrm_pol_inexact_node
83  * |                 |
84  * |                 +- root: sorted by saddr/prefix
85  * |                 |              |
86  * |                 |         xfrm_pol_inexact_node
87  * |                 |              |
88  * |                 |              + root: unused
89  * |                 |              |
90  * |                 |              + hhead: saddr:daddr policies
91  * |                 |
92  * |                 +- coarse policies and all any:daddr policies
93  * |
94  * +---- root_s: sorted by saddr:prefix
95  * |                 |
96  * |        xfrm_pol_inexact_node
97  * |                 |
98  * |                 + root: unused
99  * |                 |
100  * |                 + hhead: saddr:any policies
101  * |
102  * +---- coarse policies and all any:any policies
103  *
104  * Lookups return four candidate lists:
105  * 1. any:any list from top-level xfrm_pol_inexact_bin
106  * 2. any:daddr list from daddr tree
107  * 3. saddr:daddr list from 2nd level daddr tree
108  * 4. saddr:any list from saddr tree
109  *
110  * This result set then needs to be searched for the policy with
111  * the lowest priority.  If two results have same prio, youngest one wins.
112  */
113
/*
 * Lookup key of an xfrm_pol_inexact_bin: network namespace, interface id,
 * address family, direction and policy type.  Filled in by
 * xfrm_policy_inexact_alloc_bin() and handed to the rhashtable helpers.
 */
struct xfrm_pol_inexact_key {
	possible_net_t net;
	u32 if_id;
	u16 family;
	u8 dir, type;
};
120
/*
 * Top-level container for the inexact policies that share one key
 * (see struct xfrm_pol_inexact_key).  Stored in xfrm_policy_inexact_table
 * and additionally linked on the per-netns inexact_bins list.
 */
struct xfrm_pol_inexact_bin {
	/* hash key; embedded so the rhashtable can compare against it */
	struct xfrm_pol_inexact_key k;
	/* rhashtable linkage */
	struct rhash_head head;
	/* list containing '*:*' policies */
	struct hlist_head hhead;

	/* sequence counter, initialised against xfrm_policy_lock */
	seqcount_spinlock_t count;
	/* tree sorted by daddr/prefix */
	struct rb_root root_d;

	/* tree sorted by saddr/prefix */
	struct rb_root root_s;

	/* slow path below */
	/* linkage on net->xfrm.inexact_bins */
	struct list_head inexact_bins;
	/* NOTE(review): presumably for deferred freeing of the bin - confirm */
	struct rcu_head rcu;
};
138
/*
 * Index of each candidate list in struct xfrm_pol_inexact_candidates.
 * The names follow the four lists described in the search-tree comment:
 * saddr:daddr (BOTH), saddr:any, any:daddr and any:any.
 */
enum xfrm_pol_inexact_candidate_type {
	XFRM_POL_CAND_BOTH,
	XFRM_POL_CAND_SADDR,
	XFRM_POL_CAND_DADDR,
	XFRM_POL_CAND_ANY,

	/* number of candidate types; used as array size, not a valid index */
	XFRM_POL_CAND_MAX,
};
147
/*
 * Result set of an inexact lookup: one list-head pointer per
 * enum xfrm_pol_inexact_candidate_type value.
 */
struct xfrm_pol_inexact_candidates {
	struct hlist_head *res[XFRM_POL_CAND_MAX];
};
151
/*
 * Collection of flow dissector key structures.
 * NOTE(review): presumably the dissection target of xfrm_session_dissector;
 * the code that fills it is outside this part of the file - confirm.
 */
struct xfrm_flow_keys {
	struct flow_dissector_key_basic basic;
	struct flow_dissector_key_control control;
	/* IPv4 and IPv6 address keys share storage */
	union {
		struct flow_dissector_key_ipv4_addrs ipv4;
		struct flow_dissector_key_ipv6_addrs ipv6;
	} addrs;
	struct flow_dissector_key_ip ip;
	struct flow_dissector_key_icmp icmp;
	struct flow_dissector_key_ports ports;
	struct flow_dissector_key_keyid gre;
};
164
/* Flow dissector instance; its initialisation is not in this part of the file. */
static struct flow_dissector xfrm_session_dissector __ro_after_init;

/*
 * xfrm_if_cb is an RCU pointer, read with rcu_dereference() in
 * xfrm_if_get_cb().
 * NOTE(review): xfrm_if_cb_lock presumably serialises its writers - confirm.
 */
static DEFINE_SPINLOCK(xfrm_if_cb_lock);
static struct xfrm_if_cb const __rcu *xfrm_if_cb __read_mostly;

/*
 * Per-address-family policy operations, indexed by family number
 * (slots 0..AF_INET6) and read under RCU in xfrm_policy_get_afinfo().
 * NOTE(review): xfrm_policy_afinfo_lock presumably serialises
 * registration - confirm.
 */
static DEFINE_SPINLOCK(xfrm_policy_afinfo_lock);
static struct xfrm_policy_afinfo const __rcu *xfrm_policy_afinfo[AF_INET6 + 1]
						__read_mostly;

/* NOTE(review): slab cache, presumably for struct xfrm_dst - confirm. */
static struct kmem_cache *xfrm_dst_cache __ro_after_init;

/* Hash table of xfrm_pol_inexact_bin entries and its (forward-declared) parameters. */
static struct rhashtable xfrm_policy_inexact_table;
static const struct rhashtable_params xfrm_pol_inexact_params;
178
/*
 * Forward declarations of file-local helpers.  Their definitions are not
 * in this part of the file.
 */
static void xfrm_init_pmtu(struct xfrm_dst **bundle, int nr);
static int stale_bundle(struct dst_entry *dst);
static int xfrm_bundle_ok(struct xfrm_dst *xdst);
/* timer callback installed on policy->polq.hold_timer by xfrm_policy_alloc() */
static void xfrm_policy_queue_process(struct timer_list *t);

static void __xfrm_policy_link(struct xfrm_policy *pol, int dir);
static struct xfrm_policy *__xfrm_policy_unlink(struct xfrm_policy *pol,
						int dir);

static struct xfrm_pol_inexact_bin *
xfrm_policy_inexact_lookup(struct net *net, u8 type, u16 family, u8 dir,
			   u32 if_id);

static struct xfrm_pol_inexact_bin *
xfrm_policy_inexact_lookup_rcu(struct net *net,
			       u8 type, u16 family, u8 dir, u32 if_id);
static struct xfrm_policy *
xfrm_policy_insert_list(struct hlist_head *chain, struct xfrm_policy *policy,
			bool excl);
static void xfrm_policy_insert_inexact_list(struct hlist_head *chain,
					    struct xfrm_policy *policy);

static bool
xfrm_policy_find_inexact_candidates(struct xfrm_pol_inexact_candidates *cand,
				    struct xfrm_pol_inexact_bin *b,
				    const xfrm_address_t *saddr,
				    const xfrm_address_t *daddr);
206
207 static inline bool xfrm_pol_hold_rcu(struct xfrm_policy *policy)
208 {
209         return refcount_inc_not_zero(&policy->refcnt);
210 }
211
212 static inline bool
213 __xfrm4_selector_match(const struct xfrm_selector *sel, const struct flowi *fl)
214 {
215         const struct flowi4 *fl4 = &fl->u.ip4;
216
217         return  addr4_match(fl4->daddr, sel->daddr.a4, sel->prefixlen_d) &&
218                 addr4_match(fl4->saddr, sel->saddr.a4, sel->prefixlen_s) &&
219                 !((xfrm_flowi_dport(fl, &fl4->uli) ^ sel->dport) & sel->dport_mask) &&
220                 !((xfrm_flowi_sport(fl, &fl4->uli) ^ sel->sport) & sel->sport_mask) &&
221                 (fl4->flowi4_proto == sel->proto || !sel->proto) &&
222                 (fl4->flowi4_oif == sel->ifindex || !sel->ifindex);
223 }
224
225 static inline bool
226 __xfrm6_selector_match(const struct xfrm_selector *sel, const struct flowi *fl)
227 {
228         const struct flowi6 *fl6 = &fl->u.ip6;
229
230         return  addr_match(&fl6->daddr, &sel->daddr, sel->prefixlen_d) &&
231                 addr_match(&fl6->saddr, &sel->saddr, sel->prefixlen_s) &&
232                 !((xfrm_flowi_dport(fl, &fl6->uli) ^ sel->dport) & sel->dport_mask) &&
233                 !((xfrm_flowi_sport(fl, &fl6->uli) ^ sel->sport) & sel->sport_mask) &&
234                 (fl6->flowi6_proto == sel->proto || !sel->proto) &&
235                 (fl6->flowi6_oif == sel->ifindex || !sel->ifindex);
236 }
237
238 bool xfrm_selector_match(const struct xfrm_selector *sel, const struct flowi *fl,
239                          unsigned short family)
240 {
241         switch (family) {
242         case AF_INET:
243                 return __xfrm4_selector_match(sel, fl);
244         case AF_INET6:
245                 return __xfrm6_selector_match(sel, fl);
246         }
247         return false;
248 }
249
250 static const struct xfrm_policy_afinfo *xfrm_policy_get_afinfo(unsigned short family)
251 {
252         const struct xfrm_policy_afinfo *afinfo;
253
254         if (unlikely(family >= ARRAY_SIZE(xfrm_policy_afinfo)))
255                 return NULL;
256         rcu_read_lock();
257         afinfo = rcu_dereference(xfrm_policy_afinfo[family]);
258         if (unlikely(!afinfo))
259                 rcu_read_unlock();
260         return afinfo;
261 }
262
263 /* Called with rcu_read_lock(). */
264 static const struct xfrm_if_cb *xfrm_if_get_cb(void)
265 {
266         return rcu_dereference(xfrm_if_cb);
267 }
268
269 struct dst_entry *__xfrm_dst_lookup(struct net *net, int tos, int oif,
270                                     const xfrm_address_t *saddr,
271                                     const xfrm_address_t *daddr,
272                                     int family, u32 mark)
273 {
274         const struct xfrm_policy_afinfo *afinfo;
275         struct dst_entry *dst;
276
277         afinfo = xfrm_policy_get_afinfo(family);
278         if (unlikely(afinfo == NULL))
279                 return ERR_PTR(-EAFNOSUPPORT);
280
281         dst = afinfo->dst_lookup(net, tos, oif, saddr, daddr, mark);
282
283         rcu_read_unlock();
284
285         return dst;
286 }
287 EXPORT_SYMBOL(__xfrm_dst_lookup);
288
289 static inline struct dst_entry *xfrm_dst_lookup(struct xfrm_state *x,
290                                                 int tos, int oif,
291                                                 xfrm_address_t *prev_saddr,
292                                                 xfrm_address_t *prev_daddr,
293                                                 int family, u32 mark)
294 {
295         struct net *net = xs_net(x);
296         xfrm_address_t *saddr = &x->props.saddr;
297         xfrm_address_t *daddr = &x->id.daddr;
298         struct dst_entry *dst;
299
300         if (x->type->flags & XFRM_TYPE_LOCAL_COADDR) {
301                 saddr = x->coaddr;
302                 daddr = prev_daddr;
303         }
304         if (x->type->flags & XFRM_TYPE_REMOTE_COADDR) {
305                 saddr = prev_saddr;
306                 daddr = x->coaddr;
307         }
308
309         dst = __xfrm_dst_lookup(net, tos, oif, saddr, daddr, family, mark);
310
311         if (!IS_ERR(dst)) {
312                 if (prev_saddr != saddr)
313                         memcpy(prev_saddr, saddr,  sizeof(*prev_saddr));
314                 if (prev_daddr != daddr)
315                         memcpy(prev_daddr, daddr,  sizeof(*prev_daddr));
316         }
317
318         return dst;
319 }
320
321 static inline unsigned long make_jiffies(long secs)
322 {
323         if (secs >= (MAX_SCHEDULE_TIMEOUT-1)/HZ)
324                 return MAX_SCHEDULE_TIMEOUT-1;
325         else
326                 return secs*HZ;
327 }
328
329 static void xfrm_policy_timer(struct timer_list *t)
330 {
331         struct xfrm_policy *xp = from_timer(xp, t, timer);
332         time64_t now = ktime_get_real_seconds();
333         time64_t next = TIME64_MAX;
334         int warn = 0;
335         int dir;
336
337         read_lock(&xp->lock);
338
339         if (unlikely(xp->walk.dead))
340                 goto out;
341
342         dir = xfrm_policy_id2dir(xp->index);
343
344         if (xp->lft.hard_add_expires_seconds) {
345                 time64_t tmo = xp->lft.hard_add_expires_seconds +
346                         xp->curlft.add_time - now;
347                 if (tmo <= 0)
348                         goto expired;
349                 if (tmo < next)
350                         next = tmo;
351         }
352         if (xp->lft.hard_use_expires_seconds) {
353                 time64_t tmo = xp->lft.hard_use_expires_seconds +
354                         (READ_ONCE(xp->curlft.use_time) ? : xp->curlft.add_time) - now;
355                 if (tmo <= 0)
356                         goto expired;
357                 if (tmo < next)
358                         next = tmo;
359         }
360         if (xp->lft.soft_add_expires_seconds) {
361                 time64_t tmo = xp->lft.soft_add_expires_seconds +
362                         xp->curlft.add_time - now;
363                 if (tmo <= 0) {
364                         warn = 1;
365                         tmo = XFRM_KM_TIMEOUT;
366                 }
367                 if (tmo < next)
368                         next = tmo;
369         }
370         if (xp->lft.soft_use_expires_seconds) {
371                 time64_t tmo = xp->lft.soft_use_expires_seconds +
372                         (READ_ONCE(xp->curlft.use_time) ? : xp->curlft.add_time) - now;
373                 if (tmo <= 0) {
374                         warn = 1;
375                         tmo = XFRM_KM_TIMEOUT;
376                 }
377                 if (tmo < next)
378                         next = tmo;
379         }
380
381         if (warn)
382                 km_policy_expired(xp, dir, 0, 0);
383         if (next != TIME64_MAX &&
384             !mod_timer(&xp->timer, jiffies + make_jiffies(next)))
385                 xfrm_pol_hold(xp);
386
387 out:
388         read_unlock(&xp->lock);
389         xfrm_pol_put(xp);
390         return;
391
392 expired:
393         read_unlock(&xp->lock);
394         if (!xfrm_policy_delete(xp, dir))
395                 km_policy_expired(xp, dir, 1, 0);
396         xfrm_pol_put(xp);
397 }
398
399 /* Allocate xfrm_policy. Not used here, it is supposed to be used by pfkeyv2
400  * SPD calls.
401  */
402
403 struct xfrm_policy *xfrm_policy_alloc(struct net *net, gfp_t gfp)
404 {
405         struct xfrm_policy *policy;
406
407         policy = kzalloc(sizeof(struct xfrm_policy), gfp);
408
409         if (policy) {
410                 write_pnet(&policy->xp_net, net);
411                 INIT_LIST_HEAD(&policy->walk.all);
412                 INIT_HLIST_NODE(&policy->bydst_inexact_list);
413                 INIT_HLIST_NODE(&policy->bydst);
414                 INIT_HLIST_NODE(&policy->byidx);
415                 rwlock_init(&policy->lock);
416                 refcount_set(&policy->refcnt, 1);
417                 skb_queue_head_init(&policy->polq.hold_queue);
418                 timer_setup(&policy->timer, xfrm_policy_timer, 0);
419                 timer_setup(&policy->polq.hold_timer,
420                             xfrm_policy_queue_process, 0);
421         }
422         return policy;
423 }
424 EXPORT_SYMBOL(xfrm_policy_alloc);
425
426 static void xfrm_policy_destroy_rcu(struct rcu_head *head)
427 {
428         struct xfrm_policy *policy = container_of(head, struct xfrm_policy, rcu);
429
430         security_xfrm_policy_free(policy->security);
431         kfree(policy);
432 }
433
434 /* Destroy xfrm_policy: descendant resources must be released to this moment. */
435
436 void xfrm_policy_destroy(struct xfrm_policy *policy)
437 {
438         BUG_ON(!policy->walk.dead);
439
440         if (del_timer(&policy->timer) || del_timer(&policy->polq.hold_timer))
441                 BUG();
442
443         xfrm_dev_policy_free(policy);
444         call_rcu(&policy->rcu, xfrm_policy_destroy_rcu);
445 }
446 EXPORT_SYMBOL(xfrm_policy_destroy);
447
448 /* Rule must be locked. Release descendant resources, announce
449  * entry dead. The rule must be unlinked from lists to the moment.
450  */
451
452 static void xfrm_policy_kill(struct xfrm_policy *policy)
453 {
454         write_lock_bh(&policy->lock);
455         policy->walk.dead = 1;
456         write_unlock_bh(&policy->lock);
457
458         atomic_inc(&policy->genid);
459
460         if (del_timer(&policy->polq.hold_timer))
461                 xfrm_pol_put(policy);
462         skb_queue_purge(&policy->polq.hold_queue);
463
464         if (del_timer(&policy->timer))
465                 xfrm_pol_put(policy);
466
467         xfrm_pol_put(policy);
468 }
469
470 static unsigned int xfrm_policy_hashmax __read_mostly = 1 * 1024 * 1024;
471
472 static inline unsigned int idx_hash(struct net *net, u32 index)
473 {
474         return __idx_hash(index, net->xfrm.policy_idx_hmask);
475 }
476
477 /* calculate policy hash thresholds */
478 static void __get_hash_thresh(struct net *net,
479                               unsigned short family, int dir,
480                               u8 *dbits, u8 *sbits)
481 {
482         switch (family) {
483         case AF_INET:
484                 *dbits = net->xfrm.policy_bydst[dir].dbits4;
485                 *sbits = net->xfrm.policy_bydst[dir].sbits4;
486                 break;
487
488         case AF_INET6:
489                 *dbits = net->xfrm.policy_bydst[dir].dbits6;
490                 *sbits = net->xfrm.policy_bydst[dir].sbits6;
491                 break;
492
493         default:
494                 *dbits = 0;
495                 *sbits = 0;
496         }
497 }
498
499 static struct hlist_head *policy_hash_bysel(struct net *net,
500                                             const struct xfrm_selector *sel,
501                                             unsigned short family, int dir)
502 {
503         unsigned int hmask = net->xfrm.policy_bydst[dir].hmask;
504         unsigned int hash;
505         u8 dbits;
506         u8 sbits;
507
508         __get_hash_thresh(net, family, dir, &dbits, &sbits);
509         hash = __sel_hash(sel, family, hmask, dbits, sbits);
510
511         if (hash == hmask + 1)
512                 return NULL;
513
514         return rcu_dereference_check(net->xfrm.policy_bydst[dir].table,
515                      lockdep_is_held(&net->xfrm.xfrm_policy_lock)) + hash;
516 }
517
518 static struct hlist_head *policy_hash_direct(struct net *net,
519                                              const xfrm_address_t *daddr,
520                                              const xfrm_address_t *saddr,
521                                              unsigned short family, int dir)
522 {
523         unsigned int hmask = net->xfrm.policy_bydst[dir].hmask;
524         unsigned int hash;
525         u8 dbits;
526         u8 sbits;
527
528         __get_hash_thresh(net, family, dir, &dbits, &sbits);
529         hash = __addr_hash(daddr, saddr, family, hmask, dbits, sbits);
530
531         return rcu_dereference_check(net->xfrm.policy_bydst[dir].table,
532                      lockdep_is_held(&net->xfrm.xfrm_policy_lock)) + hash;
533 }
534
535 static void xfrm_dst_hash_transfer(struct net *net,
536                                    struct hlist_head *list,
537                                    struct hlist_head *ndsttable,
538                                    unsigned int nhashmask,
539                                    int dir)
540 {
541         struct hlist_node *tmp, *entry0 = NULL;
542         struct xfrm_policy *pol;
543         unsigned int h0 = 0;
544         u8 dbits;
545         u8 sbits;
546
547 redo:
548         hlist_for_each_entry_safe(pol, tmp, list, bydst) {
549                 unsigned int h;
550
551                 __get_hash_thresh(net, pol->family, dir, &dbits, &sbits);
552                 h = __addr_hash(&pol->selector.daddr, &pol->selector.saddr,
553                                 pol->family, nhashmask, dbits, sbits);
554                 if (!entry0 || pol->xdo.type == XFRM_DEV_OFFLOAD_PACKET) {
555                         hlist_del_rcu(&pol->bydst);
556                         hlist_add_head_rcu(&pol->bydst, ndsttable + h);
557                         h0 = h;
558                 } else {
559                         if (h != h0)
560                                 continue;
561                         hlist_del_rcu(&pol->bydst);
562                         hlist_add_behind_rcu(&pol->bydst, entry0);
563                 }
564                 entry0 = &pol->bydst;
565         }
566         if (!hlist_empty(list)) {
567                 entry0 = NULL;
568                 goto redo;
569         }
570 }
571
572 static void xfrm_idx_hash_transfer(struct hlist_head *list,
573                                    struct hlist_head *nidxtable,
574                                    unsigned int nhashmask)
575 {
576         struct hlist_node *tmp;
577         struct xfrm_policy *pol;
578
579         hlist_for_each_entry_safe(pol, tmp, list, byidx) {
580                 unsigned int h;
581
582                 h = __idx_hash(pol->index, nhashmask);
583                 hlist_add_head(&pol->byidx, nidxtable+h);
584         }
585 }
586
/*
 * Mask for a hash table twice the size of the one described by
 * @old_hmask: 2 * old_hmask + 1, computed in unsigned int arithmetic.
 */
static unsigned long xfrm_new_hash_mask(unsigned int old_hmask)
{
	unsigned int doubled = (old_hmask << 1) | 1U;

	return doubled;
}
591
592 static void xfrm_bydst_resize(struct net *net, int dir)
593 {
594         unsigned int hmask = net->xfrm.policy_bydst[dir].hmask;
595         unsigned int nhashmask = xfrm_new_hash_mask(hmask);
596         unsigned int nsize = (nhashmask + 1) * sizeof(struct hlist_head);
597         struct hlist_head *ndst = xfrm_hash_alloc(nsize);
598         struct hlist_head *odst;
599         int i;
600
601         if (!ndst)
602                 return;
603
604         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
605         write_seqcount_begin(&net->xfrm.xfrm_policy_hash_generation);
606
607         odst = rcu_dereference_protected(net->xfrm.policy_bydst[dir].table,
608                                 lockdep_is_held(&net->xfrm.xfrm_policy_lock));
609
610         for (i = hmask; i >= 0; i--)
611                 xfrm_dst_hash_transfer(net, odst + i, ndst, nhashmask, dir);
612
613         rcu_assign_pointer(net->xfrm.policy_bydst[dir].table, ndst);
614         net->xfrm.policy_bydst[dir].hmask = nhashmask;
615
616         write_seqcount_end(&net->xfrm.xfrm_policy_hash_generation);
617         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
618
619         synchronize_rcu();
620
621         xfrm_hash_free(odst, (hmask + 1) * sizeof(struct hlist_head));
622 }
623
624 static void xfrm_byidx_resize(struct net *net)
625 {
626         unsigned int hmask = net->xfrm.policy_idx_hmask;
627         unsigned int nhashmask = xfrm_new_hash_mask(hmask);
628         unsigned int nsize = (nhashmask + 1) * sizeof(struct hlist_head);
629         struct hlist_head *oidx = net->xfrm.policy_byidx;
630         struct hlist_head *nidx = xfrm_hash_alloc(nsize);
631         int i;
632
633         if (!nidx)
634                 return;
635
636         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
637
638         for (i = hmask; i >= 0; i--)
639                 xfrm_idx_hash_transfer(oidx + i, nidx, nhashmask);
640
641         net->xfrm.policy_byidx = nidx;
642         net->xfrm.policy_idx_hmask = nhashmask;
643
644         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
645
646         xfrm_hash_free(oidx, (hmask + 1) * sizeof(struct hlist_head));
647 }
648
649 static inline int xfrm_bydst_should_resize(struct net *net, int dir, int *total)
650 {
651         unsigned int cnt = net->xfrm.policy_count[dir];
652         unsigned int hmask = net->xfrm.policy_bydst[dir].hmask;
653
654         if (total)
655                 *total += cnt;
656
657         if ((hmask + 1) < xfrm_policy_hashmax &&
658             cnt > hmask)
659                 return 1;
660
661         return 0;
662 }
663
664 static inline int xfrm_byidx_should_resize(struct net *net, int total)
665 {
666         unsigned int hmask = net->xfrm.policy_idx_hmask;
667
668         if ((hmask + 1) < xfrm_policy_hashmax &&
669             total > hmask)
670                 return 1;
671
672         return 0;
673 }
674
675 void xfrm_spd_getinfo(struct net *net, struct xfrmk_spdinfo *si)
676 {
677         si->incnt = net->xfrm.policy_count[XFRM_POLICY_IN];
678         si->outcnt = net->xfrm.policy_count[XFRM_POLICY_OUT];
679         si->fwdcnt = net->xfrm.policy_count[XFRM_POLICY_FWD];
680         si->inscnt = net->xfrm.policy_count[XFRM_POLICY_IN+XFRM_POLICY_MAX];
681         si->outscnt = net->xfrm.policy_count[XFRM_POLICY_OUT+XFRM_POLICY_MAX];
682         si->fwdscnt = net->xfrm.policy_count[XFRM_POLICY_FWD+XFRM_POLICY_MAX];
683         si->spdhcnt = net->xfrm.policy_idx_hmask;
684         si->spdhmcnt = xfrm_policy_hashmax;
685 }
686 EXPORT_SYMBOL(xfrm_spd_getinfo);
687
688 static DEFINE_MUTEX(hash_resize_mutex);
689 static void xfrm_hash_resize(struct work_struct *work)
690 {
691         struct net *net = container_of(work, struct net, xfrm.policy_hash_work);
692         int dir, total;
693
694         mutex_lock(&hash_resize_mutex);
695
696         total = 0;
697         for (dir = 0; dir < XFRM_POLICY_MAX; dir++) {
698                 if (xfrm_bydst_should_resize(net, dir, &total))
699                         xfrm_bydst_resize(net, dir);
700         }
701         if (xfrm_byidx_should_resize(net, total))
702                 xfrm_byidx_resize(net);
703
704         mutex_unlock(&hash_resize_mutex);
705 }
706
707 /* Make sure *pol can be inserted into fastbin.
708  * Useful to check that later insert requests will be successful
709  * (provided xfrm_policy_lock is held throughout).
710  */
711 static struct xfrm_pol_inexact_bin *
712 xfrm_policy_inexact_alloc_bin(const struct xfrm_policy *pol, u8 dir)
713 {
714         struct xfrm_pol_inexact_bin *bin, *prev;
715         struct xfrm_pol_inexact_key k = {
716                 .family = pol->family,
717                 .type = pol->type,
718                 .dir = dir,
719                 .if_id = pol->if_id,
720         };
721         struct net *net = xp_net(pol);
722
723         lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
724
725         write_pnet(&k.net, net);
726         bin = rhashtable_lookup_fast(&xfrm_policy_inexact_table, &k,
727                                      xfrm_pol_inexact_params);
728         if (bin)
729                 return bin;
730
731         bin = kzalloc(sizeof(*bin), GFP_ATOMIC);
732         if (!bin)
733                 return NULL;
734
735         bin->k = k;
736         INIT_HLIST_HEAD(&bin->hhead);
737         bin->root_d = RB_ROOT;
738         bin->root_s = RB_ROOT;
739         seqcount_spinlock_init(&bin->count, &net->xfrm.xfrm_policy_lock);
740
741         prev = rhashtable_lookup_get_insert_key(&xfrm_policy_inexact_table,
742                                                 &bin->k, &bin->head,
743                                                 xfrm_pol_inexact_params);
744         if (!prev) {
745                 list_add(&bin->inexact_bins, &net->xfrm.inexact_bins);
746                 return bin;
747         }
748
749         kfree(bin);
750
751         return IS_ERR(prev) ? NULL : prev;
752 }
753
754 static bool xfrm_pol_inexact_addr_use_any_list(const xfrm_address_t *addr,
755                                                int family, u8 prefixlen)
756 {
757         if (xfrm_addr_any(addr, family))
758                 return true;
759
760         if (family == AF_INET6 && prefixlen < INEXACT_PREFIXLEN_IPV6)
761                 return true;
762
763         if (family == AF_INET && prefixlen < INEXACT_PREFIXLEN_IPV4)
764                 return true;
765
766         return false;
767 }
768
769 static bool
770 xfrm_policy_inexact_insert_use_any_list(const struct xfrm_policy *policy)
771 {
772         const xfrm_address_t *addr;
773         bool saddr_any, daddr_any;
774         u8 prefixlen;
775
776         addr = &policy->selector.saddr;
777         prefixlen = policy->selector.prefixlen_s;
778
779         saddr_any = xfrm_pol_inexact_addr_use_any_list(addr,
780                                                        policy->family,
781                                                        prefixlen);
782         addr = &policy->selector.daddr;
783         prefixlen = policy->selector.prefixlen_d;
784         daddr_any = xfrm_pol_inexact_addr_use_any_list(addr,
785                                                        policy->family,
786                                                        prefixlen);
787         return saddr_any && daddr_any;
788 }
789
790 static void xfrm_pol_inexact_node_init(struct xfrm_pol_inexact_node *node,
791                                        const xfrm_address_t *addr, u8 prefixlen)
792 {
793         node->addr = *addr;
794         node->prefixlen = prefixlen;
795 }
796
797 static struct xfrm_pol_inexact_node *
798 xfrm_pol_inexact_node_alloc(const xfrm_address_t *addr, u8 prefixlen)
799 {
800         struct xfrm_pol_inexact_node *node;
801
802         node = kzalloc(sizeof(*node), GFP_ATOMIC);
803         if (node)
804                 xfrm_pol_inexact_node_init(node, addr, prefixlen);
805
806         return node;
807 }
808
809 static int xfrm_policy_addr_delta(const xfrm_address_t *a,
810                                   const xfrm_address_t *b,
811                                   u8 prefixlen, u16 family)
812 {
813         u32 ma, mb, mask;
814         unsigned int pdw, pbi;
815         int delta = 0;
816
817         switch (family) {
818         case AF_INET:
819                 if (prefixlen == 0)
820                         return 0;
821                 mask = ~0U << (32 - prefixlen);
822                 ma = ntohl(a->a4) & mask;
823                 mb = ntohl(b->a4) & mask;
824                 if (ma < mb)
825                         delta = -1;
826                 else if (ma > mb)
827                         delta = 1;
828                 break;
829         case AF_INET6:
830                 pdw = prefixlen >> 5;
831                 pbi = prefixlen & 0x1f;
832
833                 if (pdw) {
834                         delta = memcmp(a->a6, b->a6, pdw << 2);
835                         if (delta)
836                                 return delta;
837                 }
838                 if (pbi) {
839                         mask = ~0U << (32 - pbi);
840                         ma = ntohl(a->a6[pdw]) & mask;
841                         mb = ntohl(b->a6[pdw]) & mask;
842                         if (ma < mb)
843                                 delta = -1;
844                         else if (ma > mb)
845                                 delta = 1;
846                 }
847                 break;
848         default:
849                 break;
850         }
851
852         return delta;
853 }
854
855 static void xfrm_policy_inexact_list_reinsert(struct net *net,
856                                               struct xfrm_pol_inexact_node *n,
857                                               u16 family)
858 {
859         unsigned int matched_s, matched_d;
860         struct xfrm_policy *policy, *p;
861
862         matched_s = 0;
863         matched_d = 0;
864
865         list_for_each_entry_reverse(policy, &net->xfrm.policy_all, walk.all) {
866                 struct hlist_node *newpos = NULL;
867                 bool matches_s, matches_d;
868
869                 if (policy->walk.dead || !policy->bydst_reinsert)
870                         continue;
871
872                 WARN_ON_ONCE(policy->family != family);
873
874                 policy->bydst_reinsert = false;
875                 hlist_for_each_entry(p, &n->hhead, bydst) {
876                         if (policy->priority > p->priority)
877                                 newpos = &p->bydst;
878                         else if (policy->priority == p->priority &&
879                                  policy->pos > p->pos)
880                                 newpos = &p->bydst;
881                         else
882                                 break;
883                 }
884
885                 if (newpos && policy->xdo.type != XFRM_DEV_OFFLOAD_PACKET)
886                         hlist_add_behind_rcu(&policy->bydst, newpos);
887                 else
888                         hlist_add_head_rcu(&policy->bydst, &n->hhead);
889
890                 /* paranoia checks follow.
891                  * Check that the reinserted policy matches at least
892                  * saddr or daddr for current node prefix.
893                  *
894                  * Matching both is fine, matching saddr in one policy
895                  * (but not daddr) and then matching only daddr in another
896                  * is a bug.
897                  */
898                 matches_s = xfrm_policy_addr_delta(&policy->selector.saddr,
899                                                    &n->addr,
900                                                    n->prefixlen,
901                                                    family) == 0;
902                 matches_d = xfrm_policy_addr_delta(&policy->selector.daddr,
903                                                    &n->addr,
904                                                    n->prefixlen,
905                                                    family) == 0;
906                 if (matches_s && matches_d)
907                         continue;
908
909                 WARN_ON_ONCE(!matches_s && !matches_d);
910                 if (matches_s)
911                         matched_s++;
912                 if (matches_d)
913                         matched_d++;
914                 WARN_ON_ONCE(matched_s && matched_d);
915         }
916 }
917
918 static void xfrm_policy_inexact_node_reinsert(struct net *net,
919                                               struct xfrm_pol_inexact_node *n,
920                                               struct rb_root *new,
921                                               u16 family)
922 {
923         struct xfrm_pol_inexact_node *node;
924         struct rb_node **p, *parent;
925
926         /* we should not have another subtree here */
927         WARN_ON_ONCE(!RB_EMPTY_ROOT(&n->root));
928 restart:
929         parent = NULL;
930         p = &new->rb_node;
931         while (*p) {
932                 u8 prefixlen;
933                 int delta;
934
935                 parent = *p;
936                 node = rb_entry(*p, struct xfrm_pol_inexact_node, node);
937
938                 prefixlen = min(node->prefixlen, n->prefixlen);
939
940                 delta = xfrm_policy_addr_delta(&n->addr, &node->addr,
941                                                prefixlen, family);
942                 if (delta < 0) {
943                         p = &parent->rb_left;
944                 } else if (delta > 0) {
945                         p = &parent->rb_right;
946                 } else {
947                         bool same_prefixlen = node->prefixlen == n->prefixlen;
948                         struct xfrm_policy *tmp;
949
950                         hlist_for_each_entry(tmp, &n->hhead, bydst) {
951                                 tmp->bydst_reinsert = true;
952                                 hlist_del_rcu(&tmp->bydst);
953                         }
954
955                         node->prefixlen = prefixlen;
956
957                         xfrm_policy_inexact_list_reinsert(net, node, family);
958
959                         if (same_prefixlen) {
960                                 kfree_rcu(n, rcu);
961                                 return;
962                         }
963
964                         rb_erase(*p, new);
965                         kfree_rcu(n, rcu);
966                         n = node;
967                         goto restart;
968                 }
969         }
970
971         rb_link_node_rcu(&n->node, parent, p);
972         rb_insert_color(&n->node, new);
973 }
974
975 /* merge nodes v and n */
976 static void xfrm_policy_inexact_node_merge(struct net *net,
977                                            struct xfrm_pol_inexact_node *v,
978                                            struct xfrm_pol_inexact_node *n,
979                                            u16 family)
980 {
981         struct xfrm_pol_inexact_node *node;
982         struct xfrm_policy *tmp;
983         struct rb_node *rnode;
984
985         /* To-be-merged node v has a subtree.
986          *
987          * Dismantle it and insert its nodes to n->root.
988          */
989         while ((rnode = rb_first(&v->root)) != NULL) {
990                 node = rb_entry(rnode, struct xfrm_pol_inexact_node, node);
991                 rb_erase(&node->node, &v->root);
992                 xfrm_policy_inexact_node_reinsert(net, node, &n->root,
993                                                   family);
994         }
995
996         hlist_for_each_entry(tmp, &v->hhead, bydst) {
997                 tmp->bydst_reinsert = true;
998                 hlist_del_rcu(&tmp->bydst);
999         }
1000
1001         xfrm_policy_inexact_list_reinsert(net, n, family);
1002 }
1003
1004 static struct xfrm_pol_inexact_node *
1005 xfrm_policy_inexact_insert_node(struct net *net,
1006                                 struct rb_root *root,
1007                                 xfrm_address_t *addr,
1008                                 u16 family, u8 prefixlen, u8 dir)
1009 {
1010         struct xfrm_pol_inexact_node *cached = NULL;
1011         struct rb_node **p, *parent = NULL;
1012         struct xfrm_pol_inexact_node *node;
1013
1014         p = &root->rb_node;
1015         while (*p) {
1016                 int delta;
1017
1018                 parent = *p;
1019                 node = rb_entry(*p, struct xfrm_pol_inexact_node, node);
1020
1021                 delta = xfrm_policy_addr_delta(addr, &node->addr,
1022                                                node->prefixlen,
1023                                                family);
1024                 if (delta == 0 && prefixlen >= node->prefixlen) {
1025                         WARN_ON_ONCE(cached); /* ipsec policies got lost */
1026                         return node;
1027                 }
1028
1029                 if (delta < 0)
1030                         p = &parent->rb_left;
1031                 else
1032                         p = &parent->rb_right;
1033
1034                 if (prefixlen < node->prefixlen) {
1035                         delta = xfrm_policy_addr_delta(addr, &node->addr,
1036                                                        prefixlen,
1037                                                        family);
1038                         if (delta)
1039                                 continue;
1040
1041                         /* This node is a subnet of the new prefix. It needs
1042                          * to be removed and re-inserted with the smaller
1043                          * prefix and all nodes that are now also covered
1044                          * by the reduced prefixlen.
1045                          */
1046                         rb_erase(&node->node, root);
1047
1048                         if (!cached) {
1049                                 xfrm_pol_inexact_node_init(node, addr,
1050                                                            prefixlen);
1051                                 cached = node;
1052                         } else {
1053                                 /* This node also falls within the new
1054                                  * prefixlen. Merge the to-be-reinserted
1055                                  * node and this one.
1056                                  */
1057                                 xfrm_policy_inexact_node_merge(net, node,
1058                                                                cached, family);
1059                                 kfree_rcu(node, rcu);
1060                         }
1061
1062                         /* restart */
1063                         p = &root->rb_node;
1064                         parent = NULL;
1065                 }
1066         }
1067
1068         node = cached;
1069         if (!node) {
1070                 node = xfrm_pol_inexact_node_alloc(addr, prefixlen);
1071                 if (!node)
1072                         return NULL;
1073         }
1074
1075         rb_link_node_rcu(&node->node, parent, p);
1076         rb_insert_color(&node->node, root);
1077
1078         return node;
1079 }
1080
1081 static void xfrm_policy_inexact_gc_tree(struct rb_root *r, bool rm)
1082 {
1083         struct xfrm_pol_inexact_node *node;
1084         struct rb_node *rn = rb_first(r);
1085
1086         while (rn) {
1087                 node = rb_entry(rn, struct xfrm_pol_inexact_node, node);
1088
1089                 xfrm_policy_inexact_gc_tree(&node->root, rm);
1090                 rn = rb_next(rn);
1091
1092                 if (!hlist_empty(&node->hhead) || !RB_EMPTY_ROOT(&node->root)) {
1093                         WARN_ON_ONCE(rm);
1094                         continue;
1095                 }
1096
1097                 rb_erase(&node->node, r);
1098                 kfree_rcu(node, rcu);
1099         }
1100 }
1101
1102 static void __xfrm_policy_inexact_prune_bin(struct xfrm_pol_inexact_bin *b, bool net_exit)
1103 {
1104         write_seqcount_begin(&b->count);
1105         xfrm_policy_inexact_gc_tree(&b->root_d, net_exit);
1106         xfrm_policy_inexact_gc_tree(&b->root_s, net_exit);
1107         write_seqcount_end(&b->count);
1108
1109         if (!RB_EMPTY_ROOT(&b->root_d) || !RB_EMPTY_ROOT(&b->root_s) ||
1110             !hlist_empty(&b->hhead)) {
1111                 WARN_ON_ONCE(net_exit);
1112                 return;
1113         }
1114
1115         if (rhashtable_remove_fast(&xfrm_policy_inexact_table, &b->head,
1116                                    xfrm_pol_inexact_params) == 0) {
1117                 list_del(&b->inexact_bins);
1118                 kfree_rcu(b, rcu);
1119         }
1120 }
1121
1122 static void xfrm_policy_inexact_prune_bin(struct xfrm_pol_inexact_bin *b)
1123 {
1124         struct net *net = read_pnet(&b->k.net);
1125
1126         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1127         __xfrm_policy_inexact_prune_bin(b, false);
1128         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1129 }
1130
1131 static void __xfrm_policy_inexact_flush(struct net *net)
1132 {
1133         struct xfrm_pol_inexact_bin *bin, *t;
1134
1135         lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
1136
1137         list_for_each_entry_safe(bin, t, &net->xfrm.inexact_bins, inexact_bins)
1138                 __xfrm_policy_inexact_prune_bin(bin, false);
1139 }
1140
1141 static struct hlist_head *
1142 xfrm_policy_inexact_alloc_chain(struct xfrm_pol_inexact_bin *bin,
1143                                 struct xfrm_policy *policy, u8 dir)
1144 {
1145         struct xfrm_pol_inexact_node *n;
1146         struct net *net;
1147
1148         net = xp_net(policy);
1149         lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
1150
1151         if (xfrm_policy_inexact_insert_use_any_list(policy))
1152                 return &bin->hhead;
1153
1154         if (xfrm_pol_inexact_addr_use_any_list(&policy->selector.daddr,
1155                                                policy->family,
1156                                                policy->selector.prefixlen_d)) {
1157                 write_seqcount_begin(&bin->count);
1158                 n = xfrm_policy_inexact_insert_node(net,
1159                                                     &bin->root_s,
1160                                                     &policy->selector.saddr,
1161                                                     policy->family,
1162                                                     policy->selector.prefixlen_s,
1163                                                     dir);
1164                 write_seqcount_end(&bin->count);
1165                 if (!n)
1166                         return NULL;
1167
1168                 return &n->hhead;
1169         }
1170
1171         /* daddr is fixed */
1172         write_seqcount_begin(&bin->count);
1173         n = xfrm_policy_inexact_insert_node(net,
1174                                             &bin->root_d,
1175                                             &policy->selector.daddr,
1176                                             policy->family,
1177                                             policy->selector.prefixlen_d, dir);
1178         write_seqcount_end(&bin->count);
1179         if (!n)
1180                 return NULL;
1181
1182         /* saddr is wildcard */
1183         if (xfrm_pol_inexact_addr_use_any_list(&policy->selector.saddr,
1184                                                policy->family,
1185                                                policy->selector.prefixlen_s))
1186                 return &n->hhead;
1187
1188         write_seqcount_begin(&bin->count);
1189         n = xfrm_policy_inexact_insert_node(net,
1190                                             &n->root,
1191                                             &policy->selector.saddr,
1192                                             policy->family,
1193                                             policy->selector.prefixlen_s, dir);
1194         write_seqcount_end(&bin->count);
1195         if (!n)
1196                 return NULL;
1197
1198         return &n->hhead;
1199 }
1200
1201 static struct xfrm_policy *
1202 xfrm_policy_inexact_insert(struct xfrm_policy *policy, u8 dir, int excl)
1203 {
1204         struct xfrm_pol_inexact_bin *bin;
1205         struct xfrm_policy *delpol;
1206         struct hlist_head *chain;
1207         struct net *net;
1208
1209         bin = xfrm_policy_inexact_alloc_bin(policy, dir);
1210         if (!bin)
1211                 return ERR_PTR(-ENOMEM);
1212
1213         net = xp_net(policy);
1214         lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
1215
1216         chain = xfrm_policy_inexact_alloc_chain(bin, policy, dir);
1217         if (!chain) {
1218                 __xfrm_policy_inexact_prune_bin(bin, false);
1219                 return ERR_PTR(-ENOMEM);
1220         }
1221
1222         delpol = xfrm_policy_insert_list(chain, policy, excl);
1223         if (delpol && excl) {
1224                 __xfrm_policy_inexact_prune_bin(bin, false);
1225                 return ERR_PTR(-EEXIST);
1226         }
1227
1228         chain = &net->xfrm.policy_inexact[dir];
1229         xfrm_policy_insert_inexact_list(chain, policy);
1230
1231         if (delpol)
1232                 __xfrm_policy_inexact_prune_bin(bin, false);
1233
1234         return delpol;
1235 }
1236
1237 static void xfrm_hash_rebuild(struct work_struct *work)
1238 {
1239         struct net *net = container_of(work, struct net,
1240                                        xfrm.policy_hthresh.work);
1241         unsigned int hmask;
1242         struct xfrm_policy *pol;
1243         struct xfrm_policy *policy;
1244         struct hlist_head *chain;
1245         struct hlist_head *odst;
1246         struct hlist_node *newpos;
1247         int i;
1248         int dir;
1249         unsigned seq;
1250         u8 lbits4, rbits4, lbits6, rbits6;
1251
1252         mutex_lock(&hash_resize_mutex);
1253
1254         /* read selector prefixlen thresholds */
1255         do {
1256                 seq = read_seqbegin(&net->xfrm.policy_hthresh.lock);
1257
1258                 lbits4 = net->xfrm.policy_hthresh.lbits4;
1259                 rbits4 = net->xfrm.policy_hthresh.rbits4;
1260                 lbits6 = net->xfrm.policy_hthresh.lbits6;
1261                 rbits6 = net->xfrm.policy_hthresh.rbits6;
1262         } while (read_seqretry(&net->xfrm.policy_hthresh.lock, seq));
1263
1264         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1265         write_seqcount_begin(&net->xfrm.xfrm_policy_hash_generation);
1266
1267         /* make sure that we can insert the indirect policies again before
1268          * we start with destructive action.
1269          */
1270         list_for_each_entry(policy, &net->xfrm.policy_all, walk.all) {
1271                 struct xfrm_pol_inexact_bin *bin;
1272                 u8 dbits, sbits;
1273
1274                 if (policy->walk.dead)
1275                         continue;
1276
1277                 dir = xfrm_policy_id2dir(policy->index);
1278                 if (dir >= XFRM_POLICY_MAX)
1279                         continue;
1280
1281                 if ((dir & XFRM_POLICY_MASK) == XFRM_POLICY_OUT) {
1282                         if (policy->family == AF_INET) {
1283                                 dbits = rbits4;
1284                                 sbits = lbits4;
1285                         } else {
1286                                 dbits = rbits6;
1287                                 sbits = lbits6;
1288                         }
1289                 } else {
1290                         if (policy->family == AF_INET) {
1291                                 dbits = lbits4;
1292                                 sbits = rbits4;
1293                         } else {
1294                                 dbits = lbits6;
1295                                 sbits = rbits6;
1296                         }
1297                 }
1298
1299                 if (policy->selector.prefixlen_d < dbits ||
1300                     policy->selector.prefixlen_s < sbits)
1301                         continue;
1302
1303                 bin = xfrm_policy_inexact_alloc_bin(policy, dir);
1304                 if (!bin)
1305                         goto out_unlock;
1306
1307                 if (!xfrm_policy_inexact_alloc_chain(bin, policy, dir))
1308                         goto out_unlock;
1309         }
1310
1311         /* reset the bydst and inexact table in all directions */
1312         for (dir = 0; dir < XFRM_POLICY_MAX; dir++) {
1313                 struct hlist_node *n;
1314
1315                 hlist_for_each_entry_safe(policy, n,
1316                                           &net->xfrm.policy_inexact[dir],
1317                                           bydst_inexact_list) {
1318                         hlist_del_rcu(&policy->bydst);
1319                         hlist_del_init(&policy->bydst_inexact_list);
1320                 }
1321
1322                 hmask = net->xfrm.policy_bydst[dir].hmask;
1323                 odst = net->xfrm.policy_bydst[dir].table;
1324                 for (i = hmask; i >= 0; i--) {
1325                         hlist_for_each_entry_safe(policy, n, odst + i, bydst)
1326                                 hlist_del_rcu(&policy->bydst);
1327                 }
1328                 if ((dir & XFRM_POLICY_MASK) == XFRM_POLICY_OUT) {
1329                         /* dir out => dst = remote, src = local */
1330                         net->xfrm.policy_bydst[dir].dbits4 = rbits4;
1331                         net->xfrm.policy_bydst[dir].sbits4 = lbits4;
1332                         net->xfrm.policy_bydst[dir].dbits6 = rbits6;
1333                         net->xfrm.policy_bydst[dir].sbits6 = lbits6;
1334                 } else {
1335                         /* dir in/fwd => dst = local, src = remote */
1336                         net->xfrm.policy_bydst[dir].dbits4 = lbits4;
1337                         net->xfrm.policy_bydst[dir].sbits4 = rbits4;
1338                         net->xfrm.policy_bydst[dir].dbits6 = lbits6;
1339                         net->xfrm.policy_bydst[dir].sbits6 = rbits6;
1340                 }
1341         }
1342
1343         /* re-insert all policies by order of creation */
1344         list_for_each_entry_reverse(policy, &net->xfrm.policy_all, walk.all) {
1345                 if (policy->walk.dead)
1346                         continue;
1347                 dir = xfrm_policy_id2dir(policy->index);
1348                 if (dir >= XFRM_POLICY_MAX) {
1349                         /* skip socket policies */
1350                         continue;
1351                 }
1352                 newpos = NULL;
1353                 chain = policy_hash_bysel(net, &policy->selector,
1354                                           policy->family, dir);
1355
1356                 if (!chain) {
1357                         void *p = xfrm_policy_inexact_insert(policy, dir, 0);
1358
1359                         WARN_ONCE(IS_ERR(p), "reinsert: %ld\n", PTR_ERR(p));
1360                         continue;
1361                 }
1362
1363                 hlist_for_each_entry(pol, chain, bydst) {
1364                         if (policy->priority >= pol->priority)
1365                                 newpos = &pol->bydst;
1366                         else
1367                                 break;
1368                 }
1369                 if (newpos && policy->xdo.type != XFRM_DEV_OFFLOAD_PACKET)
1370                         hlist_add_behind_rcu(&policy->bydst, newpos);
1371                 else
1372                         hlist_add_head_rcu(&policy->bydst, chain);
1373         }
1374
1375 out_unlock:
1376         __xfrm_policy_inexact_flush(net);
1377         write_seqcount_end(&net->xfrm.xfrm_policy_hash_generation);
1378         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1379
1380         mutex_unlock(&hash_resize_mutex);
1381 }
1382
1383 void xfrm_policy_hash_rebuild(struct net *net)
1384 {
1385         schedule_work(&net->xfrm.policy_hthresh.work);
1386 }
1387 EXPORT_SYMBOL(xfrm_policy_hash_rebuild);
1388
1389 /* Generate new index... KAME seems to generate them ordered by cost
1390  * of an absolute inpredictability of ordering of rules. This will not pass. */
1391 static u32 xfrm_gen_index(struct net *net, int dir, u32 index)
1392 {
1393         for (;;) {
1394                 struct hlist_head *list;
1395                 struct xfrm_policy *p;
1396                 u32 idx;
1397                 int found;
1398
1399                 if (!index) {
1400                         idx = (net->xfrm.idx_generator | dir);
1401                         net->xfrm.idx_generator += 8;
1402                 } else {
1403                         idx = index;
1404                         index = 0;
1405                 }
1406
1407                 if (idx == 0)
1408                         idx = 8;
1409                 list = net->xfrm.policy_byidx + idx_hash(net, idx);
1410                 found = 0;
1411                 hlist_for_each_entry(p, list, byidx) {
1412                         if (p->index == idx) {
1413                                 found = 1;
1414                                 break;
1415                         }
1416                 }
1417                 if (!found)
1418                         return idx;
1419         }
1420 }
1421
1422 static inline int selector_cmp(struct xfrm_selector *s1, struct xfrm_selector *s2)
1423 {
1424         u32 *p1 = (u32 *) s1;
1425         u32 *p2 = (u32 *) s2;
1426         int len = sizeof(struct xfrm_selector) / sizeof(u32);
1427         int i;
1428
1429         for (i = 0; i < len; i++) {
1430                 if (p1[i] != p2[i])
1431                         return 1;
1432         }
1433
1434         return 0;
1435 }
1436
1437 static void xfrm_policy_requeue(struct xfrm_policy *old,
1438                                 struct xfrm_policy *new)
1439 {
1440         struct xfrm_policy_queue *pq = &old->polq;
1441         struct sk_buff_head list;
1442
1443         if (skb_queue_empty(&pq->hold_queue))
1444                 return;
1445
1446         __skb_queue_head_init(&list);
1447
1448         spin_lock_bh(&pq->hold_queue.lock);
1449         skb_queue_splice_init(&pq->hold_queue, &list);
1450         if (del_timer(&pq->hold_timer))
1451                 xfrm_pol_put(old);
1452         spin_unlock_bh(&pq->hold_queue.lock);
1453
1454         pq = &new->polq;
1455
1456         spin_lock_bh(&pq->hold_queue.lock);
1457         skb_queue_splice(&list, &pq->hold_queue);
1458         pq->timeout = XFRM_QUEUE_TMO_MIN;
1459         if (!mod_timer(&pq->hold_timer, jiffies))
1460                 xfrm_pol_hold(new);
1461         spin_unlock_bh(&pq->hold_queue.lock);
1462 }
1463
1464 static inline bool xfrm_policy_mark_match(const struct xfrm_mark *mark,
1465                                           struct xfrm_policy *pol)
1466 {
1467         return mark->v == pol->mark.v && mark->m == pol->mark.m;
1468 }
1469
1470 static u32 xfrm_pol_bin_key(const void *data, u32 len, u32 seed)
1471 {
1472         const struct xfrm_pol_inexact_key *k = data;
1473         u32 a = k->type << 24 | k->dir << 16 | k->family;
1474
1475         return jhash_3words(a, k->if_id, net_hash_mix(read_pnet(&k->net)),
1476                             seed);
1477 }
1478
1479 static u32 xfrm_pol_bin_obj(const void *data, u32 len, u32 seed)
1480 {
1481         const struct xfrm_pol_inexact_bin *b = data;
1482
1483         return xfrm_pol_bin_key(&b->k, 0, seed);
1484 }
1485
1486 static int xfrm_pol_bin_cmp(struct rhashtable_compare_arg *arg,
1487                             const void *ptr)
1488 {
1489         const struct xfrm_pol_inexact_key *key = arg->key;
1490         const struct xfrm_pol_inexact_bin *b = ptr;
1491         int ret;
1492
1493         if (!net_eq(read_pnet(&b->k.net), read_pnet(&key->net)))
1494                 return -1;
1495
1496         ret = b->k.dir ^ key->dir;
1497         if (ret)
1498                 return ret;
1499
1500         ret = b->k.type ^ key->type;
1501         if (ret)
1502                 return ret;
1503
1504         ret = b->k.family ^ key->family;
1505         if (ret)
1506                 return ret;
1507
1508         return b->k.if_id ^ key->if_id;
1509 }
1510
1511 static const struct rhashtable_params xfrm_pol_inexact_params = {
1512         .head_offset            = offsetof(struct xfrm_pol_inexact_bin, head),
1513         .hashfn                 = xfrm_pol_bin_key,
1514         .obj_hashfn             = xfrm_pol_bin_obj,
1515         .obj_cmpfn              = xfrm_pol_bin_cmp,
1516         .automatic_shrinking    = true,
1517 };
1518
1519 static void xfrm_policy_insert_inexact_list(struct hlist_head *chain,
1520                                             struct xfrm_policy *policy)
1521 {
1522         struct xfrm_policy *pol, *delpol = NULL;
1523         struct hlist_node *newpos = NULL;
1524         int i = 0;
1525
1526         hlist_for_each_entry(pol, chain, bydst_inexact_list) {
1527                 if (pol->type == policy->type &&
1528                     pol->if_id == policy->if_id &&
1529                     !selector_cmp(&pol->selector, &policy->selector) &&
1530                     xfrm_policy_mark_match(&policy->mark, pol) &&
1531                     xfrm_sec_ctx_match(pol->security, policy->security) &&
1532                     !WARN_ON(delpol)) {
1533                         delpol = pol;
1534                         if (policy->priority > pol->priority)
1535                                 continue;
1536                 } else if (policy->priority >= pol->priority) {
1537                         newpos = &pol->bydst_inexact_list;
1538                         continue;
1539                 }
1540                 if (delpol)
1541                         break;
1542         }
1543
1544         if (newpos && policy->xdo.type != XFRM_DEV_OFFLOAD_PACKET)
1545                 hlist_add_behind_rcu(&policy->bydst_inexact_list, newpos);
1546         else
1547                 hlist_add_head_rcu(&policy->bydst_inexact_list, chain);
1548
1549         hlist_for_each_entry(pol, chain, bydst_inexact_list) {
1550                 pol->pos = i;
1551                 i++;
1552         }
1553 }
1554
1555 static struct xfrm_policy *xfrm_policy_insert_list(struct hlist_head *chain,
1556                                                    struct xfrm_policy *policy,
1557                                                    bool excl)
1558 {
1559         struct xfrm_policy *pol, *newpos = NULL, *delpol = NULL;
1560
1561         hlist_for_each_entry(pol, chain, bydst) {
1562                 if (pol->type == policy->type &&
1563                     pol->if_id == policy->if_id &&
1564                     !selector_cmp(&pol->selector, &policy->selector) &&
1565                     xfrm_policy_mark_match(&policy->mark, pol) &&
1566                     xfrm_sec_ctx_match(pol->security, policy->security) &&
1567                     !WARN_ON(delpol)) {
1568                         if (excl)
1569                                 return ERR_PTR(-EEXIST);
1570                         delpol = pol;
1571                         if (policy->priority > pol->priority)
1572                                 continue;
1573                 } else if (policy->priority >= pol->priority) {
1574                         newpos = pol;
1575                         continue;
1576                 }
1577                 if (delpol)
1578                         break;
1579         }
1580
1581         if (newpos && policy->xdo.type != XFRM_DEV_OFFLOAD_PACKET)
1582                 hlist_add_behind_rcu(&policy->bydst, &newpos->bydst);
1583         else
1584                 /* Packet offload policies enter to the head
1585                  * to speed-up lookups.
1586                  */
1587                 hlist_add_head_rcu(&policy->bydst, chain);
1588
1589         return delpol;
1590 }
1591
1592 int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl)
1593 {
1594         struct net *net = xp_net(policy);
1595         struct xfrm_policy *delpol;
1596         struct hlist_head *chain;
1597
1598         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1599         chain = policy_hash_bysel(net, &policy->selector, policy->family, dir);
1600         if (chain)
1601                 delpol = xfrm_policy_insert_list(chain, policy, excl);
1602         else
1603                 delpol = xfrm_policy_inexact_insert(policy, dir, excl);
1604
1605         if (IS_ERR(delpol)) {
1606                 spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1607                 return PTR_ERR(delpol);
1608         }
1609
1610         __xfrm_policy_link(policy, dir);
1611
1612         /* After previous checking, family can either be AF_INET or AF_INET6 */
1613         if (policy->family == AF_INET)
1614                 rt_genid_bump_ipv4(net);
1615         else
1616                 rt_genid_bump_ipv6(net);
1617
1618         if (delpol) {
1619                 xfrm_policy_requeue(delpol, policy);
1620                 __xfrm_policy_unlink(delpol, dir);
1621         }
1622         policy->index = delpol ? delpol->index : xfrm_gen_index(net, dir, policy->index);
1623         hlist_add_head(&policy->byidx, net->xfrm.policy_byidx+idx_hash(net, policy->index));
1624         policy->curlft.add_time = ktime_get_real_seconds();
1625         policy->curlft.use_time = 0;
1626         if (!mod_timer(&policy->timer, jiffies + HZ))
1627                 xfrm_pol_hold(policy);
1628         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1629
1630         if (delpol)
1631                 xfrm_policy_kill(delpol);
1632         else if (xfrm_bydst_should_resize(net, dir, NULL))
1633                 schedule_work(&net->xfrm.policy_hash_work);
1634
1635         return 0;
1636 }
1637 EXPORT_SYMBOL(xfrm_policy_insert);
1638
1639 static struct xfrm_policy *
1640 __xfrm_policy_bysel_ctx(struct hlist_head *chain, const struct xfrm_mark *mark,
1641                         u32 if_id, u8 type, int dir, struct xfrm_selector *sel,
1642                         struct xfrm_sec_ctx *ctx)
1643 {
1644         struct xfrm_policy *pol;
1645
1646         if (!chain)
1647                 return NULL;
1648
1649         hlist_for_each_entry(pol, chain, bydst) {
1650                 if (pol->type == type &&
1651                     pol->if_id == if_id &&
1652                     xfrm_policy_mark_match(mark, pol) &&
1653                     !selector_cmp(sel, &pol->selector) &&
1654                     xfrm_sec_ctx_match(ctx, pol->security))
1655                         return pol;
1656         }
1657
1658         return NULL;
1659 }
1660
1661 struct xfrm_policy *
1662 xfrm_policy_bysel_ctx(struct net *net, const struct xfrm_mark *mark, u32 if_id,
1663                       u8 type, int dir, struct xfrm_selector *sel,
1664                       struct xfrm_sec_ctx *ctx, int delete, int *err)
1665 {
1666         struct xfrm_pol_inexact_bin *bin = NULL;
1667         struct xfrm_policy *pol, *ret = NULL;
1668         struct hlist_head *chain;
1669
1670         *err = 0;
1671         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1672         chain = policy_hash_bysel(net, sel, sel->family, dir);
1673         if (!chain) {
1674                 struct xfrm_pol_inexact_candidates cand;
1675                 int i;
1676
1677                 bin = xfrm_policy_inexact_lookup(net, type,
1678                                                  sel->family, dir, if_id);
1679                 if (!bin) {
1680                         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1681                         return NULL;
1682                 }
1683
1684                 if (!xfrm_policy_find_inexact_candidates(&cand, bin,
1685                                                          &sel->saddr,
1686                                                          &sel->daddr)) {
1687                         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1688                         return NULL;
1689                 }
1690
1691                 pol = NULL;
1692                 for (i = 0; i < ARRAY_SIZE(cand.res); i++) {
1693                         struct xfrm_policy *tmp;
1694
1695                         tmp = __xfrm_policy_bysel_ctx(cand.res[i], mark,
1696                                                       if_id, type, dir,
1697                                                       sel, ctx);
1698                         if (!tmp)
1699                                 continue;
1700
1701                         if (!pol || tmp->pos < pol->pos)
1702                                 pol = tmp;
1703                 }
1704         } else {
1705                 pol = __xfrm_policy_bysel_ctx(chain, mark, if_id, type, dir,
1706                                               sel, ctx);
1707         }
1708
1709         if (pol) {
1710                 xfrm_pol_hold(pol);
1711                 if (delete) {
1712                         *err = security_xfrm_policy_delete(pol->security);
1713                         if (*err) {
1714                                 spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1715                                 return pol;
1716                         }
1717                         __xfrm_policy_unlink(pol, dir);
1718                 }
1719                 ret = pol;
1720         }
1721         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1722
1723         if (ret && delete)
1724                 xfrm_policy_kill(ret);
1725         if (bin && delete)
1726                 xfrm_policy_inexact_prune_bin(bin);
1727         return ret;
1728 }
1729 EXPORT_SYMBOL(xfrm_policy_bysel_ctx);
1730
1731 struct xfrm_policy *
1732 xfrm_policy_byid(struct net *net, const struct xfrm_mark *mark, u32 if_id,
1733                  u8 type, int dir, u32 id, int delete, int *err)
1734 {
1735         struct xfrm_policy *pol, *ret;
1736         struct hlist_head *chain;
1737
1738         *err = -ENOENT;
1739         if (xfrm_policy_id2dir(id) != dir)
1740                 return NULL;
1741
1742         *err = 0;
1743         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1744         chain = net->xfrm.policy_byidx + idx_hash(net, id);
1745         ret = NULL;
1746         hlist_for_each_entry(pol, chain, byidx) {
1747                 if (pol->type == type && pol->index == id &&
1748                     pol->if_id == if_id && xfrm_policy_mark_match(mark, pol)) {
1749                         xfrm_pol_hold(pol);
1750                         if (delete) {
1751                                 *err = security_xfrm_policy_delete(
1752                                                                 pol->security);
1753                                 if (*err) {
1754                                         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1755                                         return pol;
1756                                 }
1757                                 __xfrm_policy_unlink(pol, dir);
1758                         }
1759                         ret = pol;
1760                         break;
1761                 }
1762         }
1763         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1764
1765         if (ret && delete)
1766                 xfrm_policy_kill(ret);
1767         return ret;
1768 }
1769 EXPORT_SYMBOL(xfrm_policy_byid);
1770
1771 #ifdef CONFIG_SECURITY_NETWORK_XFRM
1772 static inline int
1773 xfrm_policy_flush_secctx_check(struct net *net, u8 type, bool task_valid)
1774 {
1775         struct xfrm_policy *pol;
1776         int err = 0;
1777
1778         list_for_each_entry(pol, &net->xfrm.policy_all, walk.all) {
1779                 if (pol->walk.dead ||
1780                     xfrm_policy_id2dir(pol->index) >= XFRM_POLICY_MAX ||
1781                     pol->type != type)
1782                         continue;
1783
1784                 err = security_xfrm_policy_delete(pol->security);
1785                 if (err) {
1786                         xfrm_audit_policy_delete(pol, 0, task_valid);
1787                         return err;
1788                 }
1789         }
1790         return err;
1791 }
1792
1793 static inline int xfrm_dev_policy_flush_secctx_check(struct net *net,
1794                                                      struct net_device *dev,
1795                                                      bool task_valid)
1796 {
1797         struct xfrm_policy *pol;
1798         int err = 0;
1799
1800         list_for_each_entry(pol, &net->xfrm.policy_all, walk.all) {
1801                 if (pol->walk.dead ||
1802                     xfrm_policy_id2dir(pol->index) >= XFRM_POLICY_MAX ||
1803                     pol->xdo.dev != dev)
1804                         continue;
1805
1806                 err = security_xfrm_policy_delete(pol->security);
1807                 if (err) {
1808                         xfrm_audit_policy_delete(pol, 0, task_valid);
1809                         return err;
1810                 }
1811         }
1812         return err;
1813 }
1814 #else
1815 static inline int
1816 xfrm_policy_flush_secctx_check(struct net *net, u8 type, bool task_valid)
1817 {
1818         return 0;
1819 }
1820
1821 static inline int xfrm_dev_policy_flush_secctx_check(struct net *net,
1822                                                      struct net_device *dev,
1823                                                      bool task_valid)
1824 {
1825         return 0;
1826 }
1827 #endif
1828
1829 int xfrm_policy_flush(struct net *net, u8 type, bool task_valid)
1830 {
1831         int dir, err = 0, cnt = 0;
1832         struct xfrm_policy *pol;
1833
1834         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1835
1836         err = xfrm_policy_flush_secctx_check(net, type, task_valid);
1837         if (err)
1838                 goto out;
1839
1840 again:
1841         list_for_each_entry(pol, &net->xfrm.policy_all, walk.all) {
1842                 if (pol->walk.dead)
1843                         continue;
1844
1845                 dir = xfrm_policy_id2dir(pol->index);
1846                 if (dir >= XFRM_POLICY_MAX ||
1847                     pol->type != type)
1848                         continue;
1849
1850                 __xfrm_policy_unlink(pol, dir);
1851                 spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1852                 xfrm_dev_policy_delete(pol);
1853                 cnt++;
1854                 xfrm_audit_policy_delete(pol, 1, task_valid);
1855                 xfrm_policy_kill(pol);
1856                 spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1857                 goto again;
1858         }
1859         if (cnt)
1860                 __xfrm_policy_inexact_flush(net);
1861         else
1862                 err = -ESRCH;
1863 out:
1864         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1865         return err;
1866 }
1867 EXPORT_SYMBOL(xfrm_policy_flush);
1868
1869 int xfrm_dev_policy_flush(struct net *net, struct net_device *dev,
1870                           bool task_valid)
1871 {
1872         int dir, err = 0, cnt = 0;
1873         struct xfrm_policy *pol;
1874
1875         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1876
1877         err = xfrm_dev_policy_flush_secctx_check(net, dev, task_valid);
1878         if (err)
1879                 goto out;
1880
1881 again:
1882         list_for_each_entry(pol, &net->xfrm.policy_all, walk.all) {
1883                 if (pol->walk.dead)
1884                         continue;
1885
1886                 dir = xfrm_policy_id2dir(pol->index);
1887                 if (dir >= XFRM_POLICY_MAX ||
1888                     pol->xdo.dev != dev)
1889                         continue;
1890
1891                 __xfrm_policy_unlink(pol, dir);
1892                 spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1893                 xfrm_dev_policy_delete(pol);
1894                 cnt++;
1895                 xfrm_audit_policy_delete(pol, 1, task_valid);
1896                 xfrm_policy_kill(pol);
1897                 spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1898                 goto again;
1899         }
1900         if (cnt)
1901                 __xfrm_policy_inexact_flush(net);
1902         else
1903                 err = -ESRCH;
1904 out:
1905         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1906         return err;
1907 }
1908 EXPORT_SYMBOL(xfrm_dev_policy_flush);
1909
1910 int xfrm_policy_walk(struct net *net, struct xfrm_policy_walk *walk,
1911                      int (*func)(struct xfrm_policy *, int, int, void*),
1912                      void *data)
1913 {
1914         struct xfrm_policy *pol;
1915         struct xfrm_policy_walk_entry *x;
1916         int error = 0;
1917
1918         if (walk->type >= XFRM_POLICY_TYPE_MAX &&
1919             walk->type != XFRM_POLICY_TYPE_ANY)
1920                 return -EINVAL;
1921
1922         if (list_empty(&walk->walk.all) && walk->seq != 0)
1923                 return 0;
1924
1925         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1926         if (list_empty(&walk->walk.all))
1927                 x = list_first_entry(&net->xfrm.policy_all, struct xfrm_policy_walk_entry, all);
1928         else
1929                 x = list_first_entry(&walk->walk.all,
1930                                      struct xfrm_policy_walk_entry, all);
1931
1932         list_for_each_entry_from(x, &net->xfrm.policy_all, all) {
1933                 if (x->dead)
1934                         continue;
1935                 pol = container_of(x, struct xfrm_policy, walk);
1936                 if (walk->type != XFRM_POLICY_TYPE_ANY &&
1937                     walk->type != pol->type)
1938                         continue;
1939                 error = func(pol, xfrm_policy_id2dir(pol->index),
1940                              walk->seq, data);
1941                 if (error) {
1942                         list_move_tail(&walk->walk.all, &x->all);
1943                         goto out;
1944                 }
1945                 walk->seq++;
1946         }
1947         if (walk->seq == 0) {
1948                 error = -ENOENT;
1949                 goto out;
1950         }
1951         list_del_init(&walk->walk.all);
1952 out:
1953         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1954         return error;
1955 }
1956 EXPORT_SYMBOL(xfrm_policy_walk);
1957
1958 void xfrm_policy_walk_init(struct xfrm_policy_walk *walk, u8 type)
1959 {
1960         INIT_LIST_HEAD(&walk->walk.all);
1961         walk->walk.dead = 1;
1962         walk->type = type;
1963         walk->seq = 0;
1964 }
1965 EXPORT_SYMBOL(xfrm_policy_walk_init);
1966
1967 void xfrm_policy_walk_done(struct xfrm_policy_walk *walk, struct net *net)
1968 {
1969         if (list_empty(&walk->walk.all))
1970                 return;
1971
1972         spin_lock_bh(&net->xfrm.xfrm_policy_lock); /*FIXME where is net? */
1973         list_del(&walk->walk.all);
1974         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1975 }
1976 EXPORT_SYMBOL(xfrm_policy_walk_done);
1977
1978 /*
1979  * Find policy to apply to this flow.
1980  *
1981  * Returns 0 if policy found, else an -errno.
1982  */
1983 static int xfrm_policy_match(const struct xfrm_policy *pol,
1984                              const struct flowi *fl,
1985                              u8 type, u16 family, u32 if_id)
1986 {
1987         const struct xfrm_selector *sel = &pol->selector;
1988         int ret = -ESRCH;
1989         bool match;
1990
1991         if (pol->family != family ||
1992             pol->if_id != if_id ||
1993             (fl->flowi_mark & pol->mark.m) != pol->mark.v ||
1994             pol->type != type)
1995                 return ret;
1996
1997         match = xfrm_selector_match(sel, fl, family);
1998         if (match)
1999                 ret = security_xfrm_policy_lookup(pol->security, fl->flowi_secid);
2000         return ret;
2001 }
2002
2003 static struct xfrm_pol_inexact_node *
2004 xfrm_policy_lookup_inexact_addr(const struct rb_root *r,
2005                                 seqcount_spinlock_t *count,
2006                                 const xfrm_address_t *addr, u16 family)
2007 {
2008         const struct rb_node *parent;
2009         int seq;
2010
2011 again:
2012         seq = read_seqcount_begin(count);
2013
2014         parent = rcu_dereference_raw(r->rb_node);
2015         while (parent) {
2016                 struct xfrm_pol_inexact_node *node;
2017                 int delta;
2018
2019                 node = rb_entry(parent, struct xfrm_pol_inexact_node, node);
2020
2021                 delta = xfrm_policy_addr_delta(addr, &node->addr,
2022                                                node->prefixlen, family);
2023                 if (delta < 0) {
2024                         parent = rcu_dereference_raw(parent->rb_left);
2025                         continue;
2026                 } else if (delta > 0) {
2027                         parent = rcu_dereference_raw(parent->rb_right);
2028                         continue;
2029                 }
2030
2031                 return node;
2032         }
2033
2034         if (read_seqcount_retry(count, seq))
2035                 goto again;
2036
2037         return NULL;
2038 }
2039
2040 static bool
2041 xfrm_policy_find_inexact_candidates(struct xfrm_pol_inexact_candidates *cand,
2042                                     struct xfrm_pol_inexact_bin *b,
2043                                     const xfrm_address_t *saddr,
2044                                     const xfrm_address_t *daddr)
2045 {
2046         struct xfrm_pol_inexact_node *n;
2047         u16 family;
2048
2049         if (!b)
2050                 return false;
2051
2052         family = b->k.family;
2053         memset(cand, 0, sizeof(*cand));
2054         cand->res[XFRM_POL_CAND_ANY] = &b->hhead;
2055
2056         n = xfrm_policy_lookup_inexact_addr(&b->root_d, &b->count, daddr,
2057                                             family);
2058         if (n) {
2059                 cand->res[XFRM_POL_CAND_DADDR] = &n->hhead;
2060                 n = xfrm_policy_lookup_inexact_addr(&n->root, &b->count, saddr,
2061                                                     family);
2062                 if (n)
2063                         cand->res[XFRM_POL_CAND_BOTH] = &n->hhead;
2064         }
2065
2066         n = xfrm_policy_lookup_inexact_addr(&b->root_s, &b->count, saddr,
2067                                             family);
2068         if (n)
2069                 cand->res[XFRM_POL_CAND_SADDR] = &n->hhead;
2070
2071         return true;
2072 }
2073
2074 static struct xfrm_pol_inexact_bin *
2075 xfrm_policy_inexact_lookup_rcu(struct net *net, u8 type, u16 family,
2076                                u8 dir, u32 if_id)
2077 {
2078         struct xfrm_pol_inexact_key k = {
2079                 .family = family,
2080                 .type = type,
2081                 .dir = dir,
2082                 .if_id = if_id,
2083         };
2084
2085         write_pnet(&k.net, net);
2086
2087         return rhashtable_lookup(&xfrm_policy_inexact_table, &k,
2088                                  xfrm_pol_inexact_params);
2089 }
2090
2091 static struct xfrm_pol_inexact_bin *
2092 xfrm_policy_inexact_lookup(struct net *net, u8 type, u16 family,
2093                            u8 dir, u32 if_id)
2094 {
2095         struct xfrm_pol_inexact_bin *bin;
2096
2097         lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
2098
2099         rcu_read_lock();
2100         bin = xfrm_policy_inexact_lookup_rcu(net, type, family, dir, if_id);
2101         rcu_read_unlock();
2102
2103         return bin;
2104 }
2105
2106 static struct xfrm_policy *
2107 __xfrm_policy_eval_candidates(struct hlist_head *chain,
2108                               struct xfrm_policy *prefer,
2109                               const struct flowi *fl,
2110                               u8 type, u16 family, u32 if_id)
2111 {
2112         u32 priority = prefer ? prefer->priority : ~0u;
2113         struct xfrm_policy *pol;
2114
2115         if (!chain)
2116                 return NULL;
2117
2118         hlist_for_each_entry_rcu(pol, chain, bydst) {
2119                 int err;
2120
2121                 if (pol->priority > priority)
2122                         break;
2123
2124                 err = xfrm_policy_match(pol, fl, type, family, if_id);
2125                 if (err) {
2126                         if (err != -ESRCH)
2127                                 return ERR_PTR(err);
2128
2129                         continue;
2130                 }
2131
2132                 if (prefer) {
2133                         /* matches.  Is it older than *prefer? */
2134                         if (pol->priority == priority &&
2135                             prefer->pos < pol->pos)
2136                                 return prefer;
2137                 }
2138
2139                 return pol;
2140         }
2141
2142         return NULL;
2143 }
2144
2145 static struct xfrm_policy *
2146 xfrm_policy_eval_candidates(struct xfrm_pol_inexact_candidates *cand,
2147                             struct xfrm_policy *prefer,
2148                             const struct flowi *fl,
2149                             u8 type, u16 family, u32 if_id)
2150 {
2151         struct xfrm_policy *tmp;
2152         int i;
2153
2154         for (i = 0; i < ARRAY_SIZE(cand->res); i++) {
2155                 tmp = __xfrm_policy_eval_candidates(cand->res[i],
2156                                                     prefer,
2157                                                     fl, type, family, if_id);
2158                 if (!tmp)
2159                         continue;
2160
2161                 if (IS_ERR(tmp))
2162                         return tmp;
2163                 prefer = tmp;
2164         }
2165
2166         return prefer;
2167 }
2168
2169 static struct xfrm_policy *xfrm_policy_lookup_bytype(struct net *net, u8 type,
2170                                                      const struct flowi *fl,
2171                                                      u16 family, u8 dir,
2172                                                      u32 if_id)
2173 {
2174         struct xfrm_pol_inexact_candidates cand;
2175         const xfrm_address_t *daddr, *saddr;
2176         struct xfrm_pol_inexact_bin *bin;
2177         struct xfrm_policy *pol, *ret;
2178         struct hlist_head *chain;
2179         unsigned int sequence;
2180         int err;
2181
2182         daddr = xfrm_flowi_daddr(fl, family);
2183         saddr = xfrm_flowi_saddr(fl, family);
2184         if (unlikely(!daddr || !saddr))
2185                 return NULL;
2186
2187         rcu_read_lock();
2188  retry:
2189         do {
2190                 sequence = read_seqcount_begin(&net->xfrm.xfrm_policy_hash_generation);
2191                 chain = policy_hash_direct(net, daddr, saddr, family, dir);
2192         } while (read_seqcount_retry(&net->xfrm.xfrm_policy_hash_generation, sequence));
2193
2194         ret = NULL;
2195         hlist_for_each_entry_rcu(pol, chain, bydst) {
2196                 err = xfrm_policy_match(pol, fl, type, family, if_id);
2197                 if (err) {
2198                         if (err == -ESRCH)
2199                                 continue;
2200                         else {
2201                                 ret = ERR_PTR(err);
2202                                 goto fail;
2203                         }
2204                 } else {
2205                         ret = pol;
2206                         break;
2207                 }
2208         }
2209         if (ret && ret->xdo.type == XFRM_DEV_OFFLOAD_PACKET)
2210                 goto skip_inexact;
2211
2212         bin = xfrm_policy_inexact_lookup_rcu(net, type, family, dir, if_id);
2213         if (!bin || !xfrm_policy_find_inexact_candidates(&cand, bin, saddr,
2214                                                          daddr))
2215                 goto skip_inexact;
2216
2217         pol = xfrm_policy_eval_candidates(&cand, ret, fl, type,
2218                                           family, if_id);
2219         if (pol) {
2220                 ret = pol;
2221                 if (IS_ERR(pol))
2222                         goto fail;
2223         }
2224
2225 skip_inexact:
2226         if (read_seqcount_retry(&net->xfrm.xfrm_policy_hash_generation, sequence))
2227                 goto retry;
2228
2229         if (ret && !xfrm_pol_hold_rcu(ret))
2230                 goto retry;
2231 fail:
2232         rcu_read_unlock();
2233
2234         return ret;
2235 }
2236
2237 static struct xfrm_policy *xfrm_policy_lookup(struct net *net,
2238                                               const struct flowi *fl,
2239                                               u16 family, u8 dir, u32 if_id)
2240 {
2241 #ifdef CONFIG_XFRM_SUB_POLICY
2242         struct xfrm_policy *pol;
2243
2244         pol = xfrm_policy_lookup_bytype(net, XFRM_POLICY_TYPE_SUB, fl, family,
2245                                         dir, if_id);
2246         if (pol != NULL)
2247                 return pol;
2248 #endif
2249         return xfrm_policy_lookup_bytype(net, XFRM_POLICY_TYPE_MAIN, fl, family,
2250                                          dir, if_id);
2251 }
2252
2253 static struct xfrm_policy *xfrm_sk_policy_lookup(const struct sock *sk, int dir,
2254                                                  const struct flowi *fl,
2255                                                  u16 family, u32 if_id)
2256 {
2257         struct xfrm_policy *pol;
2258
2259         rcu_read_lock();
2260  again:
2261         pol = rcu_dereference(sk->sk_policy[dir]);
2262         if (pol != NULL) {
2263                 bool match;
2264                 int err = 0;
2265
2266                 if (pol->family != family) {
2267                         pol = NULL;
2268                         goto out;
2269                 }
2270
2271                 match = xfrm_selector_match(&pol->selector, fl, family);
2272                 if (match) {
2273                         if ((READ_ONCE(sk->sk_mark) & pol->mark.m) != pol->mark.v ||
2274                             pol->if_id != if_id) {
2275                                 pol = NULL;
2276                                 goto out;
2277                         }
2278                         err = security_xfrm_policy_lookup(pol->security,
2279                                                       fl->flowi_secid);
2280                         if (!err) {
2281                                 if (!xfrm_pol_hold_rcu(pol))
2282                                         goto again;
2283                         } else if (err == -ESRCH) {
2284                                 pol = NULL;
2285                         } else {
2286                                 pol = ERR_PTR(err);
2287                         }
2288                 } else
2289                         pol = NULL;
2290         }
2291 out:
2292         rcu_read_unlock();
2293         return pol;
2294 }
2295
2296 static void __xfrm_policy_link(struct xfrm_policy *pol, int dir)
2297 {
2298         struct net *net = xp_net(pol);
2299
2300         list_add(&pol->walk.all, &net->xfrm.policy_all);
2301         net->xfrm.policy_count[dir]++;
2302         xfrm_pol_hold(pol);
2303 }
2304
2305 static struct xfrm_policy *__xfrm_policy_unlink(struct xfrm_policy *pol,
2306                                                 int dir)
2307 {
2308         struct net *net = xp_net(pol);
2309
2310         if (list_empty(&pol->walk.all))
2311                 return NULL;
2312
2313         /* Socket policies are not hashed. */
2314         if (!hlist_unhashed(&pol->bydst)) {
2315                 hlist_del_rcu(&pol->bydst);
2316                 hlist_del_init(&pol->bydst_inexact_list);
2317                 hlist_del(&pol->byidx);
2318         }
2319
2320         list_del_init(&pol->walk.all);
2321         net->xfrm.policy_count[dir]--;
2322
2323         return pol;
2324 }
2325
2326 static void xfrm_sk_policy_link(struct xfrm_policy *pol, int dir)
2327 {
2328         __xfrm_policy_link(pol, XFRM_POLICY_MAX + dir);
2329 }
2330
2331 static void xfrm_sk_policy_unlink(struct xfrm_policy *pol, int dir)
2332 {
2333         __xfrm_policy_unlink(pol, XFRM_POLICY_MAX + dir);
2334 }
2335
2336 int xfrm_policy_delete(struct xfrm_policy *pol, int dir)
2337 {
2338         struct net *net = xp_net(pol);
2339
2340         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
2341         pol = __xfrm_policy_unlink(pol, dir);
2342         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
2343         if (pol) {
2344                 xfrm_dev_policy_delete(pol);
2345                 xfrm_policy_kill(pol);
2346                 return 0;
2347         }
2348         return -ENOENT;
2349 }
2350 EXPORT_SYMBOL(xfrm_policy_delete);
2351
2352 int xfrm_sk_policy_insert(struct sock *sk, int dir, struct xfrm_policy *pol)
2353 {
2354         struct net *net = sock_net(sk);
2355         struct xfrm_policy *old_pol;
2356
2357 #ifdef CONFIG_XFRM_SUB_POLICY
2358         if (pol && pol->type != XFRM_POLICY_TYPE_MAIN)
2359                 return -EINVAL;
2360 #endif
2361
2362         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
2363         old_pol = rcu_dereference_protected(sk->sk_policy[dir],
2364                                 lockdep_is_held(&net->xfrm.xfrm_policy_lock));
2365         if (pol) {
2366                 pol->curlft.add_time = ktime_get_real_seconds();
2367                 pol->index = xfrm_gen_index(net, XFRM_POLICY_MAX+dir, 0);
2368                 xfrm_sk_policy_link(pol, dir);
2369         }
2370         rcu_assign_pointer(sk->sk_policy[dir], pol);
2371         if (old_pol) {
2372                 if (pol)
2373                         xfrm_policy_requeue(old_pol, pol);
2374
2375                 /* Unlinking succeeds always. This is the only function
2376                  * allowed to delete or replace socket policy.
2377                  */
2378                 xfrm_sk_policy_unlink(old_pol, dir);
2379         }
2380         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
2381
2382         if (old_pol) {
2383                 xfrm_policy_kill(old_pol);
2384         }
2385         return 0;
2386 }
2387
2388 static struct xfrm_policy *clone_policy(const struct xfrm_policy *old, int dir)
2389 {
2390         struct xfrm_policy *newp = xfrm_policy_alloc(xp_net(old), GFP_ATOMIC);
2391         struct net *net = xp_net(old);
2392
2393         if (newp) {
2394                 newp->selector = old->selector;
2395                 if (security_xfrm_policy_clone(old->security,
2396                                                &newp->security)) {
2397                         kfree(newp);
2398                         return NULL;  /* ENOMEM */
2399                 }
2400                 newp->lft = old->lft;
2401                 newp->curlft = old->curlft;
2402                 newp->mark = old->mark;
2403                 newp->if_id = old->if_id;
2404                 newp->action = old->action;
2405                 newp->flags = old->flags;
2406                 newp->xfrm_nr = old->xfrm_nr;
2407                 newp->index = old->index;
2408                 newp->type = old->type;
2409                 newp->family = old->family;
2410                 memcpy(newp->xfrm_vec, old->xfrm_vec,
2411                        newp->xfrm_nr*sizeof(struct xfrm_tmpl));
2412                 spin_lock_bh(&net->xfrm.xfrm_policy_lock);
2413                 xfrm_sk_policy_link(newp, dir);
2414                 spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
2415                 xfrm_pol_put(newp);
2416         }
2417         return newp;
2418 }
2419
2420 int __xfrm_sk_clone_policy(struct sock *sk, const struct sock *osk)
2421 {
2422         const struct xfrm_policy *p;
2423         struct xfrm_policy *np;
2424         int i, ret = 0;
2425
2426         rcu_read_lock();
2427         for (i = 0; i < 2; i++) {
2428                 p = rcu_dereference(osk->sk_policy[i]);
2429                 if (p) {
2430                         np = clone_policy(p, i);
2431                         if (unlikely(!np)) {
2432                                 ret = -ENOMEM;
2433                                 break;
2434                         }
2435                         rcu_assign_pointer(sk->sk_policy[i], np);
2436                 }
2437         }
2438         rcu_read_unlock();
2439         return ret;
2440 }
2441
2442 static int
2443 xfrm_get_saddr(struct net *net, int oif, xfrm_address_t *local,
2444                xfrm_address_t *remote, unsigned short family, u32 mark)
2445 {
2446         int err;
2447         const struct xfrm_policy_afinfo *afinfo = xfrm_policy_get_afinfo(family);
2448
2449         if (unlikely(afinfo == NULL))
2450                 return -EINVAL;
2451         err = afinfo->get_saddr(net, oif, local, remote, mark);
2452         rcu_read_unlock();
2453         return err;
2454 }
2455
2456 /* Resolve list of templates for the flow, given policy. */
2457
2458 static int
2459 xfrm_tmpl_resolve_one(struct xfrm_policy *policy, const struct flowi *fl,
2460                       struct xfrm_state **xfrm, unsigned short family)
2461 {
2462         struct net *net = xp_net(policy);
2463         int nx;
2464         int i, error;
2465         xfrm_address_t *daddr = xfrm_flowi_daddr(fl, family);
2466         xfrm_address_t *saddr = xfrm_flowi_saddr(fl, family);
2467         xfrm_address_t tmp;
2468
2469         for (nx = 0, i = 0; i < policy->xfrm_nr; i++) {
2470                 struct xfrm_state *x;
2471                 xfrm_address_t *remote = daddr;
2472                 xfrm_address_t *local  = saddr;
2473                 struct xfrm_tmpl *tmpl = &policy->xfrm_vec[i];
2474
2475                 if (tmpl->mode == XFRM_MODE_TUNNEL ||
2476                     tmpl->mode == XFRM_MODE_BEET) {
2477                         remote = &tmpl->id.daddr;
2478                         local = &tmpl->saddr;
2479                         if (xfrm_addr_any(local, tmpl->encap_family)) {
2480                                 error = xfrm_get_saddr(net, fl->flowi_oif,
2481                                                        &tmp, remote,
2482                                                        tmpl->encap_family, 0);
2483                                 if (error)
2484                                         goto fail;
2485                                 local = &tmp;
2486                         }
2487                 }
2488
2489                 x = xfrm_state_find(remote, local, fl, tmpl, policy, &error,
2490                                     family, policy->if_id);
2491
2492                 if (x && x->km.state == XFRM_STATE_VALID) {
2493                         xfrm[nx++] = x;
2494                         daddr = remote;
2495                         saddr = local;
2496                         continue;
2497                 }
2498                 if (x) {
2499                         error = (x->km.state == XFRM_STATE_ERROR ?
2500                                  -EINVAL : -EAGAIN);
2501                         xfrm_state_put(x);
2502                 } else if (error == -ESRCH) {
2503                         error = -EAGAIN;
2504                 }
2505
2506                 if (!tmpl->optional)
2507                         goto fail;
2508         }
2509         return nx;
2510
2511 fail:
2512         for (nx--; nx >= 0; nx--)
2513                 xfrm_state_put(xfrm[nx]);
2514         return error;
2515 }
2516
2517 static int
2518 xfrm_tmpl_resolve(struct xfrm_policy **pols, int npols, const struct flowi *fl,
2519                   struct xfrm_state **xfrm, unsigned short family)
2520 {
2521         struct xfrm_state *tp[XFRM_MAX_DEPTH];
2522         struct xfrm_state **tpp = (npols > 1) ? tp : xfrm;
2523         int cnx = 0;
2524         int error;
2525         int ret;
2526         int i;
2527
2528         for (i = 0; i < npols; i++) {
2529                 if (cnx + pols[i]->xfrm_nr >= XFRM_MAX_DEPTH) {
2530                         error = -ENOBUFS;
2531                         goto fail;
2532                 }
2533
2534                 ret = xfrm_tmpl_resolve_one(pols[i], fl, &tpp[cnx], family);
2535                 if (ret < 0) {
2536                         error = ret;
2537                         goto fail;
2538                 } else
2539                         cnx += ret;
2540         }
2541
2542         /* found states are sorted for outbound processing */
2543         if (npols > 1)
2544                 xfrm_state_sort(xfrm, tpp, cnx, family);
2545
2546         return cnx;
2547
2548  fail:
2549         for (cnx--; cnx >= 0; cnx--)
2550                 xfrm_state_put(tpp[cnx]);
2551         return error;
2552
2553 }
2554
2555 static int xfrm_get_tos(const struct flowi *fl, int family)
2556 {
2557         if (family == AF_INET)
2558                 return IPTOS_RT_MASK & fl->u.ip4.flowi4_tos;
2559
2560         return 0;
2561 }
2562
2563 static inline struct xfrm_dst *xfrm_alloc_dst(struct net *net, int family)
2564 {
2565         const struct xfrm_policy_afinfo *afinfo = xfrm_policy_get_afinfo(family);
2566         struct dst_ops *dst_ops;
2567         struct xfrm_dst *xdst;
2568
2569         if (!afinfo)
2570                 return ERR_PTR(-EINVAL);
2571
2572         switch (family) {
2573         case AF_INET:
2574                 dst_ops = &net->xfrm.xfrm4_dst_ops;
2575                 break;
2576 #if IS_ENABLED(CONFIG_IPV6)
2577         case AF_INET6:
2578                 dst_ops = &net->xfrm.xfrm6_dst_ops;
2579                 break;
2580 #endif
2581         default:
2582                 BUG();
2583         }
2584         xdst = dst_alloc(dst_ops, NULL, DST_OBSOLETE_NONE, 0);
2585
2586         if (likely(xdst)) {
2587                 memset_after(xdst, 0, u.dst);
2588         } else
2589                 xdst = ERR_PTR(-ENOBUFS);
2590
2591         rcu_read_unlock();
2592
2593         return xdst;
2594 }
2595
2596 static void xfrm_init_path(struct xfrm_dst *path, struct dst_entry *dst,
2597                            int nfheader_len)
2598 {
2599         if (dst->ops->family == AF_INET6) {
2600                 struct rt6_info *rt = (struct rt6_info *)dst;
2601                 path->path_cookie = rt6_get_cookie(rt);
2602                 path->u.rt6.rt6i_nfheader_len = nfheader_len;
2603         }
2604 }
2605
2606 static inline int xfrm_fill_dst(struct xfrm_dst *xdst, struct net_device *dev,
2607                                 const struct flowi *fl)
2608 {
2609         const struct xfrm_policy_afinfo *afinfo =
2610                 xfrm_policy_get_afinfo(xdst->u.dst.ops->family);
2611         int err;
2612
2613         if (!afinfo)
2614                 return -EINVAL;
2615
2616         err = afinfo->fill_dst(xdst, dev, fl);
2617
2618         rcu_read_unlock();
2619
2620         return err;
2621 }
2622
2623
2624 /* Allocate chain of dst_entry's, attach known xfrm's, calculate
2625  * all the metrics... Shortly, bundle a bundle.
2626  */
2627
2628 static struct dst_entry *xfrm_bundle_create(struct xfrm_policy *policy,
2629                                             struct xfrm_state **xfrm,
2630                                             struct xfrm_dst **bundle,
2631                                             int nx,
2632                                             const struct flowi *fl,
2633                                             struct dst_entry *dst)
2634 {
2635         const struct xfrm_state_afinfo *afinfo;
2636         const struct xfrm_mode *inner_mode;
2637         struct net *net = xp_net(policy);
2638         unsigned long now = jiffies;
2639         struct net_device *dev;
2640         struct xfrm_dst *xdst_prev = NULL;
2641         struct xfrm_dst *xdst0 = NULL;
2642         int i = 0;
2643         int err;
2644         int header_len = 0;
2645         int nfheader_len = 0;
2646         int trailer_len = 0;
2647         int tos;
2648         int family = policy->selector.family;
2649         xfrm_address_t saddr, daddr;
2650
2651         xfrm_flowi_addr_get(fl, &saddr, &daddr, family);
2652
2653         tos = xfrm_get_tos(fl, family);
2654
2655         dst_hold(dst);
2656
2657         for (; i < nx; i++) {
2658                 struct xfrm_dst *xdst = xfrm_alloc_dst(net, family);
2659                 struct dst_entry *dst1 = &xdst->u.dst;
2660
2661                 err = PTR_ERR(xdst);
2662                 if (IS_ERR(xdst)) {
2663                         dst_release(dst);
2664                         goto put_states;
2665                 }
2666
2667                 bundle[i] = xdst;
2668                 if (!xdst_prev)
2669                         xdst0 = xdst;
2670                 else
2671                         /* Ref count is taken during xfrm_alloc_dst()
2672                          * No need to do dst_clone() on dst1
2673                          */
2674                         xfrm_dst_set_child(xdst_prev, &xdst->u.dst);
2675
2676                 if (xfrm[i]->sel.family == AF_UNSPEC) {
2677                         inner_mode = xfrm_ip2inner_mode(xfrm[i],
2678                                                         xfrm_af2proto(family));
2679                         if (!inner_mode) {
2680                                 err = -EAFNOSUPPORT;
2681                                 dst_release(dst);
2682                                 goto put_states;
2683                         }
2684                 } else
2685                         inner_mode = &xfrm[i]->inner_mode;
2686
2687                 xdst->route = dst;
2688                 dst_copy_metrics(dst1, dst);
2689
2690                 if (xfrm[i]->props.mode != XFRM_MODE_TRANSPORT) {
2691                         __u32 mark = 0;
2692                         int oif;
2693
2694                         if (xfrm[i]->props.smark.v || xfrm[i]->props.smark.m)
2695                                 mark = xfrm_smark_get(fl->flowi_mark, xfrm[i]);
2696
2697                         family = xfrm[i]->props.family;
2698                         oif = fl->flowi_oif ? : fl->flowi_l3mdev;
2699                         dst = xfrm_dst_lookup(xfrm[i], tos, oif,
2700                                               &saddr, &daddr, family, mark);
2701                         err = PTR_ERR(dst);
2702                         if (IS_ERR(dst))
2703                                 goto put_states;
2704                 } else
2705                         dst_hold(dst);
2706
2707                 dst1->xfrm = xfrm[i];
2708                 xdst->xfrm_genid = xfrm[i]->genid;
2709
2710                 dst1->obsolete = DST_OBSOLETE_FORCE_CHK;
2711                 dst1->lastuse = now;
2712
2713                 dst1->input = dst_discard;
2714
2715                 rcu_read_lock();
2716                 afinfo = xfrm_state_afinfo_get_rcu(inner_mode->family);
2717                 if (likely(afinfo))
2718                         dst1->output = afinfo->output;
2719                 else
2720                         dst1->output = dst_discard_out;
2721                 rcu_read_unlock();
2722
2723                 xdst_prev = xdst;
2724
2725                 header_len += xfrm[i]->props.header_len;
2726                 if (xfrm[i]->type->flags & XFRM_TYPE_NON_FRAGMENT)
2727                         nfheader_len += xfrm[i]->props.header_len;
2728                 trailer_len += xfrm[i]->props.trailer_len;
2729         }
2730
2731         xfrm_dst_set_child(xdst_prev, dst);
2732         xdst0->path = dst;
2733
2734         err = -ENODEV;
2735         dev = dst->dev;
2736         if (!dev)
2737                 goto free_dst;
2738
2739         xfrm_init_path(xdst0, dst, nfheader_len);
2740         xfrm_init_pmtu(bundle, nx);
2741
2742         for (xdst_prev = xdst0; xdst_prev != (struct xfrm_dst *)dst;
2743              xdst_prev = (struct xfrm_dst *) xfrm_dst_child(&xdst_prev->u.dst)) {
2744                 err = xfrm_fill_dst(xdst_prev, dev, fl);
2745                 if (err)
2746                         goto free_dst;
2747
2748                 xdst_prev->u.dst.header_len = header_len;
2749                 xdst_prev->u.dst.trailer_len = trailer_len;
2750                 header_len -= xdst_prev->u.dst.xfrm->props.header_len;
2751                 trailer_len -= xdst_prev->u.dst.xfrm->props.trailer_len;
2752         }
2753
2754         return &xdst0->u.dst;
2755
2756 put_states:
2757         for (; i < nx; i++)
2758                 xfrm_state_put(xfrm[i]);
2759 free_dst:
2760         if (xdst0)
2761                 dst_release_immediate(&xdst0->u.dst);
2762
2763         return ERR_PTR(err);
2764 }
2765
2766 static int xfrm_expand_policies(const struct flowi *fl, u16 family,
2767                                 struct xfrm_policy **pols,
2768                                 int *num_pols, int *num_xfrms)
2769 {
2770         int i;
2771
2772         if (*num_pols == 0 || !pols[0]) {
2773                 *num_pols = 0;
2774                 *num_xfrms = 0;
2775                 return 0;
2776         }
2777         if (IS_ERR(pols[0])) {
2778                 *num_pols = 0;
2779                 return PTR_ERR(pols[0]);
2780         }
2781
2782         *num_xfrms = pols[0]->xfrm_nr;
2783
2784 #ifdef CONFIG_XFRM_SUB_POLICY
2785         if (pols[0]->action == XFRM_POLICY_ALLOW &&
2786             pols[0]->type != XFRM_POLICY_TYPE_MAIN) {
2787                 pols[1] = xfrm_policy_lookup_bytype(xp_net(pols[0]),
2788                                                     XFRM_POLICY_TYPE_MAIN,
2789                                                     fl, family,
2790                                                     XFRM_POLICY_OUT,
2791                                                     pols[0]->if_id);
2792                 if (pols[1]) {
2793                         if (IS_ERR(pols[1])) {
2794                                 xfrm_pols_put(pols, *num_pols);
2795                                 *num_pols = 0;
2796                                 return PTR_ERR(pols[1]);
2797                         }
2798                         (*num_pols)++;
2799                         (*num_xfrms) += pols[1]->xfrm_nr;
2800                 }
2801         }
2802 #endif
2803         for (i = 0; i < *num_pols; i++) {
2804                 if (pols[i]->action != XFRM_POLICY_ALLOW) {
2805                         *num_xfrms = -1;
2806                         break;
2807                 }
2808         }
2809
2810         return 0;
2811
2812 }
2813
2814 static struct xfrm_dst *
2815 xfrm_resolve_and_create_bundle(struct xfrm_policy **pols, int num_pols,
2816                                const struct flowi *fl, u16 family,
2817                                struct dst_entry *dst_orig)
2818 {
2819         struct net *net = xp_net(pols[0]);
2820         struct xfrm_state *xfrm[XFRM_MAX_DEPTH];
2821         struct xfrm_dst *bundle[XFRM_MAX_DEPTH];
2822         struct xfrm_dst *xdst;
2823         struct dst_entry *dst;
2824         int err;
2825
2826         /* Try to instantiate a bundle */
2827         err = xfrm_tmpl_resolve(pols, num_pols, fl, xfrm, family);
2828         if (err <= 0) {
2829                 if (err == 0)
2830                         return NULL;
2831
2832                 if (err != -EAGAIN)
2833                         XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTPOLERROR);
2834                 return ERR_PTR(err);
2835         }
2836
2837         dst = xfrm_bundle_create(pols[0], xfrm, bundle, err, fl, dst_orig);
2838         if (IS_ERR(dst)) {
2839                 XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTBUNDLEGENERROR);
2840                 return ERR_CAST(dst);
2841         }
2842
2843         xdst = (struct xfrm_dst *)dst;
2844         xdst->num_xfrms = err;
2845         xdst->num_pols = num_pols;
2846         memcpy(xdst->pols, pols, sizeof(struct xfrm_policy *) * num_pols);
2847         xdst->policy_genid = atomic_read(&pols[0]->genid);
2848
2849         return xdst;
2850 }
2851
2852 static void xfrm_policy_queue_process(struct timer_list *t)
2853 {
2854         struct sk_buff *skb;
2855         struct sock *sk;
2856         struct dst_entry *dst;
2857         struct xfrm_policy *pol = from_timer(pol, t, polq.hold_timer);
2858         struct net *net = xp_net(pol);
2859         struct xfrm_policy_queue *pq = &pol->polq;
2860         struct flowi fl;
2861         struct sk_buff_head list;
2862         __u32 skb_mark;
2863
2864         spin_lock(&pq->hold_queue.lock);
2865         skb = skb_peek(&pq->hold_queue);
2866         if (!skb) {
2867                 spin_unlock(&pq->hold_queue.lock);
2868                 goto out;
2869         }
2870         dst = skb_dst(skb);
2871         sk = skb->sk;
2872
2873         /* Fixup the mark to support VTI. */
2874         skb_mark = skb->mark;
2875         skb->mark = pol->mark.v;
2876         xfrm_decode_session(net, skb, &fl, dst->ops->family);
2877         skb->mark = skb_mark;
2878         spin_unlock(&pq->hold_queue.lock);
2879
2880         dst_hold(xfrm_dst_path(dst));
2881         dst = xfrm_lookup(net, xfrm_dst_path(dst), &fl, sk, XFRM_LOOKUP_QUEUE);
2882         if (IS_ERR(dst))
2883                 goto purge_queue;
2884
2885         if (dst->flags & DST_XFRM_QUEUE) {
2886                 dst_release(dst);
2887
2888                 if (pq->timeout >= XFRM_QUEUE_TMO_MAX)
2889                         goto purge_queue;
2890
2891                 pq->timeout = pq->timeout << 1;
2892                 if (!mod_timer(&pq->hold_timer, jiffies + pq->timeout))
2893                         xfrm_pol_hold(pol);
2894                 goto out;
2895         }
2896
2897         dst_release(dst);
2898
2899         __skb_queue_head_init(&list);
2900
2901         spin_lock(&pq->hold_queue.lock);
2902         pq->timeout = 0;
2903         skb_queue_splice_init(&pq->hold_queue, &list);
2904         spin_unlock(&pq->hold_queue.lock);
2905
2906         while (!skb_queue_empty(&list)) {
2907                 skb = __skb_dequeue(&list);
2908
2909                 /* Fixup the mark to support VTI. */
2910                 skb_mark = skb->mark;
2911                 skb->mark = pol->mark.v;
2912                 xfrm_decode_session(net, skb, &fl, skb_dst(skb)->ops->family);
2913                 skb->mark = skb_mark;
2914
2915                 dst_hold(xfrm_dst_path(skb_dst(skb)));
2916                 dst = xfrm_lookup(net, xfrm_dst_path(skb_dst(skb)), &fl, skb->sk, 0);
2917                 if (IS_ERR(dst)) {
2918                         kfree_skb(skb);
2919                         continue;
2920                 }
2921
2922                 nf_reset_ct(skb);
2923                 skb_dst_drop(skb);
2924                 skb_dst_set(skb, dst);
2925
2926                 dst_output(net, skb->sk, skb);
2927         }
2928
2929 out:
2930         xfrm_pol_put(pol);
2931         return;
2932
2933 purge_queue:
2934         pq->timeout = 0;
2935         skb_queue_purge(&pq->hold_queue);
2936         xfrm_pol_put(pol);
2937 }
2938
2939 static int xdst_queue_output(struct net *net, struct sock *sk, struct sk_buff *skb)
2940 {
2941         unsigned long sched_next;
2942         struct dst_entry *dst = skb_dst(skb);
2943         struct xfrm_dst *xdst = (struct xfrm_dst *) dst;
2944         struct xfrm_policy *pol = xdst->pols[0];
2945         struct xfrm_policy_queue *pq = &pol->polq;
2946
2947         if (unlikely(skb_fclone_busy(sk, skb))) {
2948                 kfree_skb(skb);
2949                 return 0;
2950         }
2951
2952         if (pq->hold_queue.qlen > XFRM_MAX_QUEUE_LEN) {
2953                 kfree_skb(skb);
2954                 return -EAGAIN;
2955         }
2956
2957         skb_dst_force(skb);
2958
2959         spin_lock_bh(&pq->hold_queue.lock);
2960
2961         if (!pq->timeout)
2962                 pq->timeout = XFRM_QUEUE_TMO_MIN;
2963
2964         sched_next = jiffies + pq->timeout;
2965
2966         if (del_timer(&pq->hold_timer)) {
2967                 if (time_before(pq->hold_timer.expires, sched_next))
2968                         sched_next = pq->hold_timer.expires;
2969                 xfrm_pol_put(pol);
2970         }
2971
2972         __skb_queue_tail(&pq->hold_queue, skb);
2973         if (!mod_timer(&pq->hold_timer, sched_next))
2974                 xfrm_pol_hold(pol);
2975
2976         spin_unlock_bh(&pq->hold_queue.lock);
2977
2978         return 0;
2979 }
2980
2981 static struct xfrm_dst *xfrm_create_dummy_bundle(struct net *net,
2982                                                  struct xfrm_flo *xflo,
2983                                                  const struct flowi *fl,
2984                                                  int num_xfrms,
2985                                                  u16 family)
2986 {
2987         int err;
2988         struct net_device *dev;
2989         struct dst_entry *dst;
2990         struct dst_entry *dst1;
2991         struct xfrm_dst *xdst;
2992
2993         xdst = xfrm_alloc_dst(net, family);
2994         if (IS_ERR(xdst))
2995                 return xdst;
2996
2997         if (!(xflo->flags & XFRM_LOOKUP_QUEUE) ||
2998             net->xfrm.sysctl_larval_drop ||
2999             num_xfrms <= 0)
3000                 return xdst;
3001
3002         dst = xflo->dst_orig;
3003         dst1 = &xdst->u.dst;
3004         dst_hold(dst);
3005         xdst->route = dst;
3006
3007         dst_copy_metrics(dst1, dst);
3008
3009         dst1->obsolete = DST_OBSOLETE_FORCE_CHK;
3010         dst1->flags |= DST_XFRM_QUEUE;
3011         dst1->lastuse = jiffies;
3012
3013         dst1->input = dst_discard;
3014         dst1->output = xdst_queue_output;
3015
3016         dst_hold(dst);
3017         xfrm_dst_set_child(xdst, dst);
3018         xdst->path = dst;
3019
3020         xfrm_init_path((struct xfrm_dst *)dst1, dst, 0);
3021
3022         err = -ENODEV;
3023         dev = dst->dev;
3024         if (!dev)
3025                 goto free_dst;
3026
3027         err = xfrm_fill_dst(xdst, dev, fl);
3028         if (err)
3029                 goto free_dst;
3030
3031 out:
3032         return xdst;
3033
3034 free_dst:
3035         dst_release(dst1);
3036         xdst = ERR_PTR(err);
3037         goto out;
3038 }
3039
3040 static struct xfrm_dst *xfrm_bundle_lookup(struct net *net,
3041                                            const struct flowi *fl,
3042                                            u16 family, u8 dir,
3043                                            struct xfrm_flo *xflo, u32 if_id)
3044 {
3045         struct xfrm_policy *pols[XFRM_POLICY_TYPE_MAX];
3046         int num_pols = 0, num_xfrms = 0, err;
3047         struct xfrm_dst *xdst;
3048
3049         /* Resolve policies to use if we couldn't get them from
3050          * previous cache entry */
3051         num_pols = 1;
3052         pols[0] = xfrm_policy_lookup(net, fl, family, dir, if_id);
3053         err = xfrm_expand_policies(fl, family, pols,
3054                                            &num_pols, &num_xfrms);
3055         if (err < 0)
3056                 goto inc_error;
3057         if (num_pols == 0)
3058                 return NULL;
3059         if (num_xfrms <= 0)
3060                 goto make_dummy_bundle;
3061
3062         xdst = xfrm_resolve_and_create_bundle(pols, num_pols, fl, family,
3063                                               xflo->dst_orig);
3064         if (IS_ERR(xdst)) {
3065                 err = PTR_ERR(xdst);
3066                 if (err == -EREMOTE) {
3067                         xfrm_pols_put(pols, num_pols);
3068                         return NULL;
3069                 }
3070
3071                 if (err != -EAGAIN)
3072                         goto error;
3073                 goto make_dummy_bundle;
3074         } else if (xdst == NULL) {
3075                 num_xfrms = 0;
3076                 goto make_dummy_bundle;
3077         }
3078
3079         return xdst;
3080
3081 make_dummy_bundle:
3082         /* We found policies, but there's no bundles to instantiate:
3083          * either because the policy blocks, has no transformations or
3084          * we could not build template (no xfrm_states).*/
3085         xdst = xfrm_create_dummy_bundle(net, xflo, fl, num_xfrms, family);
3086         if (IS_ERR(xdst)) {
3087                 xfrm_pols_put(pols, num_pols);
3088                 return ERR_CAST(xdst);
3089         }
3090         xdst->num_pols = num_pols;
3091         xdst->num_xfrms = num_xfrms;
3092         memcpy(xdst->pols, pols, sizeof(struct xfrm_policy *) * num_pols);
3093
3094         return xdst;
3095
3096 inc_error:
3097         XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTPOLERROR);
3098 error:
3099         xfrm_pols_put(pols, num_pols);
3100         return ERR_PTR(err);
3101 }
3102
3103 static struct dst_entry *make_blackhole(struct net *net, u16 family,
3104                                         struct dst_entry *dst_orig)
3105 {
3106         const struct xfrm_policy_afinfo *afinfo = xfrm_policy_get_afinfo(family);
3107         struct dst_entry *ret;
3108
3109         if (!afinfo) {
3110                 dst_release(dst_orig);
3111                 return ERR_PTR(-EINVAL);
3112         } else {
3113                 ret = afinfo->blackhole_route(net, dst_orig);
3114         }
3115         rcu_read_unlock();
3116
3117         return ret;
3118 }
3119
3120 /* Finds/creates a bundle for given flow and if_id
3121  *
3122  * At the moment we eat a raw IP route. Mostly to speed up lookups
3123  * on interfaces with disabled IPsec.
3124  *
3125  * xfrm_lookup uses an if_id of 0 by default, and is provided for
3126  * compatibility
3127  */
3128 struct dst_entry *xfrm_lookup_with_ifid(struct net *net,
3129                                         struct dst_entry *dst_orig,
3130                                         const struct flowi *fl,
3131                                         const struct sock *sk,
3132                                         int flags, u32 if_id)
3133 {
3134         struct xfrm_policy *pols[XFRM_POLICY_TYPE_MAX];
3135         struct xfrm_dst *xdst;
3136         struct dst_entry *dst, *route;
3137         u16 family = dst_orig->ops->family;
3138         u8 dir = XFRM_POLICY_OUT;
3139         int i, err, num_pols, num_xfrms = 0, drop_pols = 0;
3140
3141         dst = NULL;
3142         xdst = NULL;
3143         route = NULL;
3144
3145         sk = sk_const_to_full_sk(sk);
3146         if (sk && sk->sk_policy[XFRM_POLICY_OUT]) {
3147                 num_pols = 1;
3148                 pols[0] = xfrm_sk_policy_lookup(sk, XFRM_POLICY_OUT, fl, family,
3149                                                 if_id);
3150                 err = xfrm_expand_policies(fl, family, pols,
3151                                            &num_pols, &num_xfrms);
3152                 if (err < 0)
3153                         goto dropdst;
3154
3155                 if (num_pols) {
3156                         if (num_xfrms <= 0) {
3157                                 drop_pols = num_pols;
3158                                 goto no_transform;
3159                         }
3160
3161                         xdst = xfrm_resolve_and_create_bundle(
3162                                         pols, num_pols, fl,
3163                                         family, dst_orig);
3164
3165                         if (IS_ERR(xdst)) {
3166                                 xfrm_pols_put(pols, num_pols);
3167                                 err = PTR_ERR(xdst);
3168                                 if (err == -EREMOTE)
3169                                         goto nopol;
3170
3171                                 goto dropdst;
3172                         } else if (xdst == NULL) {
3173                                 num_xfrms = 0;
3174                                 drop_pols = num_pols;
3175                                 goto no_transform;
3176                         }
3177
3178                         route = xdst->route;
3179                 }
3180         }
3181
3182         if (xdst == NULL) {
3183                 struct xfrm_flo xflo;
3184
3185                 xflo.dst_orig = dst_orig;
3186                 xflo.flags = flags;
3187
3188                 /* To accelerate a bit...  */
3189                 if (!if_id && ((dst_orig->flags & DST_NOXFRM) ||
3190                                !net->xfrm.policy_count[XFRM_POLICY_OUT]))
3191                         goto nopol;
3192
3193                 xdst = xfrm_bundle_lookup(net, fl, family, dir, &xflo, if_id);
3194                 if (xdst == NULL)
3195                         goto nopol;
3196                 if (IS_ERR(xdst)) {
3197                         err = PTR_ERR(xdst);
3198                         goto dropdst;
3199                 }
3200
3201                 num_pols = xdst->num_pols;
3202                 num_xfrms = xdst->num_xfrms;
3203                 memcpy(pols, xdst->pols, sizeof(struct xfrm_policy *) * num_pols);
3204                 route = xdst->route;
3205         }
3206
3207         dst = &xdst->u.dst;
3208         if (route == NULL && num_xfrms > 0) {
3209                 /* The only case when xfrm_bundle_lookup() returns a
3210                  * bundle with null route, is when the template could
3211                  * not be resolved. It means policies are there, but
3212                  * bundle could not be created, since we don't yet
3213                  * have the xfrm_state's. We need to wait for KM to
3214                  * negotiate new SA's or bail out with error.*/
3215                 if (net->xfrm.sysctl_larval_drop) {
3216                         XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTNOSTATES);
3217                         err = -EREMOTE;
3218                         goto error;
3219                 }
3220
3221                 err = -EAGAIN;
3222
3223                 XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTNOSTATES);
3224                 goto error;
3225         }
3226
3227 no_transform:
3228         if (num_pols == 0)
3229                 goto nopol;
3230
3231         if ((flags & XFRM_LOOKUP_ICMP) &&
3232             !(pols[0]->flags & XFRM_POLICY_ICMP)) {
3233                 err = -ENOENT;
3234                 goto error;
3235         }
3236
3237         for (i = 0; i < num_pols; i++)
3238                 WRITE_ONCE(pols[i]->curlft.use_time, ktime_get_real_seconds());
3239
3240         if (num_xfrms < 0) {
3241                 /* Prohibit the flow */
3242                 XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTPOLBLOCK);
3243                 err = -EPERM;
3244                 goto error;
3245         } else if (num_xfrms > 0) {
3246                 /* Flow transformed */
3247                 dst_release(dst_orig);
3248         } else {
3249                 /* Flow passes untransformed */
3250                 dst_release(dst);
3251                 dst = dst_orig;
3252         }
3253 ok:
3254         xfrm_pols_put(pols, drop_pols);
3255         if (dst && dst->xfrm &&
3256             dst->xfrm->props.mode == XFRM_MODE_TUNNEL)
3257                 dst->flags |= DST_XFRM_TUNNEL;
3258         return dst;
3259
3260 nopol:
3261         if ((!dst_orig->dev || !(dst_orig->dev->flags & IFF_LOOPBACK)) &&
3262             net->xfrm.policy_default[dir] == XFRM_USERPOLICY_BLOCK) {
3263                 err = -EPERM;
3264                 goto error;
3265         }
3266         if (!(flags & XFRM_LOOKUP_ICMP)) {
3267                 dst = dst_orig;
3268                 goto ok;
3269         }
3270         err = -ENOENT;
3271 error:
3272         dst_release(dst);
3273 dropdst:
3274         if (!(flags & XFRM_LOOKUP_KEEP_DST_REF))
3275                 dst_release(dst_orig);
3276         xfrm_pols_put(pols, drop_pols);
3277         return ERR_PTR(err);
3278 }
3279 EXPORT_SYMBOL(xfrm_lookup_with_ifid);
3280
3281 /* Main function: finds/creates a bundle for given flow.
3282  *
3283  * At the moment we eat a raw IP route. Mostly to speed up lookups
3284  * on interfaces with disabled IPsec.
3285  */
3286 struct dst_entry *xfrm_lookup(struct net *net, struct dst_entry *dst_orig,
3287                               const struct flowi *fl, const struct sock *sk,
3288                               int flags)
3289 {
3290         return xfrm_lookup_with_ifid(net, dst_orig, fl, sk, flags, 0);
3291 }
3292 EXPORT_SYMBOL(xfrm_lookup);
3293
3294 /* Callers of xfrm_lookup_route() must ensure a call to dst_output().
3295  * Otherwise we may send out blackholed packets.
3296  */
3297 struct dst_entry *xfrm_lookup_route(struct net *net, struct dst_entry *dst_orig,
3298                                     const struct flowi *fl,
3299                                     const struct sock *sk, int flags)
3300 {
3301         struct dst_entry *dst = xfrm_lookup(net, dst_orig, fl, sk,
3302                                             flags | XFRM_LOOKUP_QUEUE |
3303                                             XFRM_LOOKUP_KEEP_DST_REF);
3304
3305         if (PTR_ERR(dst) == -EREMOTE)
3306                 return make_blackhole(net, dst_orig->ops->family, dst_orig);
3307
3308         if (IS_ERR(dst))
3309                 dst_release(dst_orig);
3310
3311         return dst;
3312 }
3313 EXPORT_SYMBOL(xfrm_lookup_route);
3314
3315 static inline int
3316 xfrm_secpath_reject(int idx, struct sk_buff *skb, const struct flowi *fl)
3317 {
3318         struct sec_path *sp = skb_sec_path(skb);
3319         struct xfrm_state *x;
3320
3321         if (!sp || idx < 0 || idx >= sp->len)
3322                 return 0;
3323         x = sp->xvec[idx];
3324         if (!x->type->reject)
3325                 return 0;
3326         return x->type->reject(x, skb, fl);
3327 }
3328
3329 /* When skb is transformed back to its "native" form, we have to
3330  * check policy restrictions. At the moment we make this in maximally
3331  * stupid way. Shame on me. :-) Of course, connected sockets must
3332  * have policy cached at them.
3333  */
3334
3335 static inline int
3336 xfrm_state_ok(const struct xfrm_tmpl *tmpl, const struct xfrm_state *x,
3337               unsigned short family, u32 if_id)
3338 {
3339         if (xfrm_state_kern(x))
3340                 return tmpl->optional && !xfrm_state_addr_cmp(tmpl, x, tmpl->encap_family);
3341         return  x->id.proto == tmpl->id.proto &&
3342                 (x->id.spi == tmpl->id.spi || !tmpl->id.spi) &&
3343                 (x->props.reqid == tmpl->reqid || !tmpl->reqid) &&
3344                 x->props.mode == tmpl->mode &&
3345                 (tmpl->allalgs || (tmpl->aalgos & (1<<x->props.aalgo)) ||
3346                  !(xfrm_id_proto_match(tmpl->id.proto, IPSEC_PROTO_ANY))) &&
3347                 !(x->props.mode != XFRM_MODE_TRANSPORT &&
3348                   xfrm_state_addr_cmp(tmpl, x, family)) &&
3349                 (if_id == 0 || if_id == x->if_id);
3350 }
3351
3352 /*
3353  * 0 or more than 0 is returned when validation is succeeded (either bypass
3354  * because of optional transport mode, or next index of the matched secpath
3355  * state with the template.
3356  * -1 is returned when no matching template is found.
3357  * Otherwise "-2 - errored_index" is returned.
3358  */
3359 static inline int
3360 xfrm_policy_ok(const struct xfrm_tmpl *tmpl, const struct sec_path *sp, int start,
3361                unsigned short family, u32 if_id)
3362 {
3363         int idx = start;
3364
3365         if (tmpl->optional) {
3366                 if (tmpl->mode == XFRM_MODE_TRANSPORT)
3367                         return start;
3368         } else
3369                 start = -1;
3370         for (; idx < sp->len; idx++) {
3371                 if (xfrm_state_ok(tmpl, sp->xvec[idx], family, if_id))
3372                         return ++idx;
3373                 if (sp->xvec[idx]->props.mode != XFRM_MODE_TRANSPORT) {
3374                         if (idx < sp->verified_cnt) {
3375                                 /* Secpath entry previously verified, consider optional and
3376                                  * continue searching
3377                                  */
3378                                 continue;
3379                         }
3380
3381                         if (start == -1)
3382                                 start = -2-idx;
3383                         break;
3384                 }
3385         }
3386         return start;
3387 }
3388
3389 static void
3390 decode_session4(const struct xfrm_flow_keys *flkeys, struct flowi *fl, bool reverse)
3391 {
3392         struct flowi4 *fl4 = &fl->u.ip4;
3393
3394         memset(fl4, 0, sizeof(struct flowi4));
3395
3396         if (reverse) {
3397                 fl4->saddr = flkeys->addrs.ipv4.dst;
3398                 fl4->daddr = flkeys->addrs.ipv4.src;
3399                 fl4->fl4_sport = flkeys->ports.dst;
3400                 fl4->fl4_dport = flkeys->ports.src;
3401         } else {
3402                 fl4->saddr = flkeys->addrs.ipv4.src;
3403                 fl4->daddr = flkeys->addrs.ipv4.dst;
3404                 fl4->fl4_sport = flkeys->ports.src;
3405                 fl4->fl4_dport = flkeys->ports.dst;
3406         }
3407
3408         switch (flkeys->basic.ip_proto) {
3409         case IPPROTO_GRE:
3410                 fl4->fl4_gre_key = flkeys->gre.keyid;
3411                 break;
3412         case IPPROTO_ICMP:
3413                 fl4->fl4_icmp_type = flkeys->icmp.type;
3414                 fl4->fl4_icmp_code = flkeys->icmp.code;
3415                 break;
3416         }
3417
3418         fl4->flowi4_proto = flkeys->basic.ip_proto;
3419         fl4->flowi4_tos = flkeys->ip.tos;
3420 }
3421
3422 #if IS_ENABLED(CONFIG_IPV6)
3423 static void
3424 decode_session6(const struct xfrm_flow_keys *flkeys, struct flowi *fl, bool reverse)
3425 {
3426         struct flowi6 *fl6 = &fl->u.ip6;
3427
3428         memset(fl6, 0, sizeof(struct flowi6));
3429
3430         if (reverse) {
3431                 fl6->saddr = flkeys->addrs.ipv6.dst;
3432                 fl6->daddr = flkeys->addrs.ipv6.src;
3433                 fl6->fl6_sport = flkeys->ports.dst;
3434                 fl6->fl6_dport = flkeys->ports.src;
3435         } else {
3436                 fl6->saddr = flkeys->addrs.ipv6.src;
3437                 fl6->daddr = flkeys->addrs.ipv6.dst;
3438                 fl6->fl6_sport = flkeys->ports.src;
3439                 fl6->fl6_dport = flkeys->ports.dst;
3440         }
3441
3442         switch (flkeys->basic.ip_proto) {
3443         case IPPROTO_GRE:
3444                 fl6->fl6_gre_key = flkeys->gre.keyid;
3445                 break;
3446         case IPPROTO_ICMPV6:
3447                 fl6->fl6_icmp_type = flkeys->icmp.type;
3448                 fl6->fl6_icmp_code = flkeys->icmp.code;
3449                 break;
3450         }
3451
3452         fl6->flowi6_proto = flkeys->basic.ip_proto;
3453 }
3454 #endif
3455
3456 int __xfrm_decode_session(struct net *net, struct sk_buff *skb, struct flowi *fl,
3457                           unsigned int family, int reverse)
3458 {
3459         struct xfrm_flow_keys flkeys;
3460
3461         memset(&flkeys, 0, sizeof(flkeys));
3462         __skb_flow_dissect(net, skb, &xfrm_session_dissector, &flkeys,
3463                            NULL, 0, 0, 0, FLOW_DISSECTOR_F_STOP_AT_ENCAP);
3464
3465         switch (family) {
3466         case AF_INET:
3467                 decode_session4(&flkeys, fl, reverse);
3468                 break;
3469 #if IS_ENABLED(CONFIG_IPV6)
3470         case AF_INET6:
3471                 decode_session6(&flkeys, fl, reverse);
3472                 break;
3473 #endif
3474         default:
3475                 return -EAFNOSUPPORT;
3476         }
3477
3478         fl->flowi_mark = skb->mark;
3479         if (reverse) {
3480                 fl->flowi_oif = skb->skb_iif;
3481         } else {
3482                 int oif = 0;
3483
3484                 if (skb_dst(skb) && skb_dst(skb)->dev)
3485                         oif = skb_dst(skb)->dev->ifindex;
3486
3487                 fl->flowi_oif = oif;
3488         }
3489
3490         return security_xfrm_decode_session(skb, &fl->flowi_secid);
3491 }
3492 EXPORT_SYMBOL(__xfrm_decode_session);
3493
3494 static inline int secpath_has_nontransport(const struct sec_path *sp, int k, int *idxp)
3495 {
3496         for (; k < sp->len; k++) {
3497                 if (sp->xvec[k]->props.mode != XFRM_MODE_TRANSPORT) {
3498                         *idxp = k;
3499                         return 1;
3500                 }
3501         }
3502
3503         return 0;
3504 }
3505
/*
 * __xfrm_policy_check - check an skb against the xfrm policy that applies
 *
 * @sk:     socket the packet is for; a socket policy is only consulted when
 *          sk_to_full_sk(@sk) is non-NULL and has a policy for the direction
 * @dir:    the XFRM_POLICY_MASK bits select the policy direction; any other
 *          bit set requests a reversed flow decode
 * @skb:    packet to check
 * @family: address family used to decode the flow and to match selectors
 *
 * Returns 1 when the packet is accepted and 0 when it is rejected.
 *
 * Outline:
 *  1. decode the flow from the skb;
 *  2. require every state in the secpath to match the flow with its own
 *     selector;
 *  3. look up the policy (socket policy first, then the per-netns lookup);
 *  4. without a policy, reject if the per-netns default for the direction
 *     is XFRM_USERPOLICY_BLOCK or if the secpath contains a non-transport
 *     state, and accept otherwise;
 *  5. with an XFRM_POLICY_ALLOW policy, match the policy templates against
 *     the secpath; any other action is counted as a policy block and the
 *     packet is rejected.
 */
3506 int __xfrm_policy_check(struct sock *sk, int dir, struct sk_buff *skb,
3507                         unsigned short family)
3508 {
3509         struct net *net = dev_net(skb->dev);
3510         struct xfrm_policy *pol;
3511         struct xfrm_policy *pols[XFRM_POLICY_TYPE_MAX];
3512         int npols = 0;
3513         int xfrm_nr;
3514         int pi;
3515         int reverse;
3516         struct flowi fl;
3517         int xerr_idx = -1;
3518         const struct xfrm_if_cb *ifcb;
3519         struct sec_path *sp;
3520         u32 if_id = 0;
3521
             /* A registered xfrm interface callback may replace the netns and
              * supply the if_id used by the lookups below.
              */
3522         rcu_read_lock();
3523         ifcb = xfrm_if_get_cb();
3524
3525         if (ifcb) {
3526                 struct xfrm_if_decode_session_result r;
3527
3528                 if (ifcb->decode_session(skb, family, &r)) {
3529                         if_id = r.if_id;
3530                         net = r.net;
3531                 }
3532         }
3533         rcu_read_unlock();
3534
             /* Split @dir into the "reverse decode" request and the direction. */
3535         reverse = dir & ~XFRM_POLICY_MASK;
3536         dir &= XFRM_POLICY_MASK;
3537
3538         if (__xfrm_decode_session(net, skb, &fl, family, reverse) < 0) {
3539                 XFRM_INC_STATS(net, LINUX_MIB_XFRMINHDRERROR);
3540                 return 0;
3541         }
3542
3543         nf_nat_decode_session(skb, &fl, family);
3544
3545         /* First, check used SA against their selectors. */
3546         sp = skb_sec_path(skb);
3547         if (sp) {
3548                 int i;
3549
3550                 for (i = sp->len - 1; i >= 0; i--) {
3551                         struct xfrm_state *x = sp->xvec[i];
3552                         if (!xfrm_selector_match(&x->sel, &fl, family)) {
3553                                 XFRM_INC_STATS(net, LINUX_MIB_XFRMINSTATEMISMATCH);
3554                                 return 0;
3555                         }
3556                 }
3557         }
3558
             /* Policy lookup: socket policy first, then the per-netns lookup. */
3559         pol = NULL;
3560         sk = sk_to_full_sk(sk);
3561         if (sk && sk->sk_policy[dir]) {
3562                 pol = xfrm_sk_policy_lookup(sk, dir, &fl, family, if_id);
3563                 if (IS_ERR(pol)) {
3564                         XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLERROR);
3565                         return 0;
3566                 }
3567         }
3568
3569         if (!pol)
3570                 pol = xfrm_policy_lookup(net, &fl, family, dir, if_id);
3571
3572         if (IS_ERR(pol)) {
3573                 XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLERROR);
3574                 return 0;
3575         }
3576
             /* No policy: apply the per-netns default, and refuse packets whose
              * secpath contains a non-transport state.
              */
3577         if (!pol) {
3578                 if (net->xfrm.policy_default[dir] == XFRM_USERPOLICY_BLOCK) {
3579                         XFRM_INC_STATS(net, LINUX_MIB_XFRMINNOPOLS);
3580                         return 0;
3581                 }
3582
3583                 if (sp && secpath_has_nontransport(sp, 0, &xerr_idx)) {
3584                         xfrm_secpath_reject(xerr_idx, skb, &fl);
3585                         XFRM_INC_STATS(net, LINUX_MIB_XFRMINNOPOLS);
3586                         return 0;
3587                 }
3588                 return 1;
3589         }
3590
3591         /* This lockless write can happen from different cpus. */
3592         WRITE_ONCE(pol->curlft.use_time, ktime_get_real_seconds());
3593
3594         pols[0] = pol;
3595         npols++;
3596 #ifdef CONFIG_XFRM_SUB_POLICY
             /* pols[0] is not a main policy: also look up a main policy. */
3597         if (pols[0]->type != XFRM_POLICY_TYPE_MAIN) {
3598                 pols[1] = xfrm_policy_lookup_bytype(net, XFRM_POLICY_TYPE_MAIN,
3599                                                     &fl, family,
3600                                                     XFRM_POLICY_IN, if_id);
3601                 if (pols[1]) {
3602                         if (IS_ERR(pols[1])) {
3603                                 XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLERROR);
3604                                 xfrm_pol_put(pols[0]);
3605                                 return 0;
3606                         }
3607                         /* This write can happen from different cpus. */
3608                         WRITE_ONCE(pols[1]->curlft.use_time,
3609                                    ktime_get_real_seconds());
3610                         npols++;
3611                 }
3612         }
3613 #endif
3614
3615         if (pol->action == XFRM_POLICY_ALLOW) {
                     /* Zero-initialized stand-in (len == 0) used when the skb
                      * has no secpath, so the template walk needs no NULL checks.
                      */
3616                 static struct sec_path dummy;
3617                 struct xfrm_tmpl *tp[XFRM_MAX_DEPTH];
3618                 struct xfrm_tmpl *stp[XFRM_MAX_DEPTH];
3619                 struct xfrm_tmpl **tpp = tp;
3620                 int ti = 0;
3621                 int i, k;
3622
3623                 sp = skb_sec_path(skb);
3624                 if (!sp)
3625                         sp = &dummy;
3626
                     /* Collect the templates of all policies into tp[]. */
3627                 for (pi = 0; pi < npols; pi++) {
3628                         if (pols[pi] != pol &&
3629                             pols[pi]->action != XFRM_POLICY_ALLOW) {
3630                                 XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLBLOCK);
3631                                 goto reject;
3632                         }
                             /* tp[] and stp[] hold XFRM_MAX_DEPTH entries. */
3633                         if (ti + pols[pi]->xfrm_nr >= XFRM_MAX_DEPTH) {
3634                                 XFRM_INC_STATS(net, LINUX_MIB_XFRMINBUFFERERROR);
3635                                 goto reject_error;
3636                         }
3637                         for (i = 0; i < pols[pi]->xfrm_nr; i++)
3638                                 tpp[ti++] = &pols[pi]->xfrm_vec[i];
3639                 }
3640                 xfrm_nr = ti;
3641
                     /* Templates from more than one policy are sorted into stp[]. */
3642                 if (npols > 1) {
3643                         xfrm_tmpl_sort(stp, tpp, xfrm_nr, family);
3644                         tpp = stp;
3645                 }
3646
3647                 /* For each tunnel xfrm, find the first matching tmpl.
3648                  * For each tmpl before that, find corresponding xfrm.
3649                  * Order is _important_. Later we will implement
3650                  * some barriers, but at the moment barriers
3651                  * are implied between each two transformations.
3652                  * Upon success, marks secpath entries as having been
3653                  * verified to allow them to be skipped in future policy
3654                  * checks (e.g. nested tunnels).
3655                  */
3656                 for (i = xfrm_nr-1, k = 0; i >= 0; i--) {
3657                         k = xfrm_policy_ok(tpp[i], sp, k, family, if_id);
3658                         if (k < 0) {
3659                                 if (k < -1)
3660                                         /* "-2 - errored_index" returned */
3661                                         xerr_idx = -(2+k);
3662                                 XFRM_INC_STATS(net, LINUX_MIB_XFRMINTMPLMISMATCH);
3663                                 goto reject;
3664                         }
3665                 }
3666
                     /* A non-transport state at or after index k was not
                      * consumed by any template.
                      */
3667                 if (secpath_has_nontransport(sp, k, &xerr_idx)) {
3668                         XFRM_INC_STATS(net, LINUX_MIB_XFRMINTMPLMISMATCH);
3669                         goto reject;
3670                 }
3671
3672                 xfrm_pols_put(pols, npols);
3673                 sp->verified_cnt = k;
3674
3675                 return 1;
3676         }
3677         XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLBLOCK);
3678
     /* reject: call the ->reject() hook of the offending state, if any
      * (xerr_idx stays -1 when no state was identified).
      * reject_error: skip that hook and only put the policies.
      */
3679 reject:
3680         xfrm_secpath_reject(xerr_idx, skb, &fl);
3681 reject_error:
3682         xfrm_pols_put(pols, npols);
3683         return 0;
3684 }
3685 EXPORT_SYMBOL(__xfrm_policy_check);
3686
3687 int __xfrm_route_forward(struct sk_buff *skb, unsigned short family)
3688 {
3689         struct net *net = dev_net(skb->dev);
3690         struct flowi fl;
3691         struct dst_entry *dst;
3692         int res = 1;
3693
3694         if (xfrm_decode_session(net, skb, &fl, family) < 0) {
3695                 XFRM_INC_STATS(net, LINUX_MIB_XFRMFWDHDRERROR);
3696                 return 0;
3697         }
3698
3699         skb_dst_force(skb);
3700         if (!skb_dst(skb)) {
3701                 XFRM_INC_STATS(net, LINUX_MIB_XFRMFWDHDRERROR);
3702                 return 0;
3703         }
3704
3705         dst = xfrm_lookup(net, skb_dst(skb), &fl, NULL, XFRM_LOOKUP_QUEUE);
3706         if (IS_ERR(dst)) {
3707                 res = 0;
3708                 dst = NULL;
3709         }
3710         skb_dst_set(skb, dst);
3711         return res;
3712 }
3713 EXPORT_SYMBOL(__xfrm_route_forward);
3714
3715 /* Optimize later using cookies and generation ids. */
3716
3717 static struct dst_entry *xfrm_dst_check(struct dst_entry *dst, u32 cookie)
3718 {
3719         /* Code (such as __xfrm4_bundle_create()) sets dst->obsolete
3720          * to DST_OBSOLETE_FORCE_CHK to force all XFRM destinations to
3721          * get validated by dst_ops->check on every use.  We do this
3722          * because when a normal route referenced by an XFRM dst is
3723          * obsoleted we do not go looking around for all parent
3724          * referencing XFRM dsts so that we can invalidate them.  It
3725          * is just too much work.  Instead we make the checks here on
3726          * every use.  For example:
3727          *
3728          *      XFRM dst A --> IPv4 dst X
3729          *
3730          * X is the "xdst->route" of A (X is also the "dst->path" of A
3731          * in this example).  If X is marked obsolete, "A" will not
3732          * notice.  That's what we are validating here via the
3733          * stale_bundle() check.
3734          *
3735          * When a dst is removed from the fib tree, DST_OBSOLETE_DEAD will
3736          * be marked on it.
3737          * This will force stale_bundle() to fail on any xdst bundle with
3738          * this dst linked in it.
3739          */
3740         if (dst->obsolete < 0 && !stale_bundle(dst))
3741                 return dst;
3742
3743         return NULL;
3744 }
3745
/* Returns non-zero when xfrm_bundle_ok() rejects the bundle headed by @dst. */
static int stale_bundle(struct dst_entry *dst)
{
	struct xfrm_dst *xdst = (struct xfrm_dst *)dst;

	return !xfrm_bundle_ok(xdst);
}
3750
3751 void xfrm_dst_ifdown(struct dst_entry *dst, struct net_device *dev)
3752 {
3753         while ((dst = xfrm_dst_child(dst)) && dst->xfrm && dst->dev == dev) {
3754                 dst->dev = blackhole_netdev;
3755                 dev_hold(dst->dev);
3756                 dev_put(dev);
3757         }
3758 }
3759 EXPORT_SYMBOL(xfrm_dst_ifdown);
3760
/*
 * dst_ops->link_failure handler.  Intentionally empty: this is impossible,
 * such a dst must be popped before it reaches the point of failure.
 */
static void xfrm_link_failure(struct sk_buff *skb)
{
}
3765
3766 static struct dst_entry *xfrm_negative_advice(struct dst_entry *dst)
3767 {
3768         if (dst) {
3769                 if (dst->obsolete) {
3770                         dst_release(dst);
3771                         dst = NULL;
3772                 }
3773         }
3774         return dst;
3775 }
3776
3777 static void xfrm_init_pmtu(struct xfrm_dst **bundle, int nr)
3778 {
3779         while (nr--) {
3780                 struct xfrm_dst *xdst = bundle[nr];
3781                 u32 pmtu, route_mtu_cached;
3782                 struct dst_entry *dst;
3783
3784                 dst = &xdst->u.dst;
3785                 pmtu = dst_mtu(xfrm_dst_child(dst));
3786                 xdst->child_mtu_cached = pmtu;
3787
3788                 pmtu = xfrm_state_mtu(dst->xfrm, pmtu);
3789
3790                 route_mtu_cached = dst_mtu(xdst->route);
3791                 xdst->route_mtu_cached = route_mtu_cached;
3792
3793                 if (pmtu > route_mtu_cached)
3794                         pmtu = route_mtu_cached;
3795
3796                 dst_metric_set(dst, RTAX_MTU, pmtu);
3797         }
3798 }
3799
3800 /* Check that the bundle accepts the flow and its components are
3801  * still valid.
3802  */
3803
3804 static int xfrm_bundle_ok(struct xfrm_dst *first)
3805 {
3806         struct xfrm_dst *bundle[XFRM_MAX_DEPTH];
3807         struct dst_entry *dst = &first->u.dst;
3808         struct xfrm_dst *xdst;
3809         int start_from, nr;
3810         u32 mtu;
3811
3812         if (!dst_check(xfrm_dst_path(dst), ((struct xfrm_dst *)dst)->path_cookie) ||
3813             (dst->dev && !netif_running(dst->dev)))
3814                 return 0;
3815
3816         if (dst->flags & DST_XFRM_QUEUE)
3817                 return 1;
3818
3819         start_from = nr = 0;
3820         do {
3821                 struct xfrm_dst *xdst = (struct xfrm_dst *)dst;
3822
3823                 if (dst->xfrm->km.state != XFRM_STATE_VALID)
3824                         return 0;
3825                 if (xdst->xfrm_genid != dst->xfrm->genid)
3826                         return 0;
3827                 if (xdst->num_pols > 0 &&
3828                     xdst->policy_genid != atomic_read(&xdst->pols[0]->genid))
3829                         return 0;
3830
3831                 bundle[nr++] = xdst;
3832
3833                 mtu = dst_mtu(xfrm_dst_child(dst));
3834                 if (xdst->child_mtu_cached != mtu) {
3835                         start_from = nr;
3836                         xdst->child_mtu_cached = mtu;
3837                 }
3838
3839                 if (!dst_check(xdst->route, xdst->route_cookie))
3840                         return 0;
3841                 mtu = dst_mtu(xdst->route);
3842                 if (xdst->route_mtu_cached != mtu) {
3843                         start_from = nr;
3844                         xdst->route_mtu_cached = mtu;
3845                 }
3846
3847                 dst = xfrm_dst_child(dst);
3848         } while (dst->xfrm);
3849
3850         if (likely(!start_from))
3851                 return 1;
3852
3853         xdst = bundle[start_from - 1];
3854         mtu = xdst->child_mtu_cached;
3855         while (start_from--) {
3856                 dst = &xdst->u.dst;
3857
3858                 mtu = xfrm_state_mtu(dst->xfrm, mtu);
3859                 if (mtu > xdst->route_mtu_cached)
3860                         mtu = xdst->route_mtu_cached;
3861                 dst_metric_set(dst, RTAX_MTU, mtu);
3862                 if (!start_from)
3863                         break;
3864
3865                 xdst = bundle[start_from - 1];
3866                 xdst->child_mtu_cached = mtu;
3867         }
3868
3869         return 1;
3870 }
3871
/* dst_ops->default_advmss handler: use the advmss metric of the bundle's path. */
static unsigned int xfrm_default_advmss(const struct dst_entry *dst)
{
	const struct dst_entry *path = xfrm_dst_path(dst);

	return dst_metric_advmss(path);
}
3876
3877 static unsigned int xfrm_mtu(const struct dst_entry *dst)
3878 {
3879         unsigned int mtu = dst_metric_raw(dst, RTAX_MTU);
3880
3881         return mtu ? : dst_mtu(xfrm_dst_path(dst));
3882 }
3883
3884 static const void *xfrm_get_dst_nexthop(const struct dst_entry *dst,
3885                                         const void *daddr)
3886 {
3887         while (dst->xfrm) {
3888                 const struct xfrm_state *xfrm = dst->xfrm;
3889
3890                 dst = xfrm_dst_child(dst);
3891
3892                 if (xfrm->props.mode == XFRM_MODE_TRANSPORT)
3893                         continue;
3894                 if (xfrm->type->flags & XFRM_TYPE_REMOTE_COADDR)
3895                         daddr = xfrm->coaddr;
3896                 else if (!(xfrm->type->flags & XFRM_TYPE_LOCAL_COADDR))
3897                         daddr = &xfrm->id.daddr;
3898         }
3899         return daddr;
3900 }
3901
3902 static struct neighbour *xfrm_neigh_lookup(const struct dst_entry *dst,
3903                                            struct sk_buff *skb,
3904                                            const void *daddr)
3905 {
3906         const struct dst_entry *path = xfrm_dst_path(dst);
3907
3908         if (!skb)
3909                 daddr = xfrm_get_dst_nexthop(dst, daddr);
3910         return path->ops->neigh_lookup(path, skb, daddr);
3911 }
3912
3913 static void xfrm_confirm_neigh(const struct dst_entry *dst, const void *daddr)
3914 {
3915         const struct dst_entry *path = xfrm_dst_path(dst);
3916
3917         daddr = xfrm_get_dst_nexthop(dst, daddr);
3918         path->ops->confirm_neigh(path, daddr);
3919 }
3920
3921 int xfrm_policy_register_afinfo(const struct xfrm_policy_afinfo *afinfo, int family)
3922 {
3923         int err = 0;
3924
3925         if (WARN_ON(family >= ARRAY_SIZE(xfrm_policy_afinfo)))
3926                 return -EAFNOSUPPORT;
3927
3928         spin_lock(&xfrm_policy_afinfo_lock);
3929         if (unlikely(xfrm_policy_afinfo[family] != NULL))
3930                 err = -EEXIST;
3931         else {
3932                 struct dst_ops *dst_ops = afinfo->dst_ops;
3933                 if (likely(dst_ops->kmem_cachep == NULL))
3934                         dst_ops->kmem_cachep = xfrm_dst_cache;
3935                 if (likely(dst_ops->check == NULL))
3936                         dst_ops->check = xfrm_dst_check;
3937                 if (likely(dst_ops->default_advmss == NULL))
3938                         dst_ops->default_advmss = xfrm_default_advmss;
3939                 if (likely(dst_ops->mtu == NULL))
3940                         dst_ops->mtu = xfrm_mtu;
3941                 if (likely(dst_ops->negative_advice == NULL))
3942                         dst_ops->negative_advice = xfrm_negative_advice;
3943                 if (likely(dst_ops->link_failure == NULL))
3944                         dst_ops->link_failure = xfrm_link_failure;
3945                 if (likely(dst_ops->neigh_lookup == NULL))
3946                         dst_ops->neigh_lookup = xfrm_neigh_lookup;
3947                 if (likely(!dst_ops->confirm_neigh))
3948                         dst_ops->confirm_neigh = xfrm_confirm_neigh;
3949                 rcu_assign_pointer(xfrm_policy_afinfo[family], afinfo);
3950         }
3951         spin_unlock(&xfrm_policy_afinfo_lock);
3952
3953         return err;
3954 }
3955 EXPORT_SYMBOL(xfrm_policy_register_afinfo);
3956
3957 void xfrm_policy_unregister_afinfo(const struct xfrm_policy_afinfo *afinfo)
3958 {
3959         struct dst_ops *dst_ops = afinfo->dst_ops;
3960         int i;
3961
3962         for (i = 0; i < ARRAY_SIZE(xfrm_policy_afinfo); i++) {
3963                 if (xfrm_policy_afinfo[i] != afinfo)
3964                         continue;
3965                 RCU_INIT_POINTER(xfrm_policy_afinfo[i], NULL);
3966                 break;
3967         }
3968
3969         synchronize_rcu();
3970
3971         dst_ops->kmem_cachep = NULL;
3972         dst_ops->check = NULL;
3973         dst_ops->negative_advice = NULL;
3974         dst_ops->link_failure = NULL;
3975 }
3976 EXPORT_SYMBOL(xfrm_policy_unregister_afinfo);
3977
3978 void xfrm_if_register_cb(const struct xfrm_if_cb *ifcb)
3979 {
3980         spin_lock(&xfrm_if_cb_lock);
3981         rcu_assign_pointer(xfrm_if_cb, ifcb);
3982         spin_unlock(&xfrm_if_cb_lock);
3983 }
3984 EXPORT_SYMBOL(xfrm_if_register_cb);
3985
3986 void xfrm_if_unregister_cb(void)
3987 {
3988         RCU_INIT_POINTER(xfrm_if_cb, NULL);
3989         synchronize_rcu();
3990 }
3991 EXPORT_SYMBOL(xfrm_if_unregister_cb);
3992
3993 #ifdef CONFIG_XFRM_STATISTICS
3994 static int __net_init xfrm_statistics_init(struct net *net)
3995 {
3996         int rv;
3997         net->mib.xfrm_statistics = alloc_percpu(struct linux_xfrm_mib);
3998         if (!net->mib.xfrm_statistics)
3999                 return -ENOMEM;
4000         rv = xfrm_proc_init(net);
4001         if (rv < 0)
4002                 free_percpu(net->mib.xfrm_statistics);
4003         return rv;
4004 }
4005
4006 static void xfrm_statistics_fini(struct net *net)
4007 {
4008         xfrm_proc_fini(net);
4009         free_percpu(net->mib.xfrm_statistics);
4010 }
4011 #else
4012 static int __net_init xfrm_statistics_init(struct net *net)
4013 {
4014         return 0;
4015 }
4016
4017 static void xfrm_statistics_fini(struct net *net)
4018 {
4019 }
4020 #endif
4021
4022 static int __net_init xfrm_policy_init(struct net *net)
4023 {
4024         unsigned int hmask, sz;
4025         int dir, err;
4026
4027         if (net_eq(net, &init_net)) {
4028                 xfrm_dst_cache = kmem_cache_create("xfrm_dst_cache",
4029                                            sizeof(struct xfrm_dst),
4030                                            0, SLAB_HWCACHE_ALIGN|SLAB_PANIC,
4031                                            NULL);
4032                 err = rhashtable_init(&xfrm_policy_inexact_table,
4033                                       &xfrm_pol_inexact_params);
4034                 BUG_ON(err);
4035         }
4036
4037         hmask = 8 - 1;
4038         sz = (hmask+1) * sizeof(struct hlist_head);
4039
4040         net->xfrm.policy_byidx = xfrm_hash_alloc(sz);
4041         if (!net->xfrm.policy_byidx)
4042                 goto out_byidx;
4043         net->xfrm.policy_idx_hmask = hmask;
4044
4045         for (dir = 0; dir < XFRM_POLICY_MAX; dir++) {
4046                 struct xfrm_policy_hash *htab;
4047
4048                 net->xfrm.policy_count[dir] = 0;
4049                 net->xfrm.policy_count[XFRM_POLICY_MAX + dir] = 0;
4050                 INIT_HLIST_HEAD(&net->xfrm.policy_inexact[dir]);
4051
4052                 htab = &net->xfrm.policy_bydst[dir];
4053                 htab->table = xfrm_hash_alloc(sz);
4054                 if (!htab->table)
4055                         goto out_bydst;
4056                 htab->hmask = hmask;
4057                 htab->dbits4 = 32;
4058                 htab->sbits4 = 32;
4059                 htab->dbits6 = 128;
4060                 htab->sbits6 = 128;
4061         }
4062         net->xfrm.policy_hthresh.lbits4 = 32;
4063         net->xfrm.policy_hthresh.rbits4 = 32;
4064         net->xfrm.policy_hthresh.lbits6 = 128;
4065         net->xfrm.policy_hthresh.rbits6 = 128;
4066
4067         seqlock_init(&net->xfrm.policy_hthresh.lock);
4068
4069         INIT_LIST_HEAD(&net->xfrm.policy_all);
4070         INIT_LIST_HEAD(&net->xfrm.inexact_bins);
4071         INIT_WORK(&net->xfrm.policy_hash_work, xfrm_hash_resize);
4072         INIT_WORK(&net->xfrm.policy_hthresh.work, xfrm_hash_rebuild);
4073         return 0;
4074
4075 out_bydst:
4076         for (dir--; dir >= 0; dir--) {
4077                 struct xfrm_policy_hash *htab;
4078
4079                 htab = &net->xfrm.policy_bydst[dir];
4080                 xfrm_hash_free(htab->table, sz);
4081         }
4082         xfrm_hash_free(net->xfrm.policy_byidx, sz);
4083 out_byidx:
4084         return -ENOMEM;
4085 }
4086
4087 static void xfrm_policy_fini(struct net *net)
4088 {
4089         struct xfrm_pol_inexact_bin *b, *t;
4090         unsigned int sz;
4091         int dir;
4092
4093         flush_work(&net->xfrm.policy_hash_work);
4094 #ifdef CONFIG_XFRM_SUB_POLICY
4095         xfrm_policy_flush(net, XFRM_POLICY_TYPE_SUB, false);
4096 #endif
4097         xfrm_policy_flush(net, XFRM_POLICY_TYPE_MAIN, false);
4098
4099         WARN_ON(!list_empty(&net->xfrm.policy_all));
4100
4101         for (dir = 0; dir < XFRM_POLICY_MAX; dir++) {
4102                 struct xfrm_policy_hash *htab;
4103
4104                 WARN_ON(!hlist_empty(&net->xfrm.policy_inexact[dir]));
4105
4106                 htab = &net->xfrm.policy_bydst[dir];
4107                 sz = (htab->hmask + 1) * sizeof(struct hlist_head);
4108                 WARN_ON(!hlist_empty(htab->table));
4109                 xfrm_hash_free(htab->table, sz);
4110         }
4111
4112         sz = (net->xfrm.policy_idx_hmask + 1) * sizeof(struct hlist_head);
4113         WARN_ON(!hlist_empty(net->xfrm.policy_byidx));
4114         xfrm_hash_free(net->xfrm.policy_byidx, sz);
4115
4116         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
4117         list_for_each_entry_safe(b, t, &net->xfrm.inexact_bins, inexact_bins)
4118                 __xfrm_policy_inexact_prune_bin(b, true);
4119         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
4120 }
4121
4122 static int __net_init xfrm_net_init(struct net *net)
4123 {
4124         int rv;
4125
4126         /* Initialize the per-net locks here */
4127         spin_lock_init(&net->xfrm.xfrm_state_lock);
4128         spin_lock_init(&net->xfrm.xfrm_policy_lock);
4129         seqcount_spinlock_init(&net->xfrm.xfrm_policy_hash_generation, &net->xfrm.xfrm_policy_lock);
4130         mutex_init(&net->xfrm.xfrm_cfg_mutex);
4131         net->xfrm.policy_default[XFRM_POLICY_IN] = XFRM_USERPOLICY_ACCEPT;
4132         net->xfrm.policy_default[XFRM_POLICY_FWD] = XFRM_USERPOLICY_ACCEPT;
4133         net->xfrm.policy_default[XFRM_POLICY_OUT] = XFRM_USERPOLICY_ACCEPT;
4134
4135         rv = xfrm_statistics_init(net);
4136         if (rv < 0)
4137                 goto out_statistics;
4138         rv = xfrm_state_init(net);
4139         if (rv < 0)
4140                 goto out_state;
4141         rv = xfrm_policy_init(net);
4142         if (rv < 0)
4143                 goto out_policy;
4144         rv = xfrm_sysctl_init(net);
4145         if (rv < 0)
4146                 goto out_sysctl;
4147
4148         return 0;
4149
4150 out_sysctl:
4151         xfrm_policy_fini(net);
4152 out_policy:
4153         xfrm_state_fini(net);
4154 out_state:
4155         xfrm_statistics_fini(net);
4156 out_statistics:
4157         return rv;
4158 }
4159
4160 static void __net_exit xfrm_net_exit(struct net *net)
4161 {
4162         xfrm_sysctl_fini(net);
4163         xfrm_policy_fini(net);
4164         xfrm_state_fini(net);
4165         xfrm_statistics_fini(net);
4166 }
4167
4168 static struct pernet_operations __net_initdata xfrm_net_ops = {
4169         .init = xfrm_net_init,
4170         .exit = xfrm_net_exit,
4171 };
4172
4173 static const struct flow_dissector_key xfrm_flow_dissector_keys[] = {
4174         {
4175                 .key_id = FLOW_DISSECTOR_KEY_CONTROL,
4176                 .offset = offsetof(struct xfrm_flow_keys, control),
4177         },
4178         {
4179                 .key_id = FLOW_DISSECTOR_KEY_BASIC,
4180                 .offset = offsetof(struct xfrm_flow_keys, basic),
4181         },
4182         {
4183                 .key_id = FLOW_DISSECTOR_KEY_IPV4_ADDRS,
4184                 .offset = offsetof(struct xfrm_flow_keys, addrs.ipv4),
4185         },
4186         {
4187                 .key_id = FLOW_DISSECTOR_KEY_IPV6_ADDRS,
4188                 .offset = offsetof(struct xfrm_flow_keys, addrs.ipv6),
4189         },
4190         {
4191                 .key_id = FLOW_DISSECTOR_KEY_PORTS,
4192                 .offset = offsetof(struct xfrm_flow_keys, ports),
4193         },
4194         {
4195                 .key_id = FLOW_DISSECTOR_KEY_GRE_KEYID,
4196                 .offset = offsetof(struct xfrm_flow_keys, gre),
4197         },
4198         {
4199                 .key_id = FLOW_DISSECTOR_KEY_IP,
4200                 .offset = offsetof(struct xfrm_flow_keys, ip),
4201         },
4202         {
4203                 .key_id = FLOW_DISSECTOR_KEY_ICMP,
4204                 .offset = offsetof(struct xfrm_flow_keys, icmp),
4205         },
4206 };
4207
4208 void __init xfrm_init(void)
4209 {
4210         skb_flow_dissector_init(&xfrm_session_dissector,
4211                                 xfrm_flow_dissector_keys,
4212                                 ARRAY_SIZE(xfrm_flow_dissector_keys));
4213
4214         register_pernet_subsys(&xfrm_net_ops);
4215         xfrm_dev_init();
4216         xfrm_input_init();
4217
4218 #ifdef CONFIG_XFRM_ESPINTCP
4219         espintcp_init();
4220 #endif
4221 }
4222
4223 #ifdef CONFIG_AUDITSYSCALL
4224 static void xfrm_audit_common_policyinfo(struct xfrm_policy *xp,
4225                                          struct audit_buffer *audit_buf)
4226 {
4227         struct xfrm_sec_ctx *ctx = xp->security;
4228         struct xfrm_selector *sel = &xp->selector;
4229
4230         if (ctx)
4231                 audit_log_format(audit_buf, " sec_alg=%u sec_doi=%u sec_obj=%s",
4232                                  ctx->ctx_alg, ctx->ctx_doi, ctx->ctx_str);
4233
4234         switch (sel->family) {
4235         case AF_INET:
4236                 audit_log_format(audit_buf, " src=%pI4", &sel->saddr.a4);
4237                 if (sel->prefixlen_s != 32)
4238                         audit_log_format(audit_buf, " src_prefixlen=%d",
4239                                          sel->prefixlen_s);
4240                 audit_log_format(audit_buf, " dst=%pI4", &sel->daddr.a4);
4241                 if (sel->prefixlen_d != 32)
4242                         audit_log_format(audit_buf, " dst_prefixlen=%d",
4243                                          sel->prefixlen_d);
4244                 break;
4245         case AF_INET6:
4246                 audit_log_format(audit_buf, " src=%pI6", sel->saddr.a6);
4247                 if (sel->prefixlen_s != 128)
4248                         audit_log_format(audit_buf, " src_prefixlen=%d",
4249                                          sel->prefixlen_s);
4250                 audit_log_format(audit_buf, " dst=%pI6", sel->daddr.a6);
4251                 if (sel->prefixlen_d != 128)
4252                         audit_log_format(audit_buf, " dst_prefixlen=%d",
4253                                          sel->prefixlen_d);
4254                 break;
4255         }
4256 }
4257
/*
 * Write an "SPD-add" audit record for policy @xp.
 *
 * The record carries the user information (selected by @task_valid), the
 * operation @result and the policy description.  Nothing is logged when
 * xfrm_audit_start() yields no buffer.
 */
void xfrm_audit_policy_add(struct xfrm_policy *xp, int result, bool task_valid)
{
	struct audit_buffer *ab = xfrm_audit_start("SPD-add");

	if (!ab)
		return;

	xfrm_audit_helper_usrinfo(task_valid, ab);
	audit_log_format(ab, " res=%u", result);
	xfrm_audit_common_policyinfo(xp, ab);
	audit_log_end(ab);
}
EXPORT_SYMBOL_GPL(xfrm_audit_policy_add);
4271
/*
 * Write an "SPD-delete" audit record for policy @xp.
 *
 * The record carries the user information (selected by @task_valid), the
 * operation @result and the policy description.  Nothing is logged when
 * xfrm_audit_start() yields no buffer.
 */
void xfrm_audit_policy_delete(struct xfrm_policy *xp, int result,
			      bool task_valid)
{
	struct audit_buffer *ab = xfrm_audit_start("SPD-delete");

	if (!ab)
		return;

	xfrm_audit_helper_usrinfo(task_valid, ab);
	audit_log_format(ab, " res=%u", result);
	xfrm_audit_common_policyinfo(xp, ab);
	audit_log_end(ab);
}
EXPORT_SYMBOL_GPL(xfrm_audit_policy_delete);
4286 #endif
4287
4288 #ifdef CONFIG_XFRM_MIGRATE
4289 static bool xfrm_migrate_selector_match(const struct xfrm_selector *sel_cmp,
4290                                         const struct xfrm_selector *sel_tgt)
4291 {
4292         if (sel_cmp->proto == IPSEC_ULPROTO_ANY) {
4293                 if (sel_tgt->family == sel_cmp->family &&
4294                     xfrm_addr_equal(&sel_tgt->daddr, &sel_cmp->daddr,
4295                                     sel_cmp->family) &&
4296                     xfrm_addr_equal(&sel_tgt->saddr, &sel_cmp->saddr,
4297                                     sel_cmp->family) &&
4298                     sel_tgt->prefixlen_d == sel_cmp->prefixlen_d &&
4299                     sel_tgt->prefixlen_s == sel_cmp->prefixlen_s) {
4300                         return true;
4301                 }
4302         } else {
4303                 if (memcmp(sel_tgt, sel_cmp, sizeof(*sel_tgt)) == 0) {
4304                         return true;
4305                 }
4306         }
4307         return false;
4308 }
4309
4310 static struct xfrm_policy *xfrm_migrate_policy_find(const struct xfrm_selector *sel,
4311                                                     u8 dir, u8 type, struct net *net, u32 if_id)
4312 {
4313         struct xfrm_policy *pol, *ret = NULL;
4314         struct hlist_head *chain;
4315         u32 priority = ~0U;
4316
4317         spin_lock_bh(&net->xfrm.xfrm_policy_lock);
4318         chain = policy_hash_direct(net, &sel->daddr, &sel->saddr, sel->family, dir);
4319         hlist_for_each_entry(pol, chain, bydst) {
4320                 if ((if_id == 0 || pol->if_id == if_id) &&
4321                     xfrm_migrate_selector_match(sel, &pol->selector) &&
4322                     pol->type == type) {
4323                         ret = pol;
4324                         priority = ret->priority;
4325                         break;
4326                 }
4327         }
4328         chain = &net->xfrm.policy_inexact[dir];
4329         hlist_for_each_entry(pol, chain, bydst_inexact_list) {
4330                 if ((pol->priority >= priority) && ret)
4331                         break;
4332
4333                 if ((if_id == 0 || pol->if_id == if_id) &&
4334                     xfrm_migrate_selector_match(sel, &pol->selector) &&
4335                     pol->type == type) {
4336                         ret = pol;
4337                         break;
4338                 }
4339         }
4340
4341         xfrm_pol_hold(ret);
4342
4343         spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
4344
4345         return ret;
4346 }
4347
4348 static int migrate_tmpl_match(const struct xfrm_migrate *m, const struct xfrm_tmpl *t)
4349 {
4350         int match = 0;
4351
4352         if (t->mode == m->mode && t->id.proto == m->proto &&
4353             (m->reqid == 0 || t->reqid == m->reqid)) {
4354                 switch (t->mode) {
4355                 case XFRM_MODE_TUNNEL:
4356                 case XFRM_MODE_BEET:
4357                         if (xfrm_addr_equal(&t->id.daddr, &m->old_daddr,
4358                                             m->old_family) &&
4359                             xfrm_addr_equal(&t->saddr, &m->old_saddr,
4360                                             m->old_family)) {
4361                                 match = 1;
4362                         }
4363                         break;
4364                 case XFRM_MODE_TRANSPORT:
4365                         /* in case of transport mode, template does not store
4366                            any IP addresses, hence we just compare mode and
4367                            protocol */
4368                         match = 1;
4369                         break;
4370                 default:
4371                         break;
4372                 }
4373         }
4374         return match;
4375 }
4376
4377 /* update endpoint address(es) of template(s) */
4378 static int xfrm_policy_migrate(struct xfrm_policy *pol,
4379                                struct xfrm_migrate *m, int num_migrate,
4380                                struct netlink_ext_ack *extack)
4381 {
4382         struct xfrm_migrate *mp;
4383         int i, j, n = 0;
4384
4385         write_lock_bh(&pol->lock);
4386         if (unlikely(pol->walk.dead)) {
4387                 /* target policy has been deleted */
4388                 NL_SET_ERR_MSG(extack, "Target policy not found");
4389                 write_unlock_bh(&pol->lock);
4390                 return -ENOENT;
4391         }
4392
4393         for (i = 0; i < pol->xfrm_nr; i++) {
4394                 for (j = 0, mp = m; j < num_migrate; j++, mp++) {
4395                         if (!migrate_tmpl_match(mp, &pol->xfrm_vec[i]))
4396                                 continue;
4397                         n++;
4398                         if (pol->xfrm_vec[i].mode != XFRM_MODE_TUNNEL &&
4399                             pol->xfrm_vec[i].mode != XFRM_MODE_BEET)
4400                                 continue;
4401                         /* update endpoints */
4402                         memcpy(&pol->xfrm_vec[i].id.daddr, &mp->new_daddr,
4403                                sizeof(pol->xfrm_vec[i].id.daddr));
4404                         memcpy(&pol->xfrm_vec[i].saddr, &mp->new_saddr,
4405                                sizeof(pol->xfrm_vec[i].saddr));
4406                         pol->xfrm_vec[i].encap_family = mp->new_family;
4407                         /* flush bundles */
4408                         atomic_inc(&pol->genid);
4409                 }
4410         }
4411
4412         write_unlock_bh(&pol->lock);
4413
4414         if (!n)
4415                 return -ENODATA;
4416
4417         return 0;
4418 }
4419
4420 static int xfrm_migrate_check(const struct xfrm_migrate *m, int num_migrate,
4421                               struct netlink_ext_ack *extack)
4422 {
4423         int i, j;
4424
4425         if (num_migrate < 1 || num_migrate > XFRM_MAX_DEPTH) {
4426                 NL_SET_ERR_MSG(extack, "Invalid number of SAs to migrate, must be 0 < num <= XFRM_MAX_DEPTH (6)");
4427                 return -EINVAL;
4428         }
4429
4430         for (i = 0; i < num_migrate; i++) {
4431                 if (xfrm_addr_any(&m[i].new_daddr, m[i].new_family) ||
4432                     xfrm_addr_any(&m[i].new_saddr, m[i].new_family)) {
4433                         NL_SET_ERR_MSG(extack, "Addresses in the MIGRATE attribute's list cannot be null");
4434                         return -EINVAL;
4435                 }
4436
4437                 /* check if there is any duplicated entry */
4438                 for (j = i + 1; j < num_migrate; j++) {
4439                         if (!memcmp(&m[i].old_daddr, &m[j].old_daddr,
4440                                     sizeof(m[i].old_daddr)) &&
4441                             !memcmp(&m[i].old_saddr, &m[j].old_saddr,
4442                                     sizeof(m[i].old_saddr)) &&
4443                             m[i].proto == m[j].proto &&
4444                             m[i].mode == m[j].mode &&
4445                             m[i].reqid == m[j].reqid &&
4446                             m[i].old_family == m[j].old_family) {
4447                                 NL_SET_ERR_MSG(extack, "Entries in the MIGRATE attribute's list must be unique");
4448                                 return -EINVAL;
4449                         }
4450                 }
4451         }
4452
4453         return 0;
4454 }
4455
4456 int xfrm_migrate(const struct xfrm_selector *sel, u8 dir, u8 type,
4457                  struct xfrm_migrate *m, int num_migrate,
4458                  struct xfrm_kmaddress *k, struct net *net,
4459                  struct xfrm_encap_tmpl *encap, u32 if_id,
4460                  struct netlink_ext_ack *extack)
4461 {
4462         int i, err, nx_cur = 0, nx_new = 0;
4463         struct xfrm_policy *pol = NULL;
4464         struct xfrm_state *x, *xc;
4465         struct xfrm_state *x_cur[XFRM_MAX_DEPTH];
4466         struct xfrm_state *x_new[XFRM_MAX_DEPTH];
4467         struct xfrm_migrate *mp;
4468
4469         /* Stage 0 - sanity checks */
4470         err = xfrm_migrate_check(m, num_migrate, extack);
4471         if (err < 0)
4472                 goto out;
4473
4474         if (dir >= XFRM_POLICY_MAX) {
4475                 NL_SET_ERR_MSG(extack, "Invalid policy direction");
4476                 err = -EINVAL;
4477                 goto out;
4478         }
4479
4480         /* Stage 1 - find policy */
4481         pol = xfrm_migrate_policy_find(sel, dir, type, net, if_id);
4482         if (!pol) {
4483                 NL_SET_ERR_MSG(extack, "Target policy not found");
4484                 err = -ENOENT;
4485                 goto out;
4486         }
4487
4488         /* Stage 2 - find and update state(s) */
4489         for (i = 0, mp = m; i < num_migrate; i++, mp++) {
4490                 if ((x = xfrm_migrate_state_find(mp, net, if_id))) {
4491                         x_cur[nx_cur] = x;
4492                         nx_cur++;
4493                         xc = xfrm_state_migrate(x, mp, encap);
4494                         if (xc) {
4495                                 x_new[nx_new] = xc;
4496                                 nx_new++;
4497                         } else {
4498                                 err = -ENODATA;
4499                                 goto restore_state;
4500                         }
4501                 }
4502         }
4503
4504         /* Stage 3 - update policy */
4505         err = xfrm_policy_migrate(pol, m, num_migrate, extack);
4506         if (err < 0)
4507                 goto restore_state;
4508
4509         /* Stage 4 - delete old state(s) */
4510         if (nx_cur) {
4511                 xfrm_states_put(x_cur, nx_cur);
4512                 xfrm_states_delete(x_cur, nx_cur);
4513         }
4514
4515         /* Stage 5 - announce */
4516         km_migrate(sel, dir, type, m, num_migrate, k, encap);
4517
4518         xfrm_pol_put(pol);
4519
4520         return 0;
4521 out:
4522         return err;
4523
4524 restore_state:
4525         if (pol)
4526                 xfrm_pol_put(pol);
4527         if (nx_cur)
4528                 xfrm_states_put(x_cur, nx_cur);
4529         if (nx_new)
4530                 xfrm_states_delete(x_new, nx_new);
4531
4532         return err;
4533 }
4534 EXPORT_SYMBOL(xfrm_migrate);
4535 #endif