GNU Linux-libre 4.9.304-gnu1
[releases.git] / net / ipv4 / fou.c
1 #include <linux/module.h>
2 #include <linux/errno.h>
3 #include <linux/socket.h>
4 #include <linux/skbuff.h>
5 #include <linux/ip.h>
6 #include <linux/udp.h>
7 #include <linux/types.h>
8 #include <linux/kernel.h>
9 #include <net/genetlink.h>
10 #include <net/gue.h>
11 #include <net/ip.h>
12 #include <net/protocol.h>
13 #include <net/udp.h>
14 #include <net/udp_tunnel.h>
15 #include <net/xfrm.h>
16 #include <uapi/linux/fou.h>
17 #include <uapi/linux/genetlink.h>
18
19 struct fou {
20         struct socket *sock;
21         u8 protocol;
22         u8 flags;
23         __be16 port;
24         u8 family;
25         u16 type;
26         struct list_head list;
27         struct rcu_head rcu;
28 };
29
30 #define FOU_F_REMCSUM_NOPARTIAL BIT(0)
31
32 struct fou_cfg {
33         u16 type;
34         u8 protocol;
35         u8 flags;
36         struct udp_port_cfg udp_config;
37 };
38
39 static unsigned int fou_net_id;
40
41 struct fou_net {
42         struct list_head fou_list;
43         struct mutex fou_lock;
44 };
45
46 static inline struct fou *fou_from_sock(struct sock *sk)
47 {
48         return sk->sk_user_data;
49 }
50
51 static int fou_recv_pull(struct sk_buff *skb, struct fou *fou, size_t len)
52 {
53         /* Remove 'len' bytes from the packet (UDP header and
54          * FOU header if present).
55          */
56         if (fou->family == AF_INET)
57                 ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len);
58         else
59                 ipv6_hdr(skb)->payload_len =
60                     htons(ntohs(ipv6_hdr(skb)->payload_len) - len);
61
62         __skb_pull(skb, len);
63         skb_postpull_rcsum(skb, udp_hdr(skb), len);
64         skb_reset_transport_header(skb);
65         return iptunnel_pull_offloads(skb);
66 }
67
68 static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
69 {
70         struct fou *fou = fou_from_sock(sk);
71
72         if (!fou)
73                 return 1;
74
75         if (fou_recv_pull(skb, fou, sizeof(struct udphdr)))
76                 goto drop;
77
78         return -fou->protocol;
79
80 drop:
81         kfree_skb(skb);
82         return 0;
83 }
84
85 static struct guehdr *gue_remcsum(struct sk_buff *skb, struct guehdr *guehdr,
86                                   void *data, size_t hdrlen, u8 ipproto,
87                                   bool nopartial)
88 {
89         __be16 *pd = data;
90         size_t start = ntohs(pd[0]);
91         size_t offset = ntohs(pd[1]);
92         size_t plen = sizeof(struct udphdr) + hdrlen +
93             max_t(size_t, offset + sizeof(u16), start);
94
95         if (skb->remcsum_offload)
96                 return guehdr;
97
98         if (!pskb_may_pull(skb, plen))
99                 return NULL;
100         guehdr = (struct guehdr *)&udp_hdr(skb)[1];
101
102         skb_remcsum_process(skb, (void *)guehdr + hdrlen,
103                             start, offset, nopartial);
104
105         return guehdr;
106 }
107
108 static int gue_control_message(struct sk_buff *skb, struct guehdr *guehdr)
109 {
110         /* No support yet */
111         kfree_skb(skb);
112         return 0;
113 }
114
115 static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
116 {
117         struct fou *fou = fou_from_sock(sk);
118         size_t len, optlen, hdrlen;
119         struct guehdr *guehdr;
120         void *data;
121         u16 doffset = 0;
122         u8 proto_ctype;
123
124         if (!fou)
125                 return 1;
126
127         len = sizeof(struct udphdr) + sizeof(struct guehdr);
128         if (!pskb_may_pull(skb, len))
129                 goto drop;
130
131         guehdr = (struct guehdr *)&udp_hdr(skb)[1];
132
133         switch (guehdr->version) {
134         case 0: /* Full GUE header present */
135                 break;
136
137         case 1: {
138                 /* Direct encasulation of IPv4 or IPv6 */
139
140                 int prot;
141
142                 switch (((struct iphdr *)guehdr)->version) {
143                 case 4:
144                         prot = IPPROTO_IPIP;
145                         break;
146                 case 6:
147                         prot = IPPROTO_IPV6;
148                         break;
149                 default:
150                         goto drop;
151                 }
152
153                 if (fou_recv_pull(skb, fou, sizeof(struct udphdr)))
154                         goto drop;
155
156                 return -prot;
157         }
158
159         default: /* Undefined version */
160                 goto drop;
161         }
162
163         optlen = guehdr->hlen << 2;
164         len += optlen;
165
166         if (!pskb_may_pull(skb, len))
167                 goto drop;
168
169         /* guehdr may change after pull */
170         guehdr = (struct guehdr *)&udp_hdr(skb)[1];
171
172         hdrlen = sizeof(struct guehdr) + optlen;
173
174         if (guehdr->version != 0 || validate_gue_flags(guehdr, optlen))
175                 goto drop;
176
177         hdrlen = sizeof(struct guehdr) + optlen;
178
179         if (fou->family == AF_INET)
180                 ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len);
181         else
182                 ipv6_hdr(skb)->payload_len =
183                     htons(ntohs(ipv6_hdr(skb)->payload_len) - len);
184
185         /* Pull csum through the guehdr now . This can be used if
186          * there is a remote checksum offload.
187          */
188         skb_postpull_rcsum(skb, udp_hdr(skb), len);
189
190         data = &guehdr[1];
191
192         if (guehdr->flags & GUE_FLAG_PRIV) {
193                 __be32 flags = *(__be32 *)(data + doffset);
194
195                 doffset += GUE_LEN_PRIV;
196
197                 if (flags & GUE_PFLAG_REMCSUM) {
198                         guehdr = gue_remcsum(skb, guehdr, data + doffset,
199                                              hdrlen, guehdr->proto_ctype,
200                                              !!(fou->flags &
201                                                 FOU_F_REMCSUM_NOPARTIAL));
202                         if (!guehdr)
203                                 goto drop;
204
205                         data = &guehdr[1];
206
207                         doffset += GUE_PLEN_REMCSUM;
208                 }
209         }
210
211         if (unlikely(guehdr->control))
212                 return gue_control_message(skb, guehdr);
213
214         proto_ctype = guehdr->proto_ctype;
215         __skb_pull(skb, sizeof(struct udphdr) + hdrlen);
216         skb_reset_transport_header(skb);
217
218         if (iptunnel_pull_offloads(skb))
219                 goto drop;
220
221         return -proto_ctype;
222
223 drop:
224         kfree_skb(skb);
225         return 0;
226 }
227
228 static struct sk_buff **fou_gro_receive(struct sock *sk,
229                                         struct sk_buff **head,
230                                         struct sk_buff *skb)
231 {
232         const struct net_offload *ops;
233         struct sk_buff **pp = NULL;
234         u8 proto = fou_from_sock(sk)->protocol;
235         const struct net_offload **offloads;
236
237         /* We can clear the encap_mark for FOU as we are essentially doing
238          * one of two possible things.  We are either adding an L4 tunnel
239          * header to the outer L3 tunnel header, or we are are simply
240          * treating the GRE tunnel header as though it is a UDP protocol
241          * specific header such as VXLAN or GENEVE.
242          */
243         NAPI_GRO_CB(skb)->encap_mark = 0;
244
245         /* Flag this frame as already having an outer encap header */
246         NAPI_GRO_CB(skb)->is_fou = 1;
247
248         rcu_read_lock();
249         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
250         ops = rcu_dereference(offloads[proto]);
251         if (!ops || !ops->callbacks.gro_receive)
252                 goto out_unlock;
253
254         pp = call_gro_receive(ops->callbacks.gro_receive, head, skb);
255
256 out_unlock:
257         rcu_read_unlock();
258
259         return pp;
260 }
261
262 static int fou_gro_complete(struct sock *sk, struct sk_buff *skb,
263                             int nhoff)
264 {
265         const struct net_offload *ops;
266         u8 proto = fou_from_sock(sk)->protocol;
267         int err = -ENOSYS;
268         const struct net_offload **offloads;
269
270         rcu_read_lock();
271         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
272         ops = rcu_dereference(offloads[proto]);
273         if (WARN_ON(!ops || !ops->callbacks.gro_complete))
274                 goto out_unlock;
275
276         err = ops->callbacks.gro_complete(skb, nhoff);
277
278         skb_set_inner_mac_header(skb, nhoff);
279
280 out_unlock:
281         rcu_read_unlock();
282
283         return err;
284 }
285
286 static struct guehdr *gue_gro_remcsum(struct sk_buff *skb, unsigned int off,
287                                       struct guehdr *guehdr, void *data,
288                                       size_t hdrlen, struct gro_remcsum *grc,
289                                       bool nopartial)
290 {
291         __be16 *pd = data;
292         size_t start = ntohs(pd[0]);
293         size_t offset = ntohs(pd[1]);
294
295         if (skb->remcsum_offload)
296                 return guehdr;
297
298         if (!NAPI_GRO_CB(skb)->csum_valid)
299                 return NULL;
300
301         guehdr = skb_gro_remcsum_process(skb, (void *)guehdr, off, hdrlen,
302                                          start, offset, grc, nopartial);
303
304         skb->remcsum_offload = 1;
305
306         return guehdr;
307 }
308
309 static struct sk_buff **gue_gro_receive(struct sock *sk,
310                                         struct sk_buff **head,
311                                         struct sk_buff *skb)
312 {
313         const struct net_offload **offloads;
314         const struct net_offload *ops;
315         struct sk_buff **pp = NULL;
316         struct sk_buff *p;
317         struct guehdr *guehdr;
318         size_t len, optlen, hdrlen, off;
319         void *data;
320         u16 doffset = 0;
321         int flush = 1;
322         struct fou *fou = fou_from_sock(sk);
323         struct gro_remcsum grc;
324         u8 proto;
325
326         skb_gro_remcsum_init(&grc);
327
328         off = skb_gro_offset(skb);
329         len = off + sizeof(*guehdr);
330
331         guehdr = skb_gro_header_fast(skb, off);
332         if (skb_gro_header_hard(skb, len)) {
333                 guehdr = skb_gro_header_slow(skb, len, off);
334                 if (unlikely(!guehdr))
335                         goto out;
336         }
337
338         switch (guehdr->version) {
339         case 0:
340                 break;
341         case 1:
342                 switch (((struct iphdr *)guehdr)->version) {
343                 case 4:
344                         proto = IPPROTO_IPIP;
345                         break;
346                 case 6:
347                         proto = IPPROTO_IPV6;
348                         break;
349                 default:
350                         goto out;
351                 }
352                 goto next_proto;
353         default:
354                 goto out;
355         }
356
357         optlen = guehdr->hlen << 2;
358         len += optlen;
359
360         if (skb_gro_header_hard(skb, len)) {
361                 guehdr = skb_gro_header_slow(skb, len, off);
362                 if (unlikely(!guehdr))
363                         goto out;
364         }
365
366         if (unlikely(guehdr->control) || guehdr->version != 0 ||
367             validate_gue_flags(guehdr, optlen))
368                 goto out;
369
370         hdrlen = sizeof(*guehdr) + optlen;
371
372         /* Adjust NAPI_GRO_CB(skb)->csum to account for guehdr,
373          * this is needed if there is a remote checkcsum offload.
374          */
375         skb_gro_postpull_rcsum(skb, guehdr, hdrlen);
376
377         data = &guehdr[1];
378
379         if (guehdr->flags & GUE_FLAG_PRIV) {
380                 __be32 flags = *(__be32 *)(data + doffset);
381
382                 doffset += GUE_LEN_PRIV;
383
384                 if (flags & GUE_PFLAG_REMCSUM) {
385                         guehdr = gue_gro_remcsum(skb, off, guehdr,
386                                                  data + doffset, hdrlen, &grc,
387                                                  !!(fou->flags &
388                                                     FOU_F_REMCSUM_NOPARTIAL));
389
390                         if (!guehdr)
391                                 goto out;
392
393                         data = &guehdr[1];
394
395                         doffset += GUE_PLEN_REMCSUM;
396                 }
397         }
398
399         skb_gro_pull(skb, hdrlen);
400
401         for (p = *head; p; p = p->next) {
402                 const struct guehdr *guehdr2;
403
404                 if (!NAPI_GRO_CB(p)->same_flow)
405                         continue;
406
407                 guehdr2 = (struct guehdr *)(p->data + off);
408
409                 /* Compare base GUE header to be equal (covers
410                  * hlen, version, proto_ctype, and flags.
411                  */
412                 if (guehdr->word != guehdr2->word) {
413                         NAPI_GRO_CB(p)->same_flow = 0;
414                         continue;
415                 }
416
417                 /* Compare optional fields are the same. */
418                 if (guehdr->hlen && memcmp(&guehdr[1], &guehdr2[1],
419                                            guehdr->hlen << 2)) {
420                         NAPI_GRO_CB(p)->same_flow = 0;
421                         continue;
422                 }
423         }
424
425         proto = guehdr->proto_ctype;
426
427 next_proto:
428
429         /* We can clear the encap_mark for GUE as we are essentially doing
430          * one of two possible things.  We are either adding an L4 tunnel
431          * header to the outer L3 tunnel header, or we are are simply
432          * treating the GRE tunnel header as though it is a UDP protocol
433          * specific header such as VXLAN or GENEVE.
434          */
435         NAPI_GRO_CB(skb)->encap_mark = 0;
436
437         /* Flag this frame as already having an outer encap header */
438         NAPI_GRO_CB(skb)->is_fou = 1;
439
440         rcu_read_lock();
441         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
442         ops = rcu_dereference(offloads[proto]);
443         if (WARN_ON_ONCE(!ops || !ops->callbacks.gro_receive))
444                 goto out_unlock;
445
446         pp = call_gro_receive(ops->callbacks.gro_receive, head, skb);
447         flush = 0;
448
449 out_unlock:
450         rcu_read_unlock();
451 out:
452         NAPI_GRO_CB(skb)->flush |= flush;
453         skb_gro_remcsum_cleanup(skb, &grc);
454
455         return pp;
456 }
457
458 static int gue_gro_complete(struct sock *sk, struct sk_buff *skb, int nhoff)
459 {
460         const struct net_offload **offloads;
461         struct guehdr *guehdr = (struct guehdr *)(skb->data + nhoff);
462         const struct net_offload *ops;
463         unsigned int guehlen = 0;
464         u8 proto;
465         int err = -ENOENT;
466
467         switch (guehdr->version) {
468         case 0:
469                 proto = guehdr->proto_ctype;
470                 guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
471                 break;
472         case 1:
473                 switch (((struct iphdr *)guehdr)->version) {
474                 case 4:
475                         proto = IPPROTO_IPIP;
476                         break;
477                 case 6:
478                         proto = IPPROTO_IPV6;
479                         break;
480                 default:
481                         return err;
482                 }
483                 break;
484         default:
485                 return err;
486         }
487
488         rcu_read_lock();
489         offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
490         ops = rcu_dereference(offloads[proto]);
491         if (WARN_ON(!ops || !ops->callbacks.gro_complete))
492                 goto out_unlock;
493
494         err = ops->callbacks.gro_complete(skb, nhoff + guehlen);
495
496         skb_set_inner_mac_header(skb, nhoff + guehlen);
497
498 out_unlock:
499         rcu_read_unlock();
500         return err;
501 }
502
503 static int fou_add_to_port_list(struct net *net, struct fou *fou)
504 {
505         struct fou_net *fn = net_generic(net, fou_net_id);
506         struct fou *fout;
507
508         mutex_lock(&fn->fou_lock);
509         list_for_each_entry(fout, &fn->fou_list, list) {
510                 if (fou->port == fout->port &&
511                     fou->family == fout->family) {
512                         mutex_unlock(&fn->fou_lock);
513                         return -EALREADY;
514                 }
515         }
516
517         list_add(&fou->list, &fn->fou_list);
518         mutex_unlock(&fn->fou_lock);
519
520         return 0;
521 }
522
523 static void fou_release(struct fou *fou)
524 {
525         struct socket *sock = fou->sock;
526
527         list_del(&fou->list);
528         udp_tunnel_sock_release(sock);
529
530         kfree_rcu(fou, rcu);
531 }
532
533 static int fou_create(struct net *net, struct fou_cfg *cfg,
534                       struct socket **sockp)
535 {
536         struct socket *sock = NULL;
537         struct fou *fou = NULL;
538         struct sock *sk;
539         struct udp_tunnel_sock_cfg tunnel_cfg;
540         int err;
541
542         /* Open UDP socket */
543         err = udp_sock_create(net, &cfg->udp_config, &sock);
544         if (err < 0)
545                 goto error;
546
547         /* Allocate FOU port structure */
548         fou = kzalloc(sizeof(*fou), GFP_KERNEL);
549         if (!fou) {
550                 err = -ENOMEM;
551                 goto error;
552         }
553
554         sk = sock->sk;
555
556         fou->port = cfg->udp_config.local_udp_port;
557         fou->family = cfg->udp_config.family;
558         fou->flags = cfg->flags;
559         fou->type = cfg->type;
560         fou->sock = sock;
561
562         memset(&tunnel_cfg, 0, sizeof(tunnel_cfg));
563         tunnel_cfg.encap_type = 1;
564         tunnel_cfg.sk_user_data = fou;
565         tunnel_cfg.encap_destroy = NULL;
566
567         /* Initial for fou type */
568         switch (cfg->type) {
569         case FOU_ENCAP_DIRECT:
570                 tunnel_cfg.encap_rcv = fou_udp_recv;
571                 tunnel_cfg.gro_receive = fou_gro_receive;
572                 tunnel_cfg.gro_complete = fou_gro_complete;
573                 fou->protocol = cfg->protocol;
574                 break;
575         case FOU_ENCAP_GUE:
576                 tunnel_cfg.encap_rcv = gue_udp_recv;
577                 tunnel_cfg.gro_receive = gue_gro_receive;
578                 tunnel_cfg.gro_complete = gue_gro_complete;
579                 break;
580         default:
581                 err = -EINVAL;
582                 goto error;
583         }
584
585         setup_udp_tunnel_sock(net, sock, &tunnel_cfg);
586
587         sk->sk_allocation = GFP_ATOMIC;
588
589         err = fou_add_to_port_list(net, fou);
590         if (err)
591                 goto error;
592
593         if (sockp)
594                 *sockp = sock;
595
596         return 0;
597
598 error:
599         kfree(fou);
600         if (sock)
601                 udp_tunnel_sock_release(sock);
602
603         return err;
604 }
605
606 static int fou_destroy(struct net *net, struct fou_cfg *cfg)
607 {
608         struct fou_net *fn = net_generic(net, fou_net_id);
609         __be16 port = cfg->udp_config.local_udp_port;
610         u8 family = cfg->udp_config.family;
611         int err = -EINVAL;
612         struct fou *fou;
613
614         mutex_lock(&fn->fou_lock);
615         list_for_each_entry(fou, &fn->fou_list, list) {
616                 if (fou->port == port && fou->family == family) {
617                         fou_release(fou);
618                         err = 0;
619                         break;
620                 }
621         }
622         mutex_unlock(&fn->fou_lock);
623
624         return err;
625 }
626
627 static struct genl_family fou_nl_family = {
628         .id             = GENL_ID_GENERATE,
629         .hdrsize        = 0,
630         .name           = FOU_GENL_NAME,
631         .version        = FOU_GENL_VERSION,
632         .maxattr        = FOU_ATTR_MAX,
633         .netnsok        = true,
634 };
635
636 static const struct nla_policy fou_nl_policy[FOU_ATTR_MAX + 1] = {
637         [FOU_ATTR_PORT] = { .type = NLA_U16, },
638         [FOU_ATTR_AF] = { .type = NLA_U8, },
639         [FOU_ATTR_IPPROTO] = { .type = NLA_U8, },
640         [FOU_ATTR_TYPE] = { .type = NLA_U8, },
641         [FOU_ATTR_REMCSUM_NOPARTIAL] = { .type = NLA_FLAG, },
642 };
643
644 static int parse_nl_config(struct genl_info *info,
645                            struct fou_cfg *cfg)
646 {
647         memset(cfg, 0, sizeof(*cfg));
648
649         cfg->udp_config.family = AF_INET;
650
651         if (info->attrs[FOU_ATTR_AF]) {
652                 u8 family = nla_get_u8(info->attrs[FOU_ATTR_AF]);
653
654                 switch (family) {
655                 case AF_INET:
656                         break;
657                 case AF_INET6:
658                         cfg->udp_config.ipv6_v6only = 1;
659                         break;
660                 default:
661                         return -EAFNOSUPPORT;
662                 }
663
664                 cfg->udp_config.family = family;
665         }
666
667         if (info->attrs[FOU_ATTR_PORT]) {
668                 __be16 port = nla_get_be16(info->attrs[FOU_ATTR_PORT]);
669
670                 cfg->udp_config.local_udp_port = port;
671         }
672
673         if (info->attrs[FOU_ATTR_IPPROTO])
674                 cfg->protocol = nla_get_u8(info->attrs[FOU_ATTR_IPPROTO]);
675
676         if (info->attrs[FOU_ATTR_TYPE])
677                 cfg->type = nla_get_u8(info->attrs[FOU_ATTR_TYPE]);
678
679         if (info->attrs[FOU_ATTR_REMCSUM_NOPARTIAL])
680                 cfg->flags |= FOU_F_REMCSUM_NOPARTIAL;
681
682         return 0;
683 }
684
685 static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info)
686 {
687         struct net *net = genl_info_net(info);
688         struct fou_cfg cfg;
689         int err;
690
691         err = parse_nl_config(info, &cfg);
692         if (err)
693                 return err;
694
695         return fou_create(net, &cfg, NULL);
696 }
697
698 static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info)
699 {
700         struct net *net = genl_info_net(info);
701         struct fou_cfg cfg;
702         int err;
703
704         err = parse_nl_config(info, &cfg);
705         if (err)
706                 return err;
707
708         return fou_destroy(net, &cfg);
709 }
710
711 static int fou_fill_info(struct fou *fou, struct sk_buff *msg)
712 {
713         if (nla_put_u8(msg, FOU_ATTR_AF, fou->sock->sk->sk_family) ||
714             nla_put_be16(msg, FOU_ATTR_PORT, fou->port) ||
715             nla_put_u8(msg, FOU_ATTR_IPPROTO, fou->protocol) ||
716             nla_put_u8(msg, FOU_ATTR_TYPE, fou->type))
717                 return -1;
718
719         if (fou->flags & FOU_F_REMCSUM_NOPARTIAL)
720                 if (nla_put_flag(msg, FOU_ATTR_REMCSUM_NOPARTIAL))
721                         return -1;
722         return 0;
723 }
724
725 static int fou_dump_info(struct fou *fou, u32 portid, u32 seq,
726                          u32 flags, struct sk_buff *skb, u8 cmd)
727 {
728         void *hdr;
729
730         hdr = genlmsg_put(skb, portid, seq, &fou_nl_family, flags, cmd);
731         if (!hdr)
732                 return -ENOMEM;
733
734         if (fou_fill_info(fou, skb) < 0)
735                 goto nla_put_failure;
736
737         genlmsg_end(skb, hdr);
738         return 0;
739
740 nla_put_failure:
741         genlmsg_cancel(skb, hdr);
742         return -EMSGSIZE;
743 }
744
745 static int fou_nl_cmd_get_port(struct sk_buff *skb, struct genl_info *info)
746 {
747         struct net *net = genl_info_net(info);
748         struct fou_net *fn = net_generic(net, fou_net_id);
749         struct sk_buff *msg;
750         struct fou_cfg cfg;
751         struct fou *fout;
752         __be16 port;
753         u8 family;
754         int ret;
755
756         ret = parse_nl_config(info, &cfg);
757         if (ret)
758                 return ret;
759         port = cfg.udp_config.local_udp_port;
760         if (port == 0)
761                 return -EINVAL;
762
763         family = cfg.udp_config.family;
764         if (family != AF_INET && family != AF_INET6)
765                 return -EINVAL;
766
767         msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
768         if (!msg)
769                 return -ENOMEM;
770
771         ret = -ESRCH;
772         mutex_lock(&fn->fou_lock);
773         list_for_each_entry(fout, &fn->fou_list, list) {
774                 if (port == fout->port && family == fout->family) {
775                         ret = fou_dump_info(fout, info->snd_portid,
776                                             info->snd_seq, 0, msg,
777                                             info->genlhdr->cmd);
778                         break;
779                 }
780         }
781         mutex_unlock(&fn->fou_lock);
782         if (ret < 0)
783                 goto out_free;
784
785         return genlmsg_reply(msg, info);
786
787 out_free:
788         nlmsg_free(msg);
789         return ret;
790 }
791
792 static int fou_nl_dump(struct sk_buff *skb, struct netlink_callback *cb)
793 {
794         struct net *net = sock_net(skb->sk);
795         struct fou_net *fn = net_generic(net, fou_net_id);
796         struct fou *fout;
797         int idx = 0, ret;
798
799         mutex_lock(&fn->fou_lock);
800         list_for_each_entry(fout, &fn->fou_list, list) {
801                 if (idx++ < cb->args[0])
802                         continue;
803                 ret = fou_dump_info(fout, NETLINK_CB(cb->skb).portid,
804                                     cb->nlh->nlmsg_seq, NLM_F_MULTI,
805                                     skb, FOU_CMD_GET);
806                 if (ret)
807                         break;
808         }
809         mutex_unlock(&fn->fou_lock);
810
811         cb->args[0] = idx;
812         return skb->len;
813 }
814
815 static const struct genl_ops fou_nl_ops[] = {
816         {
817                 .cmd = FOU_CMD_ADD,
818                 .doit = fou_nl_cmd_add_port,
819                 .policy = fou_nl_policy,
820                 .flags = GENL_ADMIN_PERM,
821         },
822         {
823                 .cmd = FOU_CMD_DEL,
824                 .doit = fou_nl_cmd_rm_port,
825                 .policy = fou_nl_policy,
826                 .flags = GENL_ADMIN_PERM,
827         },
828         {
829                 .cmd = FOU_CMD_GET,
830                 .doit = fou_nl_cmd_get_port,
831                 .dumpit = fou_nl_dump,
832                 .policy = fou_nl_policy,
833         },
834 };
835
836 size_t fou_encap_hlen(struct ip_tunnel_encap *e)
837 {
838         return sizeof(struct udphdr);
839 }
840 EXPORT_SYMBOL(fou_encap_hlen);
841
842 size_t gue_encap_hlen(struct ip_tunnel_encap *e)
843 {
844         size_t len;
845         bool need_priv = false;
846
847         len = sizeof(struct udphdr) + sizeof(struct guehdr);
848
849         if (e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) {
850                 len += GUE_PLEN_REMCSUM;
851                 need_priv = true;
852         }
853
854         len += need_priv ? GUE_LEN_PRIV : 0;
855
856         return len;
857 }
858 EXPORT_SYMBOL(gue_encap_hlen);
859
860 static void fou_build_udp(struct sk_buff *skb, struct ip_tunnel_encap *e,
861                           struct flowi4 *fl4, u8 *protocol, __be16 sport)
862 {
863         struct udphdr *uh;
864
865         skb_push(skb, sizeof(struct udphdr));
866         skb_reset_transport_header(skb);
867
868         uh = udp_hdr(skb);
869
870         uh->dest = e->dport;
871         uh->source = sport;
872         uh->len = htons(skb->len);
873         udp_set_csum(!(e->flags & TUNNEL_ENCAP_FLAG_CSUM), skb,
874                      fl4->saddr, fl4->daddr, skb->len);
875
876         *protocol = IPPROTO_UDP;
877 }
878
879 int __fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
880                        u8 *protocol, __be16 *sport, int type)
881 {
882         int err;
883
884         err = iptunnel_handle_offloads(skb, type);
885         if (err)
886                 return err;
887
888         *sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
889                                                 skb, 0, 0, false);
890
891         return 0;
892 }
893 EXPORT_SYMBOL(__fou_build_header);
894
895 int fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
896                      u8 *protocol, struct flowi4 *fl4)
897 {
898         int type = e->flags & TUNNEL_ENCAP_FLAG_CSUM ? SKB_GSO_UDP_TUNNEL_CSUM :
899                                                        SKB_GSO_UDP_TUNNEL;
900         __be16 sport;
901         int err;
902
903         err = __fou_build_header(skb, e, protocol, &sport, type);
904         if (err)
905                 return err;
906
907         fou_build_udp(skb, e, fl4, protocol, sport);
908
909         return 0;
910 }
911 EXPORT_SYMBOL(fou_build_header);
912
913 int __gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
914                        u8 *protocol, __be16 *sport, int type)
915 {
916         struct guehdr *guehdr;
917         size_t hdrlen, optlen = 0;
918         void *data;
919         bool need_priv = false;
920         int err;
921
922         if ((e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) &&
923             skb->ip_summed == CHECKSUM_PARTIAL) {
924                 optlen += GUE_PLEN_REMCSUM;
925                 type |= SKB_GSO_TUNNEL_REMCSUM;
926                 need_priv = true;
927         }
928
929         optlen += need_priv ? GUE_LEN_PRIV : 0;
930
931         err = iptunnel_handle_offloads(skb, type);
932         if (err)
933                 return err;
934
935         /* Get source port (based on flow hash) before skb_push */
936         *sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
937                                                 skb, 0, 0, false);
938
939         hdrlen = sizeof(struct guehdr) + optlen;
940
941         skb_push(skb, hdrlen);
942
943         guehdr = (struct guehdr *)skb->data;
944
945         guehdr->control = 0;
946         guehdr->version = 0;
947         guehdr->hlen = optlen >> 2;
948         guehdr->flags = 0;
949         guehdr->proto_ctype = *protocol;
950
951         data = &guehdr[1];
952
953         if (need_priv) {
954                 __be32 *flags = data;
955
956                 guehdr->flags |= GUE_FLAG_PRIV;
957                 *flags = 0;
958                 data += GUE_LEN_PRIV;
959
960                 if (type & SKB_GSO_TUNNEL_REMCSUM) {
961                         u16 csum_start = skb_checksum_start_offset(skb);
962                         __be16 *pd = data;
963
964                         if (csum_start < hdrlen)
965                                 return -EINVAL;
966
967                         csum_start -= hdrlen;
968                         pd[0] = htons(csum_start);
969                         pd[1] = htons(csum_start + skb->csum_offset);
970
971                         if (!skb_is_gso(skb)) {
972                                 skb->ip_summed = CHECKSUM_NONE;
973                                 skb->encapsulation = 0;
974                         }
975
976                         *flags |= GUE_PFLAG_REMCSUM;
977                         data += GUE_PLEN_REMCSUM;
978                 }
979
980         }
981
982         return 0;
983 }
984 EXPORT_SYMBOL(__gue_build_header);
985
986 int gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
987                      u8 *protocol, struct flowi4 *fl4)
988 {
989         int type = e->flags & TUNNEL_ENCAP_FLAG_CSUM ? SKB_GSO_UDP_TUNNEL_CSUM :
990                                                        SKB_GSO_UDP_TUNNEL;
991         __be16 sport;
992         int err;
993
994         err = __gue_build_header(skb, e, protocol, &sport, type);
995         if (err)
996                 return err;
997
998         fou_build_udp(skb, e, fl4, protocol, sport);
999
1000         return 0;
1001 }
1002 EXPORT_SYMBOL(gue_build_header);
1003
1004 #ifdef CONFIG_NET_FOU_IP_TUNNELS
1005
1006 static const struct ip_tunnel_encap_ops fou_iptun_ops = {
1007         .encap_hlen = fou_encap_hlen,
1008         .build_header = fou_build_header,
1009 };
1010
1011 static const struct ip_tunnel_encap_ops gue_iptun_ops = {
1012         .encap_hlen = gue_encap_hlen,
1013         .build_header = gue_build_header,
1014 };
1015
1016 static int ip_tunnel_encap_add_fou_ops(void)
1017 {
1018         int ret;
1019
1020         ret = ip_tunnel_encap_add_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
1021         if (ret < 0) {
1022                 pr_err("can't add fou ops\n");
1023                 return ret;
1024         }
1025
1026         ret = ip_tunnel_encap_add_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
1027         if (ret < 0) {
1028                 pr_err("can't add gue ops\n");
1029                 ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
1030                 return ret;
1031         }
1032
1033         return 0;
1034 }
1035
1036 static void ip_tunnel_encap_del_fou_ops(void)
1037 {
1038         ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
1039         ip_tunnel_encap_del_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
1040 }
1041
1042 #else
1043
1044 static int ip_tunnel_encap_add_fou_ops(void)
1045 {
1046         return 0;
1047 }
1048
1049 static void ip_tunnel_encap_del_fou_ops(void)
1050 {
1051 }
1052
1053 #endif
1054
1055 static __net_init int fou_init_net(struct net *net)
1056 {
1057         struct fou_net *fn = net_generic(net, fou_net_id);
1058
1059         INIT_LIST_HEAD(&fn->fou_list);
1060         mutex_init(&fn->fou_lock);
1061         return 0;
1062 }
1063
1064 static __net_exit void fou_exit_net(struct net *net)
1065 {
1066         struct fou_net *fn = net_generic(net, fou_net_id);
1067         struct fou *fou, *next;
1068
1069         /* Close all the FOU sockets */
1070         mutex_lock(&fn->fou_lock);
1071         list_for_each_entry_safe(fou, next, &fn->fou_list, list)
1072                 fou_release(fou);
1073         mutex_unlock(&fn->fou_lock);
1074 }
1075
1076 static struct pernet_operations fou_net_ops = {
1077         .init = fou_init_net,
1078         .exit = fou_exit_net,
1079         .id   = &fou_net_id,
1080         .size = sizeof(struct fou_net),
1081 };
1082
1083 static int __init fou_init(void)
1084 {
1085         int ret;
1086
1087         ret = register_pernet_device(&fou_net_ops);
1088         if (ret)
1089                 goto exit;
1090
1091         ret = genl_register_family_with_ops(&fou_nl_family,
1092                                             fou_nl_ops);
1093         if (ret < 0)
1094                 goto unregister;
1095
1096         ret = ip_tunnel_encap_add_fou_ops();
1097         if (ret == 0)
1098                 return 0;
1099
1100         genl_unregister_family(&fou_nl_family);
1101 unregister:
1102         unregister_pernet_device(&fou_net_ops);
1103 exit:
1104         return ret;
1105 }
1106
1107 static void __exit fou_fini(void)
1108 {
1109         ip_tunnel_encap_del_fou_ops();
1110         genl_unregister_family(&fou_nl_family);
1111         unregister_pernet_device(&fou_net_ops);
1112 }
1113
1114 module_init(fou_init);
1115 module_exit(fou_fini);
1116 MODULE_AUTHOR("Tom Herbert <therbert@google.com>");
1117 MODULE_LICENSE("GPL");