GNU Linux-libre 4.19.242-gnu1
[releases.git] / net / ipv6 / seg6_local.c
1 /*
2  *  SR-IPv6 implementation
3  *
4  *  Authors:
5  *  David Lebrun <david.lebrun@uclouvain.be>
6  *  eBPF support: Mathieu Xhonneux <m.xhonneux@gmail.com>
7  *
8  *
9  *  This program is free software; you can redistribute it and/or
10  *        modify it under the terms of the GNU General Public License
11  *        as published by the Free Software Foundation; either version
12  *        2 of the License, or (at your option) any later version.
13  */
14
15 #include <linux/types.h>
16 #include <linux/skbuff.h>
17 #include <linux/net.h>
18 #include <linux/module.h>
19 #include <net/ip.h>
20 #include <net/lwtunnel.h>
21 #include <net/netevent.h>
22 #include <net/netns/generic.h>
23 #include <net/ip6_fib.h>
24 #include <net/route.h>
25 #include <net/seg6.h>
26 #include <linux/seg6.h>
27 #include <linux/seg6_local.h>
28 #include <net/addrconf.h>
29 #include <net/ip6_route.h>
30 #include <net/dst_cache.h>
31 #include <net/ip_tunnels.h>
32 #ifdef CONFIG_IPV6_SEG6_HMAC
33 #include <net/seg6_hmac.h>
34 #endif
35 #include <net/seg6_local.h>
36 #include <linux/etherdevice.h>
37 #include <linux/bpf.h>
38
39 struct seg6_local_lwt;
40
41 struct seg6_action_desc {
42         int action;
43         unsigned long attrs;
44         int (*input)(struct sk_buff *skb, struct seg6_local_lwt *slwt);
45         int static_headroom;
46 };
47
48 struct bpf_lwt_prog {
49         struct bpf_prog *prog;
50         char *name;
51 };
52
53 struct seg6_local_lwt {
54         int action;
55         struct ipv6_sr_hdr *srh;
56         int table;
57         struct in_addr nh4;
58         struct in6_addr nh6;
59         int iif;
60         int oif;
61         struct bpf_lwt_prog bpf;
62
63         int headroom;
64         struct seg6_action_desc *desc;
65 };
66
67 static struct seg6_local_lwt *seg6_local_lwtunnel(struct lwtunnel_state *lwt)
68 {
69         return (struct seg6_local_lwt *)lwt->data;
70 }
71
72 static struct ipv6_sr_hdr *get_srh(struct sk_buff *skb)
73 {
74         struct ipv6_sr_hdr *srh;
75         int len, srhoff = 0;
76
77         if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, NULL) < 0)
78                 return NULL;
79
80         if (!pskb_may_pull(skb, srhoff + sizeof(*srh)))
81                 return NULL;
82
83         srh = (struct ipv6_sr_hdr *)(skb->data + srhoff);
84
85         len = (srh->hdrlen + 1) << 3;
86
87         if (!pskb_may_pull(skb, srhoff + len))
88                 return NULL;
89
90         if (!seg6_validate_srh(srh, len))
91                 return NULL;
92
93         return srh;
94 }
95
96 static struct ipv6_sr_hdr *get_and_validate_srh(struct sk_buff *skb)
97 {
98         struct ipv6_sr_hdr *srh;
99
100         srh = get_srh(skb);
101         if (!srh)
102                 return NULL;
103
104         if (srh->segments_left == 0)
105                 return NULL;
106
107 #ifdef CONFIG_IPV6_SEG6_HMAC
108         if (!seg6_hmac_validate_skb(skb))
109                 return NULL;
110 #endif
111
112         return srh;
113 }
114
115 static bool decap_and_validate(struct sk_buff *skb, int proto)
116 {
117         struct ipv6_sr_hdr *srh;
118         unsigned int off = 0;
119
120         srh = get_srh(skb);
121         if (srh && srh->segments_left > 0)
122                 return false;
123
124 #ifdef CONFIG_IPV6_SEG6_HMAC
125         if (srh && !seg6_hmac_validate_skb(skb))
126                 return false;
127 #endif
128
129         if (ipv6_find_hdr(skb, &off, proto, NULL, NULL) < 0)
130                 return false;
131
132         if (!pskb_pull(skb, off))
133                 return false;
134
135         skb_postpull_rcsum(skb, skb_network_header(skb), off);
136
137         skb_reset_network_header(skb);
138         skb_reset_transport_header(skb);
139         if (iptunnel_pull_offloads(skb))
140                 return false;
141
142         return true;
143 }
144
145 static void advance_nextseg(struct ipv6_sr_hdr *srh, struct in6_addr *daddr)
146 {
147         struct in6_addr *addr;
148
149         srh->segments_left--;
150         addr = srh->segments + srh->segments_left;
151         *daddr = *addr;
152 }
153
154 int seg6_lookup_nexthop(struct sk_buff *skb, struct in6_addr *nhaddr,
155                         u32 tbl_id)
156 {
157         struct net *net = dev_net(skb->dev);
158         struct ipv6hdr *hdr = ipv6_hdr(skb);
159         int flags = RT6_LOOKUP_F_HAS_SADDR;
160         struct dst_entry *dst = NULL;
161         struct rt6_info *rt;
162         struct flowi6 fl6;
163
164         fl6.flowi6_iif = skb->dev->ifindex;
165         fl6.daddr = nhaddr ? *nhaddr : hdr->daddr;
166         fl6.saddr = hdr->saddr;
167         fl6.flowlabel = ip6_flowinfo(hdr);
168         fl6.flowi6_mark = skb->mark;
169         fl6.flowi6_proto = hdr->nexthdr;
170
171         if (nhaddr)
172                 fl6.flowi6_flags = FLOWI_FLAG_KNOWN_NH;
173
174         if (!tbl_id) {
175                 dst = ip6_route_input_lookup(net, skb->dev, &fl6, skb, flags);
176         } else {
177                 struct fib6_table *table;
178
179                 table = fib6_get_table(net, tbl_id);
180                 if (!table)
181                         goto out;
182
183                 rt = ip6_pol_route(net, table, 0, &fl6, skb, flags);
184                 dst = &rt->dst;
185         }
186
187         if (dst && dst->dev->flags & IFF_LOOPBACK && !dst->error) {
188                 dst_release(dst);
189                 dst = NULL;
190         }
191
192 out:
193         if (!dst) {
194                 rt = net->ipv6.ip6_blk_hole_entry;
195                 dst = &rt->dst;
196                 dst_hold(dst);
197         }
198
199         skb_dst_drop(skb);
200         skb_dst_set(skb, dst);
201         return dst->error;
202 }
203
204 /* regular endpoint function */
205 static int input_action_end(struct sk_buff *skb, struct seg6_local_lwt *slwt)
206 {
207         struct ipv6_sr_hdr *srh;
208
209         srh = get_and_validate_srh(skb);
210         if (!srh)
211                 goto drop;
212
213         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
214
215         seg6_lookup_nexthop(skb, NULL, 0);
216
217         return dst_input(skb);
218
219 drop:
220         kfree_skb(skb);
221         return -EINVAL;
222 }
223
224 /* regular endpoint, and forward to specified nexthop */
225 static int input_action_end_x(struct sk_buff *skb, struct seg6_local_lwt *slwt)
226 {
227         struct ipv6_sr_hdr *srh;
228
229         srh = get_and_validate_srh(skb);
230         if (!srh)
231                 goto drop;
232
233         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
234
235         seg6_lookup_nexthop(skb, &slwt->nh6, 0);
236
237         return dst_input(skb);
238
239 drop:
240         kfree_skb(skb);
241         return -EINVAL;
242 }
243
244 static int input_action_end_t(struct sk_buff *skb, struct seg6_local_lwt *slwt)
245 {
246         struct ipv6_sr_hdr *srh;
247
248         srh = get_and_validate_srh(skb);
249         if (!srh)
250                 goto drop;
251
252         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
253
254         seg6_lookup_nexthop(skb, NULL, slwt->table);
255
256         return dst_input(skb);
257
258 drop:
259         kfree_skb(skb);
260         return -EINVAL;
261 }
262
263 /* decapsulate and forward inner L2 frame on specified interface */
264 static int input_action_end_dx2(struct sk_buff *skb,
265                                 struct seg6_local_lwt *slwt)
266 {
267         struct net *net = dev_net(skb->dev);
268         struct net_device *odev;
269         struct ethhdr *eth;
270
271         if (!decap_and_validate(skb, NEXTHDR_NONE))
272                 goto drop;
273
274         if (!pskb_may_pull(skb, ETH_HLEN))
275                 goto drop;
276
277         skb_reset_mac_header(skb);
278         eth = (struct ethhdr *)skb->data;
279
280         /* To determine the frame's protocol, we assume it is 802.3. This avoids
281          * a call to eth_type_trans(), which is not really relevant for our
282          * use case.
283          */
284         if (!eth_proto_is_802_3(eth->h_proto))
285                 goto drop;
286
287         odev = dev_get_by_index_rcu(net, slwt->oif);
288         if (!odev)
289                 goto drop;
290
291         /* As we accept Ethernet frames, make sure the egress device is of
292          * the correct type.
293          */
294         if (odev->type != ARPHRD_ETHER)
295                 goto drop;
296
297         if (!(odev->flags & IFF_UP) || !netif_carrier_ok(odev))
298                 goto drop;
299
300         skb_orphan(skb);
301
302         if (skb_warn_if_lro(skb))
303                 goto drop;
304
305         skb_forward_csum(skb);
306
307         if (skb->len - ETH_HLEN > odev->mtu)
308                 goto drop;
309
310         skb->dev = odev;
311         skb->protocol = eth->h_proto;
312
313         return dev_queue_xmit(skb);
314
315 drop:
316         kfree_skb(skb);
317         return -EINVAL;
318 }
319
320 /* decapsulate and forward to specified nexthop */
321 static int input_action_end_dx6(struct sk_buff *skb,
322                                 struct seg6_local_lwt *slwt)
323 {
324         struct in6_addr *nhaddr = NULL;
325
326         /* this function accepts IPv6 encapsulated packets, with either
327          * an SRH with SL=0, or no SRH.
328          */
329
330         if (!decap_and_validate(skb, IPPROTO_IPV6))
331                 goto drop;
332
333         if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
334                 goto drop;
335
336         /* The inner packet is not associated to any local interface,
337          * so we do not call netif_rx().
338          *
339          * If slwt->nh6 is set to ::, then lookup the nexthop for the
340          * inner packet's DA. Otherwise, use the specified nexthop.
341          */
342
343         if (!ipv6_addr_any(&slwt->nh6))
344                 nhaddr = &slwt->nh6;
345
346         seg6_lookup_nexthop(skb, nhaddr, 0);
347
348         return dst_input(skb);
349 drop:
350         kfree_skb(skb);
351         return -EINVAL;
352 }
353
354 static int input_action_end_dx4(struct sk_buff *skb,
355                                 struct seg6_local_lwt *slwt)
356 {
357         struct iphdr *iph;
358         __be32 nhaddr;
359         int err;
360
361         if (!decap_and_validate(skb, IPPROTO_IPIP))
362                 goto drop;
363
364         if (!pskb_may_pull(skb, sizeof(struct iphdr)))
365                 goto drop;
366
367         skb->protocol = htons(ETH_P_IP);
368
369         iph = ip_hdr(skb);
370
371         nhaddr = slwt->nh4.s_addr ?: iph->daddr;
372
373         skb_dst_drop(skb);
374
375         err = ip_route_input(skb, nhaddr, iph->saddr, 0, skb->dev);
376         if (err)
377                 goto drop;
378
379         return dst_input(skb);
380
381 drop:
382         kfree_skb(skb);
383         return -EINVAL;
384 }
385
386 static int input_action_end_dt6(struct sk_buff *skb,
387                                 struct seg6_local_lwt *slwt)
388 {
389         if (!decap_and_validate(skb, IPPROTO_IPV6))
390                 goto drop;
391
392         if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
393                 goto drop;
394
395         seg6_lookup_nexthop(skb, NULL, slwt->table);
396
397         return dst_input(skb);
398
399 drop:
400         kfree_skb(skb);
401         return -EINVAL;
402 }
403
404 /* push an SRH on top of the current one */
405 static int input_action_end_b6(struct sk_buff *skb, struct seg6_local_lwt *slwt)
406 {
407         struct ipv6_sr_hdr *srh;
408         int err = -EINVAL;
409
410         srh = get_and_validate_srh(skb);
411         if (!srh)
412                 goto drop;
413
414         err = seg6_do_srh_inline(skb, slwt->srh);
415         if (err)
416                 goto drop;
417
418         ipv6_hdr(skb)->payload_len = htons(skb->len - sizeof(struct ipv6hdr));
419         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
420
421         seg6_lookup_nexthop(skb, NULL, 0);
422
423         return dst_input(skb);
424
425 drop:
426         kfree_skb(skb);
427         return err;
428 }
429
430 /* encapsulate within an outer IPv6 header and a specified SRH */
431 static int input_action_end_b6_encap(struct sk_buff *skb,
432                                      struct seg6_local_lwt *slwt)
433 {
434         struct ipv6_sr_hdr *srh;
435         int err = -EINVAL;
436
437         srh = get_and_validate_srh(skb);
438         if (!srh)
439                 goto drop;
440
441         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
442
443         skb_reset_inner_headers(skb);
444         skb->encapsulation = 1;
445
446         err = seg6_do_srh_encap(skb, slwt->srh, IPPROTO_IPV6);
447         if (err)
448                 goto drop;
449
450         ipv6_hdr(skb)->payload_len = htons(skb->len - sizeof(struct ipv6hdr));
451         skb_set_transport_header(skb, sizeof(struct ipv6hdr));
452
453         seg6_lookup_nexthop(skb, NULL, 0);
454
455         return dst_input(skb);
456
457 drop:
458         kfree_skb(skb);
459         return err;
460 }
461
462 DEFINE_PER_CPU(struct seg6_bpf_srh_state, seg6_bpf_srh_states);
463
464 bool seg6_bpf_has_valid_srh(struct sk_buff *skb)
465 {
466         struct seg6_bpf_srh_state *srh_state =
467                 this_cpu_ptr(&seg6_bpf_srh_states);
468         struct ipv6_sr_hdr *srh = srh_state->srh;
469
470         if (unlikely(srh == NULL))
471                 return false;
472
473         if (unlikely(!srh_state->valid)) {
474                 if ((srh_state->hdrlen & 7) != 0)
475                         return false;
476
477                 srh->hdrlen = (u8)(srh_state->hdrlen >> 3);
478                 if (!seg6_validate_srh(srh, (srh->hdrlen + 1) << 3))
479                         return false;
480
481                 srh_state->valid = true;
482         }
483
484         return true;
485 }
486
487 static int input_action_end_bpf(struct sk_buff *skb,
488                                 struct seg6_local_lwt *slwt)
489 {
490         struct seg6_bpf_srh_state *srh_state =
491                 this_cpu_ptr(&seg6_bpf_srh_states);
492         struct ipv6_sr_hdr *srh;
493         int ret;
494
495         srh = get_and_validate_srh(skb);
496         if (!srh) {
497                 kfree_skb(skb);
498                 return -EINVAL;
499         }
500         advance_nextseg(srh, &ipv6_hdr(skb)->daddr);
501
502         /* preempt_disable is needed to protect the per-CPU buffer srh_state,
503          * which is also accessed by the bpf_lwt_seg6_* helpers
504          */
505         preempt_disable();
506         srh_state->srh = srh;
507         srh_state->hdrlen = srh->hdrlen << 3;
508         srh_state->valid = true;
509
510         rcu_read_lock();
511         bpf_compute_data_pointers(skb);
512         ret = bpf_prog_run_save_cb(slwt->bpf.prog, skb);
513         rcu_read_unlock();
514
515         switch (ret) {
516         case BPF_OK:
517         case BPF_REDIRECT:
518                 break;
519         case BPF_DROP:
520                 goto drop;
521         default:
522                 pr_warn_once("bpf-seg6local: Illegal return value %u\n", ret);
523                 goto drop;
524         }
525
526         if (srh_state->srh && !seg6_bpf_has_valid_srh(skb))
527                 goto drop;
528
529         preempt_enable();
530         if (ret != BPF_REDIRECT)
531                 seg6_lookup_nexthop(skb, NULL, 0);
532
533         return dst_input(skb);
534
535 drop:
536         preempt_enable();
537         kfree_skb(skb);
538         return -EINVAL;
539 }
540
541 static struct seg6_action_desc seg6_action_table[] = {
542         {
543                 .action         = SEG6_LOCAL_ACTION_END,
544                 .attrs          = 0,
545                 .input          = input_action_end,
546         },
547         {
548                 .action         = SEG6_LOCAL_ACTION_END_X,
549                 .attrs          = (1 << SEG6_LOCAL_NH6),
550                 .input          = input_action_end_x,
551         },
552         {
553                 .action         = SEG6_LOCAL_ACTION_END_T,
554                 .attrs          = (1 << SEG6_LOCAL_TABLE),
555                 .input          = input_action_end_t,
556         },
557         {
558                 .action         = SEG6_LOCAL_ACTION_END_DX2,
559                 .attrs          = (1 << SEG6_LOCAL_OIF),
560                 .input          = input_action_end_dx2,
561         },
562         {
563                 .action         = SEG6_LOCAL_ACTION_END_DX6,
564                 .attrs          = (1 << SEG6_LOCAL_NH6),
565                 .input          = input_action_end_dx6,
566         },
567         {
568                 .action         = SEG6_LOCAL_ACTION_END_DX4,
569                 .attrs          = (1 << SEG6_LOCAL_NH4),
570                 .input          = input_action_end_dx4,
571         },
572         {
573                 .action         = SEG6_LOCAL_ACTION_END_DT6,
574                 .attrs          = (1 << SEG6_LOCAL_TABLE),
575                 .input          = input_action_end_dt6,
576         },
577         {
578                 .action         = SEG6_LOCAL_ACTION_END_B6,
579                 .attrs          = (1 << SEG6_LOCAL_SRH),
580                 .input          = input_action_end_b6,
581         },
582         {
583                 .action         = SEG6_LOCAL_ACTION_END_B6_ENCAP,
584                 .attrs          = (1 << SEG6_LOCAL_SRH),
585                 .input          = input_action_end_b6_encap,
586                 .static_headroom        = sizeof(struct ipv6hdr),
587         },
588         {
589                 .action         = SEG6_LOCAL_ACTION_END_BPF,
590                 .attrs          = (1 << SEG6_LOCAL_BPF),
591                 .input          = input_action_end_bpf,
592         },
593
594 };
595
596 static struct seg6_action_desc *__get_action_desc(int action)
597 {
598         struct seg6_action_desc *desc;
599         int i, count;
600
601         count = ARRAY_SIZE(seg6_action_table);
602         for (i = 0; i < count; i++) {
603                 desc = &seg6_action_table[i];
604                 if (desc->action == action)
605                         return desc;
606         }
607
608         return NULL;
609 }
610
611 static int seg6_local_input(struct sk_buff *skb)
612 {
613         struct dst_entry *orig_dst = skb_dst(skb);
614         struct seg6_action_desc *desc;
615         struct seg6_local_lwt *slwt;
616
617         if (skb->protocol != htons(ETH_P_IPV6)) {
618                 kfree_skb(skb);
619                 return -EINVAL;
620         }
621
622         slwt = seg6_local_lwtunnel(orig_dst->lwtstate);
623         desc = slwt->desc;
624
625         return desc->input(skb, slwt);
626 }
627
628 static const struct nla_policy seg6_local_policy[SEG6_LOCAL_MAX + 1] = {
629         [SEG6_LOCAL_ACTION]     = { .type = NLA_U32 },
630         [SEG6_LOCAL_SRH]        = { .type = NLA_BINARY },
631         [SEG6_LOCAL_TABLE]      = { .type = NLA_U32 },
632         [SEG6_LOCAL_NH4]        = { .type = NLA_BINARY,
633                                     .len = sizeof(struct in_addr) },
634         [SEG6_LOCAL_NH6]        = { .type = NLA_BINARY,
635                                     .len = sizeof(struct in6_addr) },
636         [SEG6_LOCAL_IIF]        = { .type = NLA_U32 },
637         [SEG6_LOCAL_OIF]        = { .type = NLA_U32 },
638         [SEG6_LOCAL_BPF]        = { .type = NLA_NESTED },
639 };
640
641 static int parse_nla_srh(struct nlattr **attrs, struct seg6_local_lwt *slwt)
642 {
643         struct ipv6_sr_hdr *srh;
644         int len;
645
646         srh = nla_data(attrs[SEG6_LOCAL_SRH]);
647         len = nla_len(attrs[SEG6_LOCAL_SRH]);
648
649         /* SRH must contain at least one segment */
650         if (len < sizeof(*srh) + sizeof(struct in6_addr))
651                 return -EINVAL;
652
653         if (!seg6_validate_srh(srh, len))
654                 return -EINVAL;
655
656         slwt->srh = kmemdup(srh, len, GFP_KERNEL);
657         if (!slwt->srh)
658                 return -ENOMEM;
659
660         slwt->headroom += len;
661
662         return 0;
663 }
664
665 static int put_nla_srh(struct sk_buff *skb, struct seg6_local_lwt *slwt)
666 {
667         struct ipv6_sr_hdr *srh;
668         struct nlattr *nla;
669         int len;
670
671         srh = slwt->srh;
672         len = (srh->hdrlen + 1) << 3;
673
674         nla = nla_reserve(skb, SEG6_LOCAL_SRH, len);
675         if (!nla)
676                 return -EMSGSIZE;
677
678         memcpy(nla_data(nla), srh, len);
679
680         return 0;
681 }
682
683 static int cmp_nla_srh(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
684 {
685         int len = (a->srh->hdrlen + 1) << 3;
686
687         if (len != ((b->srh->hdrlen + 1) << 3))
688                 return 1;
689
690         return memcmp(a->srh, b->srh, len);
691 }
692
693 static int parse_nla_table(struct nlattr **attrs, struct seg6_local_lwt *slwt)
694 {
695         slwt->table = nla_get_u32(attrs[SEG6_LOCAL_TABLE]);
696
697         return 0;
698 }
699
700 static int put_nla_table(struct sk_buff *skb, struct seg6_local_lwt *slwt)
701 {
702         if (nla_put_u32(skb, SEG6_LOCAL_TABLE, slwt->table))
703                 return -EMSGSIZE;
704
705         return 0;
706 }
707
708 static int cmp_nla_table(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
709 {
710         if (a->table != b->table)
711                 return 1;
712
713         return 0;
714 }
715
716 static int parse_nla_nh4(struct nlattr **attrs, struct seg6_local_lwt *slwt)
717 {
718         memcpy(&slwt->nh4, nla_data(attrs[SEG6_LOCAL_NH4]),
719                sizeof(struct in_addr));
720
721         return 0;
722 }
723
724 static int put_nla_nh4(struct sk_buff *skb, struct seg6_local_lwt *slwt)
725 {
726         struct nlattr *nla;
727
728         nla = nla_reserve(skb, SEG6_LOCAL_NH4, sizeof(struct in_addr));
729         if (!nla)
730                 return -EMSGSIZE;
731
732         memcpy(nla_data(nla), &slwt->nh4, sizeof(struct in_addr));
733
734         return 0;
735 }
736
737 static int cmp_nla_nh4(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
738 {
739         return memcmp(&a->nh4, &b->nh4, sizeof(struct in_addr));
740 }
741
742 static int parse_nla_nh6(struct nlattr **attrs, struct seg6_local_lwt *slwt)
743 {
744         memcpy(&slwt->nh6, nla_data(attrs[SEG6_LOCAL_NH6]),
745                sizeof(struct in6_addr));
746
747         return 0;
748 }
749
750 static int put_nla_nh6(struct sk_buff *skb, struct seg6_local_lwt *slwt)
751 {
752         struct nlattr *nla;
753
754         nla = nla_reserve(skb, SEG6_LOCAL_NH6, sizeof(struct in6_addr));
755         if (!nla)
756                 return -EMSGSIZE;
757
758         memcpy(nla_data(nla), &slwt->nh6, sizeof(struct in6_addr));
759
760         return 0;
761 }
762
763 static int cmp_nla_nh6(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
764 {
765         return memcmp(&a->nh6, &b->nh6, sizeof(struct in6_addr));
766 }
767
768 static int parse_nla_iif(struct nlattr **attrs, struct seg6_local_lwt *slwt)
769 {
770         slwt->iif = nla_get_u32(attrs[SEG6_LOCAL_IIF]);
771
772         return 0;
773 }
774
775 static int put_nla_iif(struct sk_buff *skb, struct seg6_local_lwt *slwt)
776 {
777         if (nla_put_u32(skb, SEG6_LOCAL_IIF, slwt->iif))
778                 return -EMSGSIZE;
779
780         return 0;
781 }
782
783 static int cmp_nla_iif(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
784 {
785         if (a->iif != b->iif)
786                 return 1;
787
788         return 0;
789 }
790
791 static int parse_nla_oif(struct nlattr **attrs, struct seg6_local_lwt *slwt)
792 {
793         slwt->oif = nla_get_u32(attrs[SEG6_LOCAL_OIF]);
794
795         return 0;
796 }
797
798 static int put_nla_oif(struct sk_buff *skb, struct seg6_local_lwt *slwt)
799 {
800         if (nla_put_u32(skb, SEG6_LOCAL_OIF, slwt->oif))
801                 return -EMSGSIZE;
802
803         return 0;
804 }
805
806 static int cmp_nla_oif(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
807 {
808         if (a->oif != b->oif)
809                 return 1;
810
811         return 0;
812 }
813
814 #define MAX_PROG_NAME 256
815 static const struct nla_policy bpf_prog_policy[SEG6_LOCAL_BPF_PROG_MAX + 1] = {
816         [SEG6_LOCAL_BPF_PROG]      = { .type = NLA_U32, },
817         [SEG6_LOCAL_BPF_PROG_NAME] = { .type = NLA_NUL_STRING,
818                                        .len = MAX_PROG_NAME },
819 };
820
821 static int parse_nla_bpf(struct nlattr **attrs, struct seg6_local_lwt *slwt)
822 {
823         struct nlattr *tb[SEG6_LOCAL_BPF_PROG_MAX + 1];
824         struct bpf_prog *p;
825         int ret;
826         u32 fd;
827
828         ret = nla_parse_nested(tb, SEG6_LOCAL_BPF_PROG_MAX,
829                                attrs[SEG6_LOCAL_BPF], bpf_prog_policy, NULL);
830         if (ret < 0)
831                 return ret;
832
833         if (!tb[SEG6_LOCAL_BPF_PROG] || !tb[SEG6_LOCAL_BPF_PROG_NAME])
834                 return -EINVAL;
835
836         slwt->bpf.name = nla_memdup(tb[SEG6_LOCAL_BPF_PROG_NAME], GFP_KERNEL);
837         if (!slwt->bpf.name)
838                 return -ENOMEM;
839
840         fd = nla_get_u32(tb[SEG6_LOCAL_BPF_PROG]);
841         p = bpf_prog_get_type(fd, BPF_PROG_TYPE_LWT_SEG6LOCAL);
842         if (IS_ERR(p)) {
843                 kfree(slwt->bpf.name);
844                 return PTR_ERR(p);
845         }
846
847         slwt->bpf.prog = p;
848         return 0;
849 }
850
851 static int put_nla_bpf(struct sk_buff *skb, struct seg6_local_lwt *slwt)
852 {
853         struct nlattr *nest;
854
855         if (!slwt->bpf.prog)
856                 return 0;
857
858         nest = nla_nest_start(skb, SEG6_LOCAL_BPF);
859         if (!nest)
860                 return -EMSGSIZE;
861
862         if (nla_put_u32(skb, SEG6_LOCAL_BPF_PROG, slwt->bpf.prog->aux->id))
863                 return -EMSGSIZE;
864
865         if (slwt->bpf.name &&
866             nla_put_string(skb, SEG6_LOCAL_BPF_PROG_NAME, slwt->bpf.name))
867                 return -EMSGSIZE;
868
869         return nla_nest_end(skb, nest);
870 }
871
872 static int cmp_nla_bpf(struct seg6_local_lwt *a, struct seg6_local_lwt *b)
873 {
874         if (!a->bpf.name && !b->bpf.name)
875                 return 0;
876
877         if (!a->bpf.name || !b->bpf.name)
878                 return 1;
879
880         return strcmp(a->bpf.name, b->bpf.name);
881 }
882
883 struct seg6_action_param {
884         int (*parse)(struct nlattr **attrs, struct seg6_local_lwt *slwt);
885         int (*put)(struct sk_buff *skb, struct seg6_local_lwt *slwt);
886         int (*cmp)(struct seg6_local_lwt *a, struct seg6_local_lwt *b);
887 };
888
889 static struct seg6_action_param seg6_action_params[SEG6_LOCAL_MAX + 1] = {
890         [SEG6_LOCAL_SRH]        = { .parse = parse_nla_srh,
891                                     .put = put_nla_srh,
892                                     .cmp = cmp_nla_srh },
893
894         [SEG6_LOCAL_TABLE]      = { .parse = parse_nla_table,
895                                     .put = put_nla_table,
896                                     .cmp = cmp_nla_table },
897
898         [SEG6_LOCAL_NH4]        = { .parse = parse_nla_nh4,
899                                     .put = put_nla_nh4,
900                                     .cmp = cmp_nla_nh4 },
901
902         [SEG6_LOCAL_NH6]        = { .parse = parse_nla_nh6,
903                                     .put = put_nla_nh6,
904                                     .cmp = cmp_nla_nh6 },
905
906         [SEG6_LOCAL_IIF]        = { .parse = parse_nla_iif,
907                                     .put = put_nla_iif,
908                                     .cmp = cmp_nla_iif },
909
910         [SEG6_LOCAL_OIF]        = { .parse = parse_nla_oif,
911                                     .put = put_nla_oif,
912                                     .cmp = cmp_nla_oif },
913
914         [SEG6_LOCAL_BPF]        = { .parse = parse_nla_bpf,
915                                     .put = put_nla_bpf,
916                                     .cmp = cmp_nla_bpf },
917
918 };
919
920 static int parse_nla_action(struct nlattr **attrs, struct seg6_local_lwt *slwt)
921 {
922         struct seg6_action_param *param;
923         struct seg6_action_desc *desc;
924         int i, err;
925
926         desc = __get_action_desc(slwt->action);
927         if (!desc)
928                 return -EINVAL;
929
930         if (!desc->input)
931                 return -EOPNOTSUPP;
932
933         slwt->desc = desc;
934         slwt->headroom += desc->static_headroom;
935
936         for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
937                 if (desc->attrs & (1 << i)) {
938                         if (!attrs[i])
939                                 return -EINVAL;
940
941                         param = &seg6_action_params[i];
942
943                         err = param->parse(attrs, slwt);
944                         if (err < 0)
945                                 return err;
946                 }
947         }
948
949         return 0;
950 }
951
952 static int seg6_local_build_state(struct nlattr *nla, unsigned int family,
953                                   const void *cfg, struct lwtunnel_state **ts,
954                                   struct netlink_ext_ack *extack)
955 {
956         struct nlattr *tb[SEG6_LOCAL_MAX + 1];
957         struct lwtunnel_state *newts;
958         struct seg6_local_lwt *slwt;
959         int err;
960
961         if (family != AF_INET6)
962                 return -EINVAL;
963
964         err = nla_parse_nested(tb, SEG6_LOCAL_MAX, nla, seg6_local_policy,
965                                extack);
966
967         if (err < 0)
968                 return err;
969
970         if (!tb[SEG6_LOCAL_ACTION])
971                 return -EINVAL;
972
973         newts = lwtunnel_state_alloc(sizeof(*slwt));
974         if (!newts)
975                 return -ENOMEM;
976
977         slwt = seg6_local_lwtunnel(newts);
978         slwt->action = nla_get_u32(tb[SEG6_LOCAL_ACTION]);
979
980         err = parse_nla_action(tb, slwt);
981         if (err < 0)
982                 goto out_free;
983
984         newts->type = LWTUNNEL_ENCAP_SEG6_LOCAL;
985         newts->flags = LWTUNNEL_STATE_INPUT_REDIRECT;
986         newts->headroom = slwt->headroom;
987
988         *ts = newts;
989
990         return 0;
991
992 out_free:
993         kfree(slwt->srh);
994         kfree(newts);
995         return err;
996 }
997
998 static void seg6_local_destroy_state(struct lwtunnel_state *lwt)
999 {
1000         struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1001
1002         kfree(slwt->srh);
1003
1004         if (slwt->desc->attrs & (1 << SEG6_LOCAL_BPF)) {
1005                 kfree(slwt->bpf.name);
1006                 bpf_prog_put(slwt->bpf.prog);
1007         }
1008
1009         return;
1010 }
1011
1012 static int seg6_local_fill_encap(struct sk_buff *skb,
1013                                  struct lwtunnel_state *lwt)
1014 {
1015         struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1016         struct seg6_action_param *param;
1017         int i, err;
1018
1019         if (nla_put_u32(skb, SEG6_LOCAL_ACTION, slwt->action))
1020                 return -EMSGSIZE;
1021
1022         for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1023                 if (slwt->desc->attrs & (1 << i)) {
1024                         param = &seg6_action_params[i];
1025                         err = param->put(skb, slwt);
1026                         if (err < 0)
1027                                 return err;
1028                 }
1029         }
1030
1031         return 0;
1032 }
1033
1034 static int seg6_local_get_encap_size(struct lwtunnel_state *lwt)
1035 {
1036         struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt);
1037         unsigned long attrs;
1038         int nlsize;
1039
1040         nlsize = nla_total_size(4); /* action */
1041
1042         attrs = slwt->desc->attrs;
1043
1044         if (attrs & (1 << SEG6_LOCAL_SRH))
1045                 nlsize += nla_total_size((slwt->srh->hdrlen + 1) << 3);
1046
1047         if (attrs & (1 << SEG6_LOCAL_TABLE))
1048                 nlsize += nla_total_size(4);
1049
1050         if (attrs & (1 << SEG6_LOCAL_NH4))
1051                 nlsize += nla_total_size(4);
1052
1053         if (attrs & (1 << SEG6_LOCAL_NH6))
1054                 nlsize += nla_total_size(16);
1055
1056         if (attrs & (1 << SEG6_LOCAL_IIF))
1057                 nlsize += nla_total_size(4);
1058
1059         if (attrs & (1 << SEG6_LOCAL_OIF))
1060                 nlsize += nla_total_size(4);
1061
1062         if (attrs & (1 << SEG6_LOCAL_BPF))
1063                 nlsize += nla_total_size(sizeof(struct nlattr)) +
1064                        nla_total_size(MAX_PROG_NAME) +
1065                        nla_total_size(4);
1066
1067         return nlsize;
1068 }
1069
1070 static int seg6_local_cmp_encap(struct lwtunnel_state *a,
1071                                 struct lwtunnel_state *b)
1072 {
1073         struct seg6_local_lwt *slwt_a, *slwt_b;
1074         struct seg6_action_param *param;
1075         int i;
1076
1077         slwt_a = seg6_local_lwtunnel(a);
1078         slwt_b = seg6_local_lwtunnel(b);
1079
1080         if (slwt_a->action != slwt_b->action)
1081                 return 1;
1082
1083         if (slwt_a->desc->attrs != slwt_b->desc->attrs)
1084                 return 1;
1085
1086         for (i = 0; i < SEG6_LOCAL_MAX + 1; i++) {
1087                 if (slwt_a->desc->attrs & (1 << i)) {
1088                         param = &seg6_action_params[i];
1089                         if (param->cmp(slwt_a, slwt_b))
1090                                 return 1;
1091                 }
1092         }
1093
1094         return 0;
1095 }
1096
1097 static const struct lwtunnel_encap_ops seg6_local_ops = {
1098         .build_state    = seg6_local_build_state,
1099         .destroy_state  = seg6_local_destroy_state,
1100         .input          = seg6_local_input,
1101         .fill_encap     = seg6_local_fill_encap,
1102         .get_encap_size = seg6_local_get_encap_size,
1103         .cmp_encap      = seg6_local_cmp_encap,
1104         .owner          = THIS_MODULE,
1105 };
1106
1107 int __init seg6_local_init(void)
1108 {
1109         return lwtunnel_encap_add_ops(&seg6_local_ops,
1110                                       LWTUNNEL_ENCAP_SEG6_LOCAL);
1111 }
1112
1113 void seg6_local_exit(void)
1114 {
1115         lwtunnel_encap_del_ops(&seg6_local_ops, LWTUNNEL_ENCAP_SEG6_LOCAL);
1116 }