GNU Linux-libre 5.19-rc6-gnu
[releases.git] / drivers / net / amt.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Copyright (c) 2021 Taehee Yoo <ap420073@gmail.com> */
3
4 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
5
6 #include <linux/module.h>
7 #include <linux/skbuff.h>
8 #include <linux/udp.h>
9 #include <linux/jhash.h>
10 #include <linux/if_tunnel.h>
11 #include <linux/net.h>
12 #include <linux/igmp.h>
13 #include <linux/workqueue.h>
14 #include <net/sch_generic.h>
15 #include <net/net_namespace.h>
16 #include <net/ip.h>
17 #include <net/udp.h>
18 #include <net/udp_tunnel.h>
19 #include <net/icmp.h>
20 #include <net/mld.h>
21 #include <net/amt.h>
22 #include <uapi/linux/amt.h>
23 #include <linux/security.h>
24 #include <net/gro_cells.h>
25 #include <net/ipv6.h>
26 #include <net/if_inet6.h>
27 #include <net/ndisc.h>
28 #include <net/addrconf.h>
29 #include <net/ip6_route.h>
30 #include <net/inet_common.h>
31 #include <net/ip6_checksum.h>
32
33 static struct workqueue_struct *amt_wq;
34
35 static HLIST_HEAD(source_gc_list);
36 /* Lock for source_gc_list */
37 static spinlock_t source_gc_lock;
38 static struct delayed_work source_gc_wq;
39 static char *status_str[] = {
40         "AMT_STATUS_INIT",
41         "AMT_STATUS_SENT_DISCOVERY",
42         "AMT_STATUS_RECEIVED_DISCOVERY",
43         "AMT_STATUS_SENT_ADVERTISEMENT",
44         "AMT_STATUS_RECEIVED_ADVERTISEMENT",
45         "AMT_STATUS_SENT_REQUEST",
46         "AMT_STATUS_RECEIVED_REQUEST",
47         "AMT_STATUS_SENT_QUERY",
48         "AMT_STATUS_RECEIVED_QUERY",
49         "AMT_STATUS_SENT_UPDATE",
50         "AMT_STATUS_RECEIVED_UPDATE",
51 };
52
53 static char *type_str[] = {
54         "", /* Type 0 is not defined */
55         "AMT_MSG_DISCOVERY",
56         "AMT_MSG_ADVERTISEMENT",
57         "AMT_MSG_REQUEST",
58         "AMT_MSG_MEMBERSHIP_QUERY",
59         "AMT_MSG_MEMBERSHIP_UPDATE",
60         "AMT_MSG_MULTICAST_DATA",
61         "AMT_MSG_TEARDOWN",
62 };
63
64 static char *action_str[] = {
65         "AMT_ACT_GMI",
66         "AMT_ACT_GMI_ZERO",
67         "AMT_ACT_GT",
68         "AMT_ACT_STATUS_FWD_NEW",
69         "AMT_ACT_STATUS_D_FWD_NEW",
70         "AMT_ACT_STATUS_NONE_NEW",
71 };
72
73 static struct igmpv3_grec igmpv3_zero_grec;
74
75 #if IS_ENABLED(CONFIG_IPV6)
76 #define MLD2_ALL_NODE_INIT { { { 0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01 } } }
77 static struct in6_addr mld2_all_node = MLD2_ALL_NODE_INIT;
78 static struct mld2_grec mldv2_zero_grec;
79 #endif
80
81 static struct amt_skb_cb *amt_skb_cb(struct sk_buff *skb)
82 {
83         BUILD_BUG_ON(sizeof(struct amt_skb_cb) + sizeof(struct qdisc_skb_cb) >
84                      sizeof_field(struct sk_buff, cb));
85
86         return (struct amt_skb_cb *)((void *)skb->cb +
87                 sizeof(struct qdisc_skb_cb));
88 }
89
90 static void __amt_source_gc_work(void)
91 {
92         struct amt_source_node *snode;
93         struct hlist_head gc_list;
94         struct hlist_node *t;
95
96         spin_lock_bh(&source_gc_lock);
97         hlist_move_list(&source_gc_list, &gc_list);
98         spin_unlock_bh(&source_gc_lock);
99
100         hlist_for_each_entry_safe(snode, t, &gc_list, node) {
101                 hlist_del_rcu(&snode->node);
102                 kfree_rcu(snode, rcu);
103         }
104 }
105
106 static void amt_source_gc_work(struct work_struct *work)
107 {
108         __amt_source_gc_work();
109
110         spin_lock_bh(&source_gc_lock);
111         mod_delayed_work(amt_wq, &source_gc_wq,
112                          msecs_to_jiffies(AMT_GC_INTERVAL));
113         spin_unlock_bh(&source_gc_lock);
114 }
115
116 static bool amt_addr_equal(union amt_addr *a, union amt_addr *b)
117 {
118         return !memcmp(a, b, sizeof(union amt_addr));
119 }
120
121 static u32 amt_source_hash(struct amt_tunnel_list *tunnel, union amt_addr *src)
122 {
123         u32 hash = jhash(src, sizeof(*src), tunnel->amt->hash_seed);
124
125         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
126 }
127
128 static bool amt_status_filter(struct amt_source_node *snode,
129                               enum amt_filter filter)
130 {
131         bool rc = false;
132
133         switch (filter) {
134         case AMT_FILTER_FWD:
135                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
136                     snode->flags == AMT_SOURCE_OLD)
137                         rc = true;
138                 break;
139         case AMT_FILTER_D_FWD:
140                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
141                     snode->flags == AMT_SOURCE_OLD)
142                         rc = true;
143                 break;
144         case AMT_FILTER_FWD_NEW:
145                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
146                     snode->flags == AMT_SOURCE_NEW)
147                         rc = true;
148                 break;
149         case AMT_FILTER_D_FWD_NEW:
150                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
151                     snode->flags == AMT_SOURCE_NEW)
152                         rc = true;
153                 break;
154         case AMT_FILTER_ALL:
155                 rc = true;
156                 break;
157         case AMT_FILTER_NONE_NEW:
158                 if (snode->status == AMT_SOURCE_STATUS_NONE &&
159                     snode->flags == AMT_SOURCE_NEW)
160                         rc = true;
161                 break;
162         case AMT_FILTER_BOTH:
163                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
164                      snode->status == AMT_SOURCE_STATUS_FWD) &&
165                     snode->flags == AMT_SOURCE_OLD)
166                         rc = true;
167                 break;
168         case AMT_FILTER_BOTH_NEW:
169                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
170                      snode->status == AMT_SOURCE_STATUS_FWD) &&
171                     snode->flags == AMT_SOURCE_NEW)
172                         rc = true;
173                 break;
174         default:
175                 WARN_ON_ONCE(1);
176                 break;
177         }
178
179         return rc;
180 }
181
182 static struct amt_source_node *amt_lookup_src(struct amt_tunnel_list *tunnel,
183                                               struct amt_group_node *gnode,
184                                               enum amt_filter filter,
185                                               union amt_addr *src)
186 {
187         u32 hash = amt_source_hash(tunnel, src);
188         struct amt_source_node *snode;
189
190         hlist_for_each_entry_rcu(snode, &gnode->sources[hash], node)
191                 if (amt_status_filter(snode, filter) &&
192                     amt_addr_equal(&snode->source_addr, src))
193                         return snode;
194
195         return NULL;
196 }
197
198 static u32 amt_group_hash(struct amt_tunnel_list *tunnel, union amt_addr *group)
199 {
200         u32 hash = jhash(group, sizeof(*group), tunnel->amt->hash_seed);
201
202         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
203 }
204
205 static struct amt_group_node *amt_lookup_group(struct amt_tunnel_list *tunnel,
206                                                union amt_addr *group,
207                                                union amt_addr *host,
208                                                bool v6)
209 {
210         u32 hash = amt_group_hash(tunnel, group);
211         struct amt_group_node *gnode;
212
213         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash], node) {
214                 if (amt_addr_equal(&gnode->group_addr, group) &&
215                     amt_addr_equal(&gnode->host_addr, host) &&
216                     gnode->v6 == v6)
217                         return gnode;
218         }
219
220         return NULL;
221 }
222
223 static void amt_destroy_source(struct amt_source_node *snode)
224 {
225         struct amt_group_node *gnode = snode->gnode;
226         struct amt_tunnel_list *tunnel;
227
228         tunnel = gnode->tunnel_list;
229
230         if (!gnode->v6) {
231                 netdev_dbg(snode->gnode->amt->dev,
232                            "Delete source %pI4 from %pI4\n",
233                            &snode->source_addr.ip4,
234                            &gnode->group_addr.ip4);
235 #if IS_ENABLED(CONFIG_IPV6)
236         } else {
237                 netdev_dbg(snode->gnode->amt->dev,
238                            "Delete source %pI6 from %pI6\n",
239                            &snode->source_addr.ip6,
240                            &gnode->group_addr.ip6);
241 #endif
242         }
243
244         cancel_delayed_work(&snode->source_timer);
245         hlist_del_init_rcu(&snode->node);
246         tunnel->nr_sources--;
247         gnode->nr_sources--;
248         spin_lock_bh(&source_gc_lock);
249         hlist_add_head_rcu(&snode->node, &source_gc_list);
250         spin_unlock_bh(&source_gc_lock);
251 }
252
253 static void amt_del_group(struct amt_dev *amt, struct amt_group_node *gnode)
254 {
255         struct amt_source_node *snode;
256         struct hlist_node *t;
257         int i;
258
259         if (cancel_delayed_work(&gnode->group_timer))
260                 dev_put(amt->dev);
261         hlist_del_rcu(&gnode->node);
262         gnode->tunnel_list->nr_groups--;
263
264         if (!gnode->v6)
265                 netdev_dbg(amt->dev, "Leave group %pI4\n",
266                            &gnode->group_addr.ip4);
267 #if IS_ENABLED(CONFIG_IPV6)
268         else
269                 netdev_dbg(amt->dev, "Leave group %pI6\n",
270                            &gnode->group_addr.ip6);
271 #endif
272         for (i = 0; i < amt->hash_buckets; i++)
273                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node)
274                         amt_destroy_source(snode);
275
276         /* tunnel->lock was acquired outside of amt_del_group()
277          * But rcu_read_lock() was acquired too so It's safe.
278          */
279         kfree_rcu(gnode, rcu);
280 }
281
282 /* If a source timer expires with a router filter-mode for the group of
283  * INCLUDE, the router concludes that traffic from this particular
284  * source is no longer desired on the attached network, and deletes the
285  * associated source record.
286  */
287 static void amt_source_work(struct work_struct *work)
288 {
289         struct amt_source_node *snode = container_of(to_delayed_work(work),
290                                                      struct amt_source_node,
291                                                      source_timer);
292         struct amt_group_node *gnode = snode->gnode;
293         struct amt_dev *amt = gnode->amt;
294         struct amt_tunnel_list *tunnel;
295
296         tunnel = gnode->tunnel_list;
297         spin_lock_bh(&tunnel->lock);
298         rcu_read_lock();
299         if (gnode->filter_mode == MCAST_INCLUDE) {
300                 amt_destroy_source(snode);
301                 if (!gnode->nr_sources)
302                         amt_del_group(amt, gnode);
303         } else {
304                 /* When a router filter-mode for a group is EXCLUDE,
305                  * source records are only deleted when the group timer expires
306                  */
307                 snode->status = AMT_SOURCE_STATUS_D_FWD;
308         }
309         rcu_read_unlock();
310         spin_unlock_bh(&tunnel->lock);
311 }
312
313 static void amt_act_src(struct amt_tunnel_list *tunnel,
314                         struct amt_group_node *gnode,
315                         struct amt_source_node *snode,
316                         enum amt_act act)
317 {
318         struct amt_dev *amt = tunnel->amt;
319
320         switch (act) {
321         case AMT_ACT_GMI:
322                 mod_delayed_work(amt_wq, &snode->source_timer,
323                                  msecs_to_jiffies(amt_gmi(amt)));
324                 break;
325         case AMT_ACT_GMI_ZERO:
326                 cancel_delayed_work(&snode->source_timer);
327                 break;
328         case AMT_ACT_GT:
329                 mod_delayed_work(amt_wq, &snode->source_timer,
330                                  gnode->group_timer.timer.expires);
331                 break;
332         case AMT_ACT_STATUS_FWD_NEW:
333                 snode->status = AMT_SOURCE_STATUS_FWD;
334                 snode->flags = AMT_SOURCE_NEW;
335                 break;
336         case AMT_ACT_STATUS_D_FWD_NEW:
337                 snode->status = AMT_SOURCE_STATUS_D_FWD;
338                 snode->flags = AMT_SOURCE_NEW;
339                 break;
340         case AMT_ACT_STATUS_NONE_NEW:
341                 cancel_delayed_work(&snode->source_timer);
342                 snode->status = AMT_SOURCE_STATUS_NONE;
343                 snode->flags = AMT_SOURCE_NEW;
344                 break;
345         default:
346                 WARN_ON_ONCE(1);
347                 return;
348         }
349
350         if (!gnode->v6)
351                 netdev_dbg(amt->dev, "Source %pI4 from %pI4 Acted %s\n",
352                            &snode->source_addr.ip4,
353                            &gnode->group_addr.ip4,
354                            action_str[act]);
355 #if IS_ENABLED(CONFIG_IPV6)
356         else
357                 netdev_dbg(amt->dev, "Source %pI6 from %pI6 Acted %s\n",
358                            &snode->source_addr.ip6,
359                            &gnode->group_addr.ip6,
360                            action_str[act]);
361 #endif
362 }
363
364 static struct amt_source_node *amt_alloc_snode(struct amt_group_node *gnode,
365                                                union amt_addr *src)
366 {
367         struct amt_source_node *snode;
368
369         snode = kzalloc(sizeof(*snode), GFP_ATOMIC);
370         if (!snode)
371                 return NULL;
372
373         memcpy(&snode->source_addr, src, sizeof(union amt_addr));
374         snode->gnode = gnode;
375         snode->status = AMT_SOURCE_STATUS_NONE;
376         snode->flags = AMT_SOURCE_NEW;
377         INIT_HLIST_NODE(&snode->node);
378         INIT_DELAYED_WORK(&snode->source_timer, amt_source_work);
379
380         return snode;
381 }
382
383 /* RFC 3810 - 7.2.2.  Definition of Filter Timers
384  *
385  *  Router Mode          Filter Timer         Actions/Comments
386  *  -----------       -----------------       ----------------
387  *
388  *    INCLUDE             Not Used            All listeners in
389  *                                            INCLUDE mode.
390  *
391  *    EXCLUDE             Timer > 0           At least one listener
392  *                                            in EXCLUDE mode.
393  *
394  *    EXCLUDE             Timer == 0          No more listeners in
395  *                                            EXCLUDE mode for the
396  *                                            multicast address.
397  *                                            If the Requested List
398  *                                            is empty, delete
399  *                                            Multicast Address
400  *                                            Record.  If not, switch
401  *                                            to INCLUDE filter mode;
402  *                                            the sources in the
403  *                                            Requested List are
404  *                                            moved to the Include
405  *                                            List, and the Exclude
406  *                                            List is deleted.
407  */
408 static void amt_group_work(struct work_struct *work)
409 {
410         struct amt_group_node *gnode = container_of(to_delayed_work(work),
411                                                     struct amt_group_node,
412                                                     group_timer);
413         struct amt_tunnel_list *tunnel = gnode->tunnel_list;
414         struct amt_dev *amt = gnode->amt;
415         struct amt_source_node *snode;
416         bool delete_group = true;
417         struct hlist_node *t;
418         int i, buckets;
419
420         buckets = amt->hash_buckets;
421
422         spin_lock_bh(&tunnel->lock);
423         if (gnode->filter_mode == MCAST_INCLUDE) {
424                 /* Not Used */
425                 spin_unlock_bh(&tunnel->lock);
426                 goto out;
427         }
428
429         rcu_read_lock();
430         for (i = 0; i < buckets; i++) {
431                 hlist_for_each_entry_safe(snode, t,
432                                           &gnode->sources[i], node) {
433                         if (!delayed_work_pending(&snode->source_timer) ||
434                             snode->status == AMT_SOURCE_STATUS_D_FWD) {
435                                 amt_destroy_source(snode);
436                         } else {
437                                 delete_group = false;
438                                 snode->status = AMT_SOURCE_STATUS_FWD;
439                         }
440                 }
441         }
442         if (delete_group)
443                 amt_del_group(amt, gnode);
444         else
445                 gnode->filter_mode = MCAST_INCLUDE;
446         rcu_read_unlock();
447         spin_unlock_bh(&tunnel->lock);
448 out:
449         dev_put(amt->dev);
450 }
451
452 /* Non-existant group is created as INCLUDE {empty}:
453  *
454  * RFC 3376 - 5.1. Action on Change of Interface State
455  *
456  * If no interface state existed for that multicast address before
457  * the change (i.e., the change consisted of creating a new
458  * per-interface record), or if no state exists after the change
459  * (i.e., the change consisted of deleting a per-interface record),
460  * then the "non-existent" state is considered to have a filter mode
461  * of INCLUDE and an empty source list.
462  */
463 static struct amt_group_node *amt_add_group(struct amt_dev *amt,
464                                             struct amt_tunnel_list *tunnel,
465                                             union amt_addr *group,
466                                             union amt_addr *host,
467                                             bool v6)
468 {
469         struct amt_group_node *gnode;
470         u32 hash;
471         int i;
472
473         if (tunnel->nr_groups >= amt->max_groups)
474                 return ERR_PTR(-ENOSPC);
475
476         gnode = kzalloc(sizeof(*gnode) +
477                         (sizeof(struct hlist_head) * amt->hash_buckets),
478                         GFP_ATOMIC);
479         if (unlikely(!gnode))
480                 return ERR_PTR(-ENOMEM);
481
482         gnode->amt = amt;
483         gnode->group_addr = *group;
484         gnode->host_addr = *host;
485         gnode->v6 = v6;
486         gnode->tunnel_list = tunnel;
487         gnode->filter_mode = MCAST_INCLUDE;
488         INIT_HLIST_NODE(&gnode->node);
489         INIT_DELAYED_WORK(&gnode->group_timer, amt_group_work);
490         for (i = 0; i < amt->hash_buckets; i++)
491                 INIT_HLIST_HEAD(&gnode->sources[i]);
492
493         hash = amt_group_hash(tunnel, group);
494         hlist_add_head_rcu(&gnode->node, &tunnel->groups[hash]);
495         tunnel->nr_groups++;
496
497         if (!gnode->v6)
498                 netdev_dbg(amt->dev, "Join group %pI4\n",
499                            &gnode->group_addr.ip4);
500 #if IS_ENABLED(CONFIG_IPV6)
501         else
502                 netdev_dbg(amt->dev, "Join group %pI6\n",
503                            &gnode->group_addr.ip6);
504 #endif
505
506         return gnode;
507 }
508
509 static struct sk_buff *amt_build_igmp_gq(struct amt_dev *amt)
510 {
511         u8 ra[AMT_IPHDR_OPTS] = { IPOPT_RA, 4, 0, 0 };
512         int hlen = LL_RESERVED_SPACE(amt->dev);
513         int tlen = amt->dev->needed_tailroom;
514         struct igmpv3_query *ihv3;
515         void *csum_start = NULL;
516         __sum16 *csum = NULL;
517         struct sk_buff *skb;
518         struct ethhdr *eth;
519         struct iphdr *iph;
520         unsigned int len;
521         int offset;
522
523         len = hlen + tlen + sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3);
524         skb = netdev_alloc_skb_ip_align(amt->dev, len);
525         if (!skb)
526                 return NULL;
527
528         skb_reserve(skb, hlen);
529         skb_push(skb, sizeof(*eth));
530         skb->protocol = htons(ETH_P_IP);
531         skb_reset_mac_header(skb);
532         skb->priority = TC_PRIO_CONTROL;
533         skb_put(skb, sizeof(*iph));
534         skb_put_data(skb, ra, sizeof(ra));
535         skb_put(skb, sizeof(*ihv3));
536         skb_pull(skb, sizeof(*eth));
537         skb_reset_network_header(skb);
538
539         iph             = ip_hdr(skb);
540         iph->version    = 4;
541         iph->ihl        = (sizeof(struct iphdr) + AMT_IPHDR_OPTS) >> 2;
542         iph->tos        = AMT_TOS;
543         iph->tot_len    = htons(sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3));
544         iph->frag_off   = htons(IP_DF);
545         iph->ttl        = 1;
546         iph->id         = 0;
547         iph->protocol   = IPPROTO_IGMP;
548         iph->daddr      = htonl(INADDR_ALLHOSTS_GROUP);
549         iph->saddr      = htonl(INADDR_ANY);
550         ip_send_check(iph);
551
552         eth = eth_hdr(skb);
553         ether_addr_copy(eth->h_source, amt->dev->dev_addr);
554         ip_eth_mc_map(htonl(INADDR_ALLHOSTS_GROUP), eth->h_dest);
555         eth->h_proto = htons(ETH_P_IP);
556
557         ihv3            = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
558         skb_reset_transport_header(skb);
559         ihv3->type      = IGMP_HOST_MEMBERSHIP_QUERY;
560         ihv3->code      = 1;
561         ihv3->group     = 0;
562         ihv3->qqic      = amt->qi;
563         ihv3->nsrcs     = 0;
564         ihv3->resv      = 0;
565         ihv3->suppress  = false;
566         ihv3->qrv       = amt->net->ipv4.sysctl_igmp_qrv;
567         ihv3->csum      = 0;
568         csum            = &ihv3->csum;
569         csum_start      = (void *)ihv3;
570         *csum           = ip_compute_csum(csum_start, sizeof(*ihv3));
571         offset          = skb_transport_offset(skb);
572         skb->csum       = skb_checksum(skb, offset, skb->len - offset, 0);
573         skb->ip_summed  = CHECKSUM_NONE;
574
575         skb_push(skb, sizeof(*eth) + sizeof(*iph) + AMT_IPHDR_OPTS);
576
577         return skb;
578 }
579
580 static void __amt_update_gw_status(struct amt_dev *amt, enum amt_status status,
581                                    bool validate)
582 {
583         if (validate && amt->status >= status)
584                 return;
585         netdev_dbg(amt->dev, "Update GW status %s -> %s",
586                    status_str[amt->status], status_str[status]);
587         amt->status = status;
588 }
589
590 static void __amt_update_relay_status(struct amt_tunnel_list *tunnel,
591                                       enum amt_status status,
592                                       bool validate)
593 {
594         if (validate && tunnel->status >= status)
595                 return;
596         netdev_dbg(tunnel->amt->dev,
597                    "Update Tunnel(IP = %pI4, PORT = %u) status %s -> %s",
598                    &tunnel->ip4, ntohs(tunnel->source_port),
599                    status_str[tunnel->status], status_str[status]);
600         tunnel->status = status;
601 }
602
603 static void amt_update_gw_status(struct amt_dev *amt, enum amt_status status,
604                                  bool validate)
605 {
606         spin_lock_bh(&amt->lock);
607         __amt_update_gw_status(amt, status, validate);
608         spin_unlock_bh(&amt->lock);
609 }
610
611 static void amt_update_relay_status(struct amt_tunnel_list *tunnel,
612                                     enum amt_status status, bool validate)
613 {
614         spin_lock_bh(&tunnel->lock);
615         __amt_update_relay_status(tunnel, status, validate);
616         spin_unlock_bh(&tunnel->lock);
617 }
618
619 static void amt_send_discovery(struct amt_dev *amt)
620 {
621         struct amt_header_discovery *amtd;
622         int hlen, tlen, offset;
623         struct socket *sock;
624         struct udphdr *udph;
625         struct sk_buff *skb;
626         struct iphdr *iph;
627         struct rtable *rt;
628         struct flowi4 fl4;
629         u32 len;
630         int err;
631
632         rcu_read_lock();
633         sock = rcu_dereference(amt->sock);
634         if (!sock)
635                 goto out;
636
637         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
638                 goto out;
639
640         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
641                                    amt->discovery_ip, amt->local_ip,
642                                    amt->gw_port, amt->relay_port,
643                                    IPPROTO_UDP, 0,
644                                    amt->stream_dev->ifindex);
645         if (IS_ERR(rt)) {
646                 amt->dev->stats.tx_errors++;
647                 goto out;
648         }
649
650         hlen = LL_RESERVED_SPACE(amt->dev);
651         tlen = amt->dev->needed_tailroom;
652         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
653         skb = netdev_alloc_skb_ip_align(amt->dev, len);
654         if (!skb) {
655                 ip_rt_put(rt);
656                 amt->dev->stats.tx_errors++;
657                 goto out;
658         }
659
660         skb->priority = TC_PRIO_CONTROL;
661         skb_dst_set(skb, &rt->dst);
662
663         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
664         skb_reset_network_header(skb);
665         skb_put(skb, len);
666         amtd = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
667         amtd->version   = 0;
668         amtd->type      = AMT_MSG_DISCOVERY;
669         amtd->reserved  = 0;
670         amtd->nonce     = amt->nonce;
671         skb_push(skb, sizeof(*udph));
672         skb_reset_transport_header(skb);
673         udph            = udp_hdr(skb);
674         udph->source    = amt->gw_port;
675         udph->dest      = amt->relay_port;
676         udph->len       = htons(sizeof(*udph) + sizeof(*amtd));
677         udph->check     = 0;
678         offset = skb_transport_offset(skb);
679         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
680         udph->check = csum_tcpudp_magic(amt->local_ip, amt->discovery_ip,
681                                         sizeof(*udph) + sizeof(*amtd),
682                                         IPPROTO_UDP, skb->csum);
683
684         skb_push(skb, sizeof(*iph));
685         iph             = ip_hdr(skb);
686         iph->version    = 4;
687         iph->ihl        = (sizeof(struct iphdr)) >> 2;
688         iph->tos        = AMT_TOS;
689         iph->frag_off   = 0;
690         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
691         iph->daddr      = amt->discovery_ip;
692         iph->saddr      = amt->local_ip;
693         iph->protocol   = IPPROTO_UDP;
694         iph->tot_len    = htons(len);
695
696         skb->ip_summed = CHECKSUM_NONE;
697         ip_select_ident(amt->net, skb, NULL);
698         ip_send_check(iph);
699         err = ip_local_out(amt->net, sock->sk, skb);
700         if (unlikely(net_xmit_eval(err)))
701                 amt->dev->stats.tx_errors++;
702
703         spin_lock_bh(&amt->lock);
704         __amt_update_gw_status(amt, AMT_STATUS_SENT_DISCOVERY, true);
705         spin_unlock_bh(&amt->lock);
706 out:
707         rcu_read_unlock();
708 }
709
710 static void amt_send_request(struct amt_dev *amt, bool v6)
711 {
712         struct amt_header_request *amtrh;
713         int hlen, tlen, offset;
714         struct socket *sock;
715         struct udphdr *udph;
716         struct sk_buff *skb;
717         struct iphdr *iph;
718         struct rtable *rt;
719         struct flowi4 fl4;
720         u32 len;
721         int err;
722
723         rcu_read_lock();
724         sock = rcu_dereference(amt->sock);
725         if (!sock)
726                 goto out;
727
728         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
729                 goto out;
730
731         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
732                                    amt->remote_ip, amt->local_ip,
733                                    amt->gw_port, amt->relay_port,
734                                    IPPROTO_UDP, 0,
735                                    amt->stream_dev->ifindex);
736         if (IS_ERR(rt)) {
737                 amt->dev->stats.tx_errors++;
738                 goto out;
739         }
740
741         hlen = LL_RESERVED_SPACE(amt->dev);
742         tlen = amt->dev->needed_tailroom;
743         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
744         skb = netdev_alloc_skb_ip_align(amt->dev, len);
745         if (!skb) {
746                 ip_rt_put(rt);
747                 amt->dev->stats.tx_errors++;
748                 goto out;
749         }
750
751         skb->priority = TC_PRIO_CONTROL;
752         skb_dst_set(skb, &rt->dst);
753
754         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
755         skb_reset_network_header(skb);
756         skb_put(skb, len);
757         amtrh = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
758         amtrh->version   = 0;
759         amtrh->type      = AMT_MSG_REQUEST;
760         amtrh->reserved1 = 0;
761         amtrh->p         = v6;
762         amtrh->reserved2 = 0;
763         amtrh->nonce     = amt->nonce;
764         skb_push(skb, sizeof(*udph));
765         skb_reset_transport_header(skb);
766         udph            = udp_hdr(skb);
767         udph->source    = amt->gw_port;
768         udph->dest      = amt->relay_port;
769         udph->len       = htons(sizeof(*amtrh) + sizeof(*udph));
770         udph->check     = 0;
771         offset = skb_transport_offset(skb);
772         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
773         udph->check = csum_tcpudp_magic(amt->local_ip, amt->remote_ip,
774                                         sizeof(*udph) + sizeof(*amtrh),
775                                         IPPROTO_UDP, skb->csum);
776
777         skb_push(skb, sizeof(*iph));
778         iph             = ip_hdr(skb);
779         iph->version    = 4;
780         iph->ihl        = (sizeof(struct iphdr)) >> 2;
781         iph->tos        = AMT_TOS;
782         iph->frag_off   = 0;
783         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
784         iph->daddr      = amt->remote_ip;
785         iph->saddr      = amt->local_ip;
786         iph->protocol   = IPPROTO_UDP;
787         iph->tot_len    = htons(len);
788
789         skb->ip_summed = CHECKSUM_NONE;
790         ip_select_ident(amt->net, skb, NULL);
791         ip_send_check(iph);
792         err = ip_local_out(amt->net, sock->sk, skb);
793         if (unlikely(net_xmit_eval(err)))
794                 amt->dev->stats.tx_errors++;
795
796 out:
797         rcu_read_unlock();
798 }
799
800 static void amt_send_igmp_gq(struct amt_dev *amt,
801                              struct amt_tunnel_list *tunnel)
802 {
803         struct sk_buff *skb;
804
805         skb = amt_build_igmp_gq(amt);
806         if (!skb)
807                 return;
808
809         amt_skb_cb(skb)->tunnel = tunnel;
810         dev_queue_xmit(skb);
811 }
812
813 #if IS_ENABLED(CONFIG_IPV6)
814 static struct sk_buff *amt_build_mld_gq(struct amt_dev *amt)
815 {
816         u8 ra[AMT_IP6HDR_OPTS] = { IPPROTO_ICMPV6, 0, IPV6_TLV_ROUTERALERT,
817                                    2, 0, 0, IPV6_TLV_PAD1, IPV6_TLV_PAD1 };
818         int hlen = LL_RESERVED_SPACE(amt->dev);
819         int tlen = amt->dev->needed_tailroom;
820         struct mld2_query *mld2q;
821         void *csum_start = NULL;
822         struct ipv6hdr *ip6h;
823         struct sk_buff *skb;
824         struct ethhdr *eth;
825         u32 len;
826
827         len = hlen + tlen + sizeof(*ip6h) + sizeof(ra) + sizeof(*mld2q);
828         skb = netdev_alloc_skb_ip_align(amt->dev, len);
829         if (!skb)
830                 return NULL;
831
832         skb_reserve(skb, hlen);
833         skb_push(skb, sizeof(*eth));
834         skb_reset_mac_header(skb);
835         eth = eth_hdr(skb);
836         skb->priority = TC_PRIO_CONTROL;
837         skb->protocol = htons(ETH_P_IPV6);
838         skb_put_zero(skb, sizeof(*ip6h));
839         skb_put_data(skb, ra, sizeof(ra));
840         skb_put_zero(skb, sizeof(*mld2q));
841         skb_pull(skb, sizeof(*eth));
842         skb_reset_network_header(skb);
843         ip6h                    = ipv6_hdr(skb);
844         ip6h->payload_len       = htons(sizeof(ra) + sizeof(*mld2q));
845         ip6h->nexthdr           = NEXTHDR_HOP;
846         ip6h->hop_limit         = 1;
847         ip6h->daddr             = mld2_all_node;
848         ip6_flow_hdr(ip6h, 0, 0);
849
850         if (ipv6_dev_get_saddr(amt->net, amt->dev, &ip6h->daddr, 0,
851                                &ip6h->saddr)) {
852                 amt->dev->stats.tx_errors++;
853                 kfree_skb(skb);
854                 return NULL;
855         }
856
857         eth->h_proto = htons(ETH_P_IPV6);
858         ether_addr_copy(eth->h_source, amt->dev->dev_addr);
859         ipv6_eth_mc_map(&mld2_all_node, eth->h_dest);
860
861         skb_pull(skb, sizeof(*ip6h) + sizeof(ra));
862         skb_reset_transport_header(skb);
863         mld2q                   = (struct mld2_query *)icmp6_hdr(skb);
864         mld2q->mld2q_mrc        = htons(1);
865         mld2q->mld2q_type       = ICMPV6_MGM_QUERY;
866         mld2q->mld2q_code       = 0;
867         mld2q->mld2q_cksum      = 0;
868         mld2q->mld2q_resv1      = 0;
869         mld2q->mld2q_resv2      = 0;
870         mld2q->mld2q_suppress   = 0;
871         mld2q->mld2q_qrv        = amt->qrv;
872         mld2q->mld2q_nsrcs      = 0;
873         mld2q->mld2q_qqic       = amt->qi;
874         csum_start              = (void *)mld2q;
875         mld2q->mld2q_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
876                                              sizeof(*mld2q),
877                                              IPPROTO_ICMPV6,
878                                              csum_partial(csum_start,
879                                                           sizeof(*mld2q), 0));
880
881         skb->ip_summed = CHECKSUM_NONE;
882         skb_push(skb, sizeof(*eth) + sizeof(*ip6h) + sizeof(ra));
883         return skb;
884 }
885
886 static void amt_send_mld_gq(struct amt_dev *amt, struct amt_tunnel_list *tunnel)
887 {
888         struct sk_buff *skb;
889
890         skb = amt_build_mld_gq(amt);
891         if (!skb)
892                 return;
893
894         amt_skb_cb(skb)->tunnel = tunnel;
895         dev_queue_xmit(skb);
896 }
897 #else
/*
 * amt_send_mld_gq - !CONFIG_IPV6 stub.
 *
 * Without IPv6 support no MLD query can be built, so this does nothing.
 */
static void amt_send_mld_gq(struct amt_dev *amt, struct amt_tunnel_list *tunnel)
{
}
901 #endif
902
903 static void amt_secret_work(struct work_struct *work)
904 {
905         struct amt_dev *amt = container_of(to_delayed_work(work),
906                                            struct amt_dev,
907                                            secret_wq);
908
909         spin_lock_bh(&amt->lock);
910         get_random_bytes(&amt->key, sizeof(siphash_key_t));
911         spin_unlock_bh(&amt->lock);
912         mod_delayed_work(amt_wq, &amt->secret_wq,
913                          msecs_to_jiffies(AMT_SECRET_TIMEOUT));
914 }
915
916 static void amt_discovery_work(struct work_struct *work)
917 {
918         struct amt_dev *amt = container_of(to_delayed_work(work),
919                                            struct amt_dev,
920                                            discovery_wq);
921
922         spin_lock_bh(&amt->lock);
923         if (amt->status > AMT_STATUS_SENT_DISCOVERY)
924                 goto out;
925         get_random_bytes(&amt->nonce, sizeof(__be32));
926         spin_unlock_bh(&amt->lock);
927
928         amt_send_discovery(amt);
929         spin_lock_bh(&amt->lock);
930 out:
931         mod_delayed_work(amt_wq, &amt->discovery_wq,
932                          msecs_to_jiffies(AMT_DISCOVERY_TIMEOUT));
933         spin_unlock_bh(&amt->lock);
934 }
935
936 static void amt_req_work(struct work_struct *work)
937 {
938         struct amt_dev *amt = container_of(to_delayed_work(work),
939                                            struct amt_dev,
940                                            req_wq);
941         u32 exp;
942
943         spin_lock_bh(&amt->lock);
944         if (amt->status < AMT_STATUS_RECEIVED_ADVERTISEMENT)
945                 goto out;
946
947         if (amt->req_cnt > AMT_MAX_REQ_COUNT) {
948                 netdev_dbg(amt->dev, "Gateway is not ready");
949                 amt->qi = AMT_INIT_REQ_TIMEOUT;
950                 amt->ready4 = false;
951                 amt->ready6 = false;
952                 amt->remote_ip = 0;
953                 __amt_update_gw_status(amt, AMT_STATUS_INIT, false);
954                 amt->req_cnt = 0;
955                 goto out;
956         }
957         spin_unlock_bh(&amt->lock);
958
959         amt_send_request(amt, false);
960         amt_send_request(amt, true);
961         spin_lock_bh(&amt->lock);
962         __amt_update_gw_status(amt, AMT_STATUS_SENT_REQUEST, true);
963         amt->req_cnt++;
964 out:
965         exp = min_t(u32, (1 * (1 << amt->req_cnt)), AMT_MAX_REQ_TIMEOUT);
966         mod_delayed_work(amt_wq, &amt->req_wq, msecs_to_jiffies(exp * 1000));
967         spin_unlock_bh(&amt->lock);
968 }
969
970 static bool amt_send_membership_update(struct amt_dev *amt,
971                                        struct sk_buff *skb,
972                                        bool v6)
973 {
974         struct amt_header_membership_update *amtmu;
975         struct socket *sock;
976         struct iphdr *iph;
977         struct flowi4 fl4;
978         struct rtable *rt;
979         int err;
980
981         sock = rcu_dereference_bh(amt->sock);
982         if (!sock)
983                 return true;
984
985         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmu) +
986                            sizeof(*iph) + sizeof(struct udphdr));
987         if (err)
988                 return true;
989
990         skb_reset_inner_headers(skb);
991         memset(&fl4, 0, sizeof(struct flowi4));
992         fl4.flowi4_oif         = amt->stream_dev->ifindex;
993         fl4.daddr              = amt->remote_ip;
994         fl4.saddr              = amt->local_ip;
995         fl4.flowi4_tos         = AMT_TOS;
996         fl4.flowi4_proto       = IPPROTO_UDP;
997         rt = ip_route_output_key(amt->net, &fl4);
998         if (IS_ERR(rt)) {
999                 netdev_dbg(amt->dev, "no route to %pI4\n", &amt->remote_ip);
1000                 return true;
1001         }
1002
1003         amtmu                   = skb_push(skb, sizeof(*amtmu));
1004         amtmu->version          = 0;
1005         amtmu->type             = AMT_MSG_MEMBERSHIP_UPDATE;
1006         amtmu->reserved         = 0;
1007         amtmu->nonce            = amt->nonce;
1008         amtmu->response_mac     = amt->mac;
1009
1010         if (!v6)
1011                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1012         else
1013                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1014         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1015                             fl4.saddr,
1016                             fl4.daddr,
1017                             AMT_TOS,
1018                             ip4_dst_hoplimit(&rt->dst),
1019                             0,
1020                             amt->gw_port,
1021                             amt->relay_port,
1022                             false,
1023                             false);
1024         amt_update_gw_status(amt, AMT_STATUS_SENT_UPDATE, true);
1025         return false;
1026 }
1027
1028 static void amt_send_multicast_data(struct amt_dev *amt,
1029                                     const struct sk_buff *oskb,
1030                                     struct amt_tunnel_list *tunnel,
1031                                     bool v6)
1032 {
1033         struct amt_header_mcast_data *amtmd;
1034         struct socket *sock;
1035         struct sk_buff *skb;
1036         struct iphdr *iph;
1037         struct flowi4 fl4;
1038         struct rtable *rt;
1039
1040         sock = rcu_dereference_bh(amt->sock);
1041         if (!sock)
1042                 return;
1043
1044         skb = skb_copy_expand(oskb, sizeof(*amtmd) + sizeof(*iph) +
1045                               sizeof(struct udphdr), 0, GFP_ATOMIC);
1046         if (!skb)
1047                 return;
1048
1049         skb_reset_inner_headers(skb);
1050         memset(&fl4, 0, sizeof(struct flowi4));
1051         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1052         fl4.daddr              = tunnel->ip4;
1053         fl4.saddr              = amt->local_ip;
1054         fl4.flowi4_proto       = IPPROTO_UDP;
1055         rt = ip_route_output_key(amt->net, &fl4);
1056         if (IS_ERR(rt)) {
1057                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1058                 kfree_skb(skb);
1059                 return;
1060         }
1061
1062         amtmd = skb_push(skb, sizeof(*amtmd));
1063         amtmd->version = 0;
1064         amtmd->reserved = 0;
1065         amtmd->type = AMT_MSG_MULTICAST_DATA;
1066
1067         if (!v6)
1068                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1069         else
1070                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1071         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1072                             fl4.saddr,
1073                             fl4.daddr,
1074                             AMT_TOS,
1075                             ip4_dst_hoplimit(&rt->dst),
1076                             0,
1077                             amt->relay_port,
1078                             tunnel->source_port,
1079                             false,
1080                             false);
1081 }
1082
1083 static bool amt_send_membership_query(struct amt_dev *amt,
1084                                       struct sk_buff *skb,
1085                                       struct amt_tunnel_list *tunnel,
1086                                       bool v6)
1087 {
1088         struct amt_header_membership_query *amtmq;
1089         struct socket *sock;
1090         struct rtable *rt;
1091         struct flowi4 fl4;
1092         int err;
1093
1094         sock = rcu_dereference_bh(amt->sock);
1095         if (!sock)
1096                 return true;
1097
1098         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmq) +
1099                            sizeof(struct iphdr) + sizeof(struct udphdr));
1100         if (err)
1101                 return true;
1102
1103         skb_reset_inner_headers(skb);
1104         memset(&fl4, 0, sizeof(struct flowi4));
1105         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1106         fl4.daddr              = tunnel->ip4;
1107         fl4.saddr              = amt->local_ip;
1108         fl4.flowi4_tos         = AMT_TOS;
1109         fl4.flowi4_proto       = IPPROTO_UDP;
1110         rt = ip_route_output_key(amt->net, &fl4);
1111         if (IS_ERR(rt)) {
1112                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1113                 return true;
1114         }
1115
1116         amtmq           = skb_push(skb, sizeof(*amtmq));
1117         amtmq->version  = 0;
1118         amtmq->type     = AMT_MSG_MEMBERSHIP_QUERY;
1119         amtmq->reserved = 0;
1120         amtmq->l        = 0;
1121         amtmq->g        = 0;
1122         amtmq->nonce    = tunnel->nonce;
1123         amtmq->response_mac = tunnel->mac;
1124
1125         if (!v6)
1126                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1127         else
1128                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1129         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1130                             fl4.saddr,
1131                             fl4.daddr,
1132                             AMT_TOS,
1133                             ip4_dst_hoplimit(&rt->dst),
1134                             0,
1135                             amt->relay_port,
1136                             tunnel->source_port,
1137                             false,
1138                             false);
1139         amt_update_relay_status(tunnel, AMT_STATUS_SENT_QUERY, true);
1140         return false;
1141 }
1142
1143 static netdev_tx_t amt_dev_xmit(struct sk_buff *skb, struct net_device *dev)
1144 {
1145         struct amt_dev *amt = netdev_priv(dev);
1146         struct amt_tunnel_list *tunnel;
1147         struct amt_group_node *gnode;
1148         union amt_addr group = {0,};
1149 #if IS_ENABLED(CONFIG_IPV6)
1150         struct ipv6hdr *ip6h;
1151         struct mld_msg *mld;
1152 #endif
1153         bool report = false;
1154         struct igmphdr *ih;
1155         bool query = false;
1156         struct iphdr *iph;
1157         bool data = false;
1158         bool v6 = false;
1159         u32 hash;
1160
1161         iph = ip_hdr(skb);
1162         if (iph->version == 4) {
1163                 if (!ipv4_is_multicast(iph->daddr))
1164                         goto free;
1165
1166                 if (!ip_mc_check_igmp(skb)) {
1167                         ih = igmp_hdr(skb);
1168                         switch (ih->type) {
1169                         case IGMPV3_HOST_MEMBERSHIP_REPORT:
1170                         case IGMP_HOST_MEMBERSHIP_REPORT:
1171                                 report = true;
1172                                 break;
1173                         case IGMP_HOST_MEMBERSHIP_QUERY:
1174                                 query = true;
1175                                 break;
1176                         default:
1177                                 goto free;
1178                         }
1179                 } else {
1180                         data = true;
1181                 }
1182                 v6 = false;
1183                 group.ip4 = iph->daddr;
1184 #if IS_ENABLED(CONFIG_IPV6)
1185         } else if (iph->version == 6) {
1186                 ip6h = ipv6_hdr(skb);
1187                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
1188                         goto free;
1189
1190                 if (!ipv6_mc_check_mld(skb)) {
1191                         mld = (struct mld_msg *)skb_transport_header(skb);
1192                         switch (mld->mld_type) {
1193                         case ICMPV6_MGM_REPORT:
1194                         case ICMPV6_MLD2_REPORT:
1195                                 report = true;
1196                                 break;
1197                         case ICMPV6_MGM_QUERY:
1198                                 query = true;
1199                                 break;
1200                         default:
1201                                 goto free;
1202                         }
1203                 } else {
1204                         data = true;
1205                 }
1206                 v6 = true;
1207                 group.ip6 = ip6h->daddr;
1208 #endif
1209         } else {
1210                 dev->stats.tx_errors++;
1211                 goto free;
1212         }
1213
1214         if (!pskb_may_pull(skb, sizeof(struct ethhdr)))
1215                 goto free;
1216
1217         skb_pull(skb, sizeof(struct ethhdr));
1218
1219         if (amt->mode == AMT_MODE_GATEWAY) {
1220                 /* Gateway only passes IGMP/MLD packets */
1221                 if (!report)
1222                         goto free;
1223                 if ((!v6 && !amt->ready4) || (v6 && !amt->ready6))
1224                         goto free;
1225                 if (amt_send_membership_update(amt, skb,  v6))
1226                         goto free;
1227                 goto unlock;
1228         } else if (amt->mode == AMT_MODE_RELAY) {
1229                 if (query) {
1230                         tunnel = amt_skb_cb(skb)->tunnel;
1231                         if (!tunnel) {
1232                                 WARN_ON(1);
1233                                 goto free;
1234                         }
1235
1236                         /* Do not forward unexpected query */
1237                         if (amt_send_membership_query(amt, skb, tunnel, v6))
1238                                 goto free;
1239                         goto unlock;
1240                 }
1241
1242                 if (!data)
1243                         goto free;
1244                 list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
1245                         hash = amt_group_hash(tunnel, &group);
1246                         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash],
1247                                                  node) {
1248                                 if (!v6) {
1249                                         if (gnode->group_addr.ip4 == iph->daddr)
1250                                                 goto found;
1251 #if IS_ENABLED(CONFIG_IPV6)
1252                                 } else {
1253                                         if (ipv6_addr_equal(&gnode->group_addr.ip6,
1254                                                             &ip6h->daddr))
1255                                                 goto found;
1256 #endif
1257                                 }
1258                         }
1259                         continue;
1260 found:
1261                         amt_send_multicast_data(amt, skb, tunnel, v6);
1262                 }
1263         }
1264
1265         dev_kfree_skb(skb);
1266         return NETDEV_TX_OK;
1267 free:
1268         dev_kfree_skb(skb);
1269 unlock:
1270         dev->stats.tx_dropped++;
1271         return NETDEV_TX_OK;
1272 }
1273
1274 static int amt_parse_type(struct sk_buff *skb)
1275 {
1276         struct amt_header *amth;
1277
1278         if (!pskb_may_pull(skb, sizeof(struct udphdr) +
1279                            sizeof(struct amt_header)))
1280                 return -1;
1281
1282         amth = (struct amt_header *)(udp_hdr(skb) + 1);
1283
1284         if (amth->version != 0)
1285                 return -1;
1286
1287         if (amth->type >= __AMT_MSG_MAX || !amth->type)
1288                 return -1;
1289         return amth->type;
1290 }
1291
1292 static void amt_clear_groups(struct amt_tunnel_list *tunnel)
1293 {
1294         struct amt_dev *amt = tunnel->amt;
1295         struct amt_group_node *gnode;
1296         struct hlist_node *t;
1297         int i;
1298
1299         spin_lock_bh(&tunnel->lock);
1300         rcu_read_lock();
1301         for (i = 0; i < amt->hash_buckets; i++)
1302                 hlist_for_each_entry_safe(gnode, t, &tunnel->groups[i], node)
1303                         amt_del_group(amt, gnode);
1304         rcu_read_unlock();
1305         spin_unlock_bh(&tunnel->lock);
1306 }
1307
1308 static void amt_tunnel_expire(struct work_struct *work)
1309 {
1310         struct amt_tunnel_list *tunnel = container_of(to_delayed_work(work),
1311                                                       struct amt_tunnel_list,
1312                                                       gc_wq);
1313         struct amt_dev *amt = tunnel->amt;
1314
1315         spin_lock_bh(&amt->lock);
1316         rcu_read_lock();
1317         list_del_rcu(&tunnel->list);
1318         amt->nr_tunnels--;
1319         amt_clear_groups(tunnel);
1320         rcu_read_unlock();
1321         spin_unlock_bh(&amt->lock);
1322         kfree_rcu(tunnel, rcu);
1323 }
1324
1325 static void amt_cleanup_srcs(struct amt_dev *amt,
1326                              struct amt_tunnel_list *tunnel,
1327                              struct amt_group_node *gnode)
1328 {
1329         struct amt_source_node *snode;
1330         struct hlist_node *t;
1331         int i;
1332
1333         /* Delete old sources */
1334         for (i = 0; i < amt->hash_buckets; i++) {
1335                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node) {
1336                         if (snode->flags == AMT_SOURCE_OLD)
1337                                 amt_destroy_source(snode);
1338                 }
1339         }
1340
1341         /* switch from new to old */
1342         for (i = 0; i < amt->hash_buckets; i++)  {
1343                 hlist_for_each_entry_rcu(snode, &gnode->sources[i], node) {
1344                         snode->flags = AMT_SOURCE_OLD;
1345                         if (!gnode->v6)
1346                                 netdev_dbg(snode->gnode->amt->dev,
1347                                            "Add source as OLD %pI4 from %pI4\n",
1348                                            &snode->source_addr.ip4,
1349                                            &gnode->group_addr.ip4);
1350 #if IS_ENABLED(CONFIG_IPV6)
1351                         else
1352                                 netdev_dbg(snode->gnode->amt->dev,
1353                                            "Add source as OLD %pI6 from %pI6\n",
1354                                            &snode->source_addr.ip6,
1355                                            &gnode->group_addr.ip6);
1356 #endif
1357                 }
1358         }
1359 }
1360
1361 static void amt_add_srcs(struct amt_dev *amt, struct amt_tunnel_list *tunnel,
1362                          struct amt_group_node *gnode, void *grec,
1363                          bool v6)
1364 {
1365         struct igmpv3_grec *igmp_grec;
1366         struct amt_source_node *snode;
1367 #if IS_ENABLED(CONFIG_IPV6)
1368         struct mld2_grec *mld_grec;
1369 #endif
1370         union amt_addr src = {0,};
1371         u16 nsrcs;
1372         u32 hash;
1373         int i;
1374
1375         if (!v6) {
1376                 igmp_grec = (struct igmpv3_grec *)grec;
1377                 nsrcs = ntohs(igmp_grec->grec_nsrcs);
1378         } else {
1379 #if IS_ENABLED(CONFIG_IPV6)
1380                 mld_grec = (struct mld2_grec *)grec;
1381                 nsrcs = ntohs(mld_grec->grec_nsrcs);
1382 #else
1383         return;
1384 #endif
1385         }
1386         for (i = 0; i < nsrcs; i++) {
1387                 if (tunnel->nr_sources >= amt->max_sources)
1388                         return;
1389                 if (!v6)
1390                         src.ip4 = igmp_grec->grec_src[i];
1391 #if IS_ENABLED(CONFIG_IPV6)
1392                 else
1393                         memcpy(&src.ip6, &mld_grec->grec_src[i],
1394                                sizeof(struct in6_addr));
1395 #endif
1396                 if (amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL, &src))
1397                         continue;
1398
1399                 snode = amt_alloc_snode(gnode, &src);
1400                 if (snode) {
1401                         hash = amt_source_hash(tunnel, &snode->source_addr);
1402                         hlist_add_head_rcu(&snode->node, &gnode->sources[hash]);
1403                         tunnel->nr_sources++;
1404                         gnode->nr_sources++;
1405
1406                         if (!gnode->v6)
1407                                 netdev_dbg(snode->gnode->amt->dev,
1408                                            "Add source as NEW %pI4 from %pI4\n",
1409                                            &snode->source_addr.ip4,
1410                                            &gnode->group_addr.ip4);
1411 #if IS_ENABLED(CONFIG_IPV6)
1412                         else
1413                                 netdev_dbg(snode->gnode->amt->dev,
1414                                            "Add source as NEW %pI6 from %pI6\n",
1415                                            &snode->source_addr.ip6,
1416                                            &gnode->group_addr.ip6);
1417 #endif
1418                 }
1419         }
1420 }
1421
/* Router State   Report Rec'd New Router State
 * ------------   ------------ ----------------
 * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)
 *
 * -----------+-----------+-----------+
 *            |    OLD    |    NEW    |
 * -----------+-----------+-----------+
 *    FWD     |     X     |    X+A    |
 * -----------+-----------+-----------+
 *    D_FWD   |     Y     |    Y-A    |
 * -----------+-----------+-----------+
 *    NONE    |           |     A     |
 * -----------+-----------+-----------+
 *
 * a) Received sources are NONE/NEW
 * b) All NONE will be deleted by amt_cleanup_srcs().
 * c) All OLD will be deleted by amt_cleanup_srcs().
 * d) After delete, NEW source will be switched to OLD.
 *
 * NOTE(review): amt_cleanup_srcs() as written only destroys sources whose
 * flags are AMT_SOURCE_OLD. Confirm where NONE-status sources are
 * actually reclaimed (b).
 */

/*
 * amt_lookup_act_srcs - apply @act to a set of sources of @gnode.
 * @tunnel: tunnel owning @gnode.
 * @gnode: group whose current source table ("A" below) is examined.
 * @grec: group record whose source list is "B" below; an IGMPv3
 *        struct igmpv3_grec, or an MLDv2 struct mld2_grec when @v6.
 * @ops: which set to act on:
 *       AMT_OPS_INT     - A*B: each source of B that amt_lookup_src()
 *                         finds with @filter;
 *       AMT_OPS_UNI     - A+B: every node of A passing
 *                         amt_status_filter(@filter), then each source of
 *                         B found with @filter;
 *       AMT_OPS_SUB     - A-B: nodes of A passing @filter whose address
 *                         is not listed in B;
 *       AMT_OPS_SUB_REV - B-A: sources of B not found with
 *                         AMT_FILTER_ALL but found with @filter.
 * @filter: status filter used for lookups and membership tests.
 * @act: action passed to amt_act_src() for every selected source.
 * @v6: interpret @grec as an MLDv2 record.
 *
 * Without CONFIG_IPV6, an IPv6 request returns without doing anything.
 */
static void amt_lookup_act_srcs(struct amt_tunnel_list *tunnel,
				struct amt_group_node *gnode,
				void *grec,
				enum amt_ops ops,
				enum amt_filter filter,
				enum amt_act act,
				bool v6)
{
	struct amt_dev *amt = tunnel->amt;
	struct amt_source_node *snode;
	struct igmpv3_grec *igmp_grec;
#if IS_ENABLED(CONFIG_IPV6)
	struct mld2_grec *mld_grec;
#endif
	union amt_addr src = {0,};
	struct hlist_node *t;
	u16 nsrcs;
	int i, j;

	/* Number of sources carried by the record (big-endian on the wire). */
	if (!v6) {
		igmp_grec = (struct igmpv3_grec *)grec;
		nsrcs = ntohs(igmp_grec->grec_nsrcs);
	} else {
#if IS_ENABLED(CONFIG_IPV6)
		mld_grec = (struct mld2_grec *)grec;
		nsrcs = ntohs(mld_grec->grec_nsrcs);
#else
	return;
#endif
	}

	/* Redundant with the initializer above; kept as-is. */
	memset(&src, 0, sizeof(union amt_addr));
	switch (ops) {
	case AMT_OPS_INT:
		/* A*B */
		for (i = 0; i < nsrcs; i++) {
			if (!v6)
				src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
			else
				memcpy(&src.ip6, &mld_grec->grec_src[i],
				       sizeof(struct in6_addr));
#endif
			snode = amt_lookup_src(tunnel, gnode, filter, &src);
			if (!snode)
				continue;
			amt_act_src(tunnel, gnode, snode, act);
		}
		break;
	case AMT_OPS_UNI:
		/* A+B */
		/* First every existing source matching @filter... */
		for (i = 0; i < amt->hash_buckets; i++) {
			hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
						  node) {
				if (amt_status_filter(snode, filter))
					amt_act_src(tunnel, gnode, snode, act);
			}
		}
		/* ...then the record's sources that are found with @filter. */
		for (i = 0; i < nsrcs; i++) {
			if (!v6)
				src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
			else
				memcpy(&src.ip6, &mld_grec->grec_src[i],
				       sizeof(struct in6_addr));
#endif
			snode = amt_lookup_src(tunnel, gnode, filter, &src);
			if (!snode)
				continue;
			amt_act_src(tunnel, gnode, snode, act);
		}
		break;
	case AMT_OPS_SUB:
		/* A-B */
		for (i = 0; i < amt->hash_buckets; i++) {
			hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
						  node) {
				if (!amt_status_filter(snode, filter))
					continue;
				/* Skip @snode if its address is listed in B. */
				for (j = 0; j < nsrcs; j++) {
					if (!v6)
						src.ip4 = igmp_grec->grec_src[j];
#if IS_ENABLED(CONFIG_IPV6)
					else
						memcpy(&src.ip6,
						       &mld_grec->grec_src[j],
						       sizeof(struct in6_addr));
#endif
					if (amt_addr_equal(&snode->source_addr,
							   &src))
						goto out_sub;
				}
				amt_act_src(tunnel, gnode, snode, act);
				continue;
out_sub:;
			}
		}
		break;
	case AMT_OPS_SUB_REV:
		/* B-A */
		for (i = 0; i < nsrcs; i++) {
			if (!v6)
				src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
			else
				memcpy(&src.ip6, &mld_grec->grec_src[i],
				       sizeof(struct in6_addr));
#endif
			snode = amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL,
					       &src);
			/*
			 * NOTE(review): the second lookup can only succeed if
			 * @filter matches nodes that AMT_FILTER_ALL does not.
			 * Confirm against amt_status_filter() /
			 * amt_lookup_src().
			 */
			if (!snode) {
				snode = amt_lookup_src(tunnel, gnode,
						       filter, &src);
				if (snode)
					amt_act_src(tunnel, gnode, snode, act);
			}
		}
		break;
	default:
		netdev_dbg(amt->dev, "Invalid type\n");
		return;
	}
}
1564
/*
 * amt_mcast_is_in_handler - apply an IS_IN group record to @gnode.
 * @amt: AMT device (not used directly here).
 * @tunnel: tunnel owning @gnode.
 * @gnode: group whose source states are updated, chosen by its filter_mode.
 * @grec: received group record (IGMPv3 or MLDv2, selected by @v6).
 * @zero_grec: record used with AMT_OPS_UNI so that only the existing
 *             sources are acted upon (presumably a record with no
 *             sources; confirm at the caller).
 * @v6: interpret the records as MLDv2 when true.
 *
 * The state tables follow the router-state transition format for IS_IN
 * records. Each step below is one amt_lookup_act_srcs() call, and the
 * steps are order-dependent.
 */
static void amt_mcast_is_in_handler(struct amt_dev *amt,
				    struct amt_tunnel_list *tunnel,
				    struct amt_group_node *gnode,
				    void *grec, void *zero_grec, bool v6)
{
	if (gnode->filter_mode == MCAST_INCLUDE) {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * INCLUDE (A)    IS_IN (B)    INCLUDE (A+B)           (B)=GMI
 */
		/* Update IS_IN (B) as FWD/NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
				    AMT_FILTER_NONE_NEW,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* Update INCLUDE (A) as NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* (B)=GMI */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_FWD_NEW,
				    AMT_ACT_GMI,
				    v6);
	} else {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
 *
 * NOTE(review): unlike the INCLUDE branch, no AMT_ACT_GMI step is issued
 * here even though the table lists (A)=GMI. Confirm whether the
 * FWD_NEW status update already refreshes the timer, or whether a step
 * is missing.
 */
		/* Update (A) in (X, Y) as NONE/NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_BOTH,
				    AMT_ACT_STATUS_NONE_NEW,
				    v6);
		/* Update FWD/OLD as FWD/NEW */
		amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* Update IS_IN (A) as FWD/NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_NONE_NEW,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* Update EXCLUDE (, Y-A) as D_FWD_NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
				    AMT_FILTER_D_FWD,
				    AMT_ACT_STATUS_D_FWD_NEW,
				    v6);
	}
}
1617
1618 static void amt_mcast_is_ex_handler(struct amt_dev *amt,
1619                                     struct amt_tunnel_list *tunnel,
1620                                     struct amt_group_node *gnode,
1621                                     void *grec, void *zero_grec, bool v6)
1622 {
1623         if (gnode->filter_mode == MCAST_INCLUDE) {
1624 /* Router State   Report Rec'd  New Router State         Actions
1625  * ------------   ------------  ----------------         -------
1626  * INCLUDE (A)    IS_EX (B)     EXCLUDE (A*B,B-A)        (B-A)=0
1627  *                                                       Delete (A-B)
1628  *                                                       Group Timer=GMI
1629  */
1630                 /* EXCLUDE(A*B, ) */
1631                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1632                                     AMT_FILTER_FWD,
1633                                     AMT_ACT_STATUS_FWD_NEW,
1634                                     v6);
1635                 /* EXCLUDE(, B-A) */
1636                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1637                                     AMT_FILTER_FWD,
1638                                     AMT_ACT_STATUS_D_FWD_NEW,
1639                                     v6);
1640                 /* (B-A)=0 */
1641                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1642                                     AMT_FILTER_D_FWD_NEW,
1643                                     AMT_ACT_GMI_ZERO,
1644                                     v6);
1645                 /* Group Timer=GMI */
1646                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1647                                       msecs_to_jiffies(amt_gmi(amt))))
1648                         dev_hold(amt->dev);
1649                 gnode->filter_mode = MCAST_EXCLUDE;
1650                 /* Delete (A-B) will be worked by amt_cleanup_srcs(). */
1651         } else {
1652 /* Router State   Report Rec'd  New Router State        Actions
1653  * ------------   ------------  ----------------        -------
1654  * EXCLUDE (X,Y)  IS_EX (A)     EXCLUDE (A-Y,Y*A)       (A-X-Y)=GMI
1655  *                                                      Delete (X-A)
1656  *                                                      Delete (Y-A)
1657  *                                                      Group Timer=GMI
1658  */
1659                 /* EXCLUDE (A-Y, ) */
1660                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1661                                     AMT_FILTER_D_FWD,
1662                                     AMT_ACT_STATUS_FWD_NEW,
1663                                     v6);
1664                 /* EXCLUDE (, Y*A ) */
1665                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1666                                     AMT_FILTER_D_FWD,
1667                                     AMT_ACT_STATUS_D_FWD_NEW,
1668                                     v6);
1669                 /* (A-X-Y)=GMI */
1670                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1671                                     AMT_FILTER_BOTH_NEW,
1672                                     AMT_ACT_GMI,
1673                                     v6);
1674                 /* Group Timer=GMI */
1675                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1676                                       msecs_to_jiffies(amt_gmi(amt))))
1677                         dev_hold(amt->dev);
1678                 /* Delete (X-A), (Y-A) will be worked by amt_cleanup_srcs(). */
1679         }
1680 }
1681
1682 static void amt_mcast_to_in_handler(struct amt_dev *amt,
1683                                     struct amt_tunnel_list *tunnel,
1684                                     struct amt_group_node *gnode,
1685                                     void *grec, void *zero_grec, bool v6)
1686 {
1687         if (gnode->filter_mode == MCAST_INCLUDE) {
1688 /* Router State   Report Rec'd New Router State        Actions
1689  * ------------   ------------ ----------------        -------
1690  * INCLUDE (A)    TO_IN (B)    INCLUDE (A+B)           (B)=GMI
1691  *                                                     Send Q(G,A-B)
1692  */
1693                 /* Update TO_IN (B) sources as FWD/NEW */
1694                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1695                                     AMT_FILTER_NONE_NEW,
1696                                     AMT_ACT_STATUS_FWD_NEW,
1697                                     v6);
1698                 /* Update INCLUDE (A) sources as NEW */
1699                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1700                                     AMT_FILTER_FWD,
1701                                     AMT_ACT_STATUS_FWD_NEW,
1702                                     v6);
1703                 /* (B)=GMI */
1704                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1705                                     AMT_FILTER_FWD_NEW,
1706                                     AMT_ACT_GMI,
1707                                     v6);
1708         } else {
1709 /* Router State   Report Rec'd New Router State        Actions
1710  * ------------   ------------ ----------------        -------
1711  * EXCLUDE (X,Y)  TO_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1712  *                                                     Send Q(G,X-A)
1713  *                                                     Send Q(G)
1714  */
1715                 /* Update TO_IN (A) sources as FWD/NEW */
1716                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1717                                     AMT_FILTER_NONE_NEW,
1718                                     AMT_ACT_STATUS_FWD_NEW,
1719                                     v6);
1720                 /* Update EXCLUDE(X,) sources as FWD/NEW */
1721                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1722                                     AMT_FILTER_FWD,
1723                                     AMT_ACT_STATUS_FWD_NEW,
1724                                     v6);
1725                 /* EXCLUDE (, Y-A)
1726                  * (A) are already switched to FWD_NEW.
1727                  * So, D_FWD/OLD -> D_FWD/NEW is okay.
1728                  */
1729                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1730                                     AMT_FILTER_D_FWD,
1731                                     AMT_ACT_STATUS_D_FWD_NEW,
1732                                     v6);
1733                 /* (A)=GMI
1734                  * Only FWD_NEW will have (A) sources.
1735                  */
1736                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1737                                     AMT_FILTER_FWD_NEW,
1738                                     AMT_ACT_GMI,
1739                                     v6);
1740         }
1741 }
1742
1743 static void amt_mcast_to_ex_handler(struct amt_dev *amt,
1744                                     struct amt_tunnel_list *tunnel,
1745                                     struct amt_group_node *gnode,
1746                                     void *grec, void *zero_grec, bool v6)
1747 {
1748         if (gnode->filter_mode == MCAST_INCLUDE) {
1749 /* Router State   Report Rec'd New Router State        Actions
1750  * ------------   ------------ ----------------        -------
1751  * INCLUDE (A)    TO_EX (B)    EXCLUDE (A*B,B-A)       (B-A)=0
1752  *                                                     Delete (A-B)
1753  *                                                     Send Q(G,A*B)
1754  *                                                     Group Timer=GMI
1755  */
1756                 /* EXCLUDE (A*B, ) */
1757                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1758                                     AMT_FILTER_FWD,
1759                                     AMT_ACT_STATUS_FWD_NEW,
1760                                     v6);
1761                 /* EXCLUDE (, B-A) */
1762                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1763                                     AMT_FILTER_FWD,
1764                                     AMT_ACT_STATUS_D_FWD_NEW,
1765                                     v6);
1766                 /* (B-A)=0 */
1767                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1768                                     AMT_FILTER_D_FWD_NEW,
1769                                     AMT_ACT_GMI_ZERO,
1770                                     v6);
1771                 /* Group Timer=GMI */
1772                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1773                                       msecs_to_jiffies(amt_gmi(amt))))
1774                         dev_hold(amt->dev);
1775                 gnode->filter_mode = MCAST_EXCLUDE;
1776                 /* Delete (A-B) will be worked by amt_cleanup_srcs(). */
1777         } else {
1778 /* Router State   Report Rec'd New Router State        Actions
1779  * ------------   ------------ ----------------        -------
1780  * EXCLUDE (X,Y)  TO_EX (A)    EXCLUDE (A-Y,Y*A)       (A-X-Y)=Group Timer
1781  *                                                     Delete (X-A)
1782  *                                                     Delete (Y-A)
1783  *                                                     Send Q(G,A-Y)
1784  *                                                     Group Timer=GMI
1785  */
1786                 /* Update (A-X-Y) as NONE/OLD */
1787                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1788                                     AMT_FILTER_BOTH,
1789                                     AMT_ACT_GT,
1790                                     v6);
1791                 /* EXCLUDE (A-Y, ) */
1792                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1793                                     AMT_FILTER_D_FWD,
1794                                     AMT_ACT_STATUS_FWD_NEW,
1795                                     v6);
1796                 /* EXCLUDE (, Y*A) */
1797                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1798                                     AMT_FILTER_D_FWD,
1799                                     AMT_ACT_STATUS_D_FWD_NEW,
1800                                     v6);
1801                 /* Group Timer=GMI */
1802                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1803                                       msecs_to_jiffies(amt_gmi(amt))))
1804                         dev_hold(amt->dev);
1805                 /* Delete (X-A), (Y-A) will be worked by amt_cleanup_srcs(). */
1806         }
1807 }
1808
1809 static void amt_mcast_allow_handler(struct amt_dev *amt,
1810                                     struct amt_tunnel_list *tunnel,
1811                                     struct amt_group_node *gnode,
1812                                     void *grec, void *zero_grec, bool v6)
1813 {
1814         if (gnode->filter_mode == MCAST_INCLUDE) {
1815 /* Router State   Report Rec'd New Router State        Actions
1816  * ------------   ------------ ----------------        -------
1817  * INCLUDE (A)    ALLOW (B)    INCLUDE (A+B)           (B)=GMI
1818  */
1819                 /* INCLUDE (A+B) */
1820                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1821                                     AMT_FILTER_FWD,
1822                                     AMT_ACT_STATUS_FWD_NEW,
1823                                     v6);
1824                 /* (B)=GMI */
1825                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1826                                     AMT_FILTER_FWD_NEW,
1827                                     AMT_ACT_GMI,
1828                                     v6);
1829         } else {
1830 /* Router State   Report Rec'd New Router State        Actions
1831  * ------------   ------------ ----------------        -------
1832  * EXCLUDE (X,Y)  ALLOW (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1833  */
1834                 /* EXCLUDE (X+A, ) */
1835                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1836                                     AMT_FILTER_FWD,
1837                                     AMT_ACT_STATUS_FWD_NEW,
1838                                     v6);
1839                 /* EXCLUDE (, Y-A) */
1840                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
1841                                     AMT_FILTER_D_FWD,
1842                                     AMT_ACT_STATUS_D_FWD_NEW,
1843                                     v6);
1844                 /* (A)=GMI
1845                  * All (A) source are now FWD/NEW status.
1846                  */
1847                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1848                                     AMT_FILTER_FWD_NEW,
1849                                     AMT_ACT_GMI,
1850                                     v6);
1851         }
1852 }
1853
1854 static void amt_mcast_block_handler(struct amt_dev *amt,
1855                                     struct amt_tunnel_list *tunnel,
1856                                     struct amt_group_node *gnode,
1857                                     void *grec, void *zero_grec, bool v6)
1858 {
1859         if (gnode->filter_mode == MCAST_INCLUDE) {
1860 /* Router State   Report Rec'd New Router State        Actions
1861  * ------------   ------------ ----------------        -------
1862  * INCLUDE (A)    BLOCK (B)    INCLUDE (A)             Send Q(G,A*B)
1863  */
1864                 /* INCLUDE (A) */
1865                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1866                                     AMT_FILTER_FWD,
1867                                     AMT_ACT_STATUS_FWD_NEW,
1868                                     v6);
1869         } else {
1870 /* Router State   Report Rec'd New Router State        Actions
1871  * ------------   ------------ ----------------        -------
1872  * EXCLUDE (X,Y)  BLOCK (A)    EXCLUDE (X+(A-Y),Y)     (A-X-Y)=Group Timer
1873  *                                                     Send Q(G,A-Y)
1874  */
1875                 /* (A-X-Y)=Group Timer */
1876                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1877                                     AMT_FILTER_BOTH,
1878                                     AMT_ACT_GT,
1879                                     v6);
1880                 /* EXCLUDE (X, ) */
1881                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1882                                     AMT_FILTER_FWD,
1883                                     AMT_ACT_STATUS_FWD_NEW,
1884                                     v6);
1885                 /* EXCLUDE (X+(A-Y) */
1886                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1887                                     AMT_FILTER_D_FWD,
1888                                     AMT_ACT_STATUS_FWD_NEW,
1889                                     v6);
1890                 /* EXCLUDE (, Y) */
1891                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1892                                     AMT_FILTER_D_FWD,
1893                                     AMT_ACT_STATUS_D_FWD_NEW,
1894                                     v6);
1895         }
1896 }
1897
1898 /* RFC 3376
1899  * 7.3.2. In the Presence of Older Version Group Members
1900  *
1901  * When Group Compatibility Mode is IGMPv2, a router internally
1902  * translates the following IGMPv2 messages for that group to their
1903  * IGMPv3 equivalents:
1904  *
1905  * IGMPv2 Message                IGMPv3 Equivalent
1906  * --------------                -----------------
1907  * Report                        IS_EX( {} )
1908  * Leave                         TO_IN( {} )
1909  */
1910 static void amt_igmpv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
1911                                       struct amt_tunnel_list *tunnel)
1912 {
1913         struct igmphdr *ih = igmp_hdr(skb);
1914         struct iphdr *iph = ip_hdr(skb);
1915         struct amt_group_node *gnode;
1916         union amt_addr group, host;
1917
1918         memset(&group, 0, sizeof(union amt_addr));
1919         group.ip4 = ih->group;
1920         memset(&host, 0, sizeof(union amt_addr));
1921         host.ip4 = iph->saddr;
1922
1923         gnode = amt_lookup_group(tunnel, &group, &host, false);
1924         if (!gnode) {
1925                 gnode = amt_add_group(amt, tunnel, &group, &host, false);
1926                 if (!IS_ERR(gnode)) {
1927                         gnode->filter_mode = MCAST_EXCLUDE;
1928                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1929                                               msecs_to_jiffies(amt_gmi(amt))))
1930                                 dev_hold(amt->dev);
1931                 }
1932         }
1933 }
1934
1935 /* RFC 3376
1936  * 7.3.2. In the Presence of Older Version Group Members
1937  *
1938  * When Group Compatibility Mode is IGMPv2, a router internally
1939  * translates the following IGMPv2 messages for that group to their
1940  * IGMPv3 equivalents:
1941  *
1942  * IGMPv2 Message                IGMPv3 Equivalent
1943  * --------------                -----------------
1944  * Report                        IS_EX( {} )
1945  * Leave                         TO_IN( {} )
1946  */
1947 static void amt_igmpv2_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
1948                                      struct amt_tunnel_list *tunnel)
1949 {
1950         struct igmphdr *ih = igmp_hdr(skb);
1951         struct iphdr *iph = ip_hdr(skb);
1952         struct amt_group_node *gnode;
1953         union amt_addr group, host;
1954
1955         memset(&group, 0, sizeof(union amt_addr));
1956         group.ip4 = ih->group;
1957         memset(&host, 0, sizeof(union amt_addr));
1958         host.ip4 = iph->saddr;
1959
1960         gnode = amt_lookup_group(tunnel, &group, &host, false);
1961         if (gnode)
1962                 amt_del_group(amt, gnode);
1963 }
1964
1965 static void amt_igmpv3_report_handler(struct amt_dev *amt, struct sk_buff *skb,
1966                                       struct amt_tunnel_list *tunnel)
1967 {
1968         struct igmpv3_report *ihrv3 = igmpv3_report_hdr(skb);
1969         int len = skb_transport_offset(skb) + sizeof(*ihrv3);
1970         void *zero_grec = (void *)&igmpv3_zero_grec;
1971         struct iphdr *iph = ip_hdr(skb);
1972         struct amt_group_node *gnode;
1973         union amt_addr group, host;
1974         struct igmpv3_grec *grec;
1975         u16 nsrcs;
1976         int i;
1977
1978         for (i = 0; i < ntohs(ihrv3->ngrec); i++) {
1979                 len += sizeof(*grec);
1980                 if (!ip_mc_may_pull(skb, len))
1981                         break;
1982
1983                 grec = (void *)(skb->data + len - sizeof(*grec));
1984                 nsrcs = ntohs(grec->grec_nsrcs);
1985
1986                 len += nsrcs * sizeof(__be32);
1987                 if (!ip_mc_may_pull(skb, len))
1988                         break;
1989
1990                 memset(&group, 0, sizeof(union amt_addr));
1991                 group.ip4 = grec->grec_mca;
1992                 memset(&host, 0, sizeof(union amt_addr));
1993                 host.ip4 = iph->saddr;
1994                 gnode = amt_lookup_group(tunnel, &group, &host, false);
1995                 if (!gnode) {
1996                         gnode = amt_add_group(amt, tunnel, &group, &host,
1997                                               false);
1998                         if (IS_ERR(gnode))
1999                                 continue;
2000                 }
2001
2002                 amt_add_srcs(amt, tunnel, gnode, grec, false);
2003                 switch (grec->grec_type) {
2004                 case IGMPV3_MODE_IS_INCLUDE:
2005                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2006                                                 zero_grec, false);
2007                         break;
2008                 case IGMPV3_MODE_IS_EXCLUDE:
2009                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2010                                                 zero_grec, false);
2011                         break;
2012                 case IGMPV3_CHANGE_TO_INCLUDE:
2013                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2014                                                 zero_grec, false);
2015                         break;
2016                 case IGMPV3_CHANGE_TO_EXCLUDE:
2017                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2018                                                 zero_grec, false);
2019                         break;
2020                 case IGMPV3_ALLOW_NEW_SOURCES:
2021                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2022                                                 zero_grec, false);
2023                         break;
2024                 case IGMPV3_BLOCK_OLD_SOURCES:
2025                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2026                                                 zero_grec, false);
2027                         break;
2028                 default:
2029                         break;
2030                 }
2031                 amt_cleanup_srcs(amt, tunnel, gnode);
2032         }
2033 }
2034
2035 /* caller held tunnel->lock */
2036 static void amt_igmp_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2037                                     struct amt_tunnel_list *tunnel)
2038 {
2039         struct igmphdr *ih = igmp_hdr(skb);
2040
2041         switch (ih->type) {
2042         case IGMPV3_HOST_MEMBERSHIP_REPORT:
2043                 amt_igmpv3_report_handler(amt, skb, tunnel);
2044                 break;
2045         case IGMPV2_HOST_MEMBERSHIP_REPORT:
2046                 amt_igmpv2_report_handler(amt, skb, tunnel);
2047                 break;
2048         case IGMP_HOST_LEAVE_MESSAGE:
2049                 amt_igmpv2_leave_handler(amt, skb, tunnel);
2050                 break;
2051         default:
2052                 break;
2053         }
2054 }
2055
2056 #if IS_ENABLED(CONFIG_IPV6)
2057 /* RFC 3810
2058  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2059  *
2060  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2061  * using the MLDv2 protocol for that multicast address.  When Multicast
2062  * Address Compatibility Mode is MLDv1, a router internally translates
2063  * the following MLDv1 messages for that multicast address to their
2064  * MLDv2 equivalents:
2065  *
2066  * MLDv1 Message                 MLDv2 Equivalent
2067  * --------------                -----------------
2068  * Report                        IS_EX( {} )
2069  * Done                          TO_IN( {} )
2070  */
2071 static void amt_mldv1_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2072                                      struct amt_tunnel_list *tunnel)
2073 {
2074         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2075         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2076         struct amt_group_node *gnode;
2077         union amt_addr group, host;
2078
2079         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2080         memcpy(&host.ip6, &ip6h->saddr, sizeof(struct in6_addr));
2081
2082         gnode = amt_lookup_group(tunnel, &group, &host, true);
2083         if (!gnode) {
2084                 gnode = amt_add_group(amt, tunnel, &group, &host, true);
2085                 if (!IS_ERR(gnode)) {
2086                         gnode->filter_mode = MCAST_EXCLUDE;
2087                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
2088                                               msecs_to_jiffies(amt_gmi(amt))))
2089                                 dev_hold(amt->dev);
2090                 }
2091         }
2092 }
2093
2094 /* RFC 3810
2095  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2096  *
2097  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2098  * using the MLDv2 protocol for that multicast address.  When Multicast
2099  * Address Compatibility Mode is MLDv1, a router internally translates
2100  * the following MLDv1 messages for that multicast address to their
2101  * MLDv2 equivalents:
2102  *
2103  * MLDv1 Message                 MLDv2 Equivalent
2104  * --------------                -----------------
2105  * Report                        IS_EX( {} )
2106  * Done                          TO_IN( {} )
2107  */
2108 static void amt_mldv1_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
2109                                     struct amt_tunnel_list *tunnel)
2110 {
2111         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2112         struct iphdr *iph = ip_hdr(skb);
2113         struct amt_group_node *gnode;
2114         union amt_addr group, host;
2115
2116         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2117         memset(&host, 0, sizeof(union amt_addr));
2118         host.ip4 = iph->saddr;
2119
2120         gnode = amt_lookup_group(tunnel, &group, &host, true);
2121         if (gnode) {
2122                 amt_del_group(amt, gnode);
2123                 return;
2124         }
2125 }
2126
2127 static void amt_mldv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2128                                      struct amt_tunnel_list *tunnel)
2129 {
2130         struct mld2_report *mld2r = (struct mld2_report *)icmp6_hdr(skb);
2131         int len = skb_transport_offset(skb) + sizeof(*mld2r);
2132         void *zero_grec = (void *)&mldv2_zero_grec;
2133         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2134         struct amt_group_node *gnode;
2135         union amt_addr group, host;
2136         struct mld2_grec *grec;
2137         u16 nsrcs;
2138         int i;
2139
2140         for (i = 0; i < ntohs(mld2r->mld2r_ngrec); i++) {
2141                 len += sizeof(*grec);
2142                 if (!ipv6_mc_may_pull(skb, len))
2143                         break;
2144
2145                 grec = (void *)(skb->data + len - sizeof(*grec));
2146                 nsrcs = ntohs(grec->grec_nsrcs);
2147
2148                 len += nsrcs * sizeof(struct in6_addr);
2149                 if (!ipv6_mc_may_pull(skb, len))
2150                         break;
2151
2152                 memset(&group, 0, sizeof(union amt_addr));
2153                 group.ip6 = grec->grec_mca;
2154                 memset(&host, 0, sizeof(union amt_addr));
2155                 host.ip6 = ip6h->saddr;
2156                 gnode = amt_lookup_group(tunnel, &group, &host, true);
2157                 if (!gnode) {
2158                         gnode = amt_add_group(amt, tunnel, &group, &host,
2159                                               ETH_P_IPV6);
2160                         if (IS_ERR(gnode))
2161                                 continue;
2162                 }
2163
2164                 amt_add_srcs(amt, tunnel, gnode, grec, true);
2165                 switch (grec->grec_type) {
2166                 case MLD2_MODE_IS_INCLUDE:
2167                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2168                                                 zero_grec, true);
2169                         break;
2170                 case MLD2_MODE_IS_EXCLUDE:
2171                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2172                                                 zero_grec, true);
2173                         break;
2174                 case MLD2_CHANGE_TO_INCLUDE:
2175                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2176                                                 zero_grec, true);
2177                         break;
2178                 case MLD2_CHANGE_TO_EXCLUDE:
2179                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2180                                                 zero_grec, true);
2181                         break;
2182                 case MLD2_ALLOW_NEW_SOURCES:
2183                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2184                                                 zero_grec, true);
2185                         break;
2186                 case MLD2_BLOCK_OLD_SOURCES:
2187                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2188                                                 zero_grec, true);
2189                         break;
2190                 default:
2191                         break;
2192                 }
2193                 amt_cleanup_srcs(amt, tunnel, gnode);
2194         }
2195 }
2196
2197 /* caller held tunnel->lock */
2198 static void amt_mld_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2199                                    struct amt_tunnel_list *tunnel)
2200 {
2201         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2202
2203         switch (mld->mld_type) {
2204         case ICMPV6_MGM_REPORT:
2205                 amt_mldv1_report_handler(amt, skb, tunnel);
2206                 break;
2207         case ICMPV6_MLD2_REPORT:
2208                 amt_mldv2_report_handler(amt, skb, tunnel);
2209                 break;
2210         case ICMPV6_MGM_REDUCTION:
2211                 amt_mldv1_leave_handler(amt, skb, tunnel);
2212                 break;
2213         default:
2214                 break;
2215         }
2216 }
2217 #endif
2218
2219 static bool amt_advertisement_handler(struct amt_dev *amt, struct sk_buff *skb)
2220 {
2221         struct amt_header_advertisement *amta;
2222         int hdr_size;
2223
2224         hdr_size = sizeof(*amta) + sizeof(struct udphdr);
2225         if (!pskb_may_pull(skb, hdr_size))
2226                 return true;
2227
2228         amta = (struct amt_header_advertisement *)(udp_hdr(skb) + 1);
2229         if (!amta->ip4)
2230                 return true;
2231
2232         if (amta->reserved || amta->version)
2233                 return true;
2234
2235         if (ipv4_is_loopback(amta->ip4) || ipv4_is_multicast(amta->ip4) ||
2236             ipv4_is_zeronet(amta->ip4))
2237                 return true;
2238
2239         amt->remote_ip = amta->ip4;
2240         netdev_dbg(amt->dev, "advertised remote ip = %pI4\n", &amt->remote_ip);
2241         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2242
2243         amt_update_gw_status(amt, AMT_STATUS_RECEIVED_ADVERTISEMENT, true);
2244         return false;
2245 }
2246
2247 static bool amt_multicast_data_handler(struct amt_dev *amt, struct sk_buff *skb)
2248 {
2249         struct amt_header_mcast_data *amtmd;
2250         int hdr_size, len, err;
2251         struct ethhdr *eth;
2252         struct iphdr *iph;
2253
2254         hdr_size = sizeof(*amtmd) + sizeof(struct udphdr);
2255         if (!pskb_may_pull(skb, hdr_size))
2256                 return true;
2257
2258         amtmd = (struct amt_header_mcast_data *)(udp_hdr(skb) + 1);
2259         if (amtmd->reserved || amtmd->version)
2260                 return true;
2261
2262         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_IP), false))
2263                 return true;
2264
2265         skb_reset_network_header(skb);
2266         skb_push(skb, sizeof(*eth));
2267         skb_reset_mac_header(skb);
2268         skb_pull(skb, sizeof(*eth));
2269         eth = eth_hdr(skb);
2270
2271         if (!pskb_may_pull(skb, sizeof(*iph)))
2272                 return true;
2273         iph = ip_hdr(skb);
2274
2275         if (iph->version == 4) {
2276                 if (!ipv4_is_multicast(iph->daddr))
2277                         return true;
2278                 skb->protocol = htons(ETH_P_IP);
2279                 eth->h_proto = htons(ETH_P_IP);
2280                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2281 #if IS_ENABLED(CONFIG_IPV6)
2282         } else if (iph->version == 6) {
2283                 struct ipv6hdr *ip6h;
2284
2285                 if (!pskb_may_pull(skb, sizeof(*ip6h)))
2286                         return true;
2287
2288                 ip6h = ipv6_hdr(skb);
2289                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2290                         return true;
2291                 skb->protocol = htons(ETH_P_IPV6);
2292                 eth->h_proto = htons(ETH_P_IPV6);
2293                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2294 #endif
2295         } else {
2296                 return true;
2297         }
2298
2299         skb->pkt_type = PACKET_MULTICAST;
2300         skb->ip_summed = CHECKSUM_NONE;
2301         len = skb->len;
2302         err = gro_cells_receive(&amt->gro_cells, skb);
2303         if (likely(err == NET_RX_SUCCESS))
2304                 dev_sw_netstats_rx_add(amt->dev, len);
2305         else
2306                 amt->dev->stats.rx_dropped++;
2307
2308         return false;
2309 }
2310
2311 static bool amt_membership_query_handler(struct amt_dev *amt,
2312                                          struct sk_buff *skb)
2313 {
2314         struct amt_header_membership_query *amtmq;
2315         struct igmpv3_query *ihv3;
2316         struct ethhdr *eth, *oeth;
2317         struct iphdr *iph;
2318         int hdr_size, len;
2319
2320         hdr_size = sizeof(*amtmq) + sizeof(struct udphdr);
2321         if (!pskb_may_pull(skb, hdr_size))
2322                 return true;
2323
2324         amtmq = (struct amt_header_membership_query *)(udp_hdr(skb) + 1);
2325         if (amtmq->reserved || amtmq->version)
2326                 return true;
2327
2328         hdr_size -= sizeof(*eth);
2329         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_TEB), false))
2330                 return true;
2331
2332         oeth = eth_hdr(skb);
2333         skb_reset_mac_header(skb);
2334         skb_pull(skb, sizeof(*eth));
2335         skb_reset_network_header(skb);
2336         eth = eth_hdr(skb);
2337         if (!pskb_may_pull(skb, sizeof(*iph)))
2338                 return true;
2339
2340         iph = ip_hdr(skb);
2341         if (iph->version == 4) {
2342                 if (!pskb_may_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS +
2343                                    sizeof(*ihv3)))
2344                         return true;
2345
2346                 if (!ipv4_is_multicast(iph->daddr))
2347                         return true;
2348
2349                 ihv3 = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2350                 skb_reset_transport_header(skb);
2351                 skb_push(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2352                 spin_lock_bh(&amt->lock);
2353                 amt->ready4 = true;
2354                 amt->mac = amtmq->response_mac;
2355                 amt->req_cnt = 0;
2356                 amt->qi = ihv3->qqic;
2357                 spin_unlock_bh(&amt->lock);
2358                 skb->protocol = htons(ETH_P_IP);
2359                 eth->h_proto = htons(ETH_P_IP);
2360                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2361 #if IS_ENABLED(CONFIG_IPV6)
2362         } else if (iph->version == 6) {
2363                 struct mld2_query *mld2q;
2364                 struct ipv6hdr *ip6h;
2365
2366                 if (!pskb_may_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS +
2367                                    sizeof(*mld2q)))
2368                         return true;
2369
2370                 ip6h = ipv6_hdr(skb);
2371                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2372                         return true;
2373
2374                 mld2q = skb_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2375                 skb_reset_transport_header(skb);
2376                 skb_push(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2377                 spin_lock_bh(&amt->lock);
2378                 amt->ready6 = true;
2379                 amt->mac = amtmq->response_mac;
2380                 amt->req_cnt = 0;
2381                 amt->qi = mld2q->mld2q_qqic;
2382                 spin_unlock_bh(&amt->lock);
2383                 skb->protocol = htons(ETH_P_IPV6);
2384                 eth->h_proto = htons(ETH_P_IPV6);
2385                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2386 #endif
2387         } else {
2388                 return true;
2389         }
2390
2391         ether_addr_copy(eth->h_source, oeth->h_source);
2392         skb->pkt_type = PACKET_MULTICAST;
2393         skb->ip_summed = CHECKSUM_NONE;
2394         len = skb->len;
2395         if (__netif_rx(skb) == NET_RX_SUCCESS) {
2396                 amt_update_gw_status(amt, AMT_STATUS_RECEIVED_QUERY, true);
2397                 dev_sw_netstats_rx_add(amt->dev, len);
2398         } else {
2399                 amt->dev->stats.rx_dropped++;
2400         }
2401
2402         return false;
2403 }
2404
2405 static bool amt_update_handler(struct amt_dev *amt, struct sk_buff *skb)
2406 {
2407         struct amt_header_membership_update *amtmu;
2408         struct amt_tunnel_list *tunnel;
2409         struct ethhdr *eth;
2410         struct iphdr *iph;
2411         int len, hdr_size;
2412
2413         iph = ip_hdr(skb);
2414
2415         hdr_size = sizeof(*amtmu) + sizeof(struct udphdr);
2416         if (!pskb_may_pull(skb, hdr_size))
2417                 return true;
2418
2419         amtmu = (struct amt_header_membership_update *)(udp_hdr(skb) + 1);
2420         if (amtmu->reserved || amtmu->version)
2421                 return true;
2422
2423         if (iptunnel_pull_header(skb, hdr_size, skb->protocol, false))
2424                 return true;
2425
2426         skb_reset_network_header(skb);
2427
2428         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
2429                 if (tunnel->ip4 == iph->saddr) {
2430                         if ((amtmu->nonce == tunnel->nonce &&
2431                              amtmu->response_mac == tunnel->mac)) {
2432                                 mod_delayed_work(amt_wq, &tunnel->gc_wq,
2433                                                  msecs_to_jiffies(amt_gmi(amt))
2434                                                                   * 3);
2435                                 goto report;
2436                         } else {
2437                                 netdev_dbg(amt->dev, "Invalid MAC\n");
2438                                 return true;
2439                         }
2440                 }
2441         }
2442
2443         return true;
2444
2445 report:
2446         if (!pskb_may_pull(skb, sizeof(*iph)))
2447                 return true;
2448
2449         iph = ip_hdr(skb);
2450         if (iph->version == 4) {
2451                 if (ip_mc_check_igmp(skb)) {
2452                         netdev_dbg(amt->dev, "Invalid IGMP\n");
2453                         return true;
2454                 }
2455
2456                 spin_lock_bh(&tunnel->lock);
2457                 amt_igmp_report_handler(amt, skb, tunnel);
2458                 spin_unlock_bh(&tunnel->lock);
2459
2460                 skb_push(skb, sizeof(struct ethhdr));
2461                 skb_reset_mac_header(skb);
2462                 eth = eth_hdr(skb);
2463                 skb->protocol = htons(ETH_P_IP);
2464                 eth->h_proto = htons(ETH_P_IP);
2465                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2466 #if IS_ENABLED(CONFIG_IPV6)
2467         } else if (iph->version == 6) {
2468                 struct ipv6hdr *ip6h = ipv6_hdr(skb);
2469
2470                 if (ipv6_mc_check_mld(skb)) {
2471                         netdev_dbg(amt->dev, "Invalid MLD\n");
2472                         return true;
2473                 }
2474
2475                 spin_lock_bh(&tunnel->lock);
2476                 amt_mld_report_handler(amt, skb, tunnel);
2477                 spin_unlock_bh(&tunnel->lock);
2478
2479                 skb_push(skb, sizeof(struct ethhdr));
2480                 skb_reset_mac_header(skb);
2481                 eth = eth_hdr(skb);
2482                 skb->protocol = htons(ETH_P_IPV6);
2483                 eth->h_proto = htons(ETH_P_IPV6);
2484                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2485 #endif
2486         } else {
2487                 netdev_dbg(amt->dev, "Unsupported Protocol\n");
2488                 return true;
2489         }
2490
2491         skb_pull(skb, sizeof(struct ethhdr));
2492         skb->pkt_type = PACKET_MULTICAST;
2493         skb->ip_summed = CHECKSUM_NONE;
2494         len = skb->len;
2495         if (__netif_rx(skb) == NET_RX_SUCCESS) {
2496                 amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_UPDATE,
2497                                         true);
2498                 dev_sw_netstats_rx_add(amt->dev, len);
2499         } else {
2500                 amt->dev->stats.rx_dropped++;
2501         }
2502
2503         return false;
2504 }
2505
2506 static void amt_send_advertisement(struct amt_dev *amt, __be32 nonce,
2507                                    __be32 daddr, __be16 dport)
2508 {
2509         struct amt_header_advertisement *amta;
2510         int hlen, tlen, offset;
2511         struct socket *sock;
2512         struct udphdr *udph;
2513         struct sk_buff *skb;
2514         struct iphdr *iph;
2515         struct rtable *rt;
2516         struct flowi4 fl4;
2517         u32 len;
2518         int err;
2519
2520         rcu_read_lock();
2521         sock = rcu_dereference(amt->sock);
2522         if (!sock)
2523                 goto out;
2524
2525         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
2526                 goto out;
2527
2528         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
2529                                    daddr, amt->local_ip,
2530                                    dport, amt->relay_port,
2531                                    IPPROTO_UDP, 0,
2532                                    amt->stream_dev->ifindex);
2533         if (IS_ERR(rt)) {
2534                 amt->dev->stats.tx_errors++;
2535                 goto out;
2536         }
2537
2538         hlen = LL_RESERVED_SPACE(amt->dev);
2539         tlen = amt->dev->needed_tailroom;
2540         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2541         skb = netdev_alloc_skb_ip_align(amt->dev, len);
2542         if (!skb) {
2543                 ip_rt_put(rt);
2544                 amt->dev->stats.tx_errors++;
2545                 goto out;
2546         }
2547
2548         skb->priority = TC_PRIO_CONTROL;
2549         skb_dst_set(skb, &rt->dst);
2550
2551         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2552         skb_reset_network_header(skb);
2553         skb_put(skb, len);
2554         amta = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
2555         amta->version   = 0;
2556         amta->type      = AMT_MSG_ADVERTISEMENT;
2557         amta->reserved  = 0;
2558         amta->nonce     = nonce;
2559         amta->ip4       = amt->local_ip;
2560         skb_push(skb, sizeof(*udph));
2561         skb_reset_transport_header(skb);
2562         udph            = udp_hdr(skb);
2563         udph->source    = amt->relay_port;
2564         udph->dest      = dport;
2565         udph->len       = htons(sizeof(*amta) + sizeof(*udph));
2566         udph->check     = 0;
2567         offset = skb_transport_offset(skb);
2568         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
2569         udph->check = csum_tcpudp_magic(amt->local_ip, daddr,
2570                                         sizeof(*udph) + sizeof(*amta),
2571                                         IPPROTO_UDP, skb->csum);
2572
2573         skb_push(skb, sizeof(*iph));
2574         iph             = ip_hdr(skb);
2575         iph->version    = 4;
2576         iph->ihl        = (sizeof(struct iphdr)) >> 2;
2577         iph->tos        = AMT_TOS;
2578         iph->frag_off   = 0;
2579         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
2580         iph->daddr      = daddr;
2581         iph->saddr      = amt->local_ip;
2582         iph->protocol   = IPPROTO_UDP;
2583         iph->tot_len    = htons(len);
2584
2585         skb->ip_summed = CHECKSUM_NONE;
2586         ip_select_ident(amt->net, skb, NULL);
2587         ip_send_check(iph);
2588         err = ip_local_out(amt->net, sock->sk, skb);
2589         if (unlikely(net_xmit_eval(err)))
2590                 amt->dev->stats.tx_errors++;
2591
2592 out:
2593         rcu_read_unlock();
2594 }
2595
2596 static bool amt_discovery_handler(struct amt_dev *amt, struct sk_buff *skb)
2597 {
2598         struct amt_header_discovery *amtd;
2599         struct udphdr *udph;
2600         struct iphdr *iph;
2601
2602         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtd)))
2603                 return true;
2604
2605         iph = ip_hdr(skb);
2606         udph = udp_hdr(skb);
2607         amtd = (struct amt_header_discovery *)(udp_hdr(skb) + 1);
2608
2609         if (amtd->reserved || amtd->version)
2610                 return true;
2611
2612         amt_send_advertisement(amt, amtd->nonce, iph->saddr, udph->source);
2613
2614         return false;
2615 }
2616
2617 static bool amt_request_handler(struct amt_dev *amt, struct sk_buff *skb)
2618 {
2619         struct amt_header_request *amtrh;
2620         struct amt_tunnel_list *tunnel;
2621         unsigned long long key;
2622         struct udphdr *udph;
2623         struct iphdr *iph;
2624         u64 mac;
2625         int i;
2626
2627         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtrh)))
2628                 return true;
2629
2630         iph = ip_hdr(skb);
2631         udph = udp_hdr(skb);
2632         amtrh = (struct amt_header_request *)(udp_hdr(skb) + 1);
2633
2634         if (amtrh->reserved1 || amtrh->reserved2 || amtrh->version)
2635                 return true;
2636
2637         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list)
2638                 if (tunnel->ip4 == iph->saddr)
2639                         goto send;
2640
2641         if (amt->nr_tunnels >= amt->max_tunnels) {
2642                 icmp_ndo_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0);
2643                 return true;
2644         }
2645
2646         tunnel = kzalloc(sizeof(*tunnel) +
2647                          (sizeof(struct hlist_head) * amt->hash_buckets),
2648                          GFP_ATOMIC);
2649         if (!tunnel)
2650                 return true;
2651
2652         tunnel->source_port = udph->source;
2653         tunnel->ip4 = iph->saddr;
2654
2655         memcpy(&key, &tunnel->key, sizeof(unsigned long long));
2656         tunnel->amt = amt;
2657         spin_lock_init(&tunnel->lock);
2658         for (i = 0; i < amt->hash_buckets; i++)
2659                 INIT_HLIST_HEAD(&tunnel->groups[i]);
2660
2661         INIT_DELAYED_WORK(&tunnel->gc_wq, amt_tunnel_expire);
2662
2663         spin_lock_bh(&amt->lock);
2664         list_add_tail_rcu(&tunnel->list, &amt->tunnel_list);
2665         tunnel->key = amt->key;
2666         amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_REQUEST, true);
2667         amt->nr_tunnels++;
2668         mod_delayed_work(amt_wq, &tunnel->gc_wq,
2669                          msecs_to_jiffies(amt_gmi(amt)));
2670         spin_unlock_bh(&amt->lock);
2671
2672 send:
2673         tunnel->nonce = amtrh->nonce;
2674         mac = siphash_3u32((__force u32)tunnel->ip4,
2675                            (__force u32)tunnel->source_port,
2676                            (__force u32)tunnel->nonce,
2677                            &tunnel->key);
2678         tunnel->mac = mac >> 16;
2679
2680         if (!netif_running(amt->dev) || !netif_running(amt->stream_dev))
2681                 return true;
2682
2683         if (!amtrh->p)
2684                 amt_send_igmp_gq(amt, tunnel);
2685         else
2686                 amt_send_mld_gq(amt, tunnel);
2687
2688         return false;
2689 }
2690
2691 static int amt_rcv(struct sock *sk, struct sk_buff *skb)
2692 {
2693         struct amt_dev *amt;
2694         struct iphdr *iph;
2695         int type;
2696         bool err;
2697
2698         rcu_read_lock_bh();
2699         amt = rcu_dereference_sk_user_data(sk);
2700         if (!amt) {
2701                 err = true;
2702                 kfree_skb(skb);
2703                 goto out;
2704         }
2705
2706         skb->dev = amt->dev;
2707         iph = ip_hdr(skb);
2708         type = amt_parse_type(skb);
2709         if (type == -1) {
2710                 err = true;
2711                 goto drop;
2712         }
2713
2714         if (amt->mode == AMT_MODE_GATEWAY) {
2715                 switch (type) {
2716                 case AMT_MSG_ADVERTISEMENT:
2717                         if (iph->saddr != amt->discovery_ip) {
2718                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2719                                 err = true;
2720                                 goto drop;
2721                         }
2722                         err = amt_advertisement_handler(amt, skb);
2723                         break;
2724                 case AMT_MSG_MULTICAST_DATA:
2725                         if (iph->saddr != amt->remote_ip) {
2726                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2727                                 err = true;
2728                                 goto drop;
2729                         }
2730                         err = amt_multicast_data_handler(amt, skb);
2731                         if (err)
2732                                 goto drop;
2733                         else
2734                                 goto out;
2735                 case AMT_MSG_MEMBERSHIP_QUERY:
2736                         if (iph->saddr != amt->remote_ip) {
2737                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2738                                 err = true;
2739                                 goto drop;
2740                         }
2741                         err = amt_membership_query_handler(amt, skb);
2742                         if (err)
2743                                 goto drop;
2744                         else
2745                                 goto out;
2746                 default:
2747                         err = true;
2748                         netdev_dbg(amt->dev, "Invalid type of Gateway\n");
2749                         break;
2750                 }
2751         } else {
2752                 switch (type) {
2753                 case AMT_MSG_DISCOVERY:
2754                         err = amt_discovery_handler(amt, skb);
2755                         break;
2756                 case AMT_MSG_REQUEST:
2757                         err = amt_request_handler(amt, skb);
2758                         break;
2759                 case AMT_MSG_MEMBERSHIP_UPDATE:
2760                         err = amt_update_handler(amt, skb);
2761                         if (err)
2762                                 goto drop;
2763                         else
2764                                 goto out;
2765                 default:
2766                         err = true;
2767                         netdev_dbg(amt->dev, "Invalid type of relay\n");
2768                         break;
2769                 }
2770         }
2771 drop:
2772         if (err) {
2773                 amt->dev->stats.rx_dropped++;
2774                 kfree_skb(skb);
2775         } else {
2776                 consume_skb(skb);
2777         }
2778 out:
2779         rcu_read_unlock_bh();
2780         return 0;
2781 }
2782
2783 static int amt_err_lookup(struct sock *sk, struct sk_buff *skb)
2784 {
2785         struct amt_dev *amt;
2786         int type;
2787
2788         rcu_read_lock_bh();
2789         amt = rcu_dereference_sk_user_data(sk);
2790         if (!amt)
2791                 goto out;
2792
2793         if (amt->mode != AMT_MODE_GATEWAY)
2794                 goto drop;
2795
2796         type = amt_parse_type(skb);
2797         if (type == -1)
2798                 goto drop;
2799
2800         netdev_dbg(amt->dev, "Received IGMP Unreachable of %s\n",
2801                    type_str[type]);
2802         switch (type) {
2803         case AMT_MSG_DISCOVERY:
2804                 break;
2805         case AMT_MSG_REQUEST:
2806         case AMT_MSG_MEMBERSHIP_UPDATE:
2807                 if (amt->status >= AMT_STATUS_RECEIVED_ADVERTISEMENT)
2808                         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2809                 break;
2810         default:
2811                 goto drop;
2812         }
2813 out:
2814         rcu_read_unlock_bh();
2815         return 0;
2816 drop:
2817         rcu_read_unlock_bh();
2818         amt->dev->stats.rx_dropped++;
2819         return 0;
2820 }
2821
2822 static struct socket *amt_create_sock(struct net *net, __be16 port)
2823 {
2824         struct udp_port_cfg udp_conf;
2825         struct socket *sock;
2826         int err;
2827
2828         memset(&udp_conf, 0, sizeof(udp_conf));
2829         udp_conf.family = AF_INET;
2830         udp_conf.local_ip.s_addr = htonl(INADDR_ANY);
2831
2832         udp_conf.local_udp_port = port;
2833
2834         err = udp_sock_create(net, &udp_conf, &sock);
2835         if (err < 0)
2836                 return ERR_PTR(err);
2837
2838         return sock;
2839 }
2840
2841 static int amt_socket_create(struct amt_dev *amt)
2842 {
2843         struct udp_tunnel_sock_cfg tunnel_cfg;
2844         struct socket *sock;
2845
2846         sock = amt_create_sock(amt->net, amt->relay_port);
2847         if (IS_ERR(sock))
2848                 return PTR_ERR(sock);
2849
2850         /* Mark socket as an encapsulation socket */
2851         memset(&tunnel_cfg, 0, sizeof(tunnel_cfg));
2852         tunnel_cfg.sk_user_data = amt;
2853         tunnel_cfg.encap_type = 1;
2854         tunnel_cfg.encap_rcv = amt_rcv;
2855         tunnel_cfg.encap_err_lookup = amt_err_lookup;
2856         tunnel_cfg.encap_destroy = NULL;
2857         setup_udp_tunnel_sock(amt->net, sock, &tunnel_cfg);
2858
2859         rcu_assign_pointer(amt->sock, sock);
2860         return 0;
2861 }
2862
2863 static int amt_dev_open(struct net_device *dev)
2864 {
2865         struct amt_dev *amt = netdev_priv(dev);
2866         int err;
2867
2868         amt->ready4 = false;
2869         amt->ready6 = false;
2870
2871         err = amt_socket_create(amt);
2872         if (err)
2873                 return err;
2874
2875         amt->req_cnt = 0;
2876         amt->remote_ip = 0;
2877         get_random_bytes(&amt->key, sizeof(siphash_key_t));
2878
2879         amt->status = AMT_STATUS_INIT;
2880         if (amt->mode == AMT_MODE_GATEWAY) {
2881                 mod_delayed_work(amt_wq, &amt->discovery_wq, 0);
2882                 mod_delayed_work(amt_wq, &amt->req_wq, 0);
2883         } else if (amt->mode == AMT_MODE_RELAY) {
2884                 mod_delayed_work(amt_wq, &amt->secret_wq,
2885                                  msecs_to_jiffies(AMT_SECRET_TIMEOUT));
2886         }
2887         return err;
2888 }
2889
/* ndo_stop: tear the AMT device down.
 *
 * Stops the periodic work items and detaches and releases the tunnel socket.
 * It then resets per-session state and frees every tunnel on
 * amt->tunnel_list (tunnel memory is freed after an RCU grace period).
 * Always returns 0.
 */
static int amt_dev_stop(struct net_device *dev)
{
	struct amt_dev *amt = netdev_priv(dev);
	struct amt_tunnel_list *tunnel, *tmp;
	struct socket *sock;

	cancel_delayed_work_sync(&amt->req_wq);
	cancel_delayed_work_sync(&amt->discovery_wq);
	cancel_delayed_work_sync(&amt->secret_wq);

	/* Shutdown: unpublish the socket, then wait for in-flight RCU readers
	 * of amt->sock (e.g. amt_send_advertisement()) before releasing it.
	 */
	sock = rtnl_dereference(amt->sock);
	RCU_INIT_POINTER(amt->sock, NULL);
	synchronize_net();
	if (sock)
		udp_tunnel_sock_release(sock);

	amt->ready4 = false;
	amt->ready6 = false;
	amt->req_cnt = 0;
	amt->remote_ip = 0;

	/* NOTE(review): the list is modified here without amt->lock, whereas
	 * amt_request_handler() adds under it. This relies on the socket being
	 * gone so no new tunnels arrive; confirm amt_tunnel_expire() cannot
	 * concurrently unlink a tunnel that has not been cancelled yet.
	 */
	list_for_each_entry_safe(tunnel, tmp, &amt->tunnel_list, list) {
		list_del_rcu(&tunnel->list);
		amt->nr_tunnels--;
		cancel_delayed_work_sync(&tunnel->gc_wq);
		amt_clear_groups(tunnel);
		kfree_rcu(tunnel, rcu);
	}

	return 0;
}
2922
/* Device type attached to every AMT net_device by amt_link_setup() via
 * SET_NETDEV_DEVTYPE(); its name is "amt".
 */
static const struct device_type amt_type = {
	.name = "amt",
};
2926
2927 static int amt_dev_init(struct net_device *dev)
2928 {
2929         struct amt_dev *amt = netdev_priv(dev);
2930         int err;
2931
2932         amt->dev = dev;
2933         dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
2934         if (!dev->tstats)
2935                 return -ENOMEM;
2936
2937         err = gro_cells_init(&amt->gro_cells, dev);
2938         if (err) {
2939                 free_percpu(dev->tstats);
2940                 return err;
2941         }
2942
2943         return 0;
2944 }
2945
/* ndo_uninit: release the GRO cells and per-CPU stats allocated by
 * amt_dev_init().
 */
static void amt_dev_uninit(struct net_device *dev)
{
	struct amt_dev *amt = netdev_priv(dev);

	gro_cells_destroy(&amt->gro_cells);
	free_percpu(dev->tstats);
}
2953
/* net_device callbacks for AMT interfaces. 64-bit statistics come from
 * dev_get_tstats64(), backed by the per-CPU dev->tstats that amt_dev_init()
 * allocates.
 */
static const struct net_device_ops amt_netdev_ops = {
	.ndo_init               = amt_dev_init,
	.ndo_uninit             = amt_dev_uninit,
	.ndo_open               = amt_dev_open,
	.ndo_stop               = amt_dev_stop,
	.ndo_start_xmit         = amt_dev_xmit,
	.ndo_get_stats64        = dev_get_tstats64,
};
2962
/* Link setup callback: install the AMT netdev ops and device type, and set
 * default device properties, offload feature flags and a random MAC
 * address.
 *
 * NOTE(review): ether_setup() runs last. In mainline it (re)initializes
 * type, flags, hard_header_len, addr_len, the broadcast address and the MTU
 * limits for Ethernet. That appears to override the assignments above:
 * ARPHRD_NONE, IFF_POINTOPOINT | IFF_NOARP, zero header/address length,
 * the zeroed broadcast address and ETH_MIN_MTU..ETH_MAX_MTU. The receive
 * handlers in this file build Ethernet headers, so Ethernet semantics may
 * be what is actually intended. Confirm before reordering or removing
 * anything.
 */
static void amt_link_setup(struct net_device *dev)
{
	dev->netdev_ops         = &amt_netdev_ops;
	dev->needs_free_netdev  = true;
	SET_NETDEV_DEVTYPE(dev, &amt_type);
	dev->min_mtu            = ETH_MIN_MTU;
	dev->max_mtu            = ETH_MAX_MTU;
	dev->type               = ARPHRD_NONE;
	dev->flags              = IFF_POINTOPOINT | IFF_NOARP | IFF_MULTICAST;
	dev->hard_header_len    = 0;
	dev->addr_len           = 0;
	dev->priv_flags         |= IFF_NO_QUEUE;
	dev->features           |= NETIF_F_LLTX;
	dev->features           |= NETIF_F_GSO_SOFTWARE;
	dev->features           |= NETIF_F_NETNS_LOCAL;
	dev->hw_features        |= NETIF_F_SG | NETIF_F_HW_CSUM;
	dev->hw_features        |= NETIF_F_FRAGLIST | NETIF_F_RXCSUM;
	dev->hw_features        |= NETIF_F_GSO_SOFTWARE;
	eth_hw_addr_random(dev);
	eth_zero_addr(dev->broadcast);
	ether_setup(dev);
}
2985
/* Netlink attribute policy for IFLA_AMT_*. The three address attributes are
 * untyped and sized as one IPv4 address (sizeof iphdr.daddr). For untyped
 * attributes .len acts as a minimum payload length. Ports are NLA_U16
 * (read as big-endian by amt_newlink()).
 */
static const struct nla_policy amt_policy[IFLA_AMT_MAX + 1] = {
	[IFLA_AMT_MODE]         = { .type = NLA_U32 },
	[IFLA_AMT_RELAY_PORT]   = { .type = NLA_U16 },
	[IFLA_AMT_GATEWAY_PORT] = { .type = NLA_U16 },
	[IFLA_AMT_LINK]         = { .type = NLA_U32 },
	[IFLA_AMT_LOCAL_IP]     = { .len = sizeof_field(struct iphdr, daddr) },
	[IFLA_AMT_REMOTE_IP]    = { .len = sizeof_field(struct iphdr, daddr) },
	[IFLA_AMT_DISCOVERY_IP] = { .len = sizeof_field(struct iphdr, daddr) },
	[IFLA_AMT_MAX_TUNNELS]  = { .type = NLA_U32 },
};
2996
2997 static int amt_validate(struct nlattr *tb[], struct nlattr *data[],
2998                         struct netlink_ext_ack *extack)
2999 {
3000         if (!data)
3001                 return -EINVAL;
3002
3003         if (!data[IFLA_AMT_LINK]) {
3004                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LINK],
3005                                     "Link attribute is required");
3006                 return -EINVAL;
3007         }
3008
3009         if (!data[IFLA_AMT_MODE]) {
3010                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
3011                                     "Mode attribute is required");
3012                 return -EINVAL;
3013         }
3014
3015         if (nla_get_u32(data[IFLA_AMT_MODE]) > AMT_MODE_MAX) {
3016                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
3017                                     "Mode attribute is not valid");
3018                 return -EINVAL;
3019         }
3020
3021         if (!data[IFLA_AMT_LOCAL_IP]) {
3022                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_DISCOVERY_IP],
3023                                     "Local attribute is required");
3024                 return -EINVAL;
3025         }
3026
3027         if (!data[IFLA_AMT_DISCOVERY_IP] &&
3028             nla_get_u32(data[IFLA_AMT_MODE]) == AMT_MODE_GATEWAY) {
3029                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LOCAL_IP],
3030                                     "Discovery attribute is required");
3031                 return -EINVAL;
3032         }
3033
3034         return 0;
3035 }
3036
3037 static int amt_newlink(struct net *net, struct net_device *dev,
3038                        struct nlattr *tb[], struct nlattr *data[],
3039                        struct netlink_ext_ack *extack)
3040 {
3041         struct amt_dev *amt = netdev_priv(dev);
3042         int err = -EINVAL;
3043
3044         amt->net = net;
3045         amt->mode = nla_get_u32(data[IFLA_AMT_MODE]);
3046
3047         if (data[IFLA_AMT_MAX_TUNNELS] &&
3048             nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]))
3049                 amt->max_tunnels = nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]);
3050         else
3051                 amt->max_tunnels = AMT_MAX_TUNNELS;
3052
3053         spin_lock_init(&amt->lock);
3054         amt->max_groups = AMT_MAX_GROUP;
3055         amt->max_sources = AMT_MAX_SOURCE;
3056         amt->hash_buckets = AMT_HSIZE;
3057         amt->nr_tunnels = 0;
3058         get_random_bytes(&amt->hash_seed, sizeof(amt->hash_seed));
3059         amt->stream_dev = dev_get_by_index(net,
3060                                            nla_get_u32(data[IFLA_AMT_LINK]));
3061         if (!amt->stream_dev) {
3062                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3063                                     "Can't find stream device");
3064                 return -ENODEV;
3065         }
3066
3067         if (amt->stream_dev->type != ARPHRD_ETHER) {
3068                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3069                                     "Invalid stream device type");
3070                 goto err;
3071         }
3072
3073         amt->local_ip = nla_get_in_addr(data[IFLA_AMT_LOCAL_IP]);
3074         if (ipv4_is_loopback(amt->local_ip) ||
3075             ipv4_is_zeronet(amt->local_ip) ||
3076             ipv4_is_multicast(amt->local_ip)) {
3077                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LOCAL_IP],
3078                                     "Invalid Local address");
3079                 goto err;
3080         }
3081
3082         if (data[IFLA_AMT_RELAY_PORT])
3083                 amt->relay_port = nla_get_be16(data[IFLA_AMT_RELAY_PORT]);
3084         else
3085                 amt->relay_port = htons(IANA_AMT_UDP_PORT);
3086
3087         if (data[IFLA_AMT_GATEWAY_PORT])
3088                 amt->gw_port = nla_get_be16(data[IFLA_AMT_GATEWAY_PORT]);
3089         else
3090                 amt->gw_port = htons(IANA_AMT_UDP_PORT);
3091
3092         if (!amt->relay_port) {
3093                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3094                                     "relay port must not be 0");
3095                 goto err;
3096         }
3097         if (amt->mode == AMT_MODE_RELAY) {
3098                 amt->qrv = amt->net->ipv4.sysctl_igmp_qrv;
3099                 amt->qri = 10;
3100                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3101                                        AMT_RELAY_HLEN;
3102                 dev->mtu = amt->stream_dev->mtu - AMT_RELAY_HLEN;
3103                 dev->max_mtu = dev->mtu;
3104                 dev->min_mtu = ETH_MIN_MTU + AMT_RELAY_HLEN;
3105         } else {
3106                 if (!data[IFLA_AMT_DISCOVERY_IP]) {
3107                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3108                                             "discovery must be set in gateway mode");
3109                         goto err;
3110                 }
3111                 if (!amt->gw_port) {
3112                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3113                                             "gateway port must not be 0");
3114                         goto err;
3115                 }
3116                 amt->remote_ip = 0;
3117                 amt->discovery_ip = nla_get_in_addr(data[IFLA_AMT_DISCOVERY_IP]);
3118                 if (ipv4_is_loopback(amt->discovery_ip) ||
3119                     ipv4_is_zeronet(amt->discovery_ip) ||
3120                     ipv4_is_multicast(amt->discovery_ip)) {
3121                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3122                                             "discovery must be unicast");
3123                         goto err;
3124                 }
3125
3126                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3127                                        AMT_GW_HLEN;
3128                 dev->mtu = amt->stream_dev->mtu - AMT_GW_HLEN;
3129                 dev->max_mtu = dev->mtu;
3130                 dev->min_mtu = ETH_MIN_MTU + AMT_GW_HLEN;
3131         }
3132         amt->qi = AMT_INIT_QUERY_INTERVAL;
3133
3134         err = register_netdevice(dev);
3135         if (err < 0) {
3136                 netdev_dbg(dev, "failed to register new netdev %d\n", err);
3137                 goto err;
3138         }
3139
3140         err = netdev_upper_dev_link(amt->stream_dev, dev, extack);
3141         if (err < 0) {
3142                 unregister_netdevice(dev);
3143                 goto err;
3144         }
3145
3146         INIT_DELAYED_WORK(&amt->discovery_wq, amt_discovery_work);
3147         INIT_DELAYED_WORK(&amt->req_wq, amt_req_work);
3148         INIT_DELAYED_WORK(&amt->secret_wq, amt_secret_work);
3149         INIT_LIST_HEAD(&amt->tunnel_list);
3150
3151         return 0;
3152 err:
3153         dev_put(amt->stream_dev);
3154         return err;
3155 }
3156
3157 static void amt_dellink(struct net_device *dev, struct list_head *head)
3158 {
3159         struct amt_dev *amt = netdev_priv(dev);
3160
3161         unregister_netdevice_queue(dev, head);
3162         netdev_upper_dev_unlink(amt->stream_dev, dev);
3163         dev_put(amt->stream_dev);
3164 }
3165
3166 static size_t amt_get_size(const struct net_device *dev)
3167 {
3168         return nla_total_size(sizeof(__u32)) + /* IFLA_AMT_MODE */
3169                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_RELAY_PORT */
3170                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_GATEWAY_PORT */
3171                nla_total_size(sizeof(__u32)) + /* IFLA_AMT_LINK */
3172                nla_total_size(sizeof(__u32)) + /* IFLA_MAX_TUNNELS */
3173                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_DISCOVERY_IP */
3174                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_REMOTE_IP */
3175                nla_total_size(sizeof(struct iphdr)); /* IFLA_AMT_LOCAL_IP */
3176 }
3177
3178 static int amt_fill_info(struct sk_buff *skb, const struct net_device *dev)
3179 {
3180         struct amt_dev *amt = netdev_priv(dev);
3181
3182         if (nla_put_u32(skb, IFLA_AMT_MODE, amt->mode))
3183                 goto nla_put_failure;
3184         if (nla_put_be16(skb, IFLA_AMT_RELAY_PORT, amt->relay_port))
3185                 goto nla_put_failure;
3186         if (nla_put_be16(skb, IFLA_AMT_GATEWAY_PORT, amt->gw_port))
3187                 goto nla_put_failure;
3188         if (nla_put_u32(skb, IFLA_AMT_LINK, amt->stream_dev->ifindex))
3189                 goto nla_put_failure;
3190         if (nla_put_in_addr(skb, IFLA_AMT_LOCAL_IP, amt->local_ip))
3191                 goto nla_put_failure;
3192         if (nla_put_in_addr(skb, IFLA_AMT_DISCOVERY_IP, amt->discovery_ip))
3193                 goto nla_put_failure;
3194         if (amt->remote_ip)
3195                 if (nla_put_in_addr(skb, IFLA_AMT_REMOTE_IP, amt->remote_ip))
3196                         goto nla_put_failure;
3197         if (nla_put_u32(skb, IFLA_AMT_MAX_TUNNELS, amt->max_tunnels))
3198                 goto nla_put_failure;
3199
3200         return 0;
3201
3202 nla_put_failure:
3203         return -EMSGSIZE;
3204 }
3205
3206 static struct rtnl_link_ops amt_link_ops __read_mostly = {
3207         .kind           = "amt",
3208         .maxtype        = IFLA_AMT_MAX,
3209         .policy         = amt_policy,
3210         .priv_size      = sizeof(struct amt_dev),
3211         .setup          = amt_link_setup,
3212         .validate       = amt_validate,
3213         .newlink        = amt_newlink,
3214         .dellink        = amt_dellink,
3215         .get_size       = amt_get_size,
3216         .fill_info      = amt_fill_info,
3217 };
3218
3219 static struct net_device *amt_lookup_upper_dev(struct net_device *dev)
3220 {
3221         struct net_device *upper_dev;
3222         struct amt_dev *amt;
3223
3224         for_each_netdev(dev_net(dev), upper_dev) {
3225                 if (netif_is_amt(upper_dev)) {
3226                         amt = netdev_priv(upper_dev);
3227                         if (amt->stream_dev == dev)
3228                                 return upper_dev;
3229                 }
3230         }
3231
3232         return NULL;
3233 }
3234
3235 static int amt_device_event(struct notifier_block *unused,
3236                             unsigned long event, void *ptr)
3237 {
3238         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
3239         struct net_device *upper_dev;
3240         struct amt_dev *amt;
3241         LIST_HEAD(list);
3242         int new_mtu;
3243
3244         upper_dev = amt_lookup_upper_dev(dev);
3245         if (!upper_dev)
3246                 return NOTIFY_DONE;
3247         amt = netdev_priv(upper_dev);
3248
3249         switch (event) {
3250         case NETDEV_UNREGISTER:
3251                 amt_dellink(amt->dev, &list);
3252                 unregister_netdevice_many(&list);
3253                 break;
3254         case NETDEV_CHANGEMTU:
3255                 if (amt->mode == AMT_MODE_RELAY)
3256                         new_mtu = dev->mtu - AMT_RELAY_HLEN;
3257                 else
3258                         new_mtu = dev->mtu - AMT_GW_HLEN;
3259
3260                 dev_set_mtu(amt->dev, new_mtu);
3261                 break;
3262         }
3263
3264         return NOTIFY_DONE;
3265 }
3266
3267 static struct notifier_block amt_notifier_block __read_mostly = {
3268         .notifier_call = amt_device_event,
3269 };
3270
3271 static int __init amt_init(void)
3272 {
3273         int err;
3274
3275         err = register_netdevice_notifier(&amt_notifier_block);
3276         if (err < 0)
3277                 goto err;
3278
3279         err = rtnl_link_register(&amt_link_ops);
3280         if (err < 0)
3281                 goto unregister_notifier;
3282
3283         amt_wq = alloc_workqueue("amt", WQ_UNBOUND, 1);
3284         if (!amt_wq) {
3285                 err = -ENOMEM;
3286                 goto rtnl_unregister;
3287         }
3288
3289         spin_lock_init(&source_gc_lock);
3290         spin_lock_bh(&source_gc_lock);
3291         INIT_DELAYED_WORK(&source_gc_wq, amt_source_gc_work);
3292         mod_delayed_work(amt_wq, &source_gc_wq,
3293                          msecs_to_jiffies(AMT_GC_INTERVAL));
3294         spin_unlock_bh(&source_gc_lock);
3295
3296         return 0;
3297
3298 rtnl_unregister:
3299         rtnl_link_unregister(&amt_link_ops);
3300 unregister_notifier:
3301         unregister_netdevice_notifier(&amt_notifier_block);
3302 err:
3303         pr_err("error loading AMT module loaded\n");
3304         return err;
3305 }
3306 late_initcall(amt_init);
3307
/*
 * Module exit. The steps run in this order:
 * - unregister the "amt" link type;
 * - unregister the netdevice notifier;
 * - cancel the delayed source-GC work and wait for it;
 * - run one final synchronous GC pass (presumably draining whatever is
 *   still on source_gc_list - confirm in __amt_source_gc_work());
 * - destroy amt_wq.
 */
static void __exit amt_fini(void)
{
	rtnl_link_unregister(&amt_link_ops);
	unregister_netdevice_notifier(&amt_notifier_block);
	/* Make sure the GC work is neither pending nor running ... */
	cancel_delayed_work_sync(&source_gc_wq);
	/* ... before collecting the remainder by hand. */
	__amt_source_gc_work();
	/* Last step: the workqueue must outlive every user above. */
	destroy_workqueue(amt_wq);
}
module_exit(amt_fini);
3317
/*
 * Module metadata. MODULE_ALIAS_RTNL_LINK() records an rtnl-link alias
 * for the "amt" kind, presumably so rtnetlink can autoload this module
 * on "ip link add ... type amt".
 */
MODULE_LICENSE("GPL");
MODULE_AUTHOR("Taehee Yoo <ap420073@gmail.com>");
MODULE_ALIAS_RTNL_LINK("amt");