GNU Linux-libre 6.8.7-gnu
[releases.git] / drivers / net / vxlan / vxlan_vnifilter.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  *      Vxlan vni filter for collect metadata mode
4  *
5  *      Authors: Roopa Prabhu <roopa@nvidia.com>
6  *
7  */
8
9 #include <linux/kernel.h>
10 #include <linux/slab.h>
11 #include <linux/etherdevice.h>
12 #include <linux/rhashtable.h>
13 #include <net/rtnetlink.h>
14 #include <net/net_namespace.h>
15 #include <net/sock.h>
16 #include <net/vxlan.h>
17
18 #include "vxlan_private.h"
19
20 static inline int vxlan_vni_cmp(struct rhashtable_compare_arg *arg,
21                                 const void *ptr)
22 {
23         const struct vxlan_vni_node *vnode = ptr;
24         __be32 vni = *(__be32 *)arg->key;
25
26         return vnode->vni != vni;
27 }
28
/* Hash table parameters for the per-device VNI filter table: nodes are
 * struct vxlan_vni_node, keyed on the big-endian VNI field and compared
 * with vxlan_vni_cmp().
 */
const struct rhashtable_params vxlan_vni_rht_params = {
	.head_offset = offsetof(struct vxlan_vni_node, vnode),
	.key_offset = offsetof(struct vxlan_vni_node, vni),
	.key_len = sizeof(__be32),
	.nelem_hint = 3,
	.max_size = VXLAN_N_VID,	/* at most one node per possible VNI */
	.obj_cmpfn = vxlan_vni_cmp,
	.automatic_shrinking = true,
};
38
/* Link (@del == false) or unlink (@del == true) one VNI node on the
 * per-VNI hash chains of the device's underlying vxlan socket(s), so that
 * received packets for this VNI can be demultiplexed to this device.
 *
 * Serialized against other socket-list updates by vn->sock_lock.
 */
static void vxlan_vs_add_del_vninode(struct vxlan_dev *vxlan,
				     struct vxlan_vni_node *v,
				     bool del)
{
	struct vxlan_net *vn = net_generic(vxlan->net, vxlan_net_id);
	struct vxlan_dev_node *node;
	struct vxlan_sock *vs;

	spin_lock(&vn->sock_lock);
	if (del) {
		/* hlist_unhashed() guards against nodes that were never
		 * (or are no longer) linked on a socket chain.
		 */
		if (!hlist_unhashed(&v->hlist4.hlist))
			hlist_del_init_rcu(&v->hlist4.hlist);
#if IS_ENABLED(CONFIG_IPV6)
		if (!hlist_unhashed(&v->hlist6.hlist))
			hlist_del_init_rcu(&v->hlist6.hlist);
#endif
		goto out;
	}

#if IS_ENABLED(CONFIG_IPV6)
	/* Hook onto the IPv6 socket's VNI hash chain, if one is open */
	vs = rtnl_dereference(vxlan->vn6_sock);
	if (vs && v) {
		node = &v->hlist6;
		hlist_add_head_rcu(&node->hlist, vni_head(vs, v->vni));
	}
#endif
	/* Likewise for the IPv4 socket */
	vs = rtnl_dereference(vxlan->vn4_sock);
	if (vs && v) {
		node = &v->hlist4;
		hlist_add_head_rcu(&node->hlist, vni_head(vs, v->vni));
	}
out:
	spin_unlock(&vn->sock_lock);
}
73
/* Hook every VNI node of @vxlan onto the newly opened socket @vs's
 * per-VNI hash chains (IPv6 chain when @ipv6, IPv4 chain otherwise).
 * Called when the device acquires a socket; no-op when the device has
 * no VNI group.  Serialized by vn->sock_lock.
 */
void vxlan_vs_add_vnigrp(struct vxlan_dev *vxlan,
			 struct vxlan_sock *vs,
			 bool ipv6)
{
	struct vxlan_net *vn = net_generic(vxlan->net, vxlan_net_id);
	struct vxlan_vni_group *vg = rtnl_dereference(vxlan->vnigrp);
	struct vxlan_vni_node *v, *tmp;
	struct vxlan_dev_node *node;

	if (!vg)
		return;

	spin_lock(&vn->sock_lock);
	list_for_each_entry_safe(v, tmp, &vg->vni_list, vlist) {
#if IS_ENABLED(CONFIG_IPV6)
		if (ipv6)
			node = &v->hlist6;
		else
#endif
			node = &v->hlist4;
		node->vxlan = vxlan;
		hlist_add_head_rcu(&node->hlist, vni_head(vs, v->vni));
	}
	spin_unlock(&vn->sock_lock);
}
99
/* Unlink every VNI node of @vxlan from the socket per-VNI hash chains
 * (both address families).  Called when the device releases its
 * socket(s); no-op when the device has no VNI group.
 * Serialized by vn->sock_lock.
 */
void vxlan_vs_del_vnigrp(struct vxlan_dev *vxlan)
{
	struct vxlan_vni_group *vg = rtnl_dereference(vxlan->vnigrp);
	struct vxlan_net *vn = net_generic(vxlan->net, vxlan_net_id);
	struct vxlan_vni_node *v, *tmp;

	if (!vg)
		return;

	spin_lock(&vn->sock_lock);
	list_for_each_entry_safe(v, tmp, &vg->vni_list, vlist) {
		hlist_del_init_rcu(&v->hlist4.hlist);
#if IS_ENABLED(CONFIG_IPV6)
		hlist_del_init_rcu(&v->hlist6.hlist);
#endif
	}
	spin_unlock(&vn->sock_lock);
}
118
119 static void vxlan_vnifilter_stats_get(const struct vxlan_vni_node *vninode,
120                                       struct vxlan_vni_stats *dest)
121 {
122         int i;
123
124         memset(dest, 0, sizeof(*dest));
125         for_each_possible_cpu(i) {
126                 struct vxlan_vni_stats_pcpu *pstats;
127                 struct vxlan_vni_stats temp;
128                 unsigned int start;
129
130                 pstats = per_cpu_ptr(vninode->stats, i);
131                 do {
132                         start = u64_stats_fetch_begin(&pstats->syncp);
133                         memcpy(&temp, &pstats->stats, sizeof(temp));
134                 } while (u64_stats_fetch_retry(&pstats->syncp, start));
135
136                 dest->rx_packets += temp.rx_packets;
137                 dest->rx_bytes += temp.rx_bytes;
138                 dest->rx_drops += temp.rx_drops;
139                 dest->rx_errors += temp.rx_errors;
140                 dest->tx_packets += temp.tx_packets;
141                 dest->tx_bytes += temp.tx_bytes;
142                 dest->tx_drops += temp.tx_drops;
143                 dest->tx_errors += temp.tx_errors;
144         }
145 }
146
147 static void vxlan_vnifilter_stats_add(struct vxlan_vni_node *vninode,
148                                       int type, unsigned int len)
149 {
150         struct vxlan_vni_stats_pcpu *pstats = this_cpu_ptr(vninode->stats);
151
152         u64_stats_update_begin(&pstats->syncp);
153         switch (type) {
154         case VXLAN_VNI_STATS_RX:
155                 pstats->stats.rx_bytes += len;
156                 pstats->stats.rx_packets++;
157                 break;
158         case VXLAN_VNI_STATS_RX_DROPS:
159                 pstats->stats.rx_drops++;
160                 break;
161         case VXLAN_VNI_STATS_RX_ERRORS:
162                 pstats->stats.rx_errors++;
163                 break;
164         case VXLAN_VNI_STATS_TX:
165                 pstats->stats.tx_bytes += len;
166                 pstats->stats.tx_packets++;
167                 break;
168         case VXLAN_VNI_STATS_TX_DROPS:
169                 pstats->stats.tx_drops++;
170                 break;
171         case VXLAN_VNI_STATS_TX_ERRORS:
172                 pstats->stats.tx_errors++;
173                 break;
174         }
175         u64_stats_update_end(&pstats->syncp);
176 }
177
178 void vxlan_vnifilter_count(struct vxlan_dev *vxlan, __be32 vni,
179                            struct vxlan_vni_node *vninode,
180                            int type, unsigned int len)
181 {
182         struct vxlan_vni_node *vnode;
183
184         if (!(vxlan->cfg.flags & VXLAN_F_VNIFILTER))
185                 return;
186
187         if (vninode) {
188                 vnode = vninode;
189         } else {
190                 vnode = vxlan_vnifilter_lookup(vxlan, vni);
191                 if (!vnode)
192                         return;
193         }
194
195         vxlan_vnifilter_stats_add(vnode, type, len);
196 }
197
198 static u32 vnirange(struct vxlan_vni_node *vbegin,
199                     struct vxlan_vni_node *vend)
200 {
201         return (be32_to_cpu(vend->vni) - be32_to_cpu(vbegin->vni));
202 }
203
204 static size_t vxlan_vnifilter_entry_nlmsg_size(void)
205 {
206         return NLMSG_ALIGN(sizeof(struct tunnel_msg))
207                 + nla_total_size(0) /* VXLAN_VNIFILTER_ENTRY */
208                 + nla_total_size(sizeof(u32)) /* VXLAN_VNIFILTER_ENTRY_START */
209                 + nla_total_size(sizeof(u32)) /* VXLAN_VNIFILTER_ENTRY_END */
210                 + nla_total_size(sizeof(struct in6_addr));/* VXLAN_VNIFILTER_ENTRY_GROUP{6} */
211 }
212
213 static int __vnifilter_entry_fill_stats(struct sk_buff *skb,
214                                         const struct vxlan_vni_node *vbegin)
215 {
216         struct vxlan_vni_stats vstats;
217         struct nlattr *vstats_attr;
218
219         vstats_attr = nla_nest_start(skb, VXLAN_VNIFILTER_ENTRY_STATS);
220         if (!vstats_attr)
221                 goto out_stats_err;
222
223         vxlan_vnifilter_stats_get(vbegin, &vstats);
224         if (nla_put_u64_64bit(skb, VNIFILTER_ENTRY_STATS_RX_BYTES,
225                               vstats.rx_bytes, VNIFILTER_ENTRY_STATS_PAD) ||
226             nla_put_u64_64bit(skb, VNIFILTER_ENTRY_STATS_RX_PKTS,
227                               vstats.rx_packets, VNIFILTER_ENTRY_STATS_PAD) ||
228             nla_put_u64_64bit(skb, VNIFILTER_ENTRY_STATS_RX_DROPS,
229                               vstats.rx_drops, VNIFILTER_ENTRY_STATS_PAD) ||
230             nla_put_u64_64bit(skb, VNIFILTER_ENTRY_STATS_RX_ERRORS,
231                               vstats.rx_errors, VNIFILTER_ENTRY_STATS_PAD) ||
232             nla_put_u64_64bit(skb, VNIFILTER_ENTRY_STATS_TX_BYTES,
233                               vstats.tx_bytes, VNIFILTER_ENTRY_STATS_PAD) ||
234             nla_put_u64_64bit(skb, VNIFILTER_ENTRY_STATS_TX_PKTS,
235                               vstats.tx_packets, VNIFILTER_ENTRY_STATS_PAD) ||
236             nla_put_u64_64bit(skb, VNIFILTER_ENTRY_STATS_TX_DROPS,
237                               vstats.tx_drops, VNIFILTER_ENTRY_STATS_PAD) ||
238             nla_put_u64_64bit(skb, VNIFILTER_ENTRY_STATS_TX_ERRORS,
239                               vstats.tx_errors, VNIFILTER_ENTRY_STATS_PAD))
240                 goto out_stats_err;
241
242         nla_nest_end(skb, vstats_attr);
243
244         return 0;
245
246 out_stats_err:
247         nla_nest_cancel(skb, vstats_attr);
248         return -EMSGSIZE;
249 }
250
/* Emit one VXLAN_VNIFILTER_ENTRY nested attribute describing the VNI
 * range [@vbegin, @vend] — the END attribute is omitted for a
 * single-VNI entry — plus the optional per-vni remote group address
 * and, when @fill_stats, the stats nest.
 *
 * Returns true on success; on failure the partial nest is cancelled
 * and false is returned so the caller can stop or resume its dump.
 */
static bool vxlan_fill_vni_filter_entry(struct sk_buff *skb,
					struct vxlan_vni_node *vbegin,
					struct vxlan_vni_node *vend,
					bool fill_stats)
{
	struct nlattr *ventry;
	u32 vs = be32_to_cpu(vbegin->vni);
	u32 ve = 0;

	if (vbegin != vend)
		ve = be32_to_cpu(vend->vni);

	ventry = nla_nest_start(skb, VXLAN_VNIFILTER_ENTRY);
	if (!ventry)
		return false;

	if (nla_put_u32(skb, VXLAN_VNIFILTER_ENTRY_START, vs))
		goto out_err;

	/* ve == 0 means a single-vni entry: no END attribute */
	if (ve && nla_put_u32(skb, VXLAN_VNIFILTER_ENTRY_END, ve))
		goto out_err;

	if (!vxlan_addr_any(&vbegin->remote_ip)) {
		if (vbegin->remote_ip.sa.sa_family == AF_INET) {
			if (nla_put_in_addr(skb, VXLAN_VNIFILTER_ENTRY_GROUP,
					    vbegin->remote_ip.sin.sin_addr.s_addr))
				goto out_err;
#if IS_ENABLED(CONFIG_IPV6)
		} else {
			/* non-AF_INET implies IPv6 here; the branch is
			 * compiled out entirely on !CONFIG_IPV6 kernels
			 */
			if (nla_put_in6_addr(skb, VXLAN_VNIFILTER_ENTRY_GROUP6,
					     &vbegin->remote_ip.sin6.sin6_addr))
				goto out_err;
#endif
		}
	}

	if (fill_stats && __vnifilter_entry_fill_stats(skb, vbegin))
		goto out_err;

	nla_nest_end(skb, ventry);

	return true;

out_err:
	nla_nest_cancel(skb, ventry);

	return false;
}
299
300 static void vxlan_vnifilter_notify(const struct vxlan_dev *vxlan,
301                                    struct vxlan_vni_node *vninode, int cmd)
302 {
303         struct tunnel_msg *tmsg;
304         struct sk_buff *skb;
305         struct nlmsghdr *nlh;
306         struct net *net = dev_net(vxlan->dev);
307         int err = -ENOBUFS;
308
309         skb = nlmsg_new(vxlan_vnifilter_entry_nlmsg_size(), GFP_KERNEL);
310         if (!skb)
311                 goto out_err;
312
313         err = -EMSGSIZE;
314         nlh = nlmsg_put(skb, 0, 0, cmd, sizeof(*tmsg), 0);
315         if (!nlh)
316                 goto out_err;
317         tmsg = nlmsg_data(nlh);
318         memset(tmsg, 0, sizeof(*tmsg));
319         tmsg->family = AF_BRIDGE;
320         tmsg->ifindex = vxlan->dev->ifindex;
321
322         if (!vxlan_fill_vni_filter_entry(skb, vninode, vninode, false))
323                 goto out_err;
324
325         nlmsg_end(skb, nlh);
326         rtnl_notify(skb, net, 0, RTNLGRP_TUNNEL, NULL, GFP_KERNEL);
327
328         return;
329
330 out_err:
331         rtnl_set_sk_err(net, RTNLGRP_TUNNEL, err);
332
333         kfree_skb(skb);
334 }
335
/* Dump all VNI filter entries of one vxlan device into @skb.
 *
 * Consecutive VNIs sharing the same remote IP are coalesced into a
 * single START/END range entry — unless per-VNI stats were requested,
 * in which case every VNI is emitted individually (stats are per node).
 *
 * cb->args[1] carries the per-device resume index across partial dumps
 * and is reset to 0 once the device has been dumped completely.
 */
static int vxlan_vnifilter_dump_dev(const struct net_device *dev,
				    struct sk_buff *skb,
				    struct netlink_callback *cb)
{
	struct vxlan_vni_node *tmp, *v, *vbegin = NULL, *vend = NULL;
	struct vxlan_dev *vxlan = netdev_priv(dev);
	struct tunnel_msg *new_tmsg, *tmsg;
	int idx = 0, s_idx = cb->args[1];
	struct vxlan_vni_group *vg;
	struct nlmsghdr *nlh;
	bool dump_stats;
	int err = 0;

	if (!(vxlan->cfg.flags & VXLAN_F_VNIFILTER))
		return -EINVAL;

	/* RCU needed because of the vni locking rules (rcu || rtnl) */
	vg = rcu_dereference(vxlan->vnigrp);
	if (!vg || !vg->num_vnis)
		return 0;

	tmsg = nlmsg_data(cb->nlh);
	dump_stats = !!(tmsg->flags & TUNNEL_MSG_FLAG_STATS);

	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
			RTM_NEWTUNNEL, sizeof(*new_tmsg), NLM_F_MULTI);
	if (!nlh)
		return -EMSGSIZE;
	new_tmsg = nlmsg_data(nlh);
	memset(new_tmsg, 0, sizeof(*new_tmsg));
	new_tmsg->family = PF_BRIDGE;
	new_tmsg->ifindex = dev->ifindex;

	list_for_each_entry_safe(v, tmp, &vg->vni_list, vlist) {
		if (idx < s_idx) {
			/* skip entries already dumped in a prior pass */
			idx++;
			continue;
		}
		if (!vbegin) {
			/* open a new range at @v */
			vbegin = v;
			vend = v;
			continue;
		}
		if (!dump_stats && vnirange(vend, v) == 1 &&
		    vxlan_addr_equal(&v->remote_ip, &vend->remote_ip)) {
			/* adjacent VNI with same remote: extend the range */
			goto update_end;
		} else {
			/* flush the current range, then restart at @v */
			if (!vxlan_fill_vni_filter_entry(skb, vbegin, vend,
							 dump_stats)) {
				err = -EMSGSIZE;
				break;
			}
			idx += vnirange(vbegin, vend) + 1;
			vbegin = v;
		}
update_end:
		vend = v;
	}

	/* flush the trailing open range, if any */
	if (!err && vbegin) {
		if (!vxlan_fill_vni_filter_entry(skb, vbegin, vend, dump_stats))
			err = -EMSGSIZE;
	}

	/* on -EMSGSIZE remember where to resume; 0 marks a complete dump */
	cb->args[1] = err ? idx : 0;

	nlmsg_end(skb, nlh);

	return err;
}
406
/* RTM_GETTUNNEL dump handler: dump either one vxlan device
 * (tmsg->ifindex set) or every vxlan device in the namespace.
 *
 * cb->args[0] is the device resume index for the all-device walk;
 * per-device VNI resumption is handled inside vxlan_vnifilter_dump_dev()
 * via cb->args[1].
 */
static int vxlan_vnifilter_dump(struct sk_buff *skb, struct netlink_callback *cb)
{
	int idx = 0, err = 0, s_idx = cb->args[0];
	struct net *net = sock_net(skb->sk);
	struct tunnel_msg *tmsg;
	struct net_device *dev;

	tmsg = nlmsg_data(cb->nlh);

	/* reject header flags this kernel does not understand */
	if (tmsg->flags & ~TUNNEL_MSG_VALID_USER_FLAGS) {
		NL_SET_ERR_MSG(cb->extack, "Invalid tunnelmsg flags in ancillary header");
		return -EINVAL;
	}

	rcu_read_lock();
	if (tmsg->ifindex) {
		dev = dev_get_by_index_rcu(net, tmsg->ifindex);
		if (!dev) {
			err = -ENODEV;
			goto out_err;
		}
		if (!netif_is_vxlan(dev)) {
			NL_SET_ERR_MSG(cb->extack,
				       "The device is not a vxlan device");
			err = -EINVAL;
			goto out_err;
		}
		err = vxlan_vnifilter_dump_dev(dev, skb, cb);
		/* if the dump completed without an error we return 0 here */
		if (err != -EMSGSIZE)
			goto out_err;
	} else {
		for_each_netdev_rcu(net, dev) {
			if (!netif_is_vxlan(dev))
				continue;
			if (idx < s_idx)
				goto skip;
			err = vxlan_vnifilter_dump_dev(dev, skb, cb);
			/* skb full: stop here and resume at idx next pass */
			if (err == -EMSGSIZE)
				break;
skip:
			idx++;
		}
	}
	cb->args[0] = idx;
	rcu_read_unlock();

	return skb->len;

out_err:
	rcu_read_unlock();

	return err;
}
461
/* Policy for attributes nested inside a VXLAN_VNIFILTER_ENTRY:
 * START/END bound the vni range; GROUP/GROUP6 carry an optional
 * per-vni IPv4/IPv6 remote address.
 */
static const struct nla_policy vni_filter_entry_policy[VXLAN_VNIFILTER_ENTRY_MAX + 1] = {
	[VXLAN_VNIFILTER_ENTRY_START] = { .type = NLA_U32 },
	[VXLAN_VNIFILTER_ENTRY_END] = { .type = NLA_U32 },
	[VXLAN_VNIFILTER_ENTRY_GROUP]	= { .type = NLA_BINARY,
					    .len = sizeof_field(struct iphdr, daddr) },
	[VXLAN_VNIFILTER_ENTRY_GROUP6]	= { .type = NLA_BINARY,
					    .len = sizeof(struct in6_addr) },
};
470
/* Top-level policy: a tunnel message carries nested VNIFILTER entries */
static const struct nla_policy vni_filter_policy[VXLAN_VNIFILTER_MAX + 1] = {
	[VXLAN_VNIFILTER_ENTRY] = { .type = NLA_NESTED },
};
474
/* Rewrite the all-zeros-MAC default fdb entry for @vni.
 *
 * @remote_ip:     new default remote; installed first when non-NULL and
 *                 not the any-address.
 * @old_remote_ip: previous default remote; deleted afterwards when
 *                 non-NULL and not the any-address.
 *
 * Both operations run under the fdb hash bucket lock of the
 * all-zeros-MAC/@vni bucket.  Returns 0 or the vxlan_fdb_update() error
 * (the delete path does not fail).
 */
static int vxlan_update_default_fdb_entry(struct vxlan_dev *vxlan, __be32 vni,
					  union vxlan_addr *old_remote_ip,
					  union vxlan_addr *remote_ip,
					  struct netlink_ext_ack *extack)
{
	struct vxlan_rdst *dst = &vxlan->default_dst;
	u32 hash_index;
	int err = 0;

	hash_index = fdb_head_index(vxlan, all_zeros_mac, vni);
	spin_lock_bh(&vxlan->hash_lock[hash_index]);
	if (remote_ip && !vxlan_addr_any(remote_ip)) {
		err = vxlan_fdb_update(vxlan, all_zeros_mac,
				       remote_ip,
				       NUD_REACHABLE | NUD_PERMANENT,
				       NLM_F_APPEND | NLM_F_CREATE,
				       vxlan->cfg.dst_port,
				       vni,
				       vni,
				       dst->remote_ifindex,
				       NTF_SELF, 0, true, extack);
		if (err) {
			spin_unlock_bh(&vxlan->hash_lock[hash_index]);
			return err;
		}
	}

	/* only after the new entry is in place is the old one removed */
	if (old_remote_ip && !vxlan_addr_any(old_remote_ip)) {
		__vxlan_fdb_delete(vxlan, all_zeros_mac,
				   *old_remote_ip,
				   vxlan->cfg.dst_port,
				   vni, vni,
				   dst->remote_ifindex,
				   true);
	}
	spin_unlock_bh(&vxlan->hash_lock[hash_index]);

	return err;
}
514
/* Reconcile the default fdb entry and multicast membership of @vninode
 * with a (possibly new) per-vni remote @group.
 *
 * @create:  true when the vni is being created; on update (@create
 *           false) an unchanged remote is treated as a no-op.
 * @changed: set to true once the fdb entry has actually been rewritten.
 *
 * When the device is up, multicast group membership is adjusted to
 * match: leave the old group if no other user remains, join the new one
 * (-EADDRINUSE from the join means "already a member" and is ignored).
 */
static int vxlan_vni_update_group(struct vxlan_dev *vxlan,
				  struct vxlan_vni_node *vninode,
				  union vxlan_addr *group,
				  bool create, bool *changed,
				  struct netlink_ext_ack *extack)
{
	struct vxlan_net *vn = net_generic(vxlan->net, vxlan_net_id);
	struct vxlan_rdst *dst = &vxlan->default_dst;
	union vxlan_addr *newrip = NULL, *oldrip = NULL;
	union vxlan_addr old_remote_ip;
	int ret = 0;

	memcpy(&old_remote_ip, &vninode->remote_ip, sizeof(old_remote_ip));

	/* if per vni remote ip is not present use vxlan dev
	 * default dst remote ip for fdb entry
	 */
	if (group && !vxlan_addr_any(group)) {
		newrip = group;
	} else {
		if (!vxlan_addr_any(&dst->remote_ip))
			newrip = &dst->remote_ip;
	}

	/* if old rip exists, and no newrip,
	 * explicitly delete old rip
	 */
	if (!newrip && !vxlan_addr_any(&old_remote_ip))
		oldrip = &old_remote_ip;

	if (!newrip && !oldrip)
		return 0;

	/* update with an identical remote is a no-op */
	if (!create && oldrip && newrip && vxlan_addr_equal(oldrip, newrip))
		return 0;

	ret = vxlan_update_default_fdb_entry(vxlan, vninode->vni,
					     oldrip, newrip,
					     extack);
	if (ret)
		goto out;

	if (group)
		memcpy(&vninode->remote_ip, group, sizeof(vninode->remote_ip));

	if (vxlan->dev->flags & IFF_UP) {
		/* leave the old multicast group if nobody else uses it */
		if (vxlan_addr_multicast(&old_remote_ip) &&
		    !vxlan_group_used(vn, vxlan, vninode->vni,
				      &old_remote_ip,
				      vxlan->default_dst.remote_ifindex)) {
			ret = vxlan_igmp_leave(vxlan, &old_remote_ip,
					       0);
			if (ret)
				goto out;
		}

		if (vxlan_addr_multicast(&vninode->remote_ip)) {
			ret = vxlan_igmp_join(vxlan, &vninode->remote_ip, 0);
			/* already joined is not an error */
			if (ret == -EADDRINUSE)
				ret = 0;
			if (ret)
				goto out;
		}
	}

	*changed = true;

	return 0;
out:
	return ret;
}
586
587 int vxlan_vnilist_update_group(struct vxlan_dev *vxlan,
588                                union vxlan_addr *old_remote_ip,
589                                union vxlan_addr *new_remote_ip,
590                                struct netlink_ext_ack *extack)
591 {
592         struct list_head *headp, *hpos;
593         struct vxlan_vni_group *vg;
594         struct vxlan_vni_node *vent;
595         int ret;
596
597         vg = rtnl_dereference(vxlan->vnigrp);
598
599         headp = &vg->vni_list;
600         list_for_each_prev(hpos, headp) {
601                 vent = list_entry(hpos, struct vxlan_vni_node, vlist);
602                 if (vxlan_addr_any(&vent->remote_ip)) {
603                         ret = vxlan_update_default_fdb_entry(vxlan, vent->vni,
604                                                              old_remote_ip,
605                                                              new_remote_ip,
606                                                              extack);
607                         if (ret)
608                                 return ret;
609                 }
610         }
611
612         return 0;
613 }
614
/* Tear down the default fdb entry and multicast membership of @vninode
 * prior to removing the vni.  The fdb entry being deleted is either the
 * per-vni remote or, when none was set, the device's default remote
 * that was installed on its behalf.
 */
static void vxlan_vni_delete_group(struct vxlan_dev *vxlan,
				   struct vxlan_vni_node *vninode)
{
	struct vxlan_net *vn = net_generic(vxlan->net, vxlan_net_id);
	struct vxlan_rdst *dst = &vxlan->default_dst;

	/* if per vni remote_ip not present, delete the
	 * default dst remote_ip previously added for this vni
	 */
	if (!vxlan_addr_any(&vninode->remote_ip) ||
	    !vxlan_addr_any(&dst->remote_ip))
		__vxlan_fdb_delete(vxlan, all_zeros_mac,
				   (vxlan_addr_any(&vninode->remote_ip) ?
				   dst->remote_ip : vninode->remote_ip),
				   vxlan->cfg.dst_port,
				   vninode->vni, vninode->vni,
				   dst->remote_ifindex,
				   true);

	/* leave the multicast group unless some other vni/device on the
	 * same interface still uses it
	 */
	if (vxlan->dev->flags & IFF_UP) {
		if (vxlan_addr_multicast(&vninode->remote_ip) &&
		    !vxlan_group_used(vn, vxlan, vninode->vni,
				      &vninode->remote_ip,
				      dst->remote_ifindex)) {
			vxlan_igmp_leave(vxlan, &vninode->remote_ip, 0);
		}
	}
}
643
644 static int vxlan_vni_update(struct vxlan_dev *vxlan,
645                             struct vxlan_vni_group *vg,
646                             __be32 vni, union vxlan_addr *group,
647                             bool *changed,
648                             struct netlink_ext_ack *extack)
649 {
650         struct vxlan_vni_node *vninode;
651         int ret;
652
653         vninode = rhashtable_lookup_fast(&vg->vni_hash, &vni,
654                                          vxlan_vni_rht_params);
655         if (!vninode)
656                 return 0;
657
658         ret = vxlan_vni_update_group(vxlan, vninode, group, false, changed,
659                                      extack);
660         if (ret)
661                 return ret;
662
663         if (changed)
664                 vxlan_vnifilter_notify(vxlan, vninode, RTM_NEWTUNNEL);
665
666         return 0;
667 }
668
669 static void __vxlan_vni_add_list(struct vxlan_vni_group *vg,
670                                  struct vxlan_vni_node *v)
671 {
672         struct list_head *headp, *hpos;
673         struct vxlan_vni_node *vent;
674
675         headp = &vg->vni_list;
676         list_for_each_prev(hpos, headp) {
677                 vent = list_entry(hpos, struct vxlan_vni_node, vlist);
678                 if (be32_to_cpu(v->vni) < be32_to_cpu(vent->vni))
679                         continue;
680                 else
681                         break;
682         }
683         list_add_rcu(&v->vlist, hpos);
684         vg->num_vnis++;
685 }
686
/* Unlink @v from the sorted per-device VNI list (RCU-safe) and update
 * the counter; the caller frees the node after an RCU grace period.
 */
static void __vxlan_vni_del_list(struct vxlan_vni_group *vg,
				 struct vxlan_vni_node *v)
{
	list_del_rcu(&v->vlist);
	vg->num_vnis--;
}
693
694 static struct vxlan_vni_node *vxlan_vni_alloc(struct vxlan_dev *vxlan,
695                                               __be32 vni)
696 {
697         struct vxlan_vni_node *vninode;
698
699         vninode = kzalloc(sizeof(*vninode), GFP_KERNEL);
700         if (!vninode)
701                 return NULL;
702         vninode->stats = netdev_alloc_pcpu_stats(struct vxlan_vni_stats_pcpu);
703         if (!vninode->stats) {
704                 kfree(vninode);
705                 return NULL;
706         }
707         vninode->vni = vni;
708         vninode->hlist4.vxlan = vxlan;
709 #if IS_ENABLED(CONFIG_IPV6)
710         vninode->hlist6.vxlan = vxlan;
711 #endif
712
713         return vninode;
714 }
715
/* Release a VNI node together with its per-cpu stats. */
static void vxlan_vni_free(struct vxlan_vni_node *vninode)
{
	free_percpu(vninode->stats);
	kfree(vninode);
}
721
/* Add @vni to the device's filter, or — if it already exists — update
 * its remote @group via vxlan_vni_update().
 *
 * Ordering matters: the node is first inserted into the vni hash and
 * the sorted vni list, then hooked onto the underlying socket(s) when
 * the device is up, and only afterwards is the default fdb entry
 * installed and the RTM_NEWTUNNEL notification sent.
 */
static int vxlan_vni_add(struct vxlan_dev *vxlan,
			 struct vxlan_vni_group *vg,
			 u32 vni, union vxlan_addr *group,
			 struct netlink_ext_ack *extack)
{
	struct vxlan_vni_node *vninode;
	__be32 v = cpu_to_be32(vni);
	bool changed = false;
	int err = 0;

	if (vxlan_vnifilter_lookup(vxlan, v))
		return vxlan_vni_update(vxlan, vg, v, group, &changed, extack);

	/* refuse a vni already claimed by another device on this socket */
	err = vxlan_vni_in_use(vxlan->net, vxlan, &vxlan->cfg, v);
	if (err) {
		NL_SET_ERR_MSG(extack, "VNI in use");
		return err;
	}

	vninode = vxlan_vni_alloc(vxlan, v);
	if (!vninode)
		return -ENOMEM;

	err = rhashtable_lookup_insert_fast(&vg->vni_hash,
					    &vninode->vnode,
					    vxlan_vni_rht_params);
	if (err) {
		vxlan_vni_free(vninode);
		return err;
	}

	__vxlan_vni_add_list(vg, vninode);

	if (vxlan->dev->flags & IFF_UP)
		vxlan_vs_add_del_vninode(vxlan, vninode, false);

	err = vxlan_vni_update_group(vxlan, vninode, group, true, &changed,
				     extack);

	if (changed)
		vxlan_vnifilter_notify(vxlan, vninode, RTM_NEWTUNNEL);

	return err;
}
766
767 static void vxlan_vni_node_rcu_free(struct rcu_head *rcu)
768 {
769         struct vxlan_vni_node *v;
770
771         v = container_of(rcu, struct vxlan_vni_node, rcu);
772         vxlan_vni_free(v);
773 }
774
775 static int vxlan_vni_del(struct vxlan_dev *vxlan,
776                          struct vxlan_vni_group *vg,
777                          u32 vni, struct netlink_ext_ack *extack)
778 {
779         struct vxlan_vni_node *vninode;
780         __be32 v = cpu_to_be32(vni);
781         int err = 0;
782
783         vg = rtnl_dereference(vxlan->vnigrp);
784
785         vninode = rhashtable_lookup_fast(&vg->vni_hash, &v,
786                                          vxlan_vni_rht_params);
787         if (!vninode) {
788                 err = -ENOENT;
789                 goto out;
790         }
791
792         vxlan_vni_delete_group(vxlan, vninode);
793
794         err = rhashtable_remove_fast(&vg->vni_hash,
795                                      &vninode->vnode,
796                                      vxlan_vni_rht_params);
797         if (err)
798                 goto out;
799
800         __vxlan_vni_del_list(vg, vninode);
801
802         vxlan_vnifilter_notify(vxlan, vninode, RTM_DELTUNNEL);
803
804         if (vxlan->dev->flags & IFF_UP)
805                 vxlan_vs_add_del_vninode(vxlan, vninode, true);
806
807         call_rcu(&vninode->rcu, vxlan_vni_node_rcu_free);
808
809         return 0;
810 out:
811         return err;
812 }
813
814 static int vxlan_vni_add_del(struct vxlan_dev *vxlan, __u32 start_vni,
815                              __u32 end_vni, union vxlan_addr *group,
816                              int cmd, struct netlink_ext_ack *extack)
817 {
818         struct vxlan_vni_group *vg;
819         int v, err = 0;
820
821         vg = rtnl_dereference(vxlan->vnigrp);
822
823         for (v = start_vni; v <= end_vni; v++) {
824                 switch (cmd) {
825                 case RTM_NEWTUNNEL:
826                         err = vxlan_vni_add(vxlan, vg, v, group, extack);
827                         break;
828                 case RTM_DELTUNNEL:
829                         err = vxlan_vni_del(vxlan, vg, v, extack);
830                         break;
831                 default:
832                         err = -EOPNOTSUPP;
833                         break;
834                 }
835                 if (err)
836                         goto out;
837         }
838
839         return 0;
840 out:
841         return err;
842 }
843
844 static int vxlan_process_vni_filter(struct vxlan_dev *vxlan,
845                                     struct nlattr *nlvnifilter,
846                                     int cmd, struct netlink_ext_ack *extack)
847 {
848         struct nlattr *vattrs[VXLAN_VNIFILTER_ENTRY_MAX + 1];
849         u32 vni_start = 0, vni_end = 0;
850         union vxlan_addr group;
851         int err;
852
853         err = nla_parse_nested(vattrs,
854                                VXLAN_VNIFILTER_ENTRY_MAX,
855                                nlvnifilter, vni_filter_entry_policy,
856                                extack);
857         if (err)
858                 return err;
859
860         if (vattrs[VXLAN_VNIFILTER_ENTRY_START]) {
861                 vni_start = nla_get_u32(vattrs[VXLAN_VNIFILTER_ENTRY_START]);
862                 vni_end = vni_start;
863         }
864
865         if (vattrs[VXLAN_VNIFILTER_ENTRY_END])
866                 vni_end = nla_get_u32(vattrs[VXLAN_VNIFILTER_ENTRY_END]);
867
868         if (!vni_start && !vni_end) {
869                 NL_SET_ERR_MSG_ATTR(extack, nlvnifilter,
870                                     "vni start nor end found in vni entry");
871                 return -EINVAL;
872         }
873
874         if (vattrs[VXLAN_VNIFILTER_ENTRY_GROUP]) {
875                 group.sin.sin_addr.s_addr =
876                         nla_get_in_addr(vattrs[VXLAN_VNIFILTER_ENTRY_GROUP]);
877                 group.sa.sa_family = AF_INET;
878         } else if (vattrs[VXLAN_VNIFILTER_ENTRY_GROUP6]) {
879                 group.sin6.sin6_addr =
880                         nla_get_in6_addr(vattrs[VXLAN_VNIFILTER_ENTRY_GROUP6]);
881                 group.sa.sa_family = AF_INET6;
882         } else {
883                 memset(&group, 0, sizeof(group));
884         }
885
886         if (vxlan_addr_multicast(&group) && !vxlan->default_dst.remote_ifindex) {
887                 NL_SET_ERR_MSG(extack,
888                                "Local interface required for multicast remote group");
889
890                 return -EINVAL;
891         }
892
893         err = vxlan_vni_add_del(vxlan, vni_start, vni_end, &group, cmd,
894                                 extack);
895         if (err)
896                 return err;
897
898         return 0;
899 }
900
901 void vxlan_vnigroup_uninit(struct vxlan_dev *vxlan)
902 {
903         struct vxlan_vni_node *v, *tmp;
904         struct vxlan_vni_group *vg;
905
906         vg = rtnl_dereference(vxlan->vnigrp);
907         list_for_each_entry_safe(v, tmp, &vg->vni_list, vlist) {
908                 rhashtable_remove_fast(&vg->vni_hash, &v->vnode,
909                                        vxlan_vni_rht_params);
910                 hlist_del_init_rcu(&v->hlist4.hlist);
911 #if IS_ENABLED(CONFIG_IPV6)
912                 hlist_del_init_rcu(&v->hlist6.hlist);
913 #endif
914                 __vxlan_vni_del_list(vg, v);
915                 vxlan_vnifilter_notify(vxlan, v, RTM_DELTUNNEL);
916                 call_rcu(&v->rcu, vxlan_vni_node_rcu_free);
917         }
918         rhashtable_destroy(&vg->vni_hash);
919         kfree(vg);
920 }
921
922 int vxlan_vnigroup_init(struct vxlan_dev *vxlan)
923 {
924         struct vxlan_vni_group *vg;
925         int ret;
926
927         vg = kzalloc(sizeof(*vg), GFP_KERNEL);
928         if (!vg)
929                 return -ENOMEM;
930         ret = rhashtable_init(&vg->vni_hash, &vxlan_vni_rht_params);
931         if (ret) {
932                 kfree(vg);
933                 return ret;
934         }
935         INIT_LIST_HEAD(&vg->vni_list);
936         rcu_assign_pointer(vxlan->vnigrp, vg);
937
938         return 0;
939 }
940
941 static int vxlan_vnifilter_process(struct sk_buff *skb, struct nlmsghdr *nlh,
942                                    struct netlink_ext_ack *extack)
943 {
944         struct net *net = sock_net(skb->sk);
945         struct tunnel_msg *tmsg;
946         struct vxlan_dev *vxlan;
947         struct net_device *dev;
948         struct nlattr *attr;
949         int err, vnis = 0;
950         int rem;
951
952         /* this should validate the header and check for remaining bytes */
953         err = nlmsg_parse(nlh, sizeof(*tmsg), NULL, VXLAN_VNIFILTER_MAX,
954                           vni_filter_policy, extack);
955         if (err < 0)
956                 return err;
957
958         tmsg = nlmsg_data(nlh);
959         dev = __dev_get_by_index(net, tmsg->ifindex);
960         if (!dev)
961                 return -ENODEV;
962
963         if (!netif_is_vxlan(dev)) {
964                 NL_SET_ERR_MSG_MOD(extack, "The device is not a vxlan device");
965                 return -EINVAL;
966         }
967
968         vxlan = netdev_priv(dev);
969
970         if (!(vxlan->cfg.flags & VXLAN_F_VNIFILTER))
971                 return -EOPNOTSUPP;
972
973         nlmsg_for_each_attr(attr, nlh, sizeof(*tmsg), rem) {
974                 switch (nla_type(attr)) {
975                 case VXLAN_VNIFILTER_ENTRY:
976                         err = vxlan_process_vni_filter(vxlan, attr,
977                                                        nlh->nlmsg_type, extack);
978                         break;
979                 default:
980                         continue;
981                 }
982                 vnis++;
983                 if (err)
984                         break;
985         }
986
987         if (!vnis) {
988                 NL_SET_ERR_MSG_MOD(extack, "No vnis found to process");
989                 err = -EINVAL;
990         }
991
992         return err;
993 }
994
/* Register the PF_BRIDGE tunnel-message handlers: a dump callback for
 * RTM_GETTUNNEL and doit callbacks for RTM_NEWTUNNEL/RTM_DELTUNNEL.
 * NOTE(review): rtnl_register_module() return values are ignored here,
 * so a registration failure is silent — presumably acceptable at module
 * init in this kernel version; confirm against callers.
 */
void vxlan_vnifilter_init(void)
{
        rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_GETTUNNEL, NULL,
                             vxlan_vnifilter_dump, 0);
        rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_NEWTUNNEL,
                             vxlan_vnifilter_process, NULL, 0);
        rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_DELTUNNEL,
                             vxlan_vnifilter_process, NULL, 0);
}
1004
/* Unregister the PF_BRIDGE tunnel-message handlers installed by
 * vxlan_vnifilter_init(), one rtnl_unregister() per message type.
 */
void vxlan_vnifilter_uninit(void)
{
        rtnl_unregister(PF_BRIDGE, RTM_GETTUNNEL);
        rtnl_unregister(PF_BRIDGE, RTM_NEWTUNNEL);
        rtnl_unregister(PF_BRIDGE, RTM_DELTUNNEL);
}