GNU Linux-libre 4.9.317-gnu1
[releases.git] / net / sched / act_csum.c
1 /*
2  * Checksum updating actions
3  *
4  * Copyright (c) 2010 Gregoire Baron <baronchon@n7mm.org>
5  *
6  * This program is free software; you can redistribute it and/or modify it
7  * under the terms of the GNU General Public License as published by the Free
8  * Software Foundation; either version 2 of the License, or (at your option)
9  * any later version.
10  *
11  */
12
13 #include <linux/types.h>
14 #include <linux/init.h>
15 #include <linux/kernel.h>
16 #include <linux/module.h>
17 #include <linux/spinlock.h>
18
19 #include <linux/netlink.h>
20 #include <net/netlink.h>
21 #include <linux/rtnetlink.h>
22
23 #include <linux/skbuff.h>
24
25 #include <net/ip.h>
26 #include <net/ipv6.h>
27 #include <net/icmp.h>
28 #include <linux/icmpv6.h>
29 #include <linux/igmp.h>
30 #include <net/tcp.h>
31 #include <net/udp.h>
32 #include <net/ip6_checksum.h>
33
34 #include <net/act_api.h>
35
36 #include <linux/tc_act/tc_csum.h>
37 #include <net/tc_act/tc_csum.h>
38
/*
 * Passed to tc_action_net_init() when the per-netns action table is set up.
 * NOTE(review): presumably a hash mask (15 => 16 buckets) - confirm against
 * the act_api implementation.
 */
#define CSUM_TAB_MASK 15

/*
 * Netlink attribute policy handed to nla_parse_nested() when the csum
 * options are parsed: TCA_CSUM_PARMS is declared with the size of
 * struct tc_csum.
 */
static const struct nla_policy csum_policy[TCA_CSUM_MAX + 1] = {
        [TCA_CSUM_PARMS] = { .len = sizeof(struct tc_csum), },
};

/* Key used with net_generic() to look up this action's per-netns state. */
static int csum_net_id;
/* Forward declaration: tcf_csum_init() needs the ops table defined below. */
static struct tc_action_ops act_csum_ops;
47
48 static int tcf_csum_init(struct net *net, struct nlattr *nla,
49                          struct nlattr *est, struct tc_action **a, int ovr,
50                          int bind)
51 {
52         struct tc_action_net *tn = net_generic(net, csum_net_id);
53         struct nlattr *tb[TCA_CSUM_MAX + 1];
54         struct tc_csum *parm;
55         struct tcf_csum *p;
56         int ret = 0, err;
57
58         if (nla == NULL)
59                 return -EINVAL;
60
61         err = nla_parse_nested(tb, TCA_CSUM_MAX, nla, csum_policy);
62         if (err < 0)
63                 return err;
64
65         if (tb[TCA_CSUM_PARMS] == NULL)
66                 return -EINVAL;
67         parm = nla_data(tb[TCA_CSUM_PARMS]);
68
69         if (!tcf_hash_check(tn, parm->index, a, bind)) {
70                 ret = tcf_hash_create(tn, parm->index, est, a,
71                                       &act_csum_ops, bind, false);
72                 if (ret)
73                         return ret;
74                 ret = ACT_P_CREATED;
75         } else {
76                 if (bind)/* dont override defaults */
77                         return 0;
78                 tcf_hash_release(*a, bind);
79                 if (!ovr)
80                         return -EEXIST;
81         }
82
83         p = to_tcf_csum(*a);
84         spin_lock_bh(&p->tcf_lock);
85         p->tcf_action = parm->action;
86         p->update_flags = parm->update_flags;
87         spin_unlock_bh(&p->tcf_lock);
88
89         if (ret == ACT_P_CREATED)
90                 tcf_hash_insert(tn, *a);
91
92         return ret;
93 }
94
95 /**
96  * tcf_csum_skb_nextlayer - Get next layer pointer
97  * @skb: sk_buff to use
98  * @ihl: previous summed headers length
99  * @ipl: complete packet length
100  * @jhl: next header length
101  *
102  * Check the expected next layer availability in the specified sk_buff.
103  * Return the next layer pointer if pass, NULL otherwise.
104  */
105 static void *tcf_csum_skb_nextlayer(struct sk_buff *skb,
106                                     unsigned int ihl, unsigned int ipl,
107                                     unsigned int jhl)
108 {
109         int ntkoff = skb_network_offset(skb);
110         int hl = ihl + jhl;
111
112         if (!pskb_may_pull(skb, ipl + ntkoff) || (ipl < hl) ||
113             skb_try_make_writable(skb, hl + ntkoff))
114                 return NULL;
115         else
116                 return (void *)(skb_network_header(skb) + ihl);
117 }
118
119 static int tcf_csum_ipv4_icmp(struct sk_buff *skb, unsigned int ihl,
120                               unsigned int ipl)
121 {
122         struct icmphdr *icmph;
123
124         icmph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*icmph));
125         if (icmph == NULL)
126                 return 0;
127
128         icmph->checksum = 0;
129         skb->csum = csum_partial(icmph, ipl - ihl, 0);
130         icmph->checksum = csum_fold(skb->csum);
131
132         skb->ip_summed = CHECKSUM_NONE;
133
134         return 1;
135 }
136
137 static int tcf_csum_ipv4_igmp(struct sk_buff *skb,
138                               unsigned int ihl, unsigned int ipl)
139 {
140         struct igmphdr *igmph;
141
142         igmph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*igmph));
143         if (igmph == NULL)
144                 return 0;
145
146         igmph->csum = 0;
147         skb->csum = csum_partial(igmph, ipl - ihl, 0);
148         igmph->csum = csum_fold(skb->csum);
149
150         skb->ip_summed = CHECKSUM_NONE;
151
152         return 1;
153 }
154
155 static int tcf_csum_ipv6_icmp(struct sk_buff *skb, unsigned int ihl,
156                               unsigned int ipl)
157 {
158         struct icmp6hdr *icmp6h;
159         const struct ipv6hdr *ip6h;
160
161         icmp6h = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*icmp6h));
162         if (icmp6h == NULL)
163                 return 0;
164
165         ip6h = ipv6_hdr(skb);
166         icmp6h->icmp6_cksum = 0;
167         skb->csum = csum_partial(icmp6h, ipl - ihl, 0);
168         icmp6h->icmp6_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
169                                               ipl - ihl, IPPROTO_ICMPV6,
170                                               skb->csum);
171
172         skb->ip_summed = CHECKSUM_NONE;
173
174         return 1;
175 }
176
177 static int tcf_csum_ipv4_tcp(struct sk_buff *skb, unsigned int ihl,
178                              unsigned int ipl)
179 {
180         struct tcphdr *tcph;
181         const struct iphdr *iph;
182
183         if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_TCPV4)
184                 return 1;
185
186         tcph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*tcph));
187         if (tcph == NULL)
188                 return 0;
189
190         iph = ip_hdr(skb);
191         tcph->check = 0;
192         skb->csum = csum_partial(tcph, ipl - ihl, 0);
193         tcph->check = tcp_v4_check(ipl - ihl,
194                                    iph->saddr, iph->daddr, skb->csum);
195
196         skb->ip_summed = CHECKSUM_NONE;
197
198         return 1;
199 }
200
201 static int tcf_csum_ipv6_tcp(struct sk_buff *skb, unsigned int ihl,
202                              unsigned int ipl)
203 {
204         struct tcphdr *tcph;
205         const struct ipv6hdr *ip6h;
206
207         if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_TCPV6)
208                 return 1;
209
210         tcph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*tcph));
211         if (tcph == NULL)
212                 return 0;
213
214         ip6h = ipv6_hdr(skb);
215         tcph->check = 0;
216         skb->csum = csum_partial(tcph, ipl - ihl, 0);
217         tcph->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
218                                       ipl - ihl, IPPROTO_TCP,
219                                       skb->csum);
220
221         skb->ip_summed = CHECKSUM_NONE;
222
223         return 1;
224 }
225
226 static int tcf_csum_ipv4_udp(struct sk_buff *skb, unsigned int ihl,
227                              unsigned int ipl, int udplite)
228 {
229         struct udphdr *udph;
230         const struct iphdr *iph;
231         u16 ul;
232
233         if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_UDP)
234                 return 1;
235
236         /*
237          * Support both UDP and UDPLITE checksum algorithms, Don't use
238          * udph->len to get the real length without any protocol check,
239          * UDPLITE uses udph->len for another thing,
240          * Use iph->tot_len, or just ipl.
241          */
242
243         udph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*udph));
244         if (udph == NULL)
245                 return 0;
246
247         iph = ip_hdr(skb);
248         ul = ntohs(udph->len);
249
250         if (udplite || udph->check) {
251
252                 udph->check = 0;
253
254                 if (udplite) {
255                         if (ul == 0)
256                                 skb->csum = csum_partial(udph, ipl - ihl, 0);
257                         else if ((ul >= sizeof(*udph)) && (ul <= ipl - ihl))
258                                 skb->csum = csum_partial(udph, ul, 0);
259                         else
260                                 goto ignore_obscure_skb;
261                 } else {
262                         if (ul != ipl - ihl)
263                                 goto ignore_obscure_skb;
264
265                         skb->csum = csum_partial(udph, ul, 0);
266                 }
267
268                 udph->check = csum_tcpudp_magic(iph->saddr, iph->daddr,
269                                                 ul, iph->protocol,
270                                                 skb->csum);
271
272                 if (!udph->check)
273                         udph->check = CSUM_MANGLED_0;
274         }
275
276         skb->ip_summed = CHECKSUM_NONE;
277
278 ignore_obscure_skb:
279         return 1;
280 }
281
282 static int tcf_csum_ipv6_udp(struct sk_buff *skb, unsigned int ihl,
283                              unsigned int ipl, int udplite)
284 {
285         struct udphdr *udph;
286         const struct ipv6hdr *ip6h;
287         u16 ul;
288
289         if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_UDP)
290                 return 1;
291
292         /*
293          * Support both UDP and UDPLITE checksum algorithms, Don't use
294          * udph->len to get the real length without any protocol check,
295          * UDPLITE uses udph->len for another thing,
296          * Use ip6h->payload_len + sizeof(*ip6h) ... , or just ipl.
297          */
298
299         udph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*udph));
300         if (udph == NULL)
301                 return 0;
302
303         ip6h = ipv6_hdr(skb);
304         ul = ntohs(udph->len);
305
306         udph->check = 0;
307
308         if (udplite) {
309                 if (ul == 0)
310                         skb->csum = csum_partial(udph, ipl - ihl, 0);
311
312                 else if ((ul >= sizeof(*udph)) && (ul <= ipl - ihl))
313                         skb->csum = csum_partial(udph, ul, 0);
314
315                 else
316                         goto ignore_obscure_skb;
317         } else {
318                 if (ul != ipl - ihl)
319                         goto ignore_obscure_skb;
320
321                 skb->csum = csum_partial(udph, ul, 0);
322         }
323
324         udph->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, ul,
325                                       udplite ? IPPROTO_UDPLITE : IPPROTO_UDP,
326                                       skb->csum);
327
328         if (!udph->check)
329                 udph->check = CSUM_MANGLED_0;
330
331         skb->ip_summed = CHECKSUM_NONE;
332
333 ignore_obscure_skb:
334         return 1;
335 }
336
337 static int tcf_csum_ipv4(struct sk_buff *skb, u32 update_flags)
338 {
339         const struct iphdr *iph;
340         int ntkoff;
341
342         ntkoff = skb_network_offset(skb);
343
344         if (!pskb_may_pull(skb, sizeof(*iph) + ntkoff))
345                 goto fail;
346
347         iph = ip_hdr(skb);
348
349         switch (iph->frag_off & htons(IP_OFFSET) ? 0 : iph->protocol) {
350         case IPPROTO_ICMP:
351                 if (update_flags & TCA_CSUM_UPDATE_FLAG_ICMP)
352                         if (!tcf_csum_ipv4_icmp(skb, iph->ihl * 4,
353                                                 ntohs(iph->tot_len)))
354                                 goto fail;
355                 break;
356         case IPPROTO_IGMP:
357                 if (update_flags & TCA_CSUM_UPDATE_FLAG_IGMP)
358                         if (!tcf_csum_ipv4_igmp(skb, iph->ihl * 4,
359                                                 ntohs(iph->tot_len)))
360                                 goto fail;
361                 break;
362         case IPPROTO_TCP:
363                 if (update_flags & TCA_CSUM_UPDATE_FLAG_TCP)
364                         if (!tcf_csum_ipv4_tcp(skb, iph->ihl * 4,
365                                                ntohs(iph->tot_len)))
366                                 goto fail;
367                 break;
368         case IPPROTO_UDP:
369                 if (update_flags & TCA_CSUM_UPDATE_FLAG_UDP)
370                         if (!tcf_csum_ipv4_udp(skb, iph->ihl * 4,
371                                                ntohs(iph->tot_len), 0))
372                                 goto fail;
373                 break;
374         case IPPROTO_UDPLITE:
375                 if (update_flags & TCA_CSUM_UPDATE_FLAG_UDPLITE)
376                         if (!tcf_csum_ipv4_udp(skb, iph->ihl * 4,
377                                                ntohs(iph->tot_len), 1))
378                                 goto fail;
379                 break;
380         }
381
382         if (update_flags & TCA_CSUM_UPDATE_FLAG_IPV4HDR) {
383                 if (skb_try_make_writable(skb, sizeof(*iph) + ntkoff))
384                         goto fail;
385
386                 ip_send_check(ip_hdr(skb));
387         }
388
389         return 1;
390
391 fail:
392         return 0;
393 }
394
395 static int tcf_csum_ipv6_hopopts(struct ipv6_opt_hdr *ip6xh, unsigned int ixhl,
396                                  unsigned int *pl)
397 {
398         int off, len, optlen;
399         unsigned char *xh = (void *)ip6xh;
400
401         off = sizeof(*ip6xh);
402         len = ixhl - off;
403
404         while (len > 1) {
405                 switch (xh[off]) {
406                 case IPV6_TLV_PAD1:
407                         optlen = 1;
408                         break;
409                 case IPV6_TLV_JUMBO:
410                         optlen = xh[off + 1] + 2;
411                         if (optlen != 6 || len < 6 || (off & 3) != 2)
412                                 /* wrong jumbo option length/alignment */
413                                 return 0;
414                         *pl = ntohl(*(__be32 *)(xh + off + 2));
415                         goto done;
416                 default:
417                         optlen = xh[off + 1] + 2;
418                         if (optlen > len)
419                                 /* ignore obscure options */
420                                 goto done;
421                         break;
422                 }
423                 off += optlen;
424                 len -= optlen;
425         }
426
427 done:
428         return 1;
429 }
430
431 static int tcf_csum_ipv6(struct sk_buff *skb, u32 update_flags)
432 {
433         struct ipv6hdr *ip6h;
434         struct ipv6_opt_hdr *ip6xh;
435         unsigned int hl, ixhl;
436         unsigned int pl;
437         int ntkoff;
438         u8 nexthdr;
439
440         ntkoff = skb_network_offset(skb);
441
442         hl = sizeof(*ip6h);
443
444         if (!pskb_may_pull(skb, hl + ntkoff))
445                 goto fail;
446
447         ip6h = ipv6_hdr(skb);
448
449         pl = ntohs(ip6h->payload_len);
450         nexthdr = ip6h->nexthdr;
451
452         do {
453                 switch (nexthdr) {
454                 case NEXTHDR_FRAGMENT:
455                         goto ignore_skb;
456                 case NEXTHDR_ROUTING:
457                 case NEXTHDR_HOP:
458                 case NEXTHDR_DEST:
459                         if (!pskb_may_pull(skb, hl + sizeof(*ip6xh) + ntkoff))
460                                 goto fail;
461                         ip6xh = (void *)(skb_network_header(skb) + hl);
462                         ixhl = ipv6_optlen(ip6xh);
463                         if (!pskb_may_pull(skb, hl + ixhl + ntkoff))
464                                 goto fail;
465                         ip6xh = (void *)(skb_network_header(skb) + hl);
466                         if ((nexthdr == NEXTHDR_HOP) &&
467                             !(tcf_csum_ipv6_hopopts(ip6xh, ixhl, &pl)))
468                                 goto fail;
469                         nexthdr = ip6xh->nexthdr;
470                         hl += ixhl;
471                         break;
472                 case IPPROTO_ICMPV6:
473                         if (update_flags & TCA_CSUM_UPDATE_FLAG_ICMP)
474                                 if (!tcf_csum_ipv6_icmp(skb,
475                                                         hl, pl + sizeof(*ip6h)))
476                                         goto fail;
477                         goto done;
478                 case IPPROTO_TCP:
479                         if (update_flags & TCA_CSUM_UPDATE_FLAG_TCP)
480                                 if (!tcf_csum_ipv6_tcp(skb,
481                                                        hl, pl + sizeof(*ip6h)))
482                                         goto fail;
483                         goto done;
484                 case IPPROTO_UDP:
485                         if (update_flags & TCA_CSUM_UPDATE_FLAG_UDP)
486                                 if (!tcf_csum_ipv6_udp(skb, hl,
487                                                        pl + sizeof(*ip6h), 0))
488                                         goto fail;
489                         goto done;
490                 case IPPROTO_UDPLITE:
491                         if (update_flags & TCA_CSUM_UPDATE_FLAG_UDPLITE)
492                                 if (!tcf_csum_ipv6_udp(skb, hl,
493                                                        pl + sizeof(*ip6h), 1))
494                                         goto fail;
495                         goto done;
496                 default:
497                         goto ignore_skb;
498                 }
499         } while (pskb_may_pull(skb, hl + 1 + ntkoff));
500
501 done:
502 ignore_skb:
503         return 1;
504
505 fail:
506         return 0;
507 }
508
509 static int tcf_csum(struct sk_buff *skb, const struct tc_action *a,
510                     struct tcf_result *res)
511 {
512         struct tcf_csum *p = to_tcf_csum(a);
513         int action;
514         u32 update_flags;
515
516         spin_lock(&p->tcf_lock);
517         tcf_lastuse_update(&p->tcf_tm);
518         bstats_update(&p->tcf_bstats, skb);
519         action = p->tcf_action;
520         update_flags = p->update_flags;
521         spin_unlock(&p->tcf_lock);
522
523         if (unlikely(action == TC_ACT_SHOT))
524                 goto drop;
525
526         switch (tc_skb_protocol(skb)) {
527         case cpu_to_be16(ETH_P_IP):
528                 if (!tcf_csum_ipv4(skb, update_flags))
529                         goto drop;
530                 break;
531         case cpu_to_be16(ETH_P_IPV6):
532                 if (!tcf_csum_ipv6(skb, update_flags))
533                         goto drop;
534                 break;
535         }
536
537         return action;
538
539 drop:
540         spin_lock(&p->tcf_lock);
541         p->tcf_qstats.drops++;
542         spin_unlock(&p->tcf_lock);
543         return TC_ACT_SHOT;
544 }
545
546 static int tcf_csum_dump(struct sk_buff *skb, struct tc_action *a, int bind,
547                          int ref)
548 {
549         unsigned char *b = skb_tail_pointer(skb);
550         struct tcf_csum *p = to_tcf_csum(a);
551         struct tc_csum opt = {
552                 .update_flags = p->update_flags,
553                 .index   = p->tcf_index,
554                 .action  = p->tcf_action,
555                 .refcnt  = p->tcf_refcnt - ref,
556                 .bindcnt = p->tcf_bindcnt - bind,
557         };
558         struct tcf_t t;
559
560         if (nla_put(skb, TCA_CSUM_PARMS, sizeof(opt), &opt))
561                 goto nla_put_failure;
562
563         tcf_tm_dump(&t, &p->tcf_tm);
564         if (nla_put_64bit(skb, TCA_CSUM_TM, sizeof(t), &t, TCA_CSUM_PAD))
565                 goto nla_put_failure;
566
567         return skb->len;
568
569 nla_put_failure:
570         nlmsg_trim(skb, b);
571         return -1;
572 }
573
574 static int tcf_csum_walker(struct net *net, struct sk_buff *skb,
575                            struct netlink_callback *cb, int type,
576                            const struct tc_action_ops *ops)
577 {
578         struct tc_action_net *tn = net_generic(net, csum_net_id);
579
580         return tcf_generic_walker(tn, skb, cb, type, ops);
581 }
582
583 static int tcf_csum_search(struct net *net, struct tc_action **a, u32 index)
584 {
585         struct tc_action_net *tn = net_generic(net, csum_net_id);
586
587         return tcf_hash_search(tn, a, index);
588 }
589
/*
 * Operations table of the "csum" action; registered with
 * tcf_register_action() at module load.
 */
static struct tc_action_ops act_csum_ops = {
        .kind           = "csum",
        .type           = TCA_ACT_CSUM,
        .owner          = THIS_MODULE,
        .act            = tcf_csum,             /* per-packet handler */
        .dump           = tcf_csum_dump,
        .init           = tcf_csum_init,        /* create / update */
        .walk           = tcf_csum_walker,
        .lookup         = tcf_csum_search,
        .size           = sizeof(struct tcf_csum),
};
601
602 static __net_init int csum_init_net(struct net *net)
603 {
604         struct tc_action_net *tn = net_generic(net, csum_net_id);
605
606         return tc_action_net_init(tn, &act_csum_ops, CSUM_TAB_MASK);
607 }
608
609 static void __net_exit csum_exit_net(struct net *net)
610 {
611         struct tc_action_net *tn = net_generic(net, csum_net_id);
612
613         tc_action_net_exit(tn);
614 }
615
/*
 * Per-network-namespace hooks for the csum action. .id points at
 * csum_net_id, the key the rest of this file passes to net_generic();
 * .size is the size of the per-namespace struct tc_action_net.
 */
static struct pernet_operations csum_net_ops = {
        .init = csum_init_net,
        .exit = csum_exit_net,
        .id   = &csum_net_id,
        .size = sizeof(struct tc_action_net),
};
622
/* Module metadata. */
MODULE_DESCRIPTION("Checksum updating actions");
MODULE_LICENSE("GPL");
625
/*
 * Module entry point: register the csum action together with its
 * per-namespace operations. Returns the result of tcf_register_action().
 */
static int __init csum_init_module(void)
{
        return tcf_register_action(&act_csum_ops, &csum_net_ops);
}
630
/*
 * Module exit point: unregister the csum action and its per-namespace
 * operations.
 */
static void __exit csum_cleanup_module(void)
{
        tcf_unregister_action(&act_csum_ops, &csum_net_ops);
}
635
/* Declare the module's load and unload handlers. */
module_init(csum_init_module);
module_exit(csum_cleanup_module);