GNU Linux-libre 5.15.137-gnu
[releases.git] / net / ipv6 / seg6_local.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *  SR-IPv6 implementation
4  *
5  *  Authors:
6  *  David Lebrun <david.lebrun@uclouvain.be>
7  *  eBPF support: Mathieu Xhonneux <m.xhonneux@gmail.com>
8  */
9
10 #include <linux/types.h>
11 #include <linux/skbuff.h>
12 #include <linux/net.h>
13 #include <linux/module.h>
14 #include <net/ip.h>
15 #include <net/lwtunnel.h>
16 #include <net/netevent.h>
17 #include <net/netns/generic.h>
18 #include <net/ip6_fib.h>
19 #include <net/route.h>
20 #include <net/seg6.h>
21 #include <linux/seg6.h>
22 #include <linux/seg6_local.h>
23 #include <net/addrconf.h>
24 #include <net/ip6_route.h>
25 #include <net/dst_cache.h>
26 #include <net/ip_tunnels.h>
27 #ifdef CONFIG_IPV6_SEG6_HMAC
28 #include <net/seg6_hmac.h>
29 #endif
30 #include <net/seg6_local.h>
31 #include <linux/etherdevice.h>
32 #include <linux/bpf.h>
33 #include <linux/netfilter.h>
34
35 #define SEG6_F_ATTR(i)          BIT(i)
36
37 struct seg6_local_lwt;
38
39 /* callbacks used for customizing the creation and destruction of a behavior */
40 struct seg6_local_lwtunnel_ops {
41         int (*build_state)(struct seg6_local_lwt *slwt, const void *cfg,
42                            struct netlink_ext_ack *extack);
43         void (*destroy_state)(struct seg6_local_lwt *slwt);
44 };
45
46 struct seg6_action_desc {
47         int action;
48         unsigned long attrs;
49
50         /* The optattrs field is used for specifying all the optional
51          * attributes supported by a specific behavior.
52          * It means that if one of these attributes is not provided in the
53          * netlink message during the behavior creation, no errors will be
54          * returned to the userspace.
55          *
56          * Each attribute can be only of two types (mutually exclusive):
57          * 1) required or 2) optional.
58          * Every user MUST obey to this rule! If you set an attribute as
59          * required the same attribute CANNOT be set as optional and vice
60          * versa.
61          */
62         unsigned long optattrs;
63
64         int (*input)(struct sk_buff *skb, struct seg6_local_lwt *slwt);
65         int static_headroom;
66
67         struct seg6_local_lwtunnel_ops slwt_ops;
68 };
69
70 struct bpf_lwt_prog {
71         struct bpf_prog *prog;
72         char *name;
73 };
74
75 enum seg6_end_dt_mode {
76         DT_INVALID_MODE = -EINVAL,
77         DT_LEGACY_MODE  = 0,
78         DT_VRF_MODE     = 1,
79 };
80
81 struct seg6_end_dt_info {
82         enum seg6_end_dt_mode mode;
83
84         struct net *net;
85         /* VRF device associated to the routing table used by the SRv6
86          * End.DT4/DT6 behavior for routing IPv4/IPv6 packets.
87          */
88         int vrf_ifindex;
89         int vrf_table;
90
91         /* tunneled packet family (IPv4 or IPv6).
92          * Protocol and header length are inferred from family.
93          */
94         u16 family;
95 };
96
97 struct pcpu_seg6_local_counters {
98         u64_stats_t packets;
99         u64_stats_t bytes;
100         u64_stats_t errors;
101
102         struct u64_stats_sync syncp;
103 };
104
105 /* This struct groups all the SRv6 Behavior counters supported so far.
106  *
107  * put_nla_counters() makes use of this data structure to collect all counter
108  * values after the per-CPU counter evaluation has been performed.
109  * Finally, each counter value (in seg6_local_counters) is stored in the
110  * corresponding netlink attribute and sent to user space.
111  *
112  * NB: we don't want to expose this structure to user space!
113  */
114 struct seg6_local_counters {
115         __u64 packets;
116         __u64 bytes;
117         __u64 errors;
118 };
119
120 #define seg6_local_alloc_pcpu_counters(__gfp)                           \
121         __netdev_alloc_pcpu_stats(struct pcpu_seg6_local_counters,      \
122                                   ((__gfp) | __GFP_ZERO))
123
124 #define SEG6_F_LOCAL_COUNTERS   SEG6_F_ATTR(SEG6_LOCAL_COUNTERS)
125
126 struct seg6_local_lwt {
127         int action;
128         struct ipv6_sr_hdr *srh;
129         int table;
130         struct in_addr nh4;
131         struct in6_addr nh6;
132         int iif;
133         int oif;
134         struct bpf_lwt_prog bpf;
135 #ifdef CONFIG_NET_L3_MASTER_DEV
136         struct seg6_end_dt_info dt_info;
137 #endif
138         struct pcpu_seg6_local_counters __percpu *pcpu_counters;
139
140         int headroom;
141         struct seg6_action_desc *desc;
142         /* unlike the required attrs, we have to track the optional attributes
143          * that have been effectively parsed.
144          */
145         unsigned long parsed_optattrs;
146 };
147
148 static struct seg6_local_lwt *seg6_local_lwtunnel(struct lwtunnel_state *lwt)
149 {
150         return (struct seg6_local_lwt *)lwt->data;
151 }
152
153 static struct ipv6_sr_hdr *get_and_validate_srh(struct sk_buff *skb)
154 {
155         struct ipv6_sr_hdr *srh;
156
157         srh = seg6_get_srh(skb, IP6_FH_F_SKIP_RH);
158         if (!srh)
159                 return NULL;
160
161 #ifdef CONFIG_IPV6_SEG6_HMAC
162         if (!seg6_hmac_validate_skb(skb))
163                 return NULL;
164 #endif
165
166         return srh;
167 }
168
169 static bool decap_and_validate(struct sk_buff *skb, int proto)
170 {
171         struct ipv6_sr_hdr *srh;
172         unsigned int off = 0;
173
174         srh = seg6_get_srh(skb, 0);
175         if (srh && srh->segments_left > 0)
176                 return false;
177
178 #ifdef CONFIG_IPV6_SEG6_HMAC
179         if (srh && !seg6_hmac_validate_skb(skb))
180                 return false;
181 #endif
182
183         if (ipv6_find_hdr(skb, &off, proto, NULL, NULL) < 0)
184                 return false;
185
186         if (!pskb_pull(skb, off))
187                 return false;
188
189         skb_postpull_rcsum(skb, skb_network_header(skb), off);
190
191         skb_reset_network_header(skb);
192         skb_reset_transport_header(skb);
193         if (iptunnel_pull_offloads(skb))
194                 return false;
195
196         return true;
197 }
198
199 static void advance_nextseg(struct ipv6_sr_hdr *srh, struct in6_addr *daddr)
200 {
201         struct in6_addr *addr;
202
203         srh->segments_left--;
204         addr = srh->segments + srh->segments_left;
205         *daddr = *addr;
206 }
207
208 static int
209 seg6_lookup_any_nexthop(struct sk_buff *skb, struct in6_addr *nhaddr,
210                         u32 tbl_id, bool local_delivery)
211 {
212         struct net *net = dev_net(skb->dev);
213         struct ipv6hdr *hdr = ipv6_hdr(skb);
214         int flags = RT6_LOOKUP_F_HAS_SADDR;
215         struct dst_entry *dst = NULL;
216         struct rt6_info *rt;
217         struct flowi6 fl6;
218         int dev_flags = 0;
219
220         fl6.flowi6_iif = skb->dev->ifindex;
221         fl6.daddr = nhaddr ? *nhaddr : hdr->daddr;
222         fl6.saddr = hdr->saddr;
223         fl6.flowlabel = ip6_flowinfo(hdr);
224         fl6.flowi6_mark = skb->mark;
225         fl6.flowi6_proto = hdr->nexthdr;
226
227         if (nhaddr)
228                 fl6.flowi6_flags = FLOWI_FLAG_KNOWN_NH;
229
230         if (!tbl_id) {
231                 dst = ip6_route_input_lookup(net, skb->dev, &fl6, skb, flags);
232         } else {
233                 struct fib6_table *table;
234
235                 table = fib6_get_table(net, tbl_id);
236                 if (!table)
237                         goto out;
238
239                 rt = ip6_pol_route(net, table, 0, &fl6, skb, flags);
240                 dst = &rt->dst;
241         }
242
243         /* we want to discard traffic destined for local packet processing,
244          * if @local_delivery is set to false.
245          */
246         if (!local_delivery)
247                 dev_flags |= IFF_LOOPBACK;
248
249         if (dst && (dst->dev->flags & dev_flags) && !dst->error) {
250                 dst_release(dst);
251                 dst = NULL;
252         }
253
254 out:
255         if (!dst) {
256                 rt = net->ipv6.ip6_blk_hole_entry;
257                 dst = &rt->dst;
258                 dst_hold(dst);
259         }
260
261         skb_dst_drop(skb);
262         skb_dst_set(skb, dst);
263         return dst->error;
264 }
265
266 int seg6_lookup_nexthop(struct sk_buff *skb,
267                         struct in6_addr *nhaddr, u32 tbl_id)
268 {
269         return seg6_lookup_any_nexthop(skb, nhaddr, tbl_id, false);
270 }
271
272 /* regular endpoint function */
273 static int input_action_end(struct sk_buff *skb, struct seg6_local_lwt *slwt)
274 {
275         struct ipv6_sr_hdr *srh;
276
277         srh = get_and_validate_srh(skb);
278         if (!srh)
279                 goto drop;
280
281         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
282
283         seg6_lookup_nexthop(skb, NULL, 0);
284
285         return dst_input(skb);
286
287 drop:
288         kfree_skb(skb);
289         return -EINVAL;
290 }
291
292 /* regular endpoint, and forward to specified nexthop */
293 static int input_action_end_x(struct sk_buff *skb, struct seg6_local_lwt *slwt)
294 {
295         struct ipv6_sr_hdr *srh;
296
297         srh = get_and_validate_srh(skb);
298         if (!srh)
299                 goto drop;
300
301         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
302
303         seg6_lookup_nexthop(skb, &slwt->nh6, 0);
304
305         return dst_input(skb);
306
307 drop:
308         kfree_skb(skb);
309         return -EINVAL;
310 }
311
312 static int input_action_end_t(struct sk_buff *skb, struct seg6_local_lwt *slwt)
313 {
314         struct ipv6_sr_hdr *srh;
315
316         srh = get_and_validate_srh(skb);
317         if (!srh)
318                 goto drop;
319
320         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
321
322         seg6_lookup_nexthop(skb, NULL, slwt->table);
323
324         return dst_input(skb);
325
326 drop:
327         kfree_skb(skb);
328         return -EINVAL;
329 }
330
331 /* decapsulate and forward inner L2 frame on specified interface */
332 static int input_action_end_dx2(struct sk_buff *skb,
333                                 struct seg6_local_lwt *slwt)
334 {
335         struct net *net = dev_net(skb->dev);
336         struct net_device *odev;
337         struct ethhdr *eth;
338
339         if (!decap_and_validate(skb, IPPROTO_ETHERNET))
340                 goto drop;
341
342         if (!pskb_may_pull(skb, ETH_HLEN))
343                 goto drop;
344
345         skb_reset_mac_header(skb);
346         eth = (struct ethhdr *)skb->data;
347
348         /* To determine the frame's protocol, we assume it is 802.3. This avoids
349          * a call to eth_type_trans(), which is not really relevant for our
350          * use case.
351          */
352         if (!eth_proto_is_802_3(eth->h_proto))
353                 goto drop;
354
355         odev = dev_get_by_index_rcu(net, slwt->oif);
356         if (!odev)
357                 goto drop;
358
359         /* As we accept Ethernet frames, make sure the egress device is of
360          * the correct type.
361          */
362         if (odev->type != ARPHRD_ETHER)
363                 goto drop;
364
365         if (!(odev->flags & IFF_UP) || !netif_carrier_ok(odev))
366                 goto drop;
367
368         skb_orphan(skb);
369
370         if (skb_warn_if_lro(skb))
371                 goto drop;
372
373         skb_forward_csum(skb);
374
375         if (skb->len - ETH_HLEN > odev->mtu)
376                 goto drop;
377
378         skb->dev = odev;
379         skb->protocol = eth->h_proto;
380
381         return dev_queue_xmit(skb);
382
383 drop:
384         kfree_skb(skb);
385         return -EINVAL;
386 }
387
388 static int input_action_end_dx6_finish(struct net *net, struct sock *sk,
389                                        struct sk_buff *skb)
390 {
391         struct dst_entry *orig_dst = skb_dst(skb);
392         struct in6_addr *nhaddr = NULL;
393         struct seg6_local_lwt *slwt;
394
395         slwt = seg6_local_lwtunnel(orig_dst->lwtstate);
396
397         /* The inner packet is not associated to any local interface,
398          * so we do not call netif_rx().
399          *
400          * If slwt->nh6 is set to ::, then lookup the nexthop for the
401          * inner packet's DA. Otherwise, use the specified nexthop.
402          */
403         if (!ipv6_addr_any(&slwt->nh6))
404                 nhaddr = &slwt->nh6;
405
406         seg6_lookup_nexthop(skb, nhaddr, 0);
407
408         return dst_input(skb);
409 }
410
411 /* decapsulate and forward to specified nexthop */
412 static int input_action_end_dx6(struct sk_buff *skb,
413                                 struct seg6_local_lwt *slwt)
414 {
415         /* this function accepts IPv6 encapsulated packets, with either
416          * an SRH with SL=0, or no SRH.
417          */
418
419         if (!decap_and_validate(skb, IPPROTO_IPV6))
420                 goto drop;
421
422         if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
423                 goto drop;
424
425         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
426         nf_reset_ct(skb);
427
428         if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled))
429                 return NF_HOOK(NFPROTO_IPV6, NF_INET_PRE_ROUTING,
430                                dev_net(skb->dev), NULL, skb, NULL,
431                                skb_dst(skb)->dev, input_action_end_dx6_finish);
432
433         return input_action_end_dx6_finish(dev_net(skb->dev), NULL, skb);
434 drop:
435         kfree_skb(skb);
436         return -EINVAL;
437 }
438
439 static int input_action_end_dx4_finish(struct net *net, struct sock *sk,
440                                        struct sk_buff *skb)
441 {
442         struct dst_entry *orig_dst = skb_dst(skb);
443         struct seg6_local_lwt *slwt;
444         struct iphdr *iph;
445         __be32 nhaddr;
446         int err;
447
448         slwt = seg6_local_lwtunnel(orig_dst->lwtstate);
449
450         iph = ip_hdr(skb);
451
452         nhaddr = slwt->nh4.s_addr ?: iph->daddr;
453
454         skb_dst_drop(skb);
455
456         err = ip_route_input(skb, nhaddr, iph->saddr, 0, skb->dev);
457         if (err) {
458                 kfree_skb(skb);
459                 return -EINVAL;
460         }
461
462         return dst_input(skb);
463 }
464
465 static int input_action_end_dx4(struct sk_buff *skb,
466                                 struct seg6_local_lwt *slwt)
467 {
468         if (!decap_and_validate(skb, IPPROTO_IPIP))
469                 goto drop;
470
471         if (!pskb_may_pull(skb, sizeof(struct iphdr)))
472                 goto drop;
473
474         skb->protocol = htons(ETH_P_IP);
475         skb_set_transport_header(skb, sizeof(struct iphdr));
476         nf_reset_ct(skb);
477
478         if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled))
479                 return NF_HOOK(NFPROTO_IPV4, NF_INET_PRE_ROUTING,
480                                dev_net(skb->dev), NULL, skb, NULL,
481                                skb_dst(skb)->dev, input_action_end_dx4_finish);
482
483         return input_action_end_dx4_finish(dev_net(skb->dev), NULL, skb);
484 drop:
485         kfree_skb(skb);
486         return -EINVAL;
487 }
488
489 #ifdef CONFIG_NET_L3_MASTER_DEV
490 static struct net *fib6_config_get_net(const struct fib6_config *fib6_cfg)
491 {
492         const struct nl_info *nli = &fib6_cfg->fc_nlinfo;
493
494         return nli->nl_net;
495 }
496
497 static int __seg6_end_dt_vrf_build(struct seg6_local_lwt *slwt, const void *cfg,
498                                    u16 family, struct netlink_ext_ack *extack)
499 {
500         struct seg6_end_dt_info *info = &slwt->dt_info;
501         int vrf_ifindex;
502         struct net *net;
503
504         net = fib6_config_get_net(cfg);
505
506         /* note that vrf_table was already set by parse_nla_vrftable() */
507         vrf_ifindex = l3mdev_ifindex_lookup_by_table_id(L3MDEV_TYPE_VRF, net,
508                                                         info->vrf_table);
509         if (vrf_ifindex < 0) {
510                 if (vrf_ifindex == -EPERM) {
511                         NL_SET_ERR_MSG(extack,
512                                        "Strict mode for VRF is disabled");
513                 } else if (vrf_ifindex == -ENODEV) {
514                         NL_SET_ERR_MSG(extack,
515                                        "Table has no associated VRF device");
516                 } else {
517                         pr_debug("seg6local: SRv6 End.DT* creation error=%d\n",
518                                  vrf_ifindex);
519                 }
520
521                 return vrf_ifindex;
522         }
523
524         info->net = net;
525         info->vrf_ifindex = vrf_ifindex;
526
527         info->family = family;
528         info->mode = DT_VRF_MODE;
529
530         return 0;
531 }
532
533 /* The SRv6 End.DT4/DT6 behavior extracts the inner (IPv4/IPv6) packet and
534  * routes the IPv4/IPv6 packet by looking at the configured routing table.
535  *
536  * In the SRv6 End.DT4/DT6 use case, we can receive traffic (IPv6+Segment
537  * Routing Header packets) from several interfaces and the outer IPv6
538  * destination address (DA) is used for retrieving the specific instance of the
539  * End.DT4/DT6 behavior that should process the packets.
540  *
541  * However, the inner IPv4/IPv6 packet is not really bound to any receiving
542  * interface and thus the End.DT4/DT6 sets the VRF (associated with the
543  * corresponding routing table) as the *receiving* interface.
544  * In other words, the End.DT4/DT6 processes a packet as if it has been received
545  * directly by the VRF (and not by one of its slave devices, if any).
546  * In this way, the VRF interface is used for routing the IPv4/IPv6 packet in
547  * according to the routing table configured by the End.DT4/DT6 instance.
548  *
549  * This design allows you to get some interesting features like:
550  *  1) the statistics on rx packets;
551  *  2) the possibility to install a packet sniffer on the receiving interface
552  *     (the VRF one) for looking at the incoming packets;
553  *  3) the possibility to leverage the netfilter prerouting hook for the inner
554  *     IPv4 packet.
555  *
556  * This function returns:
557  *  - the sk_buff* when the VRF rcv handler has processed the packet correctly;
558  *  - NULL when the skb is consumed by the VRF rcv handler;
559  *  - a pointer which encodes a negative error number in case of error.
560  *    Note that in this case, the function takes care of freeing the skb.
561  */
562 static struct sk_buff *end_dt_vrf_rcv(struct sk_buff *skb, u16 family,
563                                       struct net_device *dev)
564 {
565         /* based on l3mdev_ip_rcv; we are only interested in the master */
566         if (unlikely(!netif_is_l3_master(dev) && !netif_has_l3_rx_handler(dev)))
567                 goto drop;
568
569         if (unlikely(!dev->l3mdev_ops->l3mdev_l3_rcv))
570                 goto drop;
571
572         /* the decap packet IPv4/IPv6 does not come with any mac header info.
573          * We must unset the mac header to allow the VRF device to rebuild it,
574          * just in case there is a sniffer attached on the device.
575          */
576         skb_unset_mac_header(skb);
577
578         skb = dev->l3mdev_ops->l3mdev_l3_rcv(dev, skb, family);
579         if (!skb)
580                 /* the skb buffer was consumed by the handler */
581                 return NULL;
582
583         /* when a packet is received by a VRF or by one of its slaves, the
584          * master device reference is set into the skb.
585          */
586         if (unlikely(skb->dev != dev || skb->skb_iif != dev->ifindex))
587                 goto drop;
588
589         return skb;
590
591 drop:
592         kfree_skb(skb);
593         return ERR_PTR(-EINVAL);
594 }
595
596 static struct net_device *end_dt_get_vrf_rcu(struct sk_buff *skb,
597                                              struct seg6_end_dt_info *info)
598 {
599         int vrf_ifindex = info->vrf_ifindex;
600         struct net *net = info->net;
601
602         if (unlikely(vrf_ifindex < 0))
603                 goto error;
604
605         if (unlikely(!net_eq(dev_net(skb->dev), net)))
606                 goto error;
607
608         return dev_get_by_index_rcu(net, vrf_ifindex);
609
610 error:
611         return NULL;
612 }
613
614 static struct sk_buff *end_dt_vrf_core(struct sk_buff *skb,
615                                        struct seg6_local_lwt *slwt, u16 family)
616 {
617         struct seg6_end_dt_info *info = &slwt->dt_info;
618         struct net_device *vrf;
619         __be16 protocol;
620         int hdrlen;
621
622         vrf = end_dt_get_vrf_rcu(skb, info);
623         if (unlikely(!vrf))
624                 goto drop;
625
626         switch (family) {
627         case AF_INET:
628                 protocol = htons(ETH_P_IP);
629                 hdrlen = sizeof(struct iphdr);
630                 break;
631         case AF_INET6:
632                 protocol = htons(ETH_P_IPV6);
633                 hdrlen = sizeof(struct ipv6hdr);
634                 break;
635         case AF_UNSPEC:
636                 fallthrough;
637         default:
638                 goto drop;
639         }
640
641         if (unlikely(info->family != AF_UNSPEC && info->family != family)) {
642                 pr_warn_once("seg6local: SRv6 End.DT* family mismatch");
643                 goto drop;
644         }
645
646         skb->protocol = protocol;
647
648         skb_dst_drop(skb);
649
650         skb_set_transport_header(skb, hdrlen);
651         nf_reset_ct(skb);
652
653         return end_dt_vrf_rcv(skb, family, vrf);
654
655 drop:
656         kfree_skb(skb);
657         return ERR_PTR(-EINVAL);
658 }
659
660 static int input_action_end_dt4(struct sk_buff *skb,
661                                 struct seg6_local_lwt *slwt)
662 {
663         struct iphdr *iph;
664         int err;
665
666         if (!decap_and_validate(skb, IPPROTO_IPIP))
667                 goto drop;
668
669         if (!pskb_may_pull(skb, sizeof(struct iphdr)))
670                 goto drop;
671
672         skb = end_dt_vrf_core(skb, slwt, AF_INET);
673         if (!skb)
674                 /* packet has been processed and consumed by the VRF */
675                 return 0;
676
677         if (IS_ERR(skb))
678                 return PTR_ERR(skb);
679
680         iph = ip_hdr(skb);
681
682         err = ip_route_input(skb, iph->daddr, iph->saddr, 0, skb->dev);
683         if (unlikely(err))
684                 goto drop;
685
686         return dst_input(skb);
687
688 drop:
689         kfree_skb(skb);
690         return -EINVAL;
691 }
692
693 static int seg6_end_dt4_build(struct seg6_local_lwt *slwt, const void *cfg,
694                               struct netlink_ext_ack *extack)
695 {
696         return __seg6_end_dt_vrf_build(slwt, cfg, AF_INET, extack);
697 }
698
699 static enum
700 seg6_end_dt_mode seg6_end_dt6_parse_mode(struct seg6_local_lwt *slwt)
701 {
702         unsigned long parsed_optattrs = slwt->parsed_optattrs;
703         bool legacy, vrfmode;
704
705         legacy  = !!(parsed_optattrs & SEG6_F_ATTR(SEG6_LOCAL_TABLE));
706         vrfmode = !!(parsed_optattrs & SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE));
707
708         if (!(legacy ^ vrfmode))
709                 /* both are absent or present: invalid DT6 mode */
710                 return DT_INVALID_MODE;
711
712         return legacy ? DT_LEGACY_MODE : DT_VRF_MODE;
713 }
714
715 static enum seg6_end_dt_mode seg6_end_dt6_get_mode(struct seg6_local_lwt *slwt)
716 {
717         struct seg6_end_dt_info *info = &slwt->dt_info;
718
719         return info->mode;
720 }
721
722 static int seg6_end_dt6_build(struct seg6_local_lwt *slwt, const void *cfg,
723                               struct netlink_ext_ack *extack)
724 {
725         enum seg6_end_dt_mode mode = seg6_end_dt6_parse_mode(slwt);
726         struct seg6_end_dt_info *info = &slwt->dt_info;
727
728         switch (mode) {
729         case DT_LEGACY_MODE:
730                 info->mode = DT_LEGACY_MODE;
731                 return 0;
732         case DT_VRF_MODE:
733                 return __seg6_end_dt_vrf_build(slwt, cfg, AF_INET6, extack);
734         default:
735                 NL_SET_ERR_MSG(extack, "table or vrftable must be specified");
736                 return -EINVAL;
737         }
738 }
739 #endif
740
741 static int input_action_end_dt6(struct sk_buff *skb,
742                                 struct seg6_local_lwt *slwt)
743 {
744         if (!decap_and_validate(skb, IPPROTO_IPV6))
745                 goto drop;
746
747         if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
748                 goto drop;
749
750 #ifdef CONFIG_NET_L3_MASTER_DEV
751         if (seg6_end_dt6_get_mode(slwt) == DT_LEGACY_MODE)
752                 goto legacy_mode;
753
754         /* DT6_VRF_MODE */
755         skb = end_dt_vrf_core(skb, slwt, AF_INET6);
756         if (!skb)
757                 /* packet has been processed and consumed by the VRF */
758                 return 0;
759
760         if (IS_ERR(skb))
761                 return PTR_ERR(skb);
762
763         /* note: this time we do not need to specify the table because the VRF
764          * takes care of selecting the correct table.
765          */
766         seg6_lookup_any_nexthop(skb, NULL, 0, true);
767
768         return dst_input(skb);
769
770 legacy_mode:
771 #endif
772         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
773
774         seg6_lookup_any_nexthop(skb, NULL, slwt->table, true);
775
776         return dst_input(skb);
777
778 drop:
779         kfree_skb(skb);
780         return -EINVAL;
781 }
782
783 #ifdef CONFIG_NET_L3_MASTER_DEV
784 static int seg6_end_dt46_build(struct seg6_local_lwt *slwt, const void *cfg,
785                                struct netlink_ext_ack *extack)
786 {
787         return __seg6_end_dt_vrf_build(slwt, cfg, AF_UNSPEC, extack);
788 }
789
790 static int input_action_end_dt46(struct sk_buff *skb,
791                                  struct seg6_local_lwt *slwt)
792 {
793         unsigned int off = 0;
794         int nexthdr;
795
796         nexthdr = ipv6_find_hdr(skb, &off, -1, NULL, NULL);
797         if (unlikely(nexthdr < 0))
798                 goto drop;
799
800         switch (nexthdr) {
801         case IPPROTO_IPIP:
802                 return input_action_end_dt4(skb, slwt);
803         case IPPROTO_IPV6:
804                 return input_action_end_dt6(skb, slwt);
805         }
806
807 drop:
808         kfree_skb(skb);
809         return -EINVAL;
810 }
811 #endif
812
813 /* push an SRH on top of the current one */
814 static int input_action_end_b6(struct sk_buff *skb, struct seg6_local_lwt *slwt)
815 {
816         struct ipv6_sr_hdr *srh;
817         int err = -EINVAL;
818
819         srh = get_and_validate_srh(skb);
820         if (!srh)
821                 goto drop;
822
823         err = seg6_do_srh_inline(skb, slwt->srh);
824         if (err)
825                 goto drop;
826
827         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
828
829         seg6_lookup_nexthop(skb, NULL, 0);
830
831         return dst_input(skb);
832
833 drop:
834         kfree_skb(skb);
835         return err;
836 }
837
838 /* encapsulate within an outer IPv6 header and a specified SRH */
839 static int input_action_end_b6_encap(struct sk_buff *skb,
840                                      struct seg6_local_lwt *slwt)
841 {
842         struct ipv6_sr_hdr *srh;
843         int err = -EINVAL;
844
845         srh = get_and_validate_srh(skb);
846         if (!srh)
847                 goto drop;
848
849         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
850
851         skb_reset_inner_headers(skb);
852         skb->encapsulation = 1;
853
854         err = seg6_do_srh_encap(skb, slwt->srh, IPPROTO_IPV6);
855         if (err)
856                 goto drop;
857
858         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
859
860         seg6_lookup_nexthop(skb, NULL, 0);
861
862         return dst_input(skb);
863
864 drop:
865         kfree_skb(skb);
866         return err;
867 }
868
869 DEFINE_PER_CPU(struct seg6_bpf_srh_state, seg6_bpf_srh_states);
870
871 bool seg6_bpf_has_valid_srh(struct sk_buff *skb)
872 {
873         struct seg6_bpf_srh_state *srh_state =
874                 this_cpu_ptr(&seg6_bpf_srh_states);
875         struct ipv6_sr_hdr *srh = srh_state->srh;
876
877         if (unlikely(srh == NULL))
878                 return false;
879
880         if (unlikely(!srh_state->valid)) {
881                 if ((srh_state->hdrlen & 7) != 0)
882                         return false;
883
884                 srh->hdrlen = (u8)(srh_state->hdrlen >> 3);
885                 if (!seg6_validate_srh(srh, (srh->hdrlen + 1) << 3, true))
886                         return false;
887
888                 srh_state->valid = true;
889         }
890
891         return true;
892 }
893
894 static int input_action_end_bpf(struct sk_buff *skb,
895                                 struct seg6_local_lwt *slwt)
896 {
897         struct seg6_bpf_srh_state *srh_state =
898                 this_cpu_ptr(&seg6_bpf_srh_states);
899         struct ipv6_sr_hdr *srh;
900         int ret;
901
902         srh = get_and_validate_srh(skb);
903         if (!srh) {
904                 kfree_skb(skb);
905                 return -EINVAL;
906         }
907         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
908
909         /* preempt_disable is needed to protect the per-CPU buffer srh_state,
910          * which is also accessed by the bpf_lwt_seg6_* helpers
911          */
912         preempt_disable();
913         srh_state->srh = srh;
914         srh_state->hdrlen = srh->hdrlen << 3;
915         srh_state->valid = true;
916
917         rcu_read_lock();
918         bpf_compute_data_pointers(skb);
919         ret = bpf_prog_run_save_cb(slwt->bpf.prog, skb);
920         rcu_read_unlock();
921
922         switch (ret) {
923         case BPF_OK:
924         case BPF_REDIRECT:
925                 break;
926         case BPF_DROP:
927                 goto drop;
928         default:
929                 pr_warn_once("bpf-seg6local: Illegal return value %u\n", ret);
930                 goto drop;
931         }
932
933         if (srh_state->srh && !seg6_bpf_has_valid_srh(skb))
934                 goto drop;
935
936         preempt_enable();
937         if (ret != BPF_REDIRECT)
938                 seg6_lookup_nexthop(skb, NULL, 0);
939
940         return dst_input(skb);
941
942 drop:
943         preempt_enable();
944         kfree_skb(skb);
945         return -EINVAL;
946 }
947
948 static struct seg6_action_desc seg6_action_table[] = {
949         {
950                 .action         = SEG6_LOCAL_ACTION_END,
951                 .attrs          = 0,
952                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
953                 .input          = input_action_end,
954         },
955         {
956                 .action         = SEG6_LOCAL_ACTION_END_X,
957                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_NH6),
958                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
959                 .input          = input_action_end_x,
960         },
961         {
962                 .action         = SEG6_LOCAL_ACTION_END_T,
963                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_TABLE),
964                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
965                 .input          = input_action_end_t,
966         },
967         {
968                 .action         = SEG6_LOCAL_ACTION_END_DX2,
969                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_OIF),
970                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
971                 .input          = input_action_end_dx2,
972         },
973         {
974                 .action         = SEG6_LOCAL_ACTION_END_DX6,
975                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_NH6),
976                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
977                 .input          = input_action_end_dx6,
978         },
979         {
980                 .action         = SEG6_LOCAL_ACTION_END_DX4,
981                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_NH4),
982                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
983                 .input          = input_action_end_dx4,
984         },
985         {
986                 .action         = SEG6_LOCAL_ACTION_END_DT4,
987                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE),
988                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
989 #ifdef CONFIG_NET_L3_MASTER_DEV
990                 .input          = input_action_end_dt4,
991                 .slwt_ops       = {
992                                         .build_state = seg6_end_dt4_build,
993                                   },
994 #endif
995         },
996         {
997                 .action         = SEG6_LOCAL_ACTION_END_DT6,
998 #ifdef CONFIG_NET_L3_MASTER_DEV
999                 .attrs          = 0,
1000                 .optattrs       = SEG6_F_LOCAL_COUNTERS         |
1001                                   SEG6_F_ATTR(SEG6_LOCAL_TABLE) |
1002                                   SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE),
1003                 .slwt_ops       = {
1004                                         .build_state = seg6_end_dt6_build,
1005                                   },
1006 #else
1007                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_TABLE),
1008                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
1009 #endif
1010                 .input          = input_action_end_dt6,
1011         },
1012         {
1013                 .action         = SEG6_LOCAL_ACTION_END_DT46,
1014                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE),
1015                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
1016 #ifdef CONFIG_NET_L3_MASTER_DEV
1017                 .input          = input_action_end_dt46,
1018                 .slwt_ops       = {
1019                                         .build_state = seg6_end_dt46_build,
1020                                   },
1021 #endif
1022         },
1023         {
1024                 .action         = SEG6_LOCAL_ACTION_END_B6,
1025                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_SRH),
1026                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
1027                 .input          = input_action_end_b6,
1028         },
1029         {
1030                 .action         = SEG6_LOCAL_ACTION_END_B6_ENCAP,
1031                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_SRH),
1032                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
1033                 .input          = input_action_end_b6_encap,
1034                 .static_headroom        = sizeof(struct ipv6hdr),
1035         },
1036         {
1037                 .action         = SEG6_LOCAL_ACTION_END_BPF,
1038                 .attrs          = SEG6_F_ATTR(SEG6_LOCAL_BPF),
1039                 .optattrs       = SEG6_F_LOCAL_COUNTERS,
1040                 .input          = input_action_end_bpf,
1041         },
1042
1043 };
1044
1045 static struct seg6_action_desc *__get_action_desc(int action)
1046 {
1047         struct seg6_action_desc *desc;
1048         int i, count;
1049
1050         count = ARRAY_SIZE(seg6_action_table);
1051         for (i = 0; i < count; i++) {
1052                 desc = &seg6_action_table[i];
1053                 if (desc->action == action)
1054                         return desc;
1055         }
1056
1057         return NULL;
1058 }
1059
1060 static bool seg6_lwtunnel_counters_enabled(struct seg6_local_lwt *slwt)
1061 {
1062         return slwt->parsed_optattrs & SEG6_F_LOCAL_COUNTERS;
1063 }
1064
1065 static void seg6_local_update_counters(struct seg6_local_lwt *slwt,
1066                                        unsigned int len, int err)
1067 {
1068         struct pcpu_seg6_local_counters *pcounters;
1069
1070         pcounters = this_cpu_ptr(slwt->pcpu_counters);
1071         u64_stats_update_begin(&pcounters->syncp);
1072
1073         if (likely(!err)) {
1074                 u64_stats_inc(&pcounters->packets);
1075                 u64_stats_add(&pcounters->bytes, len);
1076         } else {
1077                 u64_stats_inc(&pcounters->errors);
1078         }
1079
1080         u64_stats_update_end(&pcounters->syncp);
1081 }
1082
1083 static int seg6_local_input_core(struct net *net, struct sock *sk,
1084                                  struct sk_buff *skb)
1085 {
1086         struct dst_entry *orig_dst = skb_dst(skb);
1087         struct seg6_action_desc *desc;
1088         struct seg6_local_lwt *slwt;
1089         unsigned int len = skb->len;
1090         int rc;
1091
1092         slwt = seg6_local_lwtunnel(orig_dst->lwtstate);
1093         desc = slwt->desc;
1094
1095         rc = desc->input(skb, slwt);
1096
1097         if (!seg6_lwtunnel_counters_enabled(slwt))
1098                 return rc;
1099
1100         seg6_local_update_counters(slwt, len, rc);
1101
1102         return rc;
1103 }
1104
1105 static int seg6_local_input(struct sk_buff *skb)
1106 {
1107         if (skb->protocol != htons(ETH_P_IPV6)) {
1108                 kfree_skb(skb);
1109                 return -EINVAL;
1110         }
1111
1112         if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled))
1113                 return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_IN,
1114                                dev_net(skb->dev), NULL, skb, skb->dev, NULL,
1115                                seg6_local_input_core);
1116
1117         return seg6_local_input_core(dev_net(skb->dev), NULL, skb);
1118 }
1119
1120 static const struct nla_policy seg6_local_policy[SEG6_LOCAL_MAX + 1] = {
1121         [SEG6_LOCAL_ACTION]     = { .type = NLA_U32 },
1122         [SEG6_LOCAL_SRH]        = { .type = NLA_BINARY },
1123         [SEG6_LOCAL_TABLE]      = { .type = NLA_U32 },
1124         [SEG6_LOCAL_VRFTABLE]   = { .type = NLA_U32 },
1125         [SEG6_LOCAL_NH4]        = { .type = NLA_BINARY,
1126                                     .len = sizeof(struct in_addr) },
1127         [SEG6_LOCAL_NH6]        = { .type = NLA_BINARY,
1128                                     .len = sizeof(struct in6_addr) },
1129         [SEG6_LOCAL_IIF]        = { .type = NLA_U32 },
1130         [SEG6_LOCAL_OIF]        = { .type = NLA_U32 },
1131         [SEG6_LOCAL_BPF]        = { .type = NLA_NESTED },
1132         [SEG6_LOCAL_COUNTERS]   = { .type = NLA_NESTED },
1133 };
1134
1135 static int parse_nla_srh(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1136 {
1137         struct ipv6_sr_hdr *srh;
1138         int len;
1139
1140         srh = nla_data(attrs[SEG6_LOCAL_SRH]);
1141         len = nla_len(attrs[SEG6_LOCAL_SRH]);
1142
1143         /* SRH must contain at least one segment */
1144         if (len < sizeof(*srh) + sizeof(struct in6_addr))
1145                 return -EINVAL;
1146
1147         if (!seg6_validate_srh(srh, len, false))
1148                 return -EINVAL;
1149
1150         slwt->srh = kmemdup(srh, len, GFP_KERNEL);
1151         if (!slwt->srh)
1152                 return -ENOMEM;
1153
1154         slwt->headroom += len;
1155
1156         return 0;
1157 }
1158
1159 static int put_nla_srh(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1160 {
1161         struct ipv6_sr_hdr *srh;
1162         struct nlattr *nla;
1163         int len;
1164
1165         srh = slwt->srh;
1166         len = (srh->hdrlen + 1) << 3;
1167
1168         nla = nla_reserve(skb, SEG6_LOCAL_SRH, len);
1169         if (!nla)
1170                 return -EMSGSIZE;
1171
1172         memcpy(nla_data(nla), srh, len);
1173
1174         return 0;
1175 }
1176
1177 static int cmp_nla_srh(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1178 {
1179         int len = (a->srh->hdrlen + 1) << 3;
1180
1181         if (len != ((b->srh->hdrlen + 1) << 3))
1182                 return 1;
1183
1184         return memcmp(a->srh, b->srh, len);
1185 }
1186
1187 static void destroy_attr_srh(struct seg6_local_lwt *slwt)
1188 {
1189         kfree(slwt->srh);
1190 }
1191
1192 static int parse_nla_table(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1193 {
1194         slwt->table = nla_get_u32(attrs[SEG6_LOCAL_TABLE]);
1195
1196         return 0;
1197 }
1198
1199 static int put_nla_table(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1200 {
1201         if (nla_put_u32(skb, SEG6_LOCAL_TABLE, slwt->table))
1202                 return -EMSGSIZE;
1203
1204         return 0;
1205 }
1206
1207 static int cmp_nla_table(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1208 {
1209         if (a->table != b->table)
1210                 return 1;
1211
1212         return 0;
1213 }
1214
1215 static struct
1216 seg6_end_dt_info *seg6_possible_end_dt_info(struct seg6_local_lwt *slwt)
1217 {
1218 #ifdef CONFIG_NET_L3_MASTER_DEV
1219         return &slwt->dt_info;
1220 #else
1221         return ERR_PTR(-EOPNOTSUPP);
1222 #endif
1223 }
1224
1225 static int parse_nla_vrftable(struct nlattr **attrs,
1226                               struct seg6_local_lwt *slwt)
1227 {
1228         struct seg6_end_dt_info *info = seg6_possible_end_dt_info(slwt);
1229
1230         if (IS_ERR(info))
1231                 return PTR_ERR(info);
1232
1233         info->vrf_table = nla_get_u32(attrs[SEG6_LOCAL_VRFTABLE]);
1234
1235         return 0;
1236 }
1237
1238 static int put_nla_vrftable(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1239 {
1240         struct seg6_end_dt_info *info = seg6_possible_end_dt_info(slwt);
1241
1242         if (IS_ERR(info))
1243                 return PTR_ERR(info);
1244
1245         if (nla_put_u32(skb, SEG6_LOCAL_VRFTABLE, info->vrf_table))
1246                 return -EMSGSIZE;
1247
1248         return 0;
1249 }
1250
1251 static int cmp_nla_vrftable(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1252 {
1253         struct seg6_end_dt_info *info_a = seg6_possible_end_dt_info(a);
1254         struct seg6_end_dt_info *info_b = seg6_possible_end_dt_info(b);
1255
1256         if (info_a->vrf_table != info_b->vrf_table)
1257                 return 1;
1258
1259         return 0;
1260 }
1261
1262 static int parse_nla_nh4(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1263 {
1264         memcpy(&slwt->nh4, nla_data(attrs[SEG6_LOCAL_NH4]),
1265                sizeof(struct in_addr));
1266
1267         return 0;
1268 }
1269
1270 static int put_nla_nh4(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1271 {
1272         struct nlattr *nla;
1273
1274         nla = nla_reserve(skb, SEG6_LOCAL_NH4, sizeof(struct in_addr));
1275         if (!nla)
1276                 return -EMSGSIZE;
1277
1278         memcpy(nla_data(nla), &slwt->nh4, sizeof(struct in_addr));
1279
1280         return 0;
1281 }
1282
1283 static int cmp_nla_nh4(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1284 {
1285         return memcmp(&a->nh4, &b->nh4, sizeof(struct in_addr));
1286 }
1287
1288 static int parse_nla_nh6(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1289 {
1290         memcpy(&slwt->nh6, nla_data(attrs[SEG6_LOCAL_NH6]),
1291                sizeof(struct in6_addr));
1292
1293         return 0;
1294 }
1295
1296 static int put_nla_nh6(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1297 {
1298         struct nlattr *nla;
1299
1300         nla = nla_reserve(skb, SEG6_LOCAL_NH6, sizeof(struct in6_addr));
1301         if (!nla)
1302                 return -EMSGSIZE;
1303
1304         memcpy(nla_data(nla), &slwt->nh6, sizeof(struct in6_addr));
1305
1306         return 0;
1307 }
1308
1309 static int cmp_nla_nh6(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1310 {
1311         return memcmp(&a->nh6, &b->nh6, sizeof(struct in6_addr));
1312 }
1313
1314 static int parse_nla_iif(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1315 {
1316         slwt->iif = nla_get_u32(attrs[SEG6_LOCAL_IIF]);
1317
1318         return 0;
1319 }
1320
1321 static int put_nla_iif(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1322 {
1323         if (nla_put_u32(skb, SEG6_LOCAL_IIF, slwt->iif))
1324                 return -EMSGSIZE;
1325
1326         return 0;
1327 }
1328
1329 static int cmp_nla_iif(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1330 {
1331         if (a->iif != b->iif)
1332                 return 1;
1333
1334         return 0;
1335 }
1336
1337 static int parse_nla_oif(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1338 {
1339         slwt->oif = nla_get_u32(attrs[SEG6_LOCAL_OIF]);
1340
1341         return 0;
1342 }
1343
1344 static int put_nla_oif(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1345 {
1346         if (nla_put_u32(skb, SEG6_LOCAL_OIF, slwt->oif))
1347                 return -EMSGSIZE;
1348
1349         return 0;
1350 }
1351
1352 static int cmp_nla_oif(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1353 {
1354         if (a->oif != b->oif)
1355                 return 1;
1356
1357         return 0;
1358 }
1359
1360 #define MAX_PROG_NAME 256
1361 static const struct nla_policy bpf_prog_policy[SEG6_LOCAL_BPF_PROG_MAX + 1] = {
1362         [SEG6_LOCAL_BPF_PROG]      = { .type = NLA_U32, },
1363         [SEG6_LOCAL_BPF_PROG_NAME] = { .type = NLA_NUL_STRING,
1364                                        .len = MAX_PROG_NAME },
1365 };
1366
1367 static int parse_nla_bpf(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1368 {
1369         struct nlattr *tb[SEG6_LOCAL_BPF_PROG_MAX + 1];
1370         struct bpf_prog *p;
1371         int ret;
1372         u32 fd;
1373
1374         ret = nla_parse_nested_deprecated(tb, SEG6_LOCAL_BPF_PROG_MAX,
1375                                           attrs[SEG6_LOCAL_BPF],
1376                                           bpf_prog_policy, NULL);
1377         if (ret < 0)
1378                 return ret;
1379
1380         if (!tb[SEG6_LOCAL_BPF_PROG] || !tb[SEG6_LOCAL_BPF_PROG_NAME])
1381                 return -EINVAL;
1382
1383         slwt->bpf.name = nla_memdup(tb[SEG6_LOCAL_BPF_PROG_NAME], GFP_KERNEL);
1384         if (!slwt->bpf.name)
1385                 return -ENOMEM;
1386
1387         fd = nla_get_u32(tb[SEG6_LOCAL_BPF_PROG]);
1388         p = bpf_prog_get_type(fd, BPF_PROG_TYPE_LWT_SEG6LOCAL);
1389         if (IS_ERR(p)) {
1390                 kfree(slwt->bpf.name);
1391                 return PTR_ERR(p);
1392         }
1393
1394         slwt->bpf.prog = p;
1395         return 0;
1396 }
1397
1398 static int put_nla_bpf(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1399 {
1400         struct nlattr *nest;
1401
1402         if (!slwt->bpf.prog)
1403                 return 0;
1404
1405         nest = nla_nest_start_noflag(skb, SEG6_LOCAL_BPF);
1406         if (!nest)
1407                 return -EMSGSIZE;
1408
1409         if (nla_put_u32(skb, SEG6_LOCAL_BPF_PROG, slwt->bpf.prog->aux->id))
1410                 return -EMSGSIZE;
1411
1412         if (slwt->bpf.name &&
1413             nla_put_string(skb, SEG6_LOCAL_BPF_PROG_NAME, slwt->bpf.name))
1414                 return -EMSGSIZE;
1415
1416         return nla_nest_end(skb, nest);
1417 }
1418
1419 static int cmp_nla_bpf(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1420 {
1421         if (!a->bpf.name && !b->bpf.name)
1422                 return 0;
1423
1424         if (!a->bpf.name || !b->bpf.name)
1425                 return 1;
1426
1427         return strcmp(a->bpf.name, b->bpf.name);
1428 }
1429
1430 static void destroy_attr_bpf(struct seg6_local_lwt *slwt)
1431 {
1432         kfree(slwt->bpf.name);
1433         if (slwt->bpf.prog)
1434                 bpf_prog_put(slwt->bpf.prog);
1435 }
1436
1437 static const struct
1438 nla_policy seg6_local_counters_policy[SEG6_LOCAL_CNT_MAX + 1] = {
1439         [SEG6_LOCAL_CNT_PACKETS]        = { .type = NLA_U64 },
1440         [SEG6_LOCAL_CNT_BYTES]          = { .type = NLA_U64 },
1441         [SEG6_LOCAL_CNT_ERRORS]         = { .type = NLA_U64 },
1442 };
1443
1444 static int parse_nla_counters(struct nlattr **attrs,
1445                               struct seg6_local_lwt *slwt)
1446 {
1447         struct pcpu_seg6_local_counters __percpu *pcounters;
1448         struct nlattr *tb[SEG6_LOCAL_CNT_MAX + 1];
1449         int ret;
1450
1451         ret = nla_parse_nested_deprecated(tb, SEG6_LOCAL_CNT_MAX,
1452                                           attrs[SEG6_LOCAL_COUNTERS],
1453                                           seg6_local_counters_policy, NULL);
1454         if (ret < 0)
1455                 return ret;
1456
1457         /* basic support for SRv6 Behavior counters requires at least:
1458          * packets, bytes and errors.
1459          */
1460         if (!tb[SEG6_LOCAL_CNT_PACKETS] || !tb[SEG6_LOCAL_CNT_BYTES] ||
1461             !tb[SEG6_LOCAL_CNT_ERRORS])
1462                 return -EINVAL;
1463
1464         /* counters are always zero initialized */
1465         pcounters = seg6_local_alloc_pcpu_counters(GFP_KERNEL);
1466         if (!pcounters)
1467                 return -ENOMEM;
1468
1469         slwt->pcpu_counters = pcounters;
1470
1471         return 0;
1472 }
1473
1474 static int seg6_local_fill_nla_counters(struct sk_buff *skb,
1475                                         struct seg6_local_counters *counters)
1476 {
1477         if (nla_put_u64_64bit(skb, SEG6_LOCAL_CNT_PACKETS, counters->packets,
1478                               SEG6_LOCAL_CNT_PAD))
1479                 return -EMSGSIZE;
1480
1481         if (nla_put_u64_64bit(skb, SEG6_LOCAL_CNT_BYTES, counters->bytes,
1482                               SEG6_LOCAL_CNT_PAD))
1483                 return -EMSGSIZE;
1484
1485         if (nla_put_u64_64bit(skb, SEG6_LOCAL_CNT_ERRORS, counters->errors,
1486                               SEG6_LOCAL_CNT_PAD))
1487                 return -EMSGSIZE;
1488
1489         return 0;
1490 }
1491
1492 static int put_nla_counters(struct sk_buff *skb, struct seg6_local_lwt *slwt)
1493 {
1494         struct seg6_local_counters counters = { 0, 0, 0 };
1495         struct nlattr *nest;
1496         int rc, i;
1497
1498         nest = nla_nest_start(skb, SEG6_LOCAL_COUNTERS);
1499         if (!nest)
1500                 return -EMSGSIZE;
1501
1502         for_each_possible_cpu(i) {
1503                 struct pcpu_seg6_local_counters *pcounters;
1504                 u64 packets, bytes, errors;
1505                 unsigned int start;
1506
1507                 pcounters = per_cpu_ptr(slwt->pcpu_counters, i);
1508                 do {
1509                         start = u64_stats_fetch_begin_irq(&pcounters->syncp);
1510
1511                         packets = u64_stats_read(&pcounters->packets);
1512                         bytes = u64_stats_read(&pcounters->bytes);
1513                         errors = u64_stats_read(&pcounters->errors);
1514
1515                 } while (u64_stats_fetch_retry_irq(&pcounters->syncp, start));
1516
1517                 counters.packets += packets;
1518                 counters.bytes += bytes;
1519                 counters.errors += errors;
1520         }
1521
1522         rc = seg6_local_fill_nla_counters(skb, &counters);
1523         if (rc < 0) {
1524                 nla_nest_cancel(skb, nest);
1525                 return rc;
1526         }
1527
1528         return nla_nest_end(skb, nest);
1529 }
1530
1531 static int cmp_nla_counters(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
1532 {
1533         /* a and b are equal if both have pcpu_counters set or not */
1534         return (!!((unsigned long)a->pcpu_counters)) ^
1535                 (!!((unsigned long)b->pcpu_counters));
1536 }
1537
1538 static void destroy_attr_counters(struct seg6_local_lwt *slwt)
1539 {
1540         free_percpu(slwt->pcpu_counters);
1541 }
1542
1543 struct seg6_action_param {
1544         int (*parse)(struct nlattr **attrs, struct seg6_local_lwt *slwt);
1545         int (*put)(struct sk_buff *skb, struct seg6_local_lwt *slwt);
1546         int (*cmp)(struct seg6_local_lwt *a, struct seg6_local_lwt *b);
1547
1548         /* optional destroy() callback useful for releasing resources which
1549          * have been previously acquired in the corresponding parse()
1550          * function.
1551          */
1552         void (*destroy)(struct seg6_local_lwt *slwt);
1553 };
1554
1555 static struct seg6_action_param seg6_action_params[SEG6_LOCAL_MAX + 1] = {
1556         [SEG6_LOCAL_SRH]        = { .parse = parse_nla_srh,
1557                                     .put = put_nla_srh,
1558                                     .cmp = cmp_nla_srh,
1559                                     .destroy = destroy_attr_srh },
1560
1561         [SEG6_LOCAL_TABLE]      = { .parse = parse_nla_table,
1562                                     .put = put_nla_table,
1563                                     .cmp = cmp_nla_table },
1564
1565         [SEG6_LOCAL_NH4]        = { .parse = parse_nla_nh4,
1566                                     .put = put_nla_nh4,
1567                                     .cmp = cmp_nla_nh4 },
1568
1569         [SEG6_LOCAL_NH6]        = { .parse = parse_nla_nh6,
1570                                     .put = put_nla_nh6,
1571                                     .cmp = cmp_nla_nh6 },
1572
1573         [SEG6_LOCAL_IIF]        = { .parse = parse_nla_iif,
1574                                     .put = put_nla_iif,
1575                                     .cmp = cmp_nla_iif },
1576
1577         [SEG6_LOCAL_OIF]        = { .parse = parse_nla_oif,
1578                                     .put = put_nla_oif,
1579                                     .cmp = cmp_nla_oif },
1580
1581         [SEG6_LOCAL_BPF]        = { .parse = parse_nla_bpf,
1582                                     .put = put_nla_bpf,
1583                                     .cmp = cmp_nla_bpf,
1584                                     .destroy = destroy_attr_bpf },
1585
1586         [SEG6_LOCAL_VRFTABLE]   = { .parse = parse_nla_vrftable,
1587                                     .put = put_nla_vrftable,
1588                                     .cmp = cmp_nla_vrftable },
1589
1590         [SEG6_LOCAL_COUNTERS]   = { .parse = parse_nla_counters,
1591                                     .put = put_nla_counters,
1592                                     .cmp = cmp_nla_counters,
1593                                     .destroy = destroy_attr_counters },
1594 };
1595
1596 /* call the destroy() callback (if available) for each set attribute in
1597  * @parsed_attrs, starting from the first attribute up to the @max_parsed
1598  * (excluded) attribute.
1599  */
1600 static void __destroy_attrs(unsigned long parsed_attrs, int max_parsed,
1601                             struct seg6_local_lwt *slwt)
1602 {
1603         struct seg6_action_param *param;
1604         int i;
1605
1606         /* Every required seg6local attribute is identified by an ID which is
1607          * encoded as a flag (i.e: 1 << ID) in the 'attrs' bitmask;
1608          *
1609          * We scan the 'parsed_attrs' bitmask, starting from the first attribute
1610          * up to the @max_parsed (excluded) attribute.
1611          * For each set attribute, we retrieve the corresponding destroy()
1612          * callback. If the callback is not available, then we skip to the next
1613          * attribute; otherwise, we call the destroy() callback.
1614          */
1615         for (i = 0; i < max_parsed; ++i) {
1616                 if (!(parsed_attrs & SEG6_F_ATTR(i)))
1617                         continue;
1618
1619                 param = &seg6_action_params[i];
1620
1621                 if (param->destroy)
1622                         param->destroy(slwt);
1623         }
1624 }
1625
1626 /* release all the resources that may have been acquired during parsing
1627  * operations.
1628  */
1629 static void destroy_attrs(struct seg6_local_lwt *slwt)
1630 {
1631         unsigned long attrs = slwt->desc->attrs | slwt->parsed_optattrs;
1632
1633         __destroy_attrs(attrs, SEG6_LOCAL_MAX + 1, slwt);
1634 }
1635
1636 static int parse_nla_optional_attrs(struct nlattr **attrs,
1637                                     struct seg6_local_lwt *slwt)
1638 {
1639         struct seg6_action_desc *desc = slwt->desc;
1640         unsigned long parsed_optattrs = 0;
1641         struct seg6_action_param *param;
1642         int err, i;
1643
1644         for (i = 0; i < SEG6_LOCAL_MAX + 1; ++i) {
1645                 if (!(desc->optattrs & SEG6_F_ATTR(i)) || !attrs[i])
1646                         continue;
1647
1648                 /* once here, the i-th attribute is provided by the
1649                  * userspace AND it is identified optional as well.
1650                  */
1651                 param = &seg6_action_params[i];
1652
1653                 err = param->parse(attrs, slwt);
1654                 if (err < 0)
1655                         goto parse_optattrs_err;
1656
1657                 /* current attribute has been correctly parsed */
1658                 parsed_optattrs |= SEG6_F_ATTR(i);
1659         }
1660
1661         /* store in the tunnel state all the optional attributed successfully
1662          * parsed.
1663          */
1664         slwt->parsed_optattrs = parsed_optattrs;
1665
1666         return 0;
1667
1668 parse_optattrs_err:
1669         __destroy_attrs(parsed_optattrs, i, slwt);
1670
1671         return err;
1672 }
1673
1674 /* call the custom constructor of the behavior during its initialization phase
1675  * and after that all its attributes have been parsed successfully.
1676  */
1677 static int
1678 seg6_local_lwtunnel_build_state(struct seg6_local_lwt *slwt, const void *cfg,
1679                                 struct netlink_ext_ack *extack)
1680 {
1681         struct seg6_action_desc *desc = slwt->desc;
1682         struct seg6_local_lwtunnel_ops *ops;
1683
1684         ops = &desc->slwt_ops;
1685         if (!ops->build_state)
1686                 return 0;
1687
1688         return ops->build_state(slwt, cfg, extack);
1689 }
1690
1691 /* call the custom destructor of the behavior which is invoked before the
1692  * tunnel is going to be destroyed.
1693  */
1694 static void seg6_local_lwtunnel_destroy_state(struct seg6_local_lwt *slwt)
1695 {
1696         struct seg6_action_desc *desc = slwt->desc;
1697         struct seg6_local_lwtunnel_ops *ops;
1698
1699         ops = &desc->slwt_ops;
1700         if (!ops->destroy_state)
1701                 return;
1702
1703         ops->destroy_state(slwt);
1704 }
1705
1706 static int parse_nla_action(struct nlattr **attrs, struct seg6_local_lwt *slwt)
1707 {
1708         struct seg6_action_param *param;
1709         struct seg6_action_desc *desc;
1710         unsigned long invalid_attrs;
1711         int i, err;
1712
1713         desc = __get_action_desc(slwt->action);
1714         if (!desc)
1715                 return -EINVAL;
1716
1717         if (!desc->input)
1718                 return -EOPNOTSUPP;
1719
1720         slwt->desc = desc;
1721         slwt->headroom += desc->static_headroom;
1722
1723         /* Forcing the desc->optattrs *set* and the desc->attrs *set* to be
1724          * disjoined, this allow us to release acquired resources by optional
1725          * attributes and by required attributes independently from each other
1726          * without any interference.
1727          * In other terms, we are sure that we do not release some the acquired
1728          * resources twice.
1729          *
1730          * Note that if an attribute is configured both as required and as
1731          * optional, it means that the user has messed something up in the
1732          * seg6_action_table. Therefore, this check is required for SRv6
1733          * behaviors to work properly.
1734          */
1735         invalid_attrs = desc->attrs & desc->optattrs;
1736         if (invalid_attrs) {
1737                 WARN_ONCE(1,
1738                           "An attribute cannot be both required AND optional");
1739                 return -EINVAL;
1740         }
1741
1742         /* parse the required attributes */
1743         for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1744                 if (desc->attrs & SEG6_F_ATTR(i)) {
1745                         if (!attrs[i])
1746                                 return -EINVAL;
1747
1748                         param = &seg6_action_params[i];
1749
1750                         err = param->parse(attrs, slwt);
1751                         if (err < 0)
1752                                 goto parse_attrs_err;
1753                 }
1754         }
1755
1756         /* parse the optional attributes, if any */
1757         err = parse_nla_optional_attrs(attrs, slwt);
1758         if (err < 0)
1759                 goto parse_attrs_err;
1760
1761         return 0;
1762
1763 parse_attrs_err:
1764         /* release any resource that may have been acquired during the i-1
1765          * parse() operations.
1766          */
1767         __destroy_attrs(desc->attrs, i, slwt);
1768
1769         return err;
1770 }
1771
1772 static int seg6_local_build_state(struct net *net, struct nlattr *nla,
1773                                   unsigned int family, const void *cfg,
1774                                   struct lwtunnel_state **ts,
1775                                   struct netlink_ext_ack *extack)
1776 {
1777         struct nlattr *tb[SEG6_LOCAL_MAX + 1];
1778         struct lwtunnel_state *newts;
1779         struct seg6_local_lwt *slwt;
1780         int err;
1781
1782         if (family != AF_INET6)
1783                 return -EINVAL;
1784
1785         err = nla_parse_nested_deprecated(tb, SEG6_LOCAL_MAX, nla,
1786                                           seg6_local_policy, extack);
1787
1788         if (err < 0)
1789                 return err;
1790
1791         if (!tb[SEG6_LOCAL_ACTION])
1792                 return -EINVAL;
1793
1794         newts = lwtunnel_state_alloc(sizeof(*slwt));
1795         if (!newts)
1796                 return -ENOMEM;
1797
1798         slwt = seg6_local_lwtunnel(newts);
1799         slwt->action = nla_get_u32(tb[SEG6_LOCAL_ACTION]);
1800
1801         err = parse_nla_action(tb, slwt);
1802         if (err < 0)
1803                 goto out_free;
1804
1805         err = seg6_local_lwtunnel_build_state(slwt, cfg, extack);
1806         if (err < 0)
1807                 goto out_destroy_attrs;
1808
1809         newts->type = LWTUNNEL_ENCAP_SEG6_LOCAL;
1810         newts->flags = LWTUNNEL_STATE_INPUT_REDIRECT;
1811         newts->headroom = slwt->headroom;
1812
1813         *ts = newts;
1814
1815         return 0;
1816
1817 out_destroy_attrs:
1818         destroy_attrs(slwt);
1819 out_free:
1820         kfree(newts);
1821         return err;
1822 }
1823
1824 static void seg6_local_destroy_state(struct lwtunnel_state *lwt)
1825 {
1826         struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1827
1828         seg6_local_lwtunnel_destroy_state(slwt);
1829
1830         destroy_attrs(slwt);
1831
1832         return;
1833 }
1834
1835 static int seg6_local_fill_encap(struct sk_buff *skb,
1836                                  struct lwtunnel_state *lwt)
1837 {
1838         struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1839         struct seg6_action_param *param;
1840         unsigned long attrs;
1841         int i, err;
1842
1843         if (nla_put_u32(skb, SEG6_LOCAL_ACTION, slwt->action))
1844                 return -EMSGSIZE;
1845
1846         attrs = slwt->desc->attrs | slwt->parsed_optattrs;
1847
1848         for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1849                 if (attrs & SEG6_F_ATTR(i)) {
1850                         param = &seg6_action_params[i];
1851                         err = param->put(skb, slwt);
1852                         if (err < 0)
1853                                 return err;
1854                 }
1855         }
1856
1857         return 0;
1858 }
1859
1860 static int seg6_local_get_encap_size(struct lwtunnel_state *lwt)
1861 {
1862         struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1863         unsigned long attrs;
1864         int nlsize;
1865
1866         nlsize = nla_total_size(4); /* action */
1867
1868         attrs = slwt->desc->attrs | slwt->parsed_optattrs;
1869
1870         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_SRH))
1871                 nlsize += nla_total_size((slwt->srh->hdrlen + 1) << 3);
1872
1873         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_TABLE))
1874                 nlsize += nla_total_size(4);
1875
1876         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_NH4))
1877                 nlsize += nla_total_size(4);
1878
1879         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_NH6))
1880                 nlsize += nla_total_size(16);
1881
1882         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_IIF))
1883                 nlsize += nla_total_size(4);
1884
1885         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_OIF))
1886                 nlsize += nla_total_size(4);
1887
1888         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_BPF))
1889                 nlsize += nla_total_size(sizeof(struct nlattr)) +
1890                        nla_total_size(MAX_PROG_NAME) +
1891                        nla_total_size(4);
1892
1893         if (attrs & SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE))
1894                 nlsize += nla_total_size(4);
1895
1896         if (attrs & SEG6_F_LOCAL_COUNTERS)
1897                 nlsize += nla_total_size(0) + /* nest SEG6_LOCAL_COUNTERS */
1898                           /* SEG6_LOCAL_CNT_PACKETS */
1899                           nla_total_size_64bit(sizeof(__u64)) +
1900                           /* SEG6_LOCAL_CNT_BYTES */
1901                           nla_total_size_64bit(sizeof(__u64)) +
1902                           /* SEG6_LOCAL_CNT_ERRORS */
1903                           nla_total_size_64bit(sizeof(__u64));
1904
1905         return nlsize;
1906 }
1907
1908 static int seg6_local_cmp_encap(struct lwtunnel_state *a,
1909                                 struct lwtunnel_state *b)
1910 {
1911         struct seg6_local_lwt *slwt_a, *slwt_b;
1912         struct seg6_action_param *param;
1913         unsigned long attrs_a, attrs_b;
1914         int i;
1915
1916         slwt_a = seg6_local_lwtunnel(a);
1917         slwt_b = seg6_local_lwtunnel(b);
1918
1919         if (slwt_a->action != slwt_b->action)
1920                 return 1;
1921
1922         attrs_a = slwt_a->desc->attrs | slwt_a->parsed_optattrs;
1923         attrs_b = slwt_b->desc->attrs | slwt_b->parsed_optattrs;
1924
1925         if (attrs_a != attrs_b)
1926                 return 1;
1927
1928         for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1929                 if (attrs_a & SEG6_F_ATTR(i)) {
1930                         param = &seg6_action_params[i];
1931                         if (param->cmp(slwt_a, slwt_b))
1932                                 return 1;
1933                 }
1934         }
1935
1936         return 0;
1937 }
1938
1939 static const struct lwtunnel_encap_ops seg6_local_ops = {
1940         .build_state    = seg6_local_build_state,
1941         .destroy_state  = seg6_local_destroy_state,
1942         .input          = seg6_local_input,
1943         .fill_encap     = seg6_local_fill_encap,
1944         .get_encap_size = seg6_local_get_encap_size,
1945         .cmp_encap      = seg6_local_cmp_encap,
1946         .owner          = THIS_MODULE,
1947 };
1948
1949 int __init seg6_local_init(void)
1950 {
1951         /* If the max total number of defined attributes is reached, then your
1952          * kernel build stops here.
1953          *
1954          * This check is required to avoid arithmetic overflows when processing
1955          * behavior attributes and the maximum number of defined attributes
1956          * exceeds the allowed value.
1957          */
1958         BUILD_BUG_ON(SEG6_LOCAL_MAX + 1 > BITS_PER_TYPE(unsigned long));
1959
1960         return lwtunnel_encap_add_ops(&seg6_local_ops,
1961                                       LWTUNNEL_ENCAP_SEG6_LOCAL);
1962 }
1963
1964 void seg6_local_exit(void)
1965 {
1966         lwtunnel_encap_del_ops(&seg6_local_ops, LWTUNNEL_ENCAP_SEG6_LOCAL);
1967 }