1 // SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
2 /* -
3  * net/sched/act_ct.c  Connection Tracking action
4  *
5  * Authors:   Paul Blakey <paulb@mellanox.com>
6  *            Yossi Kuperman <yossiku@mellanox.com>
7  *            Marcelo Ricardo Leitner <marcelo.leitner@gmail.com>
8  */
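/* Illustrative usage (not part of this file): a typical flower + ct rule
 * pair, assuming the iproute2 tc-ct(8) syntax. Details vary by iproute2
 * version; this is a sketch, not authoritative:
 *
 *   tc filter add dev eth0 ingress prio 1 chain 0 proto ip flower \
 *       ct_state -trk action ct zone 1 pipe action goto chain 1
 *   tc filter add dev eth0 ingress prio 1 chain 1 proto ip flower \
 *       ct_state +trk+est action mirred egress redirect dev eth1
 */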
9
10 #include <linux/module.h>
11 #include <linux/init.h>
12 #include <linux/kernel.h>
13 #include <linux/skbuff.h>
14 #include <linux/rtnetlink.h>
15 #include <linux/pkt_cls.h>
16 #include <linux/ip.h>
17 #include <linux/ipv6.h>
18 #include <linux/rhashtable.h>
19 #include <net/netlink.h>
20 #include <net/pkt_sched.h>
21 #include <net/pkt_cls.h>
22 #include <net/act_api.h>
23 #include <net/ip.h>
24 #include <net/ipv6_frag.h>
25 #include <uapi/linux/tc_act/tc_ct.h>
26 #include <net/tc_act/tc_ct.h>
27
28 #include <net/netfilter/nf_flow_table.h>
29 #include <net/netfilter/nf_conntrack.h>
30 #include <net/netfilter/nf_conntrack_core.h>
31 #include <net/netfilter/nf_conntrack_zones.h>
32 #include <net/netfilter/nf_conntrack_helper.h>
33 #include <net/netfilter/nf_conntrack_acct.h>
34 #include <net/netfilter/ipv6/nf_defrag_ipv6.h>
35 #include <net/netfilter/nf_conntrack_act_ct.h>
36 #include <uapi/linux/netfilter/nf_nat.h>
37
38 static struct workqueue_struct *act_ct_wq;
39 static struct rhashtable zones_ht;
40 static DEFINE_MUTEX(zones_mutex);
41
42 struct tcf_ct_flow_table {
43         struct rhash_head node; /* In zones tables */
44
45         struct rcu_work rwork;
46         struct nf_flowtable nf_ft;
47         refcount_t ref;
48         u16 zone;
49
50         bool dying;
51 };
52
53 static const struct rhashtable_params zones_params = {
54         .head_offset = offsetof(struct tcf_ct_flow_table, node),
55         .key_offset = offsetof(struct tcf_ct_flow_table, zone),
56         .key_len = sizeof_field(struct tcf_ct_flow_table, zone),
57         .automatic_shrinking = true,
58 };
59
60 static struct flow_action_entry *
61 tcf_ct_flow_table_flow_action_get_next(struct flow_action *flow_action)
62 {
63         int i = flow_action->num_entries++;
64
65         return &flow_action->entries[i];
66 }
67
68 static void tcf_ct_add_mangle_action(struct flow_action *action,
69                                      enum flow_action_mangle_base htype,
70                                      u32 offset,
71                                      u32 mask,
72                                      u32 val)
73 {
74         struct flow_action_entry *entry;
75
76         entry = tcf_ct_flow_table_flow_action_get_next(action);
77         entry->id = FLOW_ACTION_MANGLE;
78         entry->mangle.htype = htype;
79         entry->mangle.mask = ~mask;
80         entry->mangle.offset = offset;
81         entry->mangle.val = val;
82 }
83
84 /* The following nat helper functions check if the inverted reverse tuple
85  * (target) is different from the current dir tuple - meaning NAT for ports
86  * and/or IP is needed, and add the relevant mangle actions.
87  */
88 static void
89 tcf_ct_flow_table_add_action_nat_ipv4(const struct nf_conntrack_tuple *tuple,
90                                       struct nf_conntrack_tuple target,
91                                       struct flow_action *action)
92 {
93         if (memcmp(&target.src.u3, &tuple->src.u3, sizeof(target.src.u3)))
94                 tcf_ct_add_mangle_action(action, FLOW_ACT_MANGLE_HDR_TYPE_IP4,
95                                          offsetof(struct iphdr, saddr),
96                                          0xFFFFFFFF,
97                                          be32_to_cpu(target.src.u3.ip));
98         if (memcmp(&target.dst.u3, &tuple->dst.u3, sizeof(target.dst.u3)))
99                 tcf_ct_add_mangle_action(action, FLOW_ACT_MANGLE_HDR_TYPE_IP4,
100                                          offsetof(struct iphdr, daddr),
101                                          0xFFFFFFFF,
102                                          be32_to_cpu(target.dst.u3.ip));
103 }
104
105 static void
106 tcf_ct_add_ipv6_addr_mangle_action(struct flow_action *action,
107                                    union nf_inet_addr *addr,
108                                    u32 offset)
109 {
110         int i;
111
112         for (i = 0; i < sizeof(struct in6_addr) / sizeof(u32); i++)
113                 tcf_ct_add_mangle_action(action, FLOW_ACT_MANGLE_HDR_TYPE_IP6,
114                                          i * sizeof(u32) + offset,
115                                          0xFFFFFFFF, be32_to_cpu(addr->ip6[i]));
116 }
117
118 static void
119 tcf_ct_flow_table_add_action_nat_ipv6(const struct nf_conntrack_tuple *tuple,
120                                       struct nf_conntrack_tuple target,
121                                       struct flow_action *action)
122 {
123         if (memcmp(&target.src.u3, &tuple->src.u3, sizeof(target.src.u3)))
124                 tcf_ct_add_ipv6_addr_mangle_action(action, &target.src.u3,
125                                                    offsetof(struct ipv6hdr,
126                                                             saddr));
127         if (memcmp(&target.dst.u3, &tuple->dst.u3, sizeof(target.dst.u3)))
128                 tcf_ct_add_ipv6_addr_mangle_action(action, &target.dst.u3,
129                                                    offsetof(struct ipv6hdr,
130                                                             daddr));
131 }
132
133 static void
134 tcf_ct_flow_table_add_action_nat_tcp(const struct nf_conntrack_tuple *tuple,
135                                      struct nf_conntrack_tuple target,
136                                      struct flow_action *action)
137 {
138         __be16 target_src = target.src.u.tcp.port;
139         __be16 target_dst = target.dst.u.tcp.port;
140
141         if (target_src != tuple->src.u.tcp.port)
142                 tcf_ct_add_mangle_action(action, FLOW_ACT_MANGLE_HDR_TYPE_TCP,
143                                          offsetof(struct tcphdr, source),
144                                          0xFFFF, be16_to_cpu(target_src));
145         if (target_dst != tuple->dst.u.tcp.port)
146                 tcf_ct_add_mangle_action(action, FLOW_ACT_MANGLE_HDR_TYPE_TCP,
147                                          offsetof(struct tcphdr, dest),
148                                          0xFFFF, be16_to_cpu(target_dst));
149 }
150
151 static void
152 tcf_ct_flow_table_add_action_nat_udp(const struct nf_conntrack_tuple *tuple,
153                                      struct nf_conntrack_tuple target,
154                                      struct flow_action *action)
155 {
156         __be16 target_src = target.src.u.udp.port;
157         __be16 target_dst = target.dst.u.udp.port;
158
159         if (target_src != tuple->src.u.udp.port)
160                 tcf_ct_add_mangle_action(action, FLOW_ACT_MANGLE_HDR_TYPE_UDP,
161                                          offsetof(struct udphdr, source),
162                                          0xFFFF, be16_to_cpu(target_src));
163         if (target_dst != tuple->dst.u.udp.port)
164                 tcf_ct_add_mangle_action(action, FLOW_ACT_MANGLE_HDR_TYPE_UDP,
165                                          offsetof(struct udphdr, dest),
166                                          0xFFFF, be16_to_cpu(target_dst));
167 }
168
169 static void tcf_ct_flow_table_add_action_meta(struct nf_conn *ct,
170                                               enum ip_conntrack_dir dir,
171                                               enum ip_conntrack_info ctinfo,
172                                               struct flow_action *action)
173 {
174         struct nf_conn_labels *ct_labels;
175         struct flow_action_entry *entry;
176         u32 *act_ct_labels;
177
178         entry = tcf_ct_flow_table_flow_action_get_next(action);
179         entry->id = FLOW_ACTION_CT_METADATA;
180 #if IS_ENABLED(CONFIG_NF_CONNTRACK_MARK)
181         entry->ct_metadata.mark = READ_ONCE(ct->mark);
182 #endif
183         /* aligns with the CT reference set on the skb by nf_ct_set() */
184         entry->ct_metadata.cookie = (unsigned long)ct | ctinfo;
185         entry->ct_metadata.orig_dir = dir == IP_CT_DIR_ORIGINAL;
186
187         act_ct_labels = entry->ct_metadata.labels;
188         ct_labels = nf_ct_labels_find(ct);
189         if (ct_labels)
190                 memcpy(act_ct_labels, ct_labels->bits, NF_CT_LABELS_MAX_SIZE);
191         else
192                 memset(act_ct_labels, 0, NF_CT_LABELS_MAX_SIZE);
193 }
194
195 static int tcf_ct_flow_table_add_action_nat(struct net *net,
196                                             struct nf_conn *ct,
197                                             enum ip_conntrack_dir dir,
198                                             struct flow_action *action)
199 {
200         const struct nf_conntrack_tuple *tuple = &ct->tuplehash[dir].tuple;
201         struct nf_conntrack_tuple target;
202
203         if (!(ct->status & IPS_NAT_MASK))
204                 return 0;
205
206         nf_ct_invert_tuple(&target, &ct->tuplehash[!dir].tuple);
207
208         switch (tuple->src.l3num) {
209         case NFPROTO_IPV4:
210                 tcf_ct_flow_table_add_action_nat_ipv4(tuple, target,
211                                                       action);
212                 break;
213         case NFPROTO_IPV6:
214                 tcf_ct_flow_table_add_action_nat_ipv6(tuple, target,
215                                                       action);
216                 break;
217         default:
218                 return -EOPNOTSUPP;
219         }
220
221         switch (nf_ct_protonum(ct)) {
222         case IPPROTO_TCP:
223                 tcf_ct_flow_table_add_action_nat_tcp(tuple, target, action);
224                 break;
225         case IPPROTO_UDP:
226                 tcf_ct_flow_table_add_action_nat_udp(tuple, target, action);
227                 break;
228         default:
229                 return -EOPNOTSUPP;
230         }
231
232         return 0;
233 }
234
235 static int tcf_ct_flow_table_fill_actions(struct net *net,
236                                           struct flow_offload *flow,
237                                           enum flow_offload_tuple_dir tdir,
238                                           struct nf_flow_rule *flow_rule)
239 {
240         struct flow_action *action = &flow_rule->rule->action;
241         int num_entries = action->num_entries;
242         struct nf_conn *ct = flow->ct;
243         enum ip_conntrack_info ctinfo;
244         enum ip_conntrack_dir dir;
245         int i, err;
246
247         switch (tdir) {
248         case FLOW_OFFLOAD_DIR_ORIGINAL:
249                 dir = IP_CT_DIR_ORIGINAL;
250                 ctinfo = IP_CT_ESTABLISHED;
251                 set_bit(NF_FLOW_HW_ESTABLISHED, &flow->flags);
252                 break;
253         case FLOW_OFFLOAD_DIR_REPLY:
254                 dir = IP_CT_DIR_REPLY;
255                 ctinfo = IP_CT_ESTABLISHED_REPLY;
256                 break;
257         default:
258                 return -EOPNOTSUPP;
259         }
260
261         err = tcf_ct_flow_table_add_action_nat(net, ct, dir, action);
262         if (err)
263                 goto err_nat;
264
265         tcf_ct_flow_table_add_action_meta(ct, dir, ctinfo, action);
266         return 0;
267
268 err_nat:
269         /* Clear filled actions */
270         for (i = num_entries; i < action->num_entries; i++)
271                 memset(&action->entries[i], 0, sizeof(action->entries[i]));
272         action->num_entries = num_entries;
273
274         return err;
275 }
276
277 static bool tcf_ct_flow_is_outdated(const struct flow_offload *flow)
278 {
279         return test_bit(IPS_SEEN_REPLY_BIT, &flow->ct->status) &&
280                test_bit(IPS_HW_OFFLOAD_BIT, &flow->ct->status) &&
281                !test_bit(NF_FLOW_HW_PENDING, &flow->flags) &&
282                !test_bit(NF_FLOW_HW_ESTABLISHED, &flow->flags);
283 }
284
285 static void tcf_ct_flow_table_get_ref(struct tcf_ct_flow_table *ct_ft);
286
287 static void tcf_ct_nf_get(struct nf_flowtable *ft)
288 {
289         struct tcf_ct_flow_table *ct_ft =
290                 container_of(ft, struct tcf_ct_flow_table, nf_ft);
291
292         tcf_ct_flow_table_get_ref(ct_ft);
293 }
294
295 static void tcf_ct_flow_table_put(struct tcf_ct_flow_table *ct_ft);
296
297 static void tcf_ct_nf_put(struct nf_flowtable *ft)
298 {
299         struct tcf_ct_flow_table *ct_ft =
300                 container_of(ft, struct tcf_ct_flow_table, nf_ft);
301
302         tcf_ct_flow_table_put(ct_ft);
303 }
304
305 static struct nf_flowtable_type flowtable_ct = {
306         .gc             = tcf_ct_flow_is_outdated,
307         .action         = tcf_ct_flow_table_fill_actions,
308         .get            = tcf_ct_nf_get,
309         .put            = tcf_ct_nf_put,
310         .owner          = THIS_MODULE,
311 };
312
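/* One flow table is kept per conntrack zone and shared by all ct actions in
 * that zone: zones_ht is keyed by zone id and each table is reference
 * counted, so the first action in a zone creates it and the last one to go
 * away schedules its teardown.
 */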
313 static int tcf_ct_flow_table_get(struct net *net, struct tcf_ct_params *params)
314 {
315         struct tcf_ct_flow_table *ct_ft;
316         int err = -ENOMEM;
317
318         mutex_lock(&zones_mutex);
319         ct_ft = rhashtable_lookup_fast(&zones_ht, &params->zone, zones_params);
320         if (ct_ft && refcount_inc_not_zero(&ct_ft->ref))
321                 goto out_unlock;
322
323         ct_ft = kzalloc(sizeof(*ct_ft), GFP_KERNEL);
324         if (!ct_ft)
325                 goto err_alloc;
326         refcount_set(&ct_ft->ref, 1);
327
328         ct_ft->zone = params->zone;
329         err = rhashtable_insert_fast(&zones_ht, &ct_ft->node, zones_params);
330         if (err)
331                 goto err_insert;
332
333         ct_ft->nf_ft.type = &flowtable_ct;
334         ct_ft->nf_ft.flags |= NF_FLOWTABLE_HW_OFFLOAD |
335                               NF_FLOWTABLE_COUNTER;
336         err = nf_flow_table_init(&ct_ft->nf_ft);
337         if (err)
338                 goto err_init;
339         write_pnet(&ct_ft->nf_ft.net, net);
340
341         __module_get(THIS_MODULE);
342 out_unlock:
343         params->ct_ft = ct_ft;
344         params->nf_ft = &ct_ft->nf_ft;
345         mutex_unlock(&zones_mutex);
346
347         return 0;
348
349 err_init:
350         rhashtable_remove_fast(&zones_ht, &ct_ft->node, zones_params);
351 err_insert:
352         kfree(ct_ft);
353 err_alloc:
354         mutex_unlock(&zones_mutex);
355         return err;
356 }
357
358 static void tcf_ct_flow_table_get_ref(struct tcf_ct_flow_table *ct_ft)
359 {
360         refcount_inc(&ct_ft->ref);
361 }
362
363 static void tcf_ct_flow_table_cleanup_work(struct work_struct *work)
364 {
365         struct tcf_ct_flow_table *ct_ft;
366         struct flow_block *block;
367
368         ct_ft = container_of(to_rcu_work(work), struct tcf_ct_flow_table,
369                              rwork);
370         nf_flow_table_free(&ct_ft->nf_ft);
371
372         block = &ct_ft->nf_ft.flow_block;
373         down_write(&ct_ft->nf_ft.flow_block_lock);
374         WARN_ON(!list_empty(&block->cb_list));
375         up_write(&ct_ft->nf_ft.flow_block_lock);
376         kfree(ct_ft);
377
378         module_put(THIS_MODULE);
379 }
380
381 static void tcf_ct_flow_table_put(struct tcf_ct_flow_table *ct_ft)
382 {
383         if (refcount_dec_and_test(&ct_ft->ref)) {
384                 rhashtable_remove_fast(&zones_ht, &ct_ft->node, zones_params);
385                 INIT_RCU_WORK(&ct_ft->rwork, tcf_ct_flow_table_cleanup_work);
386                 queue_rcu_work(act_ct_wq, &ct_ft->rwork);
387         }
388 }
389
390 static void tcf_ct_flow_tc_ifidx(struct flow_offload *entry,
391                                  struct nf_conn_act_ct_ext *act_ct_ext, u8 dir)
392 {
393         entry->tuplehash[dir].tuple.xmit_type = FLOW_OFFLOAD_XMIT_TC;
394         entry->tuplehash[dir].tuple.tc.iifidx = act_ct_ext->ifindex[dir];
395 }
396
397 static void tcf_ct_flow_ct_ext_ifidx_update(struct flow_offload *entry)
398 {
399         struct nf_conn_act_ct_ext *act_ct_ext;
400
401         act_ct_ext = nf_conn_act_ct_ext_find(entry->ct);
402         if (act_ct_ext) {
403                 tcf_ct_flow_tc_ifidx(entry, act_ct_ext, FLOW_OFFLOAD_DIR_ORIGINAL);
404                 tcf_ct_flow_tc_ifidx(entry, act_ct_ext, FLOW_OFFLOAD_DIR_REPLY);
405         }
406 }
407
408 static void tcf_ct_flow_table_add(struct tcf_ct_flow_table *ct_ft,
409                                   struct nf_conn *ct,
410                                   bool tcp, bool bidirectional)
411 {
412         struct nf_conn_act_ct_ext *act_ct_ext;
413         struct flow_offload *entry;
414         int err;
415
416         if (test_and_set_bit(IPS_OFFLOAD_BIT, &ct->status))
417                 return;
418
419         entry = flow_offload_alloc(ct);
420         if (!entry) {
421                 WARN_ON_ONCE(1);
422                 goto err_alloc;
423         }
424
425         if (tcp) {
426                 ct->proto.tcp.seen[0].flags |= IP_CT_TCP_FLAG_BE_LIBERAL;
427                 ct->proto.tcp.seen[1].flags |= IP_CT_TCP_FLAG_BE_LIBERAL;
428         }
429         if (bidirectional)
430                 __set_bit(NF_FLOW_HW_BIDIRECTIONAL, &entry->flags);
431
432         act_ct_ext = nf_conn_act_ct_ext_find(ct);
433         if (act_ct_ext) {
434                 tcf_ct_flow_tc_ifidx(entry, act_ct_ext, FLOW_OFFLOAD_DIR_ORIGINAL);
435                 tcf_ct_flow_tc_ifidx(entry, act_ct_ext, FLOW_OFFLOAD_DIR_REPLY);
436         }
437
438         err = flow_offload_add(&ct_ft->nf_ft, entry);
439         if (err)
440                 goto err_add;
441
442         return;
443
444 err_add:
445         flow_offload_free(entry);
446 err_alloc:
447         clear_bit(IPS_OFFLOAD_BIT, &ct->status);
448 }
449
450 static void tcf_ct_flow_table_process_conn(struct tcf_ct_flow_table *ct_ft,
451                                            struct nf_conn *ct,
452                                            enum ip_conntrack_info ctinfo)
453 {
454         bool tcp = false, bidirectional = true;
455
456         switch (nf_ct_protonum(ct)) {
457         case IPPROTO_TCP:
458                 if ((ctinfo != IP_CT_ESTABLISHED &&
459                      ctinfo != IP_CT_ESTABLISHED_REPLY) ||
460                     !test_bit(IPS_ASSURED_BIT, &ct->status) ||
461                     ct->proto.tcp.state != TCP_CONNTRACK_ESTABLISHED)
462                         return;
463
464                 tcp = true;
465                 break;
466         case IPPROTO_UDP:
467                 if (!nf_ct_is_confirmed(ct))
468                         return;
469                 if (!test_bit(IPS_ASSURED_BIT, &ct->status))
470                         bidirectional = false;
471                 break;
472 #ifdef CONFIG_NF_CT_PROTO_GRE
473         case IPPROTO_GRE: {
474                 struct nf_conntrack_tuple *tuple;
475
476                 if ((ctinfo != IP_CT_ESTABLISHED &&
477                      ctinfo != IP_CT_ESTABLISHED_REPLY) ||
478                     !test_bit(IPS_ASSURED_BIT, &ct->status) ||
479                     ct->status & IPS_NAT_MASK)
480                         return;
481
482                 tuple = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple;
483                 /* No support for GRE v1 */
484                 if (tuple->src.u.gre.key || tuple->dst.u.gre.key)
485                         return;
486                 break;
487         }
488 #endif
489         default:
490                 return;
491         }
492
493         if (nf_ct_ext_exist(ct, NF_CT_EXT_HELPER) ||
494             ct->status & IPS_SEQ_ADJUST)
495                 return;
496
497         tcf_ct_flow_table_add(ct_ft, ct, tcp, bidirectional);
498 }
499
500 static bool
501 tcf_ct_flow_table_fill_tuple_ipv4(struct sk_buff *skb,
502                                   struct flow_offload_tuple *tuple,
503                                   struct tcphdr **tcph)
504 {
505         struct flow_ports *ports;
506         unsigned int thoff;
507         struct iphdr *iph;
508         size_t hdrsize;
509         u8 ipproto;
510
511         if (!pskb_network_may_pull(skb, sizeof(*iph)))
512                 return false;
513
514         iph = ip_hdr(skb);
515         thoff = iph->ihl * 4;
516
517         if (ip_is_fragment(iph) ||
518             unlikely(thoff != sizeof(struct iphdr)))
519                 return false;
520
521         ipproto = iph->protocol;
522         switch (ipproto) {
523         case IPPROTO_TCP:
524                 hdrsize = sizeof(struct tcphdr);
525                 break;
526         case IPPROTO_UDP:
527                 hdrsize = sizeof(*ports);
528                 break;
529 #ifdef CONFIG_NF_CT_PROTO_GRE
530         case IPPROTO_GRE:
531                 hdrsize = sizeof(struct gre_base_hdr);
532                 break;
533 #endif
534         default:
535                 return false;
536         }
537
538         if (iph->ttl <= 1)
539                 return false;
540
541         if (!pskb_network_may_pull(skb, thoff + hdrsize))
542                 return false;
543
544         switch (ipproto) {
545         case IPPROTO_TCP:
546                 *tcph = (void *)(skb_network_header(skb) + thoff);
547                 fallthrough;
548         case IPPROTO_UDP:
549                 ports = (struct flow_ports *)(skb_network_header(skb) + thoff);
550                 tuple->src_port = ports->source;
551                 tuple->dst_port = ports->dest;
552                 break;
553         case IPPROTO_GRE: {
554                 struct gre_base_hdr *greh;
555
556                 greh = (struct gre_base_hdr *)(skb_network_header(skb) + thoff);
557                 if ((greh->flags & GRE_VERSION) != GRE_VERSION_0)
558                         return false;
559                 break;
560         }
561         }
562
563         iph = ip_hdr(skb);
564
565         tuple->src_v4.s_addr = iph->saddr;
566         tuple->dst_v4.s_addr = iph->daddr;
567         tuple->l3proto = AF_INET;
568         tuple->l4proto = ipproto;
569
570         return true;
571 }
572
573 static bool
574 tcf_ct_flow_table_fill_tuple_ipv6(struct sk_buff *skb,
575                                   struct flow_offload_tuple *tuple,
576                                   struct tcphdr **tcph)
577 {
578         struct flow_ports *ports;
579         struct ipv6hdr *ip6h;
580         unsigned int thoff;
581         size_t hdrsize;
582         u8 nexthdr;
583
584         if (!pskb_network_may_pull(skb, sizeof(*ip6h)))
585                 return false;
586
587         ip6h = ipv6_hdr(skb);
588         thoff = sizeof(*ip6h);
589
590         nexthdr = ip6h->nexthdr;
591         switch (nexthdr) {
592         case IPPROTO_TCP:
593                 hdrsize = sizeof(struct tcphdr);
594                 break;
595         case IPPROTO_UDP:
596                 hdrsize = sizeof(*ports);
597                 break;
598 #ifdef CONFIG_NF_CT_PROTO_GRE
599         case IPPROTO_GRE:
600                 hdrsize = sizeof(struct gre_base_hdr);
601                 break;
602 #endif
603         default:
604                 return false;
605         }
606
607         if (ip6h->hop_limit <= 1)
608                 return false;
609
610         if (!pskb_network_may_pull(skb, thoff + hdrsize))
611                 return false;
612
613         switch (nexthdr) {
614         case IPPROTO_TCP:
615                 *tcph = (void *)(skb_network_header(skb) + thoff);
616                 fallthrough;
617         case IPPROTO_UDP:
618                 ports = (struct flow_ports *)(skb_network_header(skb) + thoff);
619                 tuple->src_port = ports->source;
620                 tuple->dst_port = ports->dest;
621                 break;
622         case IPPROTO_GRE: {
623                 struct gre_base_hdr *greh;
624
625                 greh = (struct gre_base_hdr *)(skb_network_header(skb) + thoff);
626                 if ((greh->flags & GRE_VERSION) != GRE_VERSION_0)
627                         return false;
628                 break;
629         }
630         }
631
632         ip6h = ipv6_hdr(skb);
633
634         tuple->src_v6 = ip6h->saddr;
635         tuple->dst_v6 = ip6h->daddr;
636         tuple->l3proto = AF_INET6;
637         tuple->l4proto = nexthdr;
638
639         return true;
640 }
641
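/* Software fast path: if the packet matches an entry in the zone's flow
 * table, restore the conntrack reference and ctinfo on the skb directly from
 * that entry, refresh the flow, and skip the full conntrack lookup in
 * tcf_ct_act(). Returns true when the skb was handled this way.
 */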
642 static bool tcf_ct_flow_table_lookup(struct tcf_ct_params *p,
643                                      struct sk_buff *skb,
644                                      u8 family)
645 {
646         struct nf_flowtable *nf_ft = &p->ct_ft->nf_ft;
647         struct flow_offload_tuple_rhash *tuplehash;
648         struct flow_offload_tuple tuple = {};
649         enum ip_conntrack_info ctinfo;
650         struct tcphdr *tcph = NULL;
651         bool force_refresh = false;
652         struct flow_offload *flow;
653         struct nf_conn *ct;
654         u8 dir;
655
656         switch (family) {
657         case NFPROTO_IPV4:
658                 if (!tcf_ct_flow_table_fill_tuple_ipv4(skb, &tuple, &tcph))
659                         return false;
660                 break;
661         case NFPROTO_IPV6:
662                 if (!tcf_ct_flow_table_fill_tuple_ipv6(skb, &tuple, &tcph))
663                         return false;
664                 break;
665         default:
666                 return false;
667         }
668
669         tuplehash = flow_offload_lookup(nf_ft, &tuple);
670         if (!tuplehash)
671                 return false;
672
673         dir = tuplehash->tuple.dir;
674         flow = container_of(tuplehash, struct flow_offload, tuplehash[dir]);
675         ct = flow->ct;
676
677         if (dir == FLOW_OFFLOAD_DIR_REPLY &&
678             !test_bit(NF_FLOW_HW_BIDIRECTIONAL, &flow->flags)) {
679                 /* Only offload reply direction after the connection has
680                  * become assured.
681                  */
682                 if (test_bit(IPS_ASSURED_BIT, &ct->status))
683                         set_bit(NF_FLOW_HW_BIDIRECTIONAL, &flow->flags);
684                 else if (test_bit(NF_FLOW_HW_ESTABLISHED, &flow->flags))
685                         /* If flow_table flow has already been updated to the
686                          * established state, then don't refresh.
687                          */
688                         return false;
689                 force_refresh = true;
690         }
691
692         if (tcph && (unlikely(tcph->fin || tcph->rst))) {
693                 flow_offload_teardown(flow);
694                 return false;
695         }
696
697         if (dir == FLOW_OFFLOAD_DIR_ORIGINAL)
698                 ctinfo = test_bit(IPS_SEEN_REPLY_BIT, &ct->status) ?
699                         IP_CT_ESTABLISHED : IP_CT_NEW;
700         else
701                 ctinfo = IP_CT_ESTABLISHED_REPLY;
702
703         nf_conn_act_ct_ext_fill(skb, ct, ctinfo);
704         tcf_ct_flow_ct_ext_ifidx_update(flow);
705         flow_offload_refresh(nf_ft, flow, force_refresh);
706         if (!test_bit(IPS_ASSURED_BIT, &ct->status)) {
707                 /* Process this flow in SW to allow promoting to ASSURED */
708                 return false;
709         }
710
711         nf_conntrack_get(&ct->ct_general);
712         nf_ct_set(skb, ct, ctinfo);
713         if (nf_ft->flags & NF_FLOWTABLE_COUNTER)
714                 nf_ct_acct_update(ct, dir, skb->len);
715
716         return true;
717 }
718
719 static int tcf_ct_flow_tables_init(void)
720 {
721         return rhashtable_init(&zones_ht, &zones_params);
722 }
723
724 static void tcf_ct_flow_tables_uninit(void)
725 {
726         rhashtable_destroy(&zones_ht);
727 }
728
729 static struct tc_action_ops act_ct_ops;
730
731 struct tc_ct_action_net {
732         struct tc_action_net tn; /* Must be first */
733         bool labels;
734 };
735
736 /* Determine whether skb->_nfct is equal to the result of conntrack lookup. */
737 static bool tcf_ct_skb_nfct_cached(struct net *net, struct sk_buff *skb,
738                                    u16 zone_id, bool force)
739 {
740         enum ip_conntrack_info ctinfo;
741         struct nf_conn *ct;
742
743         ct = nf_ct_get(skb, &ctinfo);
744         if (!ct)
745                 return false;
746         if (!net_eq(net, read_pnet(&ct->ct_net)))
747                 goto drop_ct;
748         if (nf_ct_zone(ct)->id != zone_id)
749                 goto drop_ct;
750
751         /* Force conntrack entry direction. */
752         if (force && CTINFO2DIR(ctinfo) != IP_CT_DIR_ORIGINAL) {
753                 if (nf_ct_is_confirmed(ct))
754                         nf_ct_kill(ct);
755
756                 goto drop_ct;
757         }
758
759         return true;
760
761 drop_ct:
762         nf_ct_put(ct);
763         nf_ct_set(skb, NULL, IP_CT_UNTRACKED);
764
765         return false;
766 }
767
768 /* Trim the skb to the length specified by the IP/IPv6 header,
769  * removing any trailing lower-layer padding. This prepares the skb
770  * for higher-layer processing that assumes skb->len excludes padding
771  * (such as nf_ip_checksum). The caller needs to pull the skb to the
772  * network header, and ensure ip_hdr/ipv6_hdr points to valid data.
773  */
774 static int tcf_ct_skb_network_trim(struct sk_buff *skb, int family)
775 {
776         unsigned int len;
777
778         switch (family) {
779         case NFPROTO_IPV4:
780                 len = ntohs(ip_hdr(skb)->tot_len);
781                 break;
782         case NFPROTO_IPV6:
783                 len = sizeof(struct ipv6hdr)
784                         + ntohs(ipv6_hdr(skb)->payload_len);
785                 break;
786         default:
787                 len = skb->len;
788         }
789
790         return pskb_trim_rcsum(skb, len);
791 }
792
793 static u8 tcf_ct_skb_nf_family(struct sk_buff *skb)
794 {
795         u8 family = NFPROTO_UNSPEC;
796
797         switch (skb_protocol(skb, true)) {
798         case htons(ETH_P_IP):
799                 family = NFPROTO_IPV4;
800                 break;
801         case htons(ETH_P_IPV6):
802                 family = NFPROTO_IPV6;
803                 break;
804         default:
805                 break;
806         }
807
808         return family;
809 }
810
811 static int tcf_ct_ipv4_is_fragment(struct sk_buff *skb, bool *frag)
812 {
813         unsigned int len;
814
815         len =  skb_network_offset(skb) + sizeof(struct iphdr);
816         if (unlikely(skb->len < len))
817                 return -EINVAL;
818         if (unlikely(!pskb_may_pull(skb, len)))
819                 return -ENOMEM;
820
821         *frag = ip_is_fragment(ip_hdr(skb));
822         return 0;
823 }
824
825 static int tcf_ct_ipv6_is_fragment(struct sk_buff *skb, bool *frag)
826 {
827         unsigned int flags = 0, len, payload_ofs = 0;
828         unsigned short frag_off;
829         int nexthdr;
830
831         len =  skb_network_offset(skb) + sizeof(struct ipv6hdr);
832         if (unlikely(skb->len < len))
833                 return -EINVAL;
834         if (unlikely(!pskb_may_pull(skb, len)))
835                 return -ENOMEM;
836
837         nexthdr = ipv6_find_hdr(skb, &payload_ofs, -1, &frag_off, &flags);
838         if (unlikely(nexthdr < 0))
839                 return -EPROTO;
840
841         *frag = flags & IP6_FH_F_FRAG;
842         return 0;
843 }
844
845 static int tcf_ct_handle_fragments(struct net *net, struct sk_buff *skb,
846                                    u8 family, u16 zone, bool *defrag)
847 {
848         enum ip_conntrack_info ctinfo;
849         struct nf_conn *ct;
850         int err = 0;
851         bool frag;
852         u16 mru;
853
854         /* Previously seen (loopback)? Ignore. */
855         ct = nf_ct_get(skb, &ctinfo);
856         if ((ct && !nf_ct_is_template(ct)) || ctinfo == IP_CT_UNTRACKED)
857                 return 0;
858
859         if (family == NFPROTO_IPV4)
860                 err = tcf_ct_ipv4_is_fragment(skb, &frag);
861         else
862                 err = tcf_ct_ipv6_is_fragment(skb, &frag);
863         if (err || !frag)
864                 return err;
865
866         mru = tc_skb_cb(skb)->mru;
867
868         if (family == NFPROTO_IPV4) {
869                 enum ip_defrag_users user = IP_DEFRAG_CONNTRACK_IN + zone;
870
871                 memset(IPCB(skb), 0, sizeof(struct inet_skb_parm));
872                 local_bh_disable();
873                 err = ip_defrag(net, skb, user);
874                 local_bh_enable();
875                 if (err && err != -EINPROGRESS)
876                         return err;
877
878                 if (!err) {
879                         *defrag = true;
880                         mru = IPCB(skb)->frag_max_size;
881                 }
882         } else { /* NFPROTO_IPV6 */
883 #if IS_ENABLED(CONFIG_NF_DEFRAG_IPV6)
884                 enum ip6_defrag_users user = IP6_DEFRAG_CONNTRACK_IN + zone;
885
886                 memset(IP6CB(skb), 0, sizeof(struct inet6_skb_parm));
887                 err = nf_ct_frag6_gather(net, skb, user);
888                 if (err && err != -EINPROGRESS)
889                         goto out_free;
890
891                 if (!err) {
892                         *defrag = true;
893                         mru = IP6CB(skb)->frag_max_size;
894                 }
895 #else
896                 err = -EOPNOTSUPP;
897                 goto out_free;
898 #endif
899         }
900
901         if (err != -EINPROGRESS)
902                 tc_skb_cb(skb)->mru = mru;
903         skb_clear_hash(skb);
904         skb->ignore_df = 1;
905         return err;
906
907 out_free:
908         kfree_skb(skb);
909         return err;
910 }
911
912 static void tcf_ct_params_free(struct tcf_ct_params *params)
913 {
914         if (params->ct_ft)
915                 tcf_ct_flow_table_put(params->ct_ft);
916         if (params->tmpl)
917                 nf_ct_put(params->tmpl);
918         kfree(params);
919 }
920
921 static void tcf_ct_params_free_rcu(struct rcu_head *head)
922 {
923         struct tcf_ct_params *params;
924
925         params = container_of(head, struct tcf_ct_params, rcu);
926         tcf_ct_params_free(params);
927 }
928
929 #if IS_ENABLED(CONFIG_NF_NAT)
930 /* Modelled after nf_nat_ipv[46]_fn().
931  * range is only used for new, uninitialized NAT state.
932  * Returns either NF_ACCEPT or NF_DROP.
933  */
934 static int ct_nat_execute(struct sk_buff *skb, struct nf_conn *ct,
935                           enum ip_conntrack_info ctinfo,
936                           const struct nf_nat_range2 *range,
937                           enum nf_nat_manip_type maniptype)
938 {
939         __be16 proto = skb_protocol(skb, true);
940         int hooknum, err = NF_ACCEPT;
941
942         /* See HOOK2MANIP(). */
943         if (maniptype == NF_NAT_MANIP_SRC)
944                 hooknum = NF_INET_LOCAL_IN; /* Source NAT */
945         else
946                 hooknum = NF_INET_LOCAL_OUT; /* Destination NAT */
947
948         switch (ctinfo) {
949         case IP_CT_RELATED:
950         case IP_CT_RELATED_REPLY:
951                 if (proto == htons(ETH_P_IP) &&
952                     ip_hdr(skb)->protocol == IPPROTO_ICMP) {
953                         if (!nf_nat_icmp_reply_translation(skb, ct, ctinfo,
954                                                            hooknum))
955                                 err = NF_DROP;
956                         goto out;
957                 } else if (IS_ENABLED(CONFIG_IPV6) && proto == htons(ETH_P_IPV6)) {
958                         __be16 frag_off;
959                         u8 nexthdr = ipv6_hdr(skb)->nexthdr;
960                         int hdrlen = ipv6_skip_exthdr(skb,
961                                                       sizeof(struct ipv6hdr),
962                                                       &nexthdr, &frag_off);
963
964                         if (hdrlen >= 0 && nexthdr == IPPROTO_ICMPV6) {
965                                 if (!nf_nat_icmpv6_reply_translation(skb, ct,
966                                                                      ctinfo,
967                                                                      hooknum,
968                                                                      hdrlen))
969                                         err = NF_DROP;
970                                 goto out;
971                         }
972                 }
973                 /* Non-ICMP, fall thru to initialize if needed. */
974                 fallthrough;
975         case IP_CT_NEW:
976                 /* Seen it before?  This can happen for loopback, retrans,
977                  * or local packets.
978                  */
979                 if (!nf_nat_initialized(ct, maniptype)) {
980                         /* Initialize according to the NAT action. */
981                         err = (range && range->flags & NF_NAT_RANGE_MAP_IPS)
982                                 /* Action is set up to establish a new
983                                  * mapping.
984                                  */
985                                 ? nf_nat_setup_info(ct, range, maniptype)
986                                 : nf_nat_alloc_null_binding(ct, hooknum);
987                         if (err != NF_ACCEPT)
988                                 goto out;
989                 }
990                 break;
991
992         case IP_CT_ESTABLISHED:
993         case IP_CT_ESTABLISHED_REPLY:
994                 break;
995
996         default:
997                 err = NF_DROP;
998                 goto out;
999         }
1000
1001         err = nf_nat_packet(ct, ctinfo, hooknum, skb);
1002         if (err == NF_ACCEPT) {
1003                 if (maniptype == NF_NAT_MANIP_SRC)
1004                         tc_skb_cb(skb)->post_ct_snat = 1;
1005                 if (maniptype == NF_NAT_MANIP_DST)
1006                         tc_skb_cb(skb)->post_ct_dnat = 1;
1007         }
1008 out:
1009         return err;
1010 }
1011 #endif /* CONFIG_NF_NAT */
1012
1013 static void tcf_ct_act_set_mark(struct nf_conn *ct, u32 mark, u32 mask)
1014 {
1015 #if IS_ENABLED(CONFIG_NF_CONNTRACK_MARK)
1016         u32 new_mark;
1017
1018         if (!mask)
1019                 return;
1020
1021         new_mark = mark | (READ_ONCE(ct->mark) & ~(mask));
1022         if (READ_ONCE(ct->mark) != new_mark) {
1023                 WRITE_ONCE(ct->mark, new_mark);
1024                 if (nf_ct_is_confirmed(ct))
1025                         nf_conntrack_event_cache(IPCT_MARK, ct);
1026         }
1027 #endif
1028 }
1029
1030 static void tcf_ct_act_set_labels(struct nf_conn *ct,
1031                                   u32 *labels,
1032                                   u32 *labels_m)
1033 {
1034 #if IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS)
1035         size_t labels_sz = sizeof_field(struct tcf_ct_params, labels);
1036
1037         if (!memchr_inv(labels_m, 0, labels_sz))
1038                 return;
1039
1040         nf_connlabels_replace(ct, labels, labels_m, 4);
1041 #endif
1042 }
1043
1044 static int tcf_ct_act_nat(struct sk_buff *skb,
1045                           struct nf_conn *ct,
1046                           enum ip_conntrack_info ctinfo,
1047                           int ct_action,
1048                           struct nf_nat_range2 *range,
1049                           bool commit)
1050 {
1051 #if IS_ENABLED(CONFIG_NF_NAT)
1052         int err;
1053         enum nf_nat_manip_type maniptype;
1054
1055         if (!(ct_action & TCA_CT_ACT_NAT))
1056                 return NF_ACCEPT;
1057
1058         /* Add NAT extension if not confirmed yet. */
1059         if (!nf_ct_is_confirmed(ct) && !nf_ct_nat_ext_add(ct))
1060                 return NF_DROP;   /* Can't NAT. */
1061
1062         if (ctinfo != IP_CT_NEW && (ct->status & IPS_NAT_MASK) &&
1063             (ctinfo != IP_CT_RELATED || commit)) {
1064                 /* NAT an established or related connection like before. */
1065                 if (CTINFO2DIR(ctinfo) == IP_CT_DIR_REPLY)
1066                         /* This is the REPLY direction for a connection
1067                          * for which NAT was applied in the forward
1068                          * direction.  Do the reverse NAT.
1069                          */
1070                         maniptype = ct->status & IPS_SRC_NAT
1071                                 ? NF_NAT_MANIP_DST : NF_NAT_MANIP_SRC;
1072                 else
1073                         maniptype = ct->status & IPS_SRC_NAT
1074                                 ? NF_NAT_MANIP_SRC : NF_NAT_MANIP_DST;
1075         } else if (ct_action & TCA_CT_ACT_NAT_SRC) {
1076                 maniptype = NF_NAT_MANIP_SRC;
1077         } else if (ct_action & TCA_CT_ACT_NAT_DST) {
1078                 maniptype = NF_NAT_MANIP_DST;
1079         } else {
1080                 return NF_ACCEPT;
1081         }
1082
1083         err = ct_nat_execute(skb, ct, ctinfo, range, maniptype);
1084         if (err == NF_ACCEPT && ct->status & IPS_DST_NAT) {
1085                 if (ct->status & IPS_SRC_NAT) {
1086                         if (maniptype == NF_NAT_MANIP_SRC)
1087                                 maniptype = NF_NAT_MANIP_DST;
1088                         else
1089                                 maniptype = NF_NAT_MANIP_SRC;
1090
1091                         err = ct_nat_execute(skb, ct, ctinfo, range,
1092                                              maniptype);
1093                 } else if (CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL) {
1094                         err = ct_nat_execute(skb, ct, ctinfo, NULL,
1095                                              NF_NAT_MANIP_SRC);
1096                 }
1097         }
1098         return err;
1099 #else
1100         return NF_ACCEPT;
1101 #endif
1102 }
1103
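/* Datapath entry point for act_ct. In order: handle TCA_CT_ACT_CLEAR,
 * defragment and trim the packet to the L3 length, try the per-zone flow
 * table fast path or fall back to nf_conntrack_in(), apply NAT, and on
 * commit set mark/labels and confirm the connection before considering the
 * flow for offload.
 */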
1104 static int tcf_ct_act(struct sk_buff *skb, const struct tc_action *a,
1105                       struct tcf_result *res)
1106 {
1107         struct net *net = dev_net(skb->dev);
1108         bool cached, commit, clear, force;
1109         enum ip_conntrack_info ctinfo;
1110         struct tcf_ct *c = to_ct(a);
1111         struct nf_conn *tmpl = NULL;
1112         struct nf_hook_state state;
1113         int nh_ofs, err, retval;
1114         struct tcf_ct_params *p;
1115         bool skip_add = false;
1116         bool defrag = false;
1117         struct nf_conn *ct;
1118         u8 family;
1119
1120         p = rcu_dereference_bh(c->params);
1121
1122         retval = READ_ONCE(c->tcf_action);
1123         commit = p->ct_action & TCA_CT_ACT_COMMIT;
1124         clear = p->ct_action & TCA_CT_ACT_CLEAR;
1125         force = p->ct_action & TCA_CT_ACT_FORCE;
1126         tmpl = p->tmpl;
1127
1128         tcf_lastuse_update(&c->tcf_tm);
1129         tcf_action_update_bstats(&c->common, skb);
1130
1131         if (clear) {
1132                 tc_skb_cb(skb)->post_ct = false;
1133                 ct = nf_ct_get(skb, &ctinfo);
1134                 if (ct) {
1135                         nf_ct_put(ct);
1136                         nf_ct_set(skb, NULL, IP_CT_UNTRACKED);
1137                 }
1138
1139                 goto out_clear;
1140         }
1141
1142         family = tcf_ct_skb_nf_family(skb);
1143         if (family == NFPROTO_UNSPEC)
1144                 goto drop;
1145
1146         /* The conntrack module expects to be working at L3.
1147          * We also try to pull the IPv4/6 header into the linear area.
1148          */
1149         nh_ofs = skb_network_offset(skb);
1150         skb_pull_rcsum(skb, nh_ofs);
1151         err = tcf_ct_handle_fragments(net, skb, family, p->zone, &defrag);
1152         if (err)
1153                 goto out_frag;
1154
1155         err = tcf_ct_skb_network_trim(skb, family);
1156         if (err)
1157                 goto drop;
1158
1159         /* If we are recirculating packets to match on ct fields and
1160          * committing with a separate ct action, then we don't need to
1161          * actually run the packet through conntrack twice unless it's for a
1162          * different zone.
1163          */
1164         cached = tcf_ct_skb_nfct_cached(net, skb, p->zone, force);
1165         if (!cached) {
1166                 if (tcf_ct_flow_table_lookup(p, skb, family)) {
1167                         skip_add = true;
1168                         goto do_nat;
1169                 }
1170
1171                 /* Associate skb with specified zone. */
1172                 if (tmpl) {
1173                         nf_conntrack_put(skb_nfct(skb));
1174                         nf_conntrack_get(&tmpl->ct_general);
1175                         nf_ct_set(skb, tmpl, IP_CT_NEW);
1176                 }
1177
1178                 state.hook = NF_INET_PRE_ROUTING;
1179                 state.net = net;
1180                 state.pf = family;
1181                 err = nf_conntrack_in(skb, &state);
1182                 if (err != NF_ACCEPT)
1183                         goto out_push;
1184         }
1185
1186 do_nat:
1187         ct = nf_ct_get(skb, &ctinfo);
1188         if (!ct)
1189                 goto out_push;
1190         nf_ct_deliver_cached_events(ct);
1191         nf_conn_act_ct_ext_fill(skb, ct, ctinfo);
1192
1193         err = tcf_ct_act_nat(skb, ct, ctinfo, p->ct_action, &p->range, commit);
1194         if (err != NF_ACCEPT)
1195                 goto drop;
1196
1197         if (commit) {
1198                 tcf_ct_act_set_mark(ct, p->mark, p->mark_mask);
1199                 tcf_ct_act_set_labels(ct, p->labels, p->labels_mask);
1200
1201                 if (!nf_ct_is_confirmed(ct))
1202                         nf_conn_act_ct_ext_add(skb, ct, ctinfo);
1203
1204                 /* This will take care of sending queued events
1205                  * even if the connection is already confirmed.
1206                  */
1207                 if (nf_conntrack_confirm(skb) != NF_ACCEPT)
1208                         goto drop;
1209         }
1210
1211         if (!skip_add)
1212                 tcf_ct_flow_table_process_conn(p->ct_ft, ct, ctinfo);
1213
1214 out_push:
1215         skb_push_rcsum(skb, nh_ofs);
1216
1217         tc_skb_cb(skb)->post_ct = true;
1218         tc_skb_cb(skb)->zone = p->zone;
1219 out_clear:
1220         if (defrag)
1221                 qdisc_skb_cb(skb)->pkt_len = skb->len;
1222         return retval;
1223
1224 out_frag:
1225         if (err != -EINPROGRESS)
1226                 tcf_action_inc_drop_qstats(&c->common);
1227         return TC_ACT_CONSUMED;
1228
1229 drop:
1230         tcf_action_inc_drop_qstats(&c->common);
1231         return TC_ACT_SHOT;
1232 }
1233
1234 static const struct nla_policy ct_policy[TCA_CT_MAX + 1] = {
1235         [TCA_CT_ACTION] = { .type = NLA_U16 },
1236         [TCA_CT_PARMS] = NLA_POLICY_EXACT_LEN(sizeof(struct tc_ct)),
1237         [TCA_CT_ZONE] = { .type = NLA_U16 },
1238         [TCA_CT_MARK] = { .type = NLA_U32 },
1239         [TCA_CT_MARK_MASK] = { .type = NLA_U32 },
1240         [TCA_CT_LABELS] = { .type = NLA_BINARY,
1241                             .len = 128 / BITS_PER_BYTE },
1242         [TCA_CT_LABELS_MASK] = { .type = NLA_BINARY,
1243                                  .len = 128 / BITS_PER_BYTE },
1244         [TCA_CT_NAT_IPV4_MIN] = { .type = NLA_U32 },
1245         [TCA_CT_NAT_IPV4_MAX] = { .type = NLA_U32 },
1246         [TCA_CT_NAT_IPV6_MIN] = NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
1247         [TCA_CT_NAT_IPV6_MAX] = NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
1248         [TCA_CT_NAT_PORT_MIN] = { .type = NLA_U16 },
1249         [TCA_CT_NAT_PORT_MAX] = { .type = NLA_U16 },
1250 };
1251
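/* Illustrative command (hypothetical device and addresses) exercising the
 * NAT attributes parsed below, assuming tc-ct(8) syntax:
 *
 *   tc filter add dev eth0 ingress proto ip flower ct_state -trk \
 *       action ct commit zone 1 nat src addr 10.0.0.1 port 1000-2000 pipe \
 *       action goto chain 1
 */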
1252 static int tcf_ct_fill_params_nat(struct tcf_ct_params *p,
1253                                   struct tc_ct *parm,
1254                                   struct nlattr **tb,
1255                                   struct netlink_ext_ack *extack)
1256 {
1257         struct nf_nat_range2 *range;
1258
1259         if (!(p->ct_action & TCA_CT_ACT_NAT))
1260                 return 0;
1261
1262         if (!IS_ENABLED(CONFIG_NF_NAT)) {
1263                 NL_SET_ERR_MSG_MOD(extack, "Netfilter nat isn't enabled in kernel");
1264                 return -EOPNOTSUPP;
1265         }
1266
1267         if (!(p->ct_action & (TCA_CT_ACT_NAT_SRC | TCA_CT_ACT_NAT_DST)))
1268                 return 0;
1269
1270         if ((p->ct_action & TCA_CT_ACT_NAT_SRC) &&
1271             (p->ct_action & TCA_CT_ACT_NAT_DST)) {
1272                 NL_SET_ERR_MSG_MOD(extack, "dnat and snat can't be enabled at the same time");
1273                 return -EOPNOTSUPP;
1274         }
1275
1276         range = &p->range;
1277         if (tb[TCA_CT_NAT_IPV4_MIN]) {
1278                 struct nlattr *max_attr = tb[TCA_CT_NAT_IPV4_MAX];
1279
1280                 p->ipv4_range = true;
1281                 range->flags |= NF_NAT_RANGE_MAP_IPS;
1282                 range->min_addr.ip =
1283                         nla_get_in_addr(tb[TCA_CT_NAT_IPV4_MIN]);
1284
1285                 range->max_addr.ip = max_attr ?
1286                                      nla_get_in_addr(max_attr) :
1287                                      range->min_addr.ip;
1288         } else if (tb[TCA_CT_NAT_IPV6_MIN]) {
1289                 struct nlattr *max_attr = tb[TCA_CT_NAT_IPV6_MAX];
1290
1291                 p->ipv4_range = false;
1292                 range->flags |= NF_NAT_RANGE_MAP_IPS;
1293                 range->min_addr.in6 =
1294                         nla_get_in6_addr(tb[TCA_CT_NAT_IPV6_MIN]);
1295
1296                 range->max_addr.in6 = max_attr ?
1297                                       nla_get_in6_addr(max_attr) :
1298                                       range->min_addr.in6;
1299         }
1300
1301         if (tb[TCA_CT_NAT_PORT_MIN]) {
1302                 range->flags |= NF_NAT_RANGE_PROTO_SPECIFIED;
1303                 range->min_proto.all = nla_get_be16(tb[TCA_CT_NAT_PORT_MIN]);
1304
1305                 range->max_proto.all = tb[TCA_CT_NAT_PORT_MAX] ?
1306                                        nla_get_be16(tb[TCA_CT_NAT_PORT_MAX]) :
1307                                        range->min_proto.all;
1308         }
1309
1310         return 0;
1311 }
1312
1313 static void tcf_ct_set_key_val(struct nlattr **tb,
1314                                void *val, int val_type,
1315                                void *mask, int mask_type,
1316                                int len)
1317 {
1318         if (!tb[val_type])
1319                 return;
1320         nla_memcpy(val, tb[val_type], len);
1321
1322         if (!mask)
1323                 return;
1324
1325         if (mask_type == TCA_CT_UNSPEC || !tb[mask_type])
1326                 memset(mask, 0xff, len);
1327         else
1328                 nla_memcpy(mask, tb[mask_type], len);
1329 }
1330
1331 static int tcf_ct_fill_params(struct net *net,
1332                               struct tcf_ct_params *p,
1333                               struct tc_ct *parm,
1334                               struct nlattr **tb,
1335                               struct netlink_ext_ack *extack)
1336 {
1337         struct tc_ct_action_net *tn = net_generic(net, act_ct_ops.net_id);
1338         struct nf_conntrack_zone zone;
1339         struct nf_conn *tmpl;
1340         int err;
1341
1342         p->zone = NF_CT_DEFAULT_ZONE_ID;
1343
1344         tcf_ct_set_key_val(tb,
1345                            &p->ct_action, TCA_CT_ACTION,
1346                            NULL, TCA_CT_UNSPEC,
1347                            sizeof(p->ct_action));
1348
1349         if (p->ct_action & TCA_CT_ACT_CLEAR)
1350                 return 0;
1351
1352         err = tcf_ct_fill_params_nat(p, parm, tb, extack);
1353         if (err)
1354                 return err;
1355
1356         if (tb[TCA_CT_MARK]) {
1357                 if (!IS_ENABLED(CONFIG_NF_CONNTRACK_MARK)) {
1358                         NL_SET_ERR_MSG_MOD(extack, "Conntrack mark isn't enabled.");
1359                         return -EOPNOTSUPP;
1360                 }
1361                 tcf_ct_set_key_val(tb,
1362                                    &p->mark, TCA_CT_MARK,
1363                                    &p->mark_mask, TCA_CT_MARK_MASK,
1364                                    sizeof(p->mark));
1365         }
1366
1367         if (tb[TCA_CT_LABELS]) {
1368                 if (!IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS)) {
1369                         NL_SET_ERR_MSG_MOD(extack, "Conntrack labels isn't enabled.");
1370                         return -EOPNOTSUPP;
1371                 }
1372
1373                 if (!tn->labels) {
1374                         NL_SET_ERR_MSG_MOD(extack, "Failed to set connlabel length");
1375                         return -EOPNOTSUPP;
1376                 }
1377                 tcf_ct_set_key_val(tb,
1378                                    p->labels, TCA_CT_LABELS,
1379                                    p->labels_mask, TCA_CT_LABELS_MASK,
1380                                    sizeof(p->labels));
1381         }
1382
1383         if (tb[TCA_CT_ZONE]) {
1384                 if (!IS_ENABLED(CONFIG_NF_CONNTRACK_ZONES)) {
1385                         NL_SET_ERR_MSG_MOD(extack, "Conntrack zones isn't enabled.");
1386                         return -EOPNOTSUPP;
1387                 }
1388
1389                 tcf_ct_set_key_val(tb,
1390                                    &p->zone, TCA_CT_ZONE,
1391                                    NULL, TCA_CT_UNSPEC,
1392                                    sizeof(p->zone));
1393         }
1394
1395         nf_ct_zone_init(&zone, p->zone, NF_CT_DEFAULT_ZONE_DIR, 0);
1396         tmpl = nf_ct_tmpl_alloc(net, &zone, GFP_KERNEL);
1397         if (!tmpl) {
1398                 NL_SET_ERR_MSG_MOD(extack, "Failed to allocate conntrack template");
1399                 return -ENOMEM;
1400         }
1401         __set_bit(IPS_CONFIRMED_BIT, &tmpl->status);
1402         p->tmpl = tmpl;
1403
1404         return 0;
1405 }
1406
1407 static int tcf_ct_init(struct net *net, struct nlattr *nla,
1408                        struct nlattr *est, struct tc_action **a,
1409                        struct tcf_proto *tp, u32 flags,
1410                        struct netlink_ext_ack *extack)
1411 {
1412         struct tc_action_net *tn = net_generic(net, act_ct_ops.net_id);
1413         bool bind = flags & TCA_ACT_FLAGS_BIND;
1414         struct tcf_ct_params *params = NULL;
1415         struct nlattr *tb[TCA_CT_MAX + 1];
1416         struct tcf_chain *goto_ch = NULL;
1417         struct tc_ct *parm;
1418         struct tcf_ct *c;
1419         int err, res = 0;
1420         u32 index;
1421
1422         if (!nla) {
1423                 NL_SET_ERR_MSG_MOD(extack, "Ct requires attributes to be passed");
1424                 return -EINVAL;
1425         }
1426
1427         err = nla_parse_nested(tb, TCA_CT_MAX, nla, ct_policy, extack);
1428         if (err < 0)
1429                 return err;
1430
1431         if (!tb[TCA_CT_PARMS]) {
1432                 NL_SET_ERR_MSG_MOD(extack, "Missing required ct parameters");
1433                 return -EINVAL;
1434         }
1435         parm = nla_data(tb[TCA_CT_PARMS]);
1436         index = parm->index;
1437         err = tcf_idr_check_alloc(tn, &index, a, bind);
1438         if (err < 0)
1439                 return err;
1440
1441         if (!err) {
1442                 err = tcf_idr_create_from_flags(tn, index, est, a,
1443                                                 &act_ct_ops, bind, flags);
1444                 if (err) {
1445                         tcf_idr_cleanup(tn, index);
1446                         return err;
1447                 }
1448                 res = ACT_P_CREATED;
1449         } else {
1450                 if (bind)
1451                         return 0;
1452
1453                 if (!(flags & TCA_ACT_FLAGS_REPLACE)) {
1454                         tcf_idr_release(*a, bind);
1455                         return -EEXIST;
1456                 }
1457         }
1458         err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack);
1459         if (err < 0)
1460                 goto cleanup;
1461
1462         c = to_ct(*a);
1463
1464         params = kzalloc(sizeof(*params), GFP_KERNEL);
1465         if (unlikely(!params)) {
1466                 err = -ENOMEM;
1467                 goto cleanup;
1468         }
1469
1470         err = tcf_ct_fill_params(net, params, parm, tb, extack);
1471         if (err)
1472                 goto cleanup;
1473
1474         err = tcf_ct_flow_table_get(net, params);
1475         if (err)
1476                 goto cleanup;
1477
1478         spin_lock_bh(&c->tcf_lock);
1479         goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch);
1480         params = rcu_replace_pointer(c->params, params,
1481                                      lockdep_is_held(&c->tcf_lock));
1482         spin_unlock_bh(&c->tcf_lock);
1483
1484         if (goto_ch)
1485                 tcf_chain_put_by_act(goto_ch);
1486         if (params)
1487                 call_rcu(&params->rcu, tcf_ct_params_free_rcu);
1488
1489         return res;
1490
1491 cleanup:
1492         if (goto_ch)
1493                 tcf_chain_put_by_act(goto_ch);
1494         if (params)
1495                 tcf_ct_params_free(params);
1496         tcf_idr_release(*a, bind);
1497         return err;
1498 }
1499
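/* Release the action's parameters. Freeing is deferred to an RCU callback
 * so that concurrent readers on the datapath can finish first.
 */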
1500 static void tcf_ct_cleanup(struct tc_action *a)
1501 {
1502         struct tcf_ct_params *params;
1503         struct tcf_ct *c = to_ct(a);
1504
1505         params = rcu_dereference_protected(c->params, 1);
1506         if (params)
1507                 call_rcu(&params->rcu, tcf_ct_params_free_rcu);
1508 }
1509
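/* Emit a value/mask pair as netlink attributes, skipping the pair entirely
 * when the mask is all zeroes (i.e. the key was never configured).
 */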
1510 static int tcf_ct_dump_key_val(struct sk_buff *skb,
1511                                void *val, int val_type,
1512                                void *mask, int mask_type,
1513                                int len)
1514 {
1515         int err;
1516
1517         if (mask && !memchr_inv(mask, 0, len))
1518                 return 0;
1519
1520         err = nla_put(skb, val_type, len, val);
1521         if (err)
1522                 return err;
1523
1524         if (mask_type != TCA_CT_UNSPEC) {
1525                 err = nla_put(skb, mask_type, len, mask);
1526                 if (err)
1527                         return err;
1528         }
1529
1530         return 0;
1531 }
1532
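/* Dump the configured NAT range (addresses and ports), but only when NAT
 * is enabled and a source or destination translation was requested.
 */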
1533 static int tcf_ct_dump_nat(struct sk_buff *skb, struct tcf_ct_params *p)
1534 {
1535         struct nf_nat_range2 *range = &p->range;
1536
1537         if (!(p->ct_action & TCA_CT_ACT_NAT))
1538                 return 0;
1539
1540         if (!(p->ct_action & (TCA_CT_ACT_NAT_SRC | TCA_CT_ACT_NAT_DST)))
1541                 return 0;
1542
1543         if (range->flags & NF_NAT_RANGE_MAP_IPS) {
1544                 if (p->ipv4_range) {
1545                         if (nla_put_in_addr(skb, TCA_CT_NAT_IPV4_MIN,
1546                                             range->min_addr.ip))
1547                                 return -1;
1548                         if (nla_put_in_addr(skb, TCA_CT_NAT_IPV4_MAX,
1549                                             range->max_addr.ip))
1550                                 return -1;
1551                 } else {
1552                         if (nla_put_in6_addr(skb, TCA_CT_NAT_IPV6_MIN,
1553                                              &range->min_addr.in6))
1554                                 return -1;
1555                         if (nla_put_in6_addr(skb, TCA_CT_NAT_IPV6_MAX,
1556                                              &range->max_addr.in6))
1557                                 return -1;
1558                 }
1559         }
1560
1561         if (range->flags & NF_NAT_RANGE_PROTO_SPECIFIED) {
1562                 if (nla_put_be16(skb, TCA_CT_NAT_PORT_MIN,
1563                                  range->min_proto.all))
1564                         return -1;
1565                 if (nla_put_be16(skb, TCA_CT_NAT_PORT_MAX,
1566                                  range->max_proto.all))
1567                         return -1;
1568         }
1569
1570         return 0;
1571 }
1572
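/* Dump the action configuration back to user space. Parameters are read
 * under tcf_lock so they cannot be replaced while being serialized.
 */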
1573 static inline int tcf_ct_dump(struct sk_buff *skb, struct tc_action *a,
1574                               int bind, int ref)
1575 {
1576         unsigned char *b = skb_tail_pointer(skb);
1577         struct tcf_ct *c = to_ct(a);
1578         struct tcf_ct_params *p;
1579
1580         struct tc_ct opt = {
1581                 .index   = c->tcf_index,
1582                 .refcnt  = refcount_read(&c->tcf_refcnt) - ref,
1583                 .bindcnt = atomic_read(&c->tcf_bindcnt) - bind,
1584         };
1585         struct tcf_t t;
1586
1587         spin_lock_bh(&c->tcf_lock);
1588         p = rcu_dereference_protected(c->params,
1589                                       lockdep_is_held(&c->tcf_lock));
1590         opt.action = c->tcf_action;
1591
1592         if (tcf_ct_dump_key_val(skb,
1593                                 &p->ct_action, TCA_CT_ACTION,
1594                                 NULL, TCA_CT_UNSPEC,
1595                                 sizeof(p->ct_action)))
1596                 goto nla_put_failure;
1597
1598         if (p->ct_action & TCA_CT_ACT_CLEAR)
1599                 goto skip_dump;
1600
1601         if (IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) &&
1602             tcf_ct_dump_key_val(skb,
1603                                 &p->mark, TCA_CT_MARK,
1604                                 &p->mark_mask, TCA_CT_MARK_MASK,
1605                                 sizeof(p->mark)))
1606                 goto nla_put_failure;
1607
1608         if (IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS) &&
1609             tcf_ct_dump_key_val(skb,
1610                                 p->labels, TCA_CT_LABELS,
1611                                 p->labels_mask, TCA_CT_LABELS_MASK,
1612                                 sizeof(p->labels)))
1613                 goto nla_put_failure;
1614
1615         if (IS_ENABLED(CONFIG_NF_CONNTRACK_ZONES) &&
1616             tcf_ct_dump_key_val(skb,
1617                                 &p->zone, TCA_CT_ZONE,
1618                                 NULL, TCA_CT_UNSPEC,
1619                                 sizeof(p->zone)))
1620                 goto nla_put_failure;
1621
1622         if (tcf_ct_dump_nat(skb, p))
1623                 goto nla_put_failure;
1624
1625 skip_dump:
1626         if (nla_put(skb, TCA_CT_PARMS, sizeof(opt), &opt))
1627                 goto nla_put_failure;
1628
1629         tcf_tm_dump(&t, &c->tcf_tm);
1630         if (nla_put_64bit(skb, TCA_CT_TM, sizeof(t), &t, TCA_CT_PAD))
1631                 goto nla_put_failure;
1632         spin_unlock_bh(&c->tcf_lock);
1633
1634         return skb->len;
1635 nla_put_failure:
1636         spin_unlock_bh(&c->tcf_lock);
1637         nlmsg_trim(skb, b);
1638         return -1;
1639 }
1640
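/* Fold counter updates (software or hardware) into the action statistics
 * and advance the last-use timestamp.
 */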
1641 static void tcf_stats_update(struct tc_action *a, u64 bytes, u64 packets,
1642                              u64 drops, u64 lastuse, bool hw)
1643 {
1644         struct tcf_ct *c = to_ct(a);
1645
1646         tcf_action_update_stats(a, bytes, packets, drops, hw);
1647         c->tcf_tm.lastuse = max_t(u64, c->tcf_tm.lastuse, lastuse);
1648 }
1649
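/* Translate the ct action into a flow_action entry for hardware offload;
 * on the non-bind path only the action id is reported.
 */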
1650 static int tcf_ct_offload_act_setup(struct tc_action *act, void *entry_data,
1651                                     u32 *index_inc, bool bind,
1652                                     struct netlink_ext_ack *extack)
1653 {
1654         if (bind) {
1655                 struct flow_action_entry *entry = entry_data;
1656
1657                 entry->id = FLOW_ACTION_CT;
1658                 entry->ct.action = tcf_ct_action(act);
1659                 entry->ct.zone = tcf_ct_zone(act);
1660                 entry->ct.flow_table = tcf_ct_ft(act);
1661                 *index_inc = 1;
1662         } else {
1663                 struct flow_offload_action *fl_action = entry_data;
1664
1665                 fl_action->id = FLOW_ACTION_CT;
1666         }
1667
1668         return 0;
1669 }
1670
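/* Action ops for "ct". Illustrative usage from user space (device names,
 * zone numbers and addresses below are placeholders, not part of this file):
 *
 *   tc filter add dev eth0 ingress proto ip flower ct_state -trk \
 *       action ct zone 1 pipe action goto chain 1
 *   tc filter add dev eth0 ingress chain 1 proto ip flower ct_state +trk+new \
 *       action ct zone 1 commit nat src addr 10.0.0.1 pipe
 */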
1671 static struct tc_action_ops act_ct_ops = {
1672         .kind           =       "ct",
1673         .id             =       TCA_ID_CT,
1674         .owner          =       THIS_MODULE,
1675         .act            =       tcf_ct_act,
1676         .dump           =       tcf_ct_dump,
1677         .init           =       tcf_ct_init,
1678         .cleanup        =       tcf_ct_cleanup,
1679         .stats_update   =       tcf_stats_update,
1680         .offload_act_setup =    tcf_ct_offload_act_setup,
1681         .size           =       sizeof(struct tcf_ct),
1682 };
1683
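/* Per-netns init: reserve enough connlabel bits to cover the ct action's
 * label field and register the action's net-generic state.
 */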
1684 static __net_init int ct_init_net(struct net *net)
1685 {
1686         unsigned int n_bits = sizeof_field(struct tcf_ct_params, labels) * 8;
1687         struct tc_ct_action_net *tn = net_generic(net, act_ct_ops.net_id);
1688
1689         if (nf_connlabels_get(net, n_bits - 1)) {
1690                 tn->labels = false;
1691                 pr_err("act_ct: Failed to set connlabels length\n");
1692         } else {
1693                 tn->labels = true;
1694         }
1695
1696         return tc_action_net_init(net, &tn->tn, &act_ct_ops);
1697 }
1698
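/* Per-netns batched exit: drop the connlabel reference taken in
 * ct_init_net() for each namespace, then tear down the action state.
 */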
1699 static void __net_exit ct_exit_net(struct list_head *net_list)
1700 {
1701         struct net *net;
1702
1703         rtnl_lock();
1704         list_for_each_entry(net, net_list, exit_list) {
1705                 struct tc_ct_action_net *tn = net_generic(net, act_ct_ops.net_id);
1706
1707                 if (tn->labels)
1708                         nf_connlabels_put(net);
1709         }
1710         rtnl_unlock();
1711
1712         tc_action_net_exit(net_list, act_ct_ops.net_id);
1713 }
1714
1715 static struct pernet_operations ct_net_ops = {
1716         .init = ct_init_net,
1717         .exit_batch = ct_exit_net,
1718         .id   = &act_ct_ops.net_id,
1719         .size = sizeof(struct tc_ct_action_net),
1720 };
1721
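/* Module init: create the ordered workqueue used for deferred flow table
 * teardown, set up the zones flow table infrastructure, register the action,
 * and enable the static branch that lets the core re-fragment packets that
 * act_ct defragmented.
 */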
1722 static int __init ct_init_module(void)
1723 {
1724         int err;
1725
1726         act_ct_wq = alloc_ordered_workqueue("act_ct_workqueue", 0);
1727         if (!act_ct_wq)
1728                 return -ENOMEM;
1729
1730         err = tcf_ct_flow_tables_init();
1731         if (err)
1732                 goto err_tbl_init;
1733
1734         err = tcf_register_action(&act_ct_ops, &ct_net_ops);
1735         if (err)
1736                 goto err_register;
1737
1738         static_branch_inc(&tcf_frag_xmit_count);
1739
1740         return 0;
1741
1742 err_register:
1743         tcf_ct_flow_tables_uninit();
1744 err_tbl_init:
1745         destroy_workqueue(act_ct_wq);
1746         return err;
1747 }
1748
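/* Module exit: undo ct_init_module() in reverse order. */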
1749 static void __exit ct_cleanup_module(void)
1750 {
1751         static_branch_dec(&tcf_frag_xmit_count);
1752         tcf_unregister_action(&act_ct_ops, &ct_net_ops);
1753         tcf_ct_flow_tables_uninit();
1754         destroy_workqueue(act_ct_wq);
1755 }
1756
1757 module_init(ct_init_module);
1758 module_exit(ct_cleanup_module);
1759 MODULE_AUTHOR("Paul Blakey <paulb@mellanox.com>");
1760 MODULE_AUTHOR("Yossi Kuperman <yossiku@mellanox.com>");
1761 MODULE_AUTHOR("Marcelo Ricardo Leitner <marcelo.leitner@gmail.com>");
1762 MODULE_DESCRIPTION("Connection tracking action");
1763 MODULE_LICENSE("GPL v2");