GNU Linux-libre 6.8.7-gnu
[releases.git] / drivers / net / vxlan / vxlan_mdb.c
1 // SPDX-License-Identifier: GPL-2.0-only
2
3 #include <linux/if_bridge.h>
4 #include <linux/in.h>
5 #include <linux/list.h>
6 #include <linux/netdevice.h>
7 #include <linux/netlink.h>
8 #include <linux/rhashtable.h>
9 #include <linux/rhashtable-types.h>
10 #include <linux/rtnetlink.h>
11 #include <linux/skbuff.h>
12 #include <linux/types.h>
13 #include <net/netlink.h>
14 #include <net/vxlan.h>
15
16 #include "vxlan_private.h"
17
/* Key of an MDB entry. The whole struct is the rhashtable key
 * (vxlan_mdb_rht_params.key_len is sizeof() of it), so every byte of it
 * takes part in hashing and comparison.
 */
struct vxlan_mdb_entry_key {
	union vxlan_addr src;	/* source address; all-zeros for (*, G) */
	union vxlan_addr dst;	/* group address; all-zeros for the catch-all entry */
	__be32 vni;		/* source VNI, network byte order */
};
23
/* An MDB entry: one per key, holding the remotes it replicates to. */
struct vxlan_mdb_entry {
	struct rhash_head rhnode;	/* linkage in vxlan->mdb_tbl */
	struct list_head remotes;	/* list of struct vxlan_mdb_remote */
	struct vxlan_mdb_entry_key key;
	struct hlist_node mdb_node;	/* linkage in vxlan->mdb_list (walked by dumps) */
	struct rcu_head rcu;
};
31
/* Remote is blocked. This is set on (S, G) remotes created from the source list
 * of an EXCLUDE-mode (*, G) entry, and reported to user space as
 * MDB_FLAGS_BLOCKED.
 */
#define VXLAN_MDB_REMOTE_F_BLOCKED	BIT(0)

/* One replication destination of an MDB entry. */
struct vxlan_mdb_remote {
	struct list_head list;		/* linkage in vxlan_mdb_entry::remotes */
	struct vxlan_rdst __rcu *rd;	/* destination IP / port / VNI / ifindex */
	u8 flags;			/* VXLAN_MDB_REMOTE_F_* */
	u8 filter_mode;			/* MCAST_INCLUDE or MCAST_EXCLUDE */
	u8 rt_protocol;			/* RTPROT_* that installed the remote */
	struct hlist_head src_list;	/* list of struct vxlan_mdb_src_entry */
	struct rcu_head rcu;
};
43
/* Marks a source entry for deletion during a replace operation. */
#define VXLAN_SGRP_F_DELETE	BIT(0)

/* A source address attached to a (*, G) remote. */
struct vxlan_mdb_src_entry {
	struct hlist_node node;	/* linkage in vxlan_mdb_remote::src_list */
	union vxlan_addr addr;
	u8 flags;		/* VXLAN_SGRP_F_* */
};
51
/* Dump resume state. It lives in netlink_callback::ctx between dump calls. */
struct vxlan_mdb_dump_ctx {
	/* NOTE(review): presumably reserved for the generic rtnetlink MDB dump
	 * code, which shares cb->ctx. Confirm before reusing this field.
	 */
	long reserved;
	long entry_idx;		/* index in vxlan->mdb_list to resume at */
	long remote_idx;	/* index in the entry's remotes to resume at */
};
57
/* One source address parsed from the request's MDBE_ATTR_SRC_LIST. */
struct vxlan_mdb_config_src_entry {
	union vxlan_addr addr;
	struct list_head node;	/* linkage in vxlan_mdb_config::src_list */
};
62
/* Parsed form of an MDB add or delete request. */
struct vxlan_mdb_config {
	struct vxlan_dev *vxlan;
	struct vxlan_mdb_entry_key group;	/* key of the entry to act on */
	struct list_head src_list;	/* list of struct vxlan_mdb_config_src_entry */
	union vxlan_addr remote_ip;	/* remote destination IP (MDBE_ATTR_DST) */
	u32 remote_ifindex;		/* outgoing interface, 0 if not given */
	__be32 remote_vni;		/* network byte order */
	__be16 remote_port;		/* network byte order */
	u16 nlflags;			/* NLM_F_* flags of the request */
	u8 flags;			/* VXLAN_MDB_REMOTE_F_* for the remote */
	u8 filter_mode;			/* MCAST_INCLUDE or MCAST_EXCLUDE */
	u8 rt_protocol;			/* RTPROT_* */
};
76
/* NOTE(review): this struct is not used in this part of the file. It is
 * presumably the set of match criteria for an MDB flush request; confirm
 * against its users.
 */
struct vxlan_mdb_flush_desc {
	union vxlan_addr remote_ip;
	__be32 src_vni;
	__be32 remote_vni;
	__be16 remote_port;
	u8 rt_protocol;
};
84
/* MDB entries are hashed and compared on the whole struct vxlan_mdb_entry_key,
 * so a key must be fully initialised before it is looked up or inserted.
 */
static const struct rhashtable_params vxlan_mdb_rht_params = {
	.head_offset = offsetof(struct vxlan_mdb_entry, rhnode),
	.key_offset = offsetof(struct vxlan_mdb_entry, key),
	.key_len = sizeof(struct vxlan_mdb_entry_key),
	.automatic_shrinking = true,
};

/* The source-list helpers below use these to install and remove the derived
 * (S, G) entries.
 */
static int __vxlan_mdb_add(const struct vxlan_mdb_config *cfg,
			   struct netlink_ext_ack *extack);
static int __vxlan_mdb_del(const struct vxlan_mdb_config *cfg,
			   struct netlink_ext_ack *extack);
96
97 static void vxlan_br_mdb_entry_fill(const struct vxlan_dev *vxlan,
98                                     const struct vxlan_mdb_entry *mdb_entry,
99                                     const struct vxlan_mdb_remote *remote,
100                                     struct br_mdb_entry *e)
101 {
102         const union vxlan_addr *dst = &mdb_entry->key.dst;
103
104         memset(e, 0, sizeof(*e));
105         e->ifindex = vxlan->dev->ifindex;
106         e->state = MDB_PERMANENT;
107
108         if (remote->flags & VXLAN_MDB_REMOTE_F_BLOCKED)
109                 e->flags |= MDB_FLAGS_BLOCKED;
110
111         switch (dst->sa.sa_family) {
112         case AF_INET:
113                 e->addr.u.ip4 = dst->sin.sin_addr.s_addr;
114                 e->addr.proto = htons(ETH_P_IP);
115                 break;
116 #if IS_ENABLED(CONFIG_IPV6)
117         case AF_INET6:
118                 e->addr.u.ip6 = dst->sin6.sin6_addr;
119                 e->addr.proto = htons(ETH_P_IPV6);
120                 break;
121 #endif
122         }
123 }
124
125 static int vxlan_mdb_entry_info_fill_srcs(struct sk_buff *skb,
126                                           const struct vxlan_mdb_remote *remote)
127 {
128         struct vxlan_mdb_src_entry *ent;
129         struct nlattr *nest;
130
131         if (hlist_empty(&remote->src_list))
132                 return 0;
133
134         nest = nla_nest_start(skb, MDBA_MDB_EATTR_SRC_LIST);
135         if (!nest)
136                 return -EMSGSIZE;
137
138         hlist_for_each_entry(ent, &remote->src_list, node) {
139                 struct nlattr *nest_ent;
140
141                 nest_ent = nla_nest_start(skb, MDBA_MDB_SRCLIST_ENTRY);
142                 if (!nest_ent)
143                         goto out_cancel_err;
144
145                 if (vxlan_nla_put_addr(skb, MDBA_MDB_SRCATTR_ADDRESS,
146                                        &ent->addr) ||
147                     nla_put_u32(skb, MDBA_MDB_SRCATTR_TIMER, 0))
148                         goto out_cancel_err;
149
150                 nla_nest_end(skb, nest_ent);
151         }
152
153         nla_nest_end(skb, nest);
154
155         return 0;
156
157 out_cancel_err:
158         nla_nest_cancel(skb, nest);
159         return -EMSGSIZE;
160 }
161
162 static int vxlan_mdb_entry_info_fill(const struct vxlan_dev *vxlan,
163                                      struct sk_buff *skb,
164                                      const struct vxlan_mdb_entry *mdb_entry,
165                                      const struct vxlan_mdb_remote *remote)
166 {
167         struct vxlan_rdst *rd = rtnl_dereference(remote->rd);
168         struct br_mdb_entry e;
169         struct nlattr *nest;
170
171         nest = nla_nest_start_noflag(skb, MDBA_MDB_ENTRY_INFO);
172         if (!nest)
173                 return -EMSGSIZE;
174
175         vxlan_br_mdb_entry_fill(vxlan, mdb_entry, remote, &e);
176
177         if (nla_put_nohdr(skb, sizeof(e), &e) ||
178             nla_put_u32(skb, MDBA_MDB_EATTR_TIMER, 0))
179                 goto nest_err;
180
181         if (!vxlan_addr_any(&mdb_entry->key.src) &&
182             vxlan_nla_put_addr(skb, MDBA_MDB_EATTR_SOURCE, &mdb_entry->key.src))
183                 goto nest_err;
184
185         if (nla_put_u8(skb, MDBA_MDB_EATTR_RTPROT, remote->rt_protocol) ||
186             nla_put_u8(skb, MDBA_MDB_EATTR_GROUP_MODE, remote->filter_mode) ||
187             vxlan_mdb_entry_info_fill_srcs(skb, remote) ||
188             vxlan_nla_put_addr(skb, MDBA_MDB_EATTR_DST, &rd->remote_ip))
189                 goto nest_err;
190
191         if (rd->remote_port && rd->remote_port != vxlan->cfg.dst_port &&
192             nla_put_u16(skb, MDBA_MDB_EATTR_DST_PORT,
193                         be16_to_cpu(rd->remote_port)))
194                 goto nest_err;
195
196         if (rd->remote_vni != vxlan->default_dst.remote_vni &&
197             nla_put_u32(skb, MDBA_MDB_EATTR_VNI, be32_to_cpu(rd->remote_vni)))
198                 goto nest_err;
199
200         if (rd->remote_ifindex &&
201             nla_put_u32(skb, MDBA_MDB_EATTR_IFINDEX, rd->remote_ifindex))
202                 goto nest_err;
203
204         if ((vxlan->cfg.flags & VXLAN_F_COLLECT_METADATA) &&
205             mdb_entry->key.vni && nla_put_u32(skb, MDBA_MDB_EATTR_SRC_VNI,
206                                               be32_to_cpu(mdb_entry->key.vni)))
207                 goto nest_err;
208
209         nla_nest_end(skb, nest);
210
211         return 0;
212
213 nest_err:
214         nla_nest_cancel(skb, nest);
215         return -EMSGSIZE;
216 }
217
218 static int vxlan_mdb_entry_fill(const struct vxlan_dev *vxlan,
219                                 struct sk_buff *skb,
220                                 struct vxlan_mdb_dump_ctx *ctx,
221                                 const struct vxlan_mdb_entry *mdb_entry)
222 {
223         int remote_idx = 0, s_remote_idx = ctx->remote_idx;
224         struct vxlan_mdb_remote *remote;
225         struct nlattr *nest;
226         int err = 0;
227
228         nest = nla_nest_start_noflag(skb, MDBA_MDB_ENTRY);
229         if (!nest)
230                 return -EMSGSIZE;
231
232         list_for_each_entry(remote, &mdb_entry->remotes, list) {
233                 if (remote_idx < s_remote_idx)
234                         goto skip;
235
236                 err = vxlan_mdb_entry_info_fill(vxlan, skb, mdb_entry, remote);
237                 if (err)
238                         break;
239 skip:
240                 remote_idx++;
241         }
242
243         ctx->remote_idx = err ? remote_idx : 0;
244         nla_nest_end(skb, nest);
245         return err;
246 }
247
248 static int vxlan_mdb_fill(const struct vxlan_dev *vxlan, struct sk_buff *skb,
249                           struct vxlan_mdb_dump_ctx *ctx)
250 {
251         int entry_idx = 0, s_entry_idx = ctx->entry_idx;
252         struct vxlan_mdb_entry *mdb_entry;
253         struct nlattr *nest;
254         int err = 0;
255
256         nest = nla_nest_start_noflag(skb, MDBA_MDB);
257         if (!nest)
258                 return -EMSGSIZE;
259
260         hlist_for_each_entry(mdb_entry, &vxlan->mdb_list, mdb_node) {
261                 if (entry_idx < s_entry_idx)
262                         goto skip;
263
264                 err = vxlan_mdb_entry_fill(vxlan, skb, ctx, mdb_entry);
265                 if (err)
266                         break;
267 skip:
268                 entry_idx++;
269         }
270
271         ctx->entry_idx = err ? entry_idx : 0;
272         nla_nest_end(skb, nest);
273         return err;
274 }
275
276 int vxlan_mdb_dump(struct net_device *dev, struct sk_buff *skb,
277                    struct netlink_callback *cb)
278 {
279         struct vxlan_mdb_dump_ctx *ctx = (void *)cb->ctx;
280         struct vxlan_dev *vxlan = netdev_priv(dev);
281         struct br_port_msg *bpm;
282         struct nlmsghdr *nlh;
283         int err;
284
285         ASSERT_RTNL();
286
287         NL_ASSERT_DUMP_CTX_FITS(struct vxlan_mdb_dump_ctx);
288
289         nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid,
290                         cb->nlh->nlmsg_seq, RTM_NEWMDB, sizeof(*bpm),
291                         NLM_F_MULTI);
292         if (!nlh)
293                 return -EMSGSIZE;
294
295         bpm = nlmsg_data(nlh);
296         memset(bpm, 0, sizeof(*bpm));
297         bpm->family = AF_BRIDGE;
298         bpm->ifindex = dev->ifindex;
299
300         err = vxlan_mdb_fill(vxlan, skb, ctx);
301
302         nlmsg_end(skb, nlh);
303
304         cb->seq = vxlan->mdb_seq;
305         nl_dump_check_consistent(cb, nlh);
306
307         return err;
308 }
309
/* Policy for one MDBE_SRC_LIST_ENTRY nest. The address may be IPv4 or IPv6
 * sized here. The exact length is checked against the group protocol later,
 * in vxlan_mdb_is_valid_source().
 */
static const struct nla_policy
vxlan_mdbe_src_list_entry_pol[MDBE_SRCATTR_MAX + 1] = {
	[MDBE_SRCATTR_ADDRESS] = NLA_POLICY_RANGE(NLA_BINARY,
						  sizeof(struct in_addr),
						  sizeof(struct in6_addr)),
};

/* Policy for the MDBE_ATTR_SRC_LIST nest: a sequence of entry nests. */
static const struct nla_policy
vxlan_mdbe_src_list_pol[MDBE_SRC_LIST_MAX + 1] = {
	[MDBE_SRC_LIST_ENTRY] = NLA_POLICY_NESTED(vxlan_mdbe_src_list_entry_pol),
};

/* VNIs must be below VXLAN_N_VID. */
static const struct netlink_range_validation vni_range = {
	.max = VXLAN_N_VID - 1,
};

/* Policy for the MDBA_SET_ENTRY_ATTRS nest of an MDB request. */
static const struct nla_policy vxlan_mdbe_attrs_pol[MDBE_ATTR_MAX + 1] = {
	[MDBE_ATTR_SOURCE] = NLA_POLICY_RANGE(NLA_BINARY,
					      sizeof(struct in_addr),
					      sizeof(struct in6_addr)),
	[MDBE_ATTR_GROUP_MODE] = NLA_POLICY_RANGE(NLA_U8, MCAST_EXCLUDE,
						  MCAST_INCLUDE),
	[MDBE_ATTR_SRC_LIST] = NLA_POLICY_NESTED(vxlan_mdbe_src_list_pol),
	[MDBE_ATTR_RTPROT] = NLA_POLICY_MIN(NLA_U8, RTPROT_STATIC),
	[MDBE_ATTR_DST] = NLA_POLICY_RANGE(NLA_BINARY,
					   sizeof(struct in_addr),
					   sizeof(struct in6_addr)),
	[MDBE_ATTR_DST_PORT] = { .type = NLA_U16 },	/* host byte order */
	[MDBE_ATTR_VNI] = NLA_POLICY_FULL_RANGE(NLA_U32, &vni_range),
	[MDBE_ATTR_IFINDEX] = NLA_POLICY_MIN(NLA_S32, 1),
	[MDBE_ATTR_SRC_VNI] = NLA_POLICY_FULL_RANGE(NLA_U32, &vni_range),
};
342
343 static bool vxlan_mdb_is_valid_source(const struct nlattr *attr, __be16 proto,
344                                       struct netlink_ext_ack *extack)
345 {
346         switch (proto) {
347         case htons(ETH_P_IP):
348                 if (nla_len(attr) != sizeof(struct in_addr)) {
349                         NL_SET_ERR_MSG_MOD(extack, "IPv4 invalid source address length");
350                         return false;
351                 }
352                 if (ipv4_is_multicast(nla_get_in_addr(attr))) {
353                         NL_SET_ERR_MSG_MOD(extack, "IPv4 multicast source address is not allowed");
354                         return false;
355                 }
356                 break;
357 #if IS_ENABLED(CONFIG_IPV6)
358         case htons(ETH_P_IPV6): {
359                 struct in6_addr src;
360
361                 if (nla_len(attr) != sizeof(struct in6_addr)) {
362                         NL_SET_ERR_MSG_MOD(extack, "IPv6 invalid source address length");
363                         return false;
364                 }
365                 src = nla_get_in6_addr(attr);
366                 if (ipv6_addr_is_multicast(&src)) {
367                         NL_SET_ERR_MSG_MOD(extack, "IPv6 multicast source address is not allowed");
368                         return false;
369                 }
370                 break;
371         }
372 #endif
373         default:
374                 NL_SET_ERR_MSG_MOD(extack, "Invalid protocol used with source address");
375                 return false;
376         }
377
378         return true;
379 }
380
381 static void vxlan_mdb_group_set(struct vxlan_mdb_entry_key *group,
382                                 const struct br_mdb_entry *entry,
383                                 const struct nlattr *source_attr)
384 {
385         switch (entry->addr.proto) {
386         case htons(ETH_P_IP):
387                 group->dst.sa.sa_family = AF_INET;
388                 group->dst.sin.sin_addr.s_addr = entry->addr.u.ip4;
389                 break;
390 #if IS_ENABLED(CONFIG_IPV6)
391         case htons(ETH_P_IPV6):
392                 group->dst.sa.sa_family = AF_INET6;
393                 group->dst.sin6.sin6_addr = entry->addr.u.ip6;
394                 break;
395 #endif
396         }
397
398         if (source_attr)
399                 vxlan_nla_get_addr(&group->src, source_attr);
400 }
401
402 static bool vxlan_mdb_is_star_g(const struct vxlan_mdb_entry_key *group)
403 {
404         return !vxlan_addr_any(&group->dst) && vxlan_addr_any(&group->src);
405 }
406
407 static bool vxlan_mdb_is_sg(const struct vxlan_mdb_entry_key *group)
408 {
409         return !vxlan_addr_any(&group->dst) && !vxlan_addr_any(&group->src);
410 }
411
412 static int vxlan_mdb_config_src_entry_init(struct vxlan_mdb_config *cfg,
413                                            __be16 proto,
414                                            const struct nlattr *src_entry,
415                                            struct netlink_ext_ack *extack)
416 {
417         struct nlattr *tb[MDBE_SRCATTR_MAX + 1];
418         struct vxlan_mdb_config_src_entry *src;
419         int err;
420
421         err = nla_parse_nested(tb, MDBE_SRCATTR_MAX, src_entry,
422                                vxlan_mdbe_src_list_entry_pol, extack);
423         if (err)
424                 return err;
425
426         if (NL_REQ_ATTR_CHECK(extack, src_entry, tb, MDBE_SRCATTR_ADDRESS))
427                 return -EINVAL;
428
429         if (!vxlan_mdb_is_valid_source(tb[MDBE_SRCATTR_ADDRESS], proto,
430                                        extack))
431                 return -EINVAL;
432
433         src = kzalloc(sizeof(*src), GFP_KERNEL);
434         if (!src)
435                 return -ENOMEM;
436
437         err = vxlan_nla_get_addr(&src->addr, tb[MDBE_SRCATTR_ADDRESS]);
438         if (err)
439                 goto err_free_src;
440
441         list_add_tail(&src->node, &cfg->src_list);
442
443         return 0;
444
445 err_free_src:
446         kfree(src);
447         return err;
448 }
449
450 static void
451 vxlan_mdb_config_src_entry_fini(struct vxlan_mdb_config_src_entry *src)
452 {
453         list_del(&src->node);
454         kfree(src);
455 }
456
457 static int vxlan_mdb_config_src_list_init(struct vxlan_mdb_config *cfg,
458                                           __be16 proto,
459                                           const struct nlattr *src_list,
460                                           struct netlink_ext_ack *extack)
461 {
462         struct vxlan_mdb_config_src_entry *src, *tmp;
463         struct nlattr *src_entry;
464         int rem, err;
465
466         nla_for_each_nested(src_entry, src_list, rem) {
467                 err = vxlan_mdb_config_src_entry_init(cfg, proto, src_entry,
468                                                       extack);
469                 if (err)
470                         goto err_src_entry_init;
471         }
472
473         return 0;
474
475 err_src_entry_init:
476         list_for_each_entry_safe_reverse(src, tmp, &cfg->src_list, node)
477                 vxlan_mdb_config_src_entry_fini(src);
478         return err;
479 }
480
481 static void vxlan_mdb_config_src_list_fini(struct vxlan_mdb_config *cfg)
482 {
483         struct vxlan_mdb_config_src_entry *src, *tmp;
484
485         list_for_each_entry_safe_reverse(src, tmp, &cfg->src_list, node)
486                 vxlan_mdb_config_src_entry_fini(src);
487 }
488
489 static int vxlan_mdb_config_attrs_init(struct vxlan_mdb_config *cfg,
490                                        const struct br_mdb_entry *entry,
491                                        const struct nlattr *set_attrs,
492                                        struct netlink_ext_ack *extack)
493 {
494         struct nlattr *mdbe_attrs[MDBE_ATTR_MAX + 1];
495         int err;
496
497         err = nla_parse_nested(mdbe_attrs, MDBE_ATTR_MAX, set_attrs,
498                                vxlan_mdbe_attrs_pol, extack);
499         if (err)
500                 return err;
501
502         if (NL_REQ_ATTR_CHECK(extack, set_attrs, mdbe_attrs, MDBE_ATTR_DST)) {
503                 NL_SET_ERR_MSG_MOD(extack, "Missing remote destination IP address");
504                 return -EINVAL;
505         }
506
507         if (mdbe_attrs[MDBE_ATTR_SOURCE] &&
508             !vxlan_mdb_is_valid_source(mdbe_attrs[MDBE_ATTR_SOURCE],
509                                        entry->addr.proto, extack))
510                 return -EINVAL;
511
512         vxlan_mdb_group_set(&cfg->group, entry, mdbe_attrs[MDBE_ATTR_SOURCE]);
513
514         /* rtnetlink code only validates that IPv4 group address is
515          * multicast.
516          */
517         if (!vxlan_addr_is_multicast(&cfg->group.dst) &&
518             !vxlan_addr_any(&cfg->group.dst)) {
519                 NL_SET_ERR_MSG_MOD(extack, "Group address is not multicast");
520                 return -EINVAL;
521         }
522
523         if (vxlan_addr_any(&cfg->group.dst) &&
524             mdbe_attrs[MDBE_ATTR_SOURCE]) {
525                 NL_SET_ERR_MSG_MOD(extack, "Source cannot be specified for the all-zeros entry");
526                 return -EINVAL;
527         }
528
529         if (vxlan_mdb_is_sg(&cfg->group))
530                 cfg->filter_mode = MCAST_INCLUDE;
531
532         if (mdbe_attrs[MDBE_ATTR_GROUP_MODE]) {
533                 if (!vxlan_mdb_is_star_g(&cfg->group)) {
534                         NL_SET_ERR_MSG_MOD(extack, "Filter mode can only be set for (*, G) entries");
535                         return -EINVAL;
536                 }
537                 cfg->filter_mode = nla_get_u8(mdbe_attrs[MDBE_ATTR_GROUP_MODE]);
538         }
539
540         if (mdbe_attrs[MDBE_ATTR_SRC_LIST]) {
541                 if (!vxlan_mdb_is_star_g(&cfg->group)) {
542                         NL_SET_ERR_MSG_MOD(extack, "Source list can only be set for (*, G) entries");
543                         return -EINVAL;
544                 }
545                 if (!mdbe_attrs[MDBE_ATTR_GROUP_MODE]) {
546                         NL_SET_ERR_MSG_MOD(extack, "Source list cannot be set without filter mode");
547                         return -EINVAL;
548                 }
549                 err = vxlan_mdb_config_src_list_init(cfg, entry->addr.proto,
550                                                      mdbe_attrs[MDBE_ATTR_SRC_LIST],
551                                                      extack);
552                 if (err)
553                         return err;
554         }
555
556         if (vxlan_mdb_is_star_g(&cfg->group) && list_empty(&cfg->src_list) &&
557             cfg->filter_mode == MCAST_INCLUDE) {
558                 NL_SET_ERR_MSG_MOD(extack, "Cannot add (*, G) INCLUDE with an empty source list");
559                 return -EINVAL;
560         }
561
562         if (mdbe_attrs[MDBE_ATTR_RTPROT])
563                 cfg->rt_protocol = nla_get_u8(mdbe_attrs[MDBE_ATTR_RTPROT]);
564
565         err = vxlan_nla_get_addr(&cfg->remote_ip, mdbe_attrs[MDBE_ATTR_DST]);
566         if (err) {
567                 NL_SET_ERR_MSG_MOD(extack, "Invalid remote destination address");
568                 goto err_src_list_fini;
569         }
570
571         if (mdbe_attrs[MDBE_ATTR_DST_PORT])
572                 cfg->remote_port =
573                         cpu_to_be16(nla_get_u16(mdbe_attrs[MDBE_ATTR_DST_PORT]));
574
575         if (mdbe_attrs[MDBE_ATTR_VNI])
576                 cfg->remote_vni =
577                         cpu_to_be32(nla_get_u32(mdbe_attrs[MDBE_ATTR_VNI]));
578
579         if (mdbe_attrs[MDBE_ATTR_IFINDEX]) {
580                 cfg->remote_ifindex =
581                         nla_get_s32(mdbe_attrs[MDBE_ATTR_IFINDEX]);
582                 if (!__dev_get_by_index(cfg->vxlan->net, cfg->remote_ifindex)) {
583                         NL_SET_ERR_MSG_MOD(extack, "Outgoing interface not found");
584                         err = -EINVAL;
585                         goto err_src_list_fini;
586                 }
587         }
588
589         if (mdbe_attrs[MDBE_ATTR_SRC_VNI])
590                 cfg->group.vni =
591                         cpu_to_be32(nla_get_u32(mdbe_attrs[MDBE_ATTR_SRC_VNI]));
592
593         return 0;
594
595 err_src_list_fini:
596         vxlan_mdb_config_src_list_fini(cfg);
597         return err;
598 }
599
600 static int vxlan_mdb_config_init(struct vxlan_mdb_config *cfg,
601                                  struct net_device *dev, struct nlattr *tb[],
602                                  u16 nlmsg_flags,
603                                  struct netlink_ext_ack *extack)
604 {
605         struct br_mdb_entry *entry = nla_data(tb[MDBA_SET_ENTRY]);
606         struct vxlan_dev *vxlan = netdev_priv(dev);
607
608         memset(cfg, 0, sizeof(*cfg));
609         cfg->vxlan = vxlan;
610         cfg->group.vni = vxlan->default_dst.remote_vni;
611         INIT_LIST_HEAD(&cfg->src_list);
612         cfg->nlflags = nlmsg_flags;
613         cfg->filter_mode = MCAST_EXCLUDE;
614         cfg->rt_protocol = RTPROT_STATIC;
615         cfg->remote_vni = vxlan->default_dst.remote_vni;
616         cfg->remote_port = vxlan->cfg.dst_port;
617
618         if (entry->ifindex != dev->ifindex) {
619                 NL_SET_ERR_MSG_MOD(extack, "Port net device must be the VXLAN net device");
620                 return -EINVAL;
621         }
622
623         /* State is not part of the entry key and can be ignored on deletion
624          * requests.
625          */
626         if ((nlmsg_flags & (NLM_F_CREATE | NLM_F_REPLACE)) &&
627             entry->state != MDB_PERMANENT) {
628                 NL_SET_ERR_MSG_MOD(extack, "MDB entry must be permanent");
629                 return -EINVAL;
630         }
631
632         if (entry->flags) {
633                 NL_SET_ERR_MSG_MOD(extack, "Invalid MDB entry flags");
634                 return -EINVAL;
635         }
636
637         if (entry->vid) {
638                 NL_SET_ERR_MSG_MOD(extack, "VID must not be specified");
639                 return -EINVAL;
640         }
641
642         if (entry->addr.proto != htons(ETH_P_IP) &&
643             entry->addr.proto != htons(ETH_P_IPV6)) {
644                 NL_SET_ERR_MSG_MOD(extack, "Group address must be an IPv4 / IPv6 address");
645                 return -EINVAL;
646         }
647
648         if (NL_REQ_ATTR_CHECK(extack, NULL, tb, MDBA_SET_ENTRY_ATTRS)) {
649                 NL_SET_ERR_MSG_MOD(extack, "Missing MDBA_SET_ENTRY_ATTRS attribute");
650                 return -EINVAL;
651         }
652
653         return vxlan_mdb_config_attrs_init(cfg, entry, tb[MDBA_SET_ENTRY_ATTRS],
654                                            extack);
655 }
656
/* Release what @cfg owns. At present that is only its parsed source list. */
static void vxlan_mdb_config_fini(struct vxlan_mdb_config *cfg)
{
	vxlan_mdb_config_src_list_fini(cfg);
}
661
662 static struct vxlan_mdb_entry *
663 vxlan_mdb_entry_lookup(struct vxlan_dev *vxlan,
664                        const struct vxlan_mdb_entry_key *group)
665 {
666         return rhashtable_lookup_fast(&vxlan->mdb_tbl, group,
667                                       vxlan_mdb_rht_params);
668 }
669
670 static struct vxlan_mdb_remote *
671 vxlan_mdb_remote_lookup(const struct vxlan_mdb_entry *mdb_entry,
672                         const union vxlan_addr *addr)
673 {
674         struct vxlan_mdb_remote *remote;
675
676         list_for_each_entry(remote, &mdb_entry->remotes, list) {
677                 struct vxlan_rdst *rd = rtnl_dereference(remote->rd);
678
679                 if (vxlan_addr_equal(addr, &rd->remote_ip))
680                         return remote;
681         }
682
683         return NULL;
684 }
685
686 static void vxlan_mdb_rdst_free(struct rcu_head *head)
687 {
688         struct vxlan_rdst *rd = container_of(head, struct vxlan_rdst, rcu);
689
690         dst_cache_destroy(&rd->dst_cache);
691         kfree(rd);
692 }
693
694 static int vxlan_mdb_remote_rdst_init(const struct vxlan_mdb_config *cfg,
695                                       struct vxlan_mdb_remote *remote)
696 {
697         struct vxlan_rdst *rd;
698         int err;
699
700         rd = kzalloc(sizeof(*rd), GFP_KERNEL);
701         if (!rd)
702                 return -ENOMEM;
703
704         err = dst_cache_init(&rd->dst_cache, GFP_KERNEL);
705         if (err)
706                 goto err_free_rdst;
707
708         rd->remote_ip = cfg->remote_ip;
709         rd->remote_port = cfg->remote_port;
710         rd->remote_vni = cfg->remote_vni;
711         rd->remote_ifindex = cfg->remote_ifindex;
712         rcu_assign_pointer(remote->rd, rd);
713
714         return 0;
715
716 err_free_rdst:
717         kfree(rd);
718         return err;
719 }
720
721 static void vxlan_mdb_remote_rdst_fini(struct vxlan_rdst *rd)
722 {
723         call_rcu(&rd->rcu, vxlan_mdb_rdst_free);
724 }
725
726 static int vxlan_mdb_remote_init(const struct vxlan_mdb_config *cfg,
727                                  struct vxlan_mdb_remote *remote)
728 {
729         int err;
730
731         err = vxlan_mdb_remote_rdst_init(cfg, remote);
732         if (err)
733                 return err;
734
735         remote->flags = cfg->flags;
736         remote->filter_mode = cfg->filter_mode;
737         remote->rt_protocol = cfg->rt_protocol;
738         INIT_HLIST_HEAD(&remote->src_list);
739
740         return 0;
741 }
742
743 static void vxlan_mdb_remote_fini(struct vxlan_dev *vxlan,
744                                   struct vxlan_mdb_remote *remote)
745 {
746         WARN_ON_ONCE(!hlist_empty(&remote->src_list));
747         vxlan_mdb_remote_rdst_fini(rtnl_dereference(remote->rd));
748 }
749
750 static struct vxlan_mdb_src_entry *
751 vxlan_mdb_remote_src_entry_lookup(const struct vxlan_mdb_remote *remote,
752                                   const union vxlan_addr *addr)
753 {
754         struct vxlan_mdb_src_entry *ent;
755
756         hlist_for_each_entry(ent, &remote->src_list, node) {
757                 if (vxlan_addr_equal(&ent->addr, addr))
758                         return ent;
759         }
760
761         return NULL;
762 }
763
764 static struct vxlan_mdb_src_entry *
765 vxlan_mdb_remote_src_entry_add(struct vxlan_mdb_remote *remote,
766                                const union vxlan_addr *addr)
767 {
768         struct vxlan_mdb_src_entry *ent;
769
770         ent = kzalloc(sizeof(*ent), GFP_KERNEL);
771         if (!ent)
772                 return NULL;
773
774         ent->addr = *addr;
775         hlist_add_head(&ent->node, &remote->src_list);
776
777         return ent;
778 }
779
780 static void
781 vxlan_mdb_remote_src_entry_del(struct vxlan_mdb_src_entry *ent)
782 {
783         hlist_del(&ent->node);
784         kfree(ent);
785 }
786
787 static int
788 vxlan_mdb_remote_src_fwd_add(const struct vxlan_mdb_config *cfg,
789                              const union vxlan_addr *addr,
790                              struct netlink_ext_ack *extack)
791 {
792         struct vxlan_mdb_config sg_cfg;
793
794         memset(&sg_cfg, 0, sizeof(sg_cfg));
795         sg_cfg.vxlan = cfg->vxlan;
796         sg_cfg.group.src = *addr;
797         sg_cfg.group.dst = cfg->group.dst;
798         sg_cfg.group.vni = cfg->group.vni;
799         INIT_LIST_HEAD(&sg_cfg.src_list);
800         sg_cfg.remote_ip = cfg->remote_ip;
801         sg_cfg.remote_ifindex = cfg->remote_ifindex;
802         sg_cfg.remote_vni = cfg->remote_vni;
803         sg_cfg.remote_port = cfg->remote_port;
804         sg_cfg.nlflags = cfg->nlflags;
805         sg_cfg.filter_mode = MCAST_INCLUDE;
806         if (cfg->filter_mode == MCAST_EXCLUDE)
807                 sg_cfg.flags = VXLAN_MDB_REMOTE_F_BLOCKED;
808         sg_cfg.rt_protocol = cfg->rt_protocol;
809
810         return __vxlan_mdb_add(&sg_cfg, extack);
811 }
812
813 static void
814 vxlan_mdb_remote_src_fwd_del(struct vxlan_dev *vxlan,
815                              const struct vxlan_mdb_entry_key *group,
816                              const struct vxlan_mdb_remote *remote,
817                              const union vxlan_addr *addr)
818 {
819         struct vxlan_rdst *rd = rtnl_dereference(remote->rd);
820         struct vxlan_mdb_config sg_cfg;
821
822         memset(&sg_cfg, 0, sizeof(sg_cfg));
823         sg_cfg.vxlan = vxlan;
824         sg_cfg.group.src = *addr;
825         sg_cfg.group.dst = group->dst;
826         sg_cfg.group.vni = group->vni;
827         INIT_LIST_HEAD(&sg_cfg.src_list);
828         sg_cfg.remote_ip = rd->remote_ip;
829
830         __vxlan_mdb_del(&sg_cfg, NULL);
831 }
832
833 static int
834 vxlan_mdb_remote_src_add(const struct vxlan_mdb_config *cfg,
835                          struct vxlan_mdb_remote *remote,
836                          const struct vxlan_mdb_config_src_entry *src,
837                          struct netlink_ext_ack *extack)
838 {
839         struct vxlan_mdb_src_entry *ent;
840         int err;
841
842         ent = vxlan_mdb_remote_src_entry_lookup(remote, &src->addr);
843         if (!ent) {
844                 ent = vxlan_mdb_remote_src_entry_add(remote, &src->addr);
845                 if (!ent)
846                         return -ENOMEM;
847         } else if (!(cfg->nlflags & NLM_F_REPLACE)) {
848                 NL_SET_ERR_MSG_MOD(extack, "Source entry already exists");
849                 return -EEXIST;
850         }
851
852         err = vxlan_mdb_remote_src_fwd_add(cfg, &ent->addr, extack);
853         if (err)
854                 goto err_src_del;
855
856         /* Clear flags in case source entry was marked for deletion as part of
857          * replace flow.
858          */
859         ent->flags = 0;
860
861         return 0;
862
863 err_src_del:
864         vxlan_mdb_remote_src_entry_del(ent);
865         return err;
866 }
867
868 static void vxlan_mdb_remote_src_del(struct vxlan_dev *vxlan,
869                                      const struct vxlan_mdb_entry_key *group,
870                                      const struct vxlan_mdb_remote *remote,
871                                      struct vxlan_mdb_src_entry *ent)
872 {
873         vxlan_mdb_remote_src_fwd_del(vxlan, group, remote, &ent->addr);
874         vxlan_mdb_remote_src_entry_del(ent);
875 }
876
877 static int vxlan_mdb_remote_srcs_add(const struct vxlan_mdb_config *cfg,
878                                      struct vxlan_mdb_remote *remote,
879                                      struct netlink_ext_ack *extack)
880 {
881         struct vxlan_mdb_config_src_entry *src;
882         struct vxlan_mdb_src_entry *ent;
883         struct hlist_node *tmp;
884         int err;
885
886         list_for_each_entry(src, &cfg->src_list, node) {
887                 err = vxlan_mdb_remote_src_add(cfg, remote, src, extack);
888                 if (err)
889                         goto err_src_del;
890         }
891
892         return 0;
893
894 err_src_del:
895         hlist_for_each_entry_safe(ent, tmp, &remote->src_list, node)
896                 vxlan_mdb_remote_src_del(cfg->vxlan, &cfg->group, remote, ent);
897         return err;
898 }
899
900 static void vxlan_mdb_remote_srcs_del(struct vxlan_dev *vxlan,
901                                       const struct vxlan_mdb_entry_key *group,
902                                       struct vxlan_mdb_remote *remote)
903 {
904         struct vxlan_mdb_src_entry *ent;
905         struct hlist_node *tmp;
906
907         hlist_for_each_entry_safe(ent, tmp, &remote->src_list, node)
908                 vxlan_mdb_remote_src_del(vxlan, group, remote, ent);
909 }
910
911 static size_t
912 vxlan_mdb_nlmsg_src_list_size(const struct vxlan_mdb_entry_key *group,
913                               const struct vxlan_mdb_remote *remote)
914 {
915         struct vxlan_mdb_src_entry *ent;
916         size_t nlmsg_size;
917
918         if (hlist_empty(&remote->src_list))
919                 return 0;
920
921         /* MDBA_MDB_EATTR_SRC_LIST */
922         nlmsg_size = nla_total_size(0);
923
924         hlist_for_each_entry(ent, &remote->src_list, node) {
925                               /* MDBA_MDB_SRCLIST_ENTRY */
926                 nlmsg_size += nla_total_size(0) +
927                               /* MDBA_MDB_SRCATTR_ADDRESS */
928                               nla_total_size(vxlan_addr_size(&group->dst)) +
929                               /* MDBA_MDB_SRCATTR_TIMER */
930                               nla_total_size(sizeof(u8));
931         }
932
933         return nlmsg_size;
934 }
935
936 static size_t
937 vxlan_mdb_nlmsg_remote_size(const struct vxlan_dev *vxlan,
938                             const struct vxlan_mdb_entry *mdb_entry,
939                             const struct vxlan_mdb_remote *remote)
940 {
941         const struct vxlan_mdb_entry_key *group = &mdb_entry->key;
942         struct vxlan_rdst *rd = rtnl_dereference(remote->rd);
943         size_t nlmsg_size;
944
945                      /* MDBA_MDB_ENTRY_INFO */
946         nlmsg_size = nla_total_size(sizeof(struct br_mdb_entry)) +
947                      /* MDBA_MDB_EATTR_TIMER */
948                      nla_total_size(sizeof(u32));
949
950         /* MDBA_MDB_EATTR_SOURCE */
951         if (vxlan_mdb_is_sg(group))
952                 nlmsg_size += nla_total_size(vxlan_addr_size(&group->dst));
953         /* MDBA_MDB_EATTR_RTPROT */
954         nlmsg_size += nla_total_size(sizeof(u8));
955         /* MDBA_MDB_EATTR_SRC_LIST */
956         nlmsg_size += vxlan_mdb_nlmsg_src_list_size(group, remote);
957         /* MDBA_MDB_EATTR_GROUP_MODE */
958         nlmsg_size += nla_total_size(sizeof(u8));
959         /* MDBA_MDB_EATTR_DST */
960         nlmsg_size += nla_total_size(vxlan_addr_size(&rd->remote_ip));
961         /* MDBA_MDB_EATTR_DST_PORT */
962         if (rd->remote_port && rd->remote_port != vxlan->cfg.dst_port)
963                 nlmsg_size += nla_total_size(sizeof(u16));
964         /* MDBA_MDB_EATTR_VNI */
965         if (rd->remote_vni != vxlan->default_dst.remote_vni)
966                 nlmsg_size += nla_total_size(sizeof(u32));
967         /* MDBA_MDB_EATTR_IFINDEX */
968         if (rd->remote_ifindex)
969                 nlmsg_size += nla_total_size(sizeof(u32));
970         /* MDBA_MDB_EATTR_SRC_VNI */
971         if ((vxlan->cfg.flags & VXLAN_F_COLLECT_METADATA) && group->vni)
972                 nlmsg_size += nla_total_size(sizeof(u32));
973
974         return nlmsg_size;
975 }
976
977 static size_t vxlan_mdb_nlmsg_size(const struct vxlan_dev *vxlan,
978                                    const struct vxlan_mdb_entry *mdb_entry,
979                                    const struct vxlan_mdb_remote *remote)
980 {
981         return NLMSG_ALIGN(sizeof(struct br_port_msg)) +
982                /* MDBA_MDB */
983                nla_total_size(0) +
984                /* MDBA_MDB_ENTRY */
985                nla_total_size(0) +
986                /* Remote entry */
987                vxlan_mdb_nlmsg_remote_size(vxlan, mdb_entry, remote);
988 }
989
990 static int vxlan_mdb_nlmsg_fill(const struct vxlan_dev *vxlan,
991                                 struct sk_buff *skb,
992                                 const struct vxlan_mdb_entry *mdb_entry,
993                                 const struct vxlan_mdb_remote *remote,
994                                 int type)
995 {
996         struct nlattr *mdb_nest, *mdb_entry_nest;
997         struct br_port_msg *bpm;
998         struct nlmsghdr *nlh;
999
1000         nlh = nlmsg_put(skb, 0, 0, type, sizeof(*bpm), 0);
1001         if (!nlh)
1002                 return -EMSGSIZE;
1003
1004         bpm = nlmsg_data(nlh);
1005         memset(bpm, 0, sizeof(*bpm));
1006         bpm->family  = AF_BRIDGE;
1007         bpm->ifindex = vxlan->dev->ifindex;
1008
1009         mdb_nest = nla_nest_start_noflag(skb, MDBA_MDB);
1010         if (!mdb_nest)
1011                 goto cancel;
1012         mdb_entry_nest = nla_nest_start_noflag(skb, MDBA_MDB_ENTRY);
1013         if (!mdb_entry_nest)
1014                 goto cancel;
1015
1016         if (vxlan_mdb_entry_info_fill(vxlan, skb, mdb_entry, remote))
1017                 goto cancel;
1018
1019         nla_nest_end(skb, mdb_entry_nest);
1020         nla_nest_end(skb, mdb_nest);
1021         nlmsg_end(skb, nlh);
1022
1023         return 0;
1024
1025 cancel:
1026         nlmsg_cancel(skb, nlh);
1027         return -EMSGSIZE;
1028 }
1029
1030 static void vxlan_mdb_remote_notify(const struct vxlan_dev *vxlan,
1031                                     const struct vxlan_mdb_entry *mdb_entry,
1032                                     const struct vxlan_mdb_remote *remote,
1033                                     int type)
1034 {
1035         struct net *net = dev_net(vxlan->dev);
1036         struct sk_buff *skb;
1037         int err = -ENOBUFS;
1038
1039         skb = nlmsg_new(vxlan_mdb_nlmsg_size(vxlan, mdb_entry, remote),
1040                         GFP_KERNEL);
1041         if (!skb)
1042                 goto errout;
1043
1044         err = vxlan_mdb_nlmsg_fill(vxlan, skb, mdb_entry, remote, type);
1045         if (err) {
1046                 kfree_skb(skb);
1047                 goto errout;
1048         }
1049
1050         rtnl_notify(skb, net, 0, RTNLGRP_MDB, NULL, GFP_KERNEL);
1051         return;
1052 errout:
1053         rtnl_set_sk_err(net, RTNLGRP_MDB, err);
1054 }
1055
1056 static int
1057 vxlan_mdb_remote_srcs_replace(const struct vxlan_mdb_config *cfg,
1058                               const struct vxlan_mdb_entry *mdb_entry,
1059                               struct vxlan_mdb_remote *remote,
1060                               struct netlink_ext_ack *extack)
1061 {
1062         struct vxlan_dev *vxlan = cfg->vxlan;
1063         struct vxlan_mdb_src_entry *ent;
1064         struct hlist_node *tmp;
1065         int err;
1066
1067         hlist_for_each_entry(ent, &remote->src_list, node)
1068                 ent->flags |= VXLAN_SGRP_F_DELETE;
1069
1070         err = vxlan_mdb_remote_srcs_add(cfg, remote, extack);
1071         if (err)
1072                 goto err_clear_delete;
1073
1074         hlist_for_each_entry_safe(ent, tmp, &remote->src_list, node) {
1075                 if (ent->flags & VXLAN_SGRP_F_DELETE)
1076                         vxlan_mdb_remote_src_del(vxlan, &mdb_entry->key, remote,
1077                                                  ent);
1078         }
1079
1080         return 0;
1081
1082 err_clear_delete:
1083         hlist_for_each_entry(ent, &remote->src_list, node)
1084                 ent->flags &= ~VXLAN_SGRP_F_DELETE;
1085         return err;
1086 }
1087
1088 static int vxlan_mdb_remote_replace(const struct vxlan_mdb_config *cfg,
1089                                     const struct vxlan_mdb_entry *mdb_entry,
1090                                     struct vxlan_mdb_remote *remote,
1091                                     struct netlink_ext_ack *extack)
1092 {
1093         struct vxlan_rdst *new_rd, *old_rd = rtnl_dereference(remote->rd);
1094         struct vxlan_dev *vxlan = cfg->vxlan;
1095         int err;
1096
1097         err = vxlan_mdb_remote_rdst_init(cfg, remote);
1098         if (err)
1099                 return err;
1100         new_rd = rtnl_dereference(remote->rd);
1101
1102         err = vxlan_mdb_remote_srcs_replace(cfg, mdb_entry, remote, extack);
1103         if (err)
1104                 goto err_rdst_reset;
1105
1106         WRITE_ONCE(remote->flags, cfg->flags);
1107         WRITE_ONCE(remote->filter_mode, cfg->filter_mode);
1108         remote->rt_protocol = cfg->rt_protocol;
1109         vxlan_mdb_remote_notify(vxlan, mdb_entry, remote, RTM_NEWMDB);
1110
1111         vxlan_mdb_remote_rdst_fini(old_rd);
1112
1113         return 0;
1114
1115 err_rdst_reset:
1116         rcu_assign_pointer(remote->rd, old_rd);
1117         vxlan_mdb_remote_rdst_fini(new_rd);
1118         return err;
1119 }
1120
1121 static int vxlan_mdb_remote_add(const struct vxlan_mdb_config *cfg,
1122                                 struct vxlan_mdb_entry *mdb_entry,
1123                                 struct netlink_ext_ack *extack)
1124 {
1125         struct vxlan_mdb_remote *remote;
1126         int err;
1127
1128         remote = vxlan_mdb_remote_lookup(mdb_entry, &cfg->remote_ip);
1129         if (remote) {
1130                 if (!(cfg->nlflags & NLM_F_REPLACE)) {
1131                         NL_SET_ERR_MSG_MOD(extack, "Replace not specified and MDB remote entry already exists");
1132                         return -EEXIST;
1133                 }
1134                 return vxlan_mdb_remote_replace(cfg, mdb_entry, remote, extack);
1135         }
1136
1137         if (!(cfg->nlflags & NLM_F_CREATE)) {
1138                 NL_SET_ERR_MSG_MOD(extack, "Create not specified and entry does not exist");
1139                 return -ENOENT;
1140         }
1141
1142         remote = kzalloc(sizeof(*remote), GFP_KERNEL);
1143         if (!remote)
1144                 return -ENOMEM;
1145
1146         err = vxlan_mdb_remote_init(cfg, remote);
1147         if (err) {
1148                 NL_SET_ERR_MSG_MOD(extack, "Failed to initialize remote MDB entry");
1149                 goto err_free_remote;
1150         }
1151
1152         err = vxlan_mdb_remote_srcs_add(cfg, remote, extack);
1153         if (err)
1154                 goto err_remote_fini;
1155
1156         list_add_rcu(&remote->list, &mdb_entry->remotes);
1157         vxlan_mdb_remote_notify(cfg->vxlan, mdb_entry, remote, RTM_NEWMDB);
1158
1159         return 0;
1160
1161 err_remote_fini:
1162         vxlan_mdb_remote_fini(cfg->vxlan, remote);
1163 err_free_remote:
1164         kfree(remote);
1165         return err;
1166 }
1167
1168 static void vxlan_mdb_remote_del(struct vxlan_dev *vxlan,
1169                                  struct vxlan_mdb_entry *mdb_entry,
1170                                  struct vxlan_mdb_remote *remote)
1171 {
1172         vxlan_mdb_remote_notify(vxlan, mdb_entry, remote, RTM_DELMDB);
1173         list_del_rcu(&remote->list);
1174         vxlan_mdb_remote_srcs_del(vxlan, &mdb_entry->key, remote);
1175         vxlan_mdb_remote_fini(vxlan, remote);
1176         kfree_rcu(remote, rcu);
1177 }
1178
1179 static struct vxlan_mdb_entry *
1180 vxlan_mdb_entry_get(struct vxlan_dev *vxlan,
1181                     const struct vxlan_mdb_entry_key *group)
1182 {
1183         struct vxlan_mdb_entry *mdb_entry;
1184         int err;
1185
1186         mdb_entry = vxlan_mdb_entry_lookup(vxlan, group);
1187         if (mdb_entry)
1188                 return mdb_entry;
1189
1190         mdb_entry = kzalloc(sizeof(*mdb_entry), GFP_KERNEL);
1191         if (!mdb_entry)
1192                 return ERR_PTR(-ENOMEM);
1193
1194         INIT_LIST_HEAD(&mdb_entry->remotes);
1195         memcpy(&mdb_entry->key, group, sizeof(mdb_entry->key));
1196         hlist_add_head(&mdb_entry->mdb_node, &vxlan->mdb_list);
1197
1198         err = rhashtable_lookup_insert_fast(&vxlan->mdb_tbl,
1199                                             &mdb_entry->rhnode,
1200                                             vxlan_mdb_rht_params);
1201         if (err)
1202                 goto err_free_entry;
1203
1204         if (hlist_is_singular_node(&mdb_entry->mdb_node, &vxlan->mdb_list))
1205                 vxlan->cfg.flags |= VXLAN_F_MDB;
1206
1207         return mdb_entry;
1208
1209 err_free_entry:
1210         hlist_del(&mdb_entry->mdb_node);
1211         kfree(mdb_entry);
1212         return ERR_PTR(err);
1213 }
1214
1215 static void vxlan_mdb_entry_put(struct vxlan_dev *vxlan,
1216                                 struct vxlan_mdb_entry *mdb_entry)
1217 {
1218         if (!list_empty(&mdb_entry->remotes))
1219                 return;
1220
1221         if (hlist_is_singular_node(&mdb_entry->mdb_node, &vxlan->mdb_list))
1222                 vxlan->cfg.flags &= ~VXLAN_F_MDB;
1223
1224         rhashtable_remove_fast(&vxlan->mdb_tbl, &mdb_entry->rhnode,
1225                                vxlan_mdb_rht_params);
1226         hlist_del(&mdb_entry->mdb_node);
1227         kfree_rcu(mdb_entry, rcu);
1228 }
1229
1230 static int __vxlan_mdb_add(const struct vxlan_mdb_config *cfg,
1231                            struct netlink_ext_ack *extack)
1232 {
1233         struct vxlan_dev *vxlan = cfg->vxlan;
1234         struct vxlan_mdb_entry *mdb_entry;
1235         int err;
1236
1237         mdb_entry = vxlan_mdb_entry_get(vxlan, &cfg->group);
1238         if (IS_ERR(mdb_entry))
1239                 return PTR_ERR(mdb_entry);
1240
1241         err = vxlan_mdb_remote_add(cfg, mdb_entry, extack);
1242         if (err)
1243                 goto err_entry_put;
1244
1245         vxlan->mdb_seq++;
1246
1247         return 0;
1248
1249 err_entry_put:
1250         vxlan_mdb_entry_put(vxlan, mdb_entry);
1251         return err;
1252 }
1253
1254 static int __vxlan_mdb_del(const struct vxlan_mdb_config *cfg,
1255                            struct netlink_ext_ack *extack)
1256 {
1257         struct vxlan_dev *vxlan = cfg->vxlan;
1258         struct vxlan_mdb_entry *mdb_entry;
1259         struct vxlan_mdb_remote *remote;
1260
1261         mdb_entry = vxlan_mdb_entry_lookup(vxlan, &cfg->group);
1262         if (!mdb_entry) {
1263                 NL_SET_ERR_MSG_MOD(extack, "Did not find MDB entry");
1264                 return -ENOENT;
1265         }
1266
1267         remote = vxlan_mdb_remote_lookup(mdb_entry, &cfg->remote_ip);
1268         if (!remote) {
1269                 NL_SET_ERR_MSG_MOD(extack, "Did not find MDB remote entry");
1270                 return -ENOENT;
1271         }
1272
1273         vxlan_mdb_remote_del(vxlan, mdb_entry, remote);
1274         vxlan_mdb_entry_put(vxlan, mdb_entry);
1275
1276         vxlan->mdb_seq++;
1277
1278         return 0;
1279 }
1280
1281 int vxlan_mdb_add(struct net_device *dev, struct nlattr *tb[], u16 nlmsg_flags,
1282                   struct netlink_ext_ack *extack)
1283 {
1284         struct vxlan_mdb_config cfg;
1285         int err;
1286
1287         ASSERT_RTNL();
1288
1289         err = vxlan_mdb_config_init(&cfg, dev, tb, nlmsg_flags, extack);
1290         if (err)
1291                 return err;
1292
1293         err = __vxlan_mdb_add(&cfg, extack);
1294
1295         vxlan_mdb_config_fini(&cfg);
1296         return err;
1297 }
1298
1299 int vxlan_mdb_del(struct net_device *dev, struct nlattr *tb[],
1300                   struct netlink_ext_ack *extack)
1301 {
1302         struct vxlan_mdb_config cfg;
1303         int err;
1304
1305         ASSERT_RTNL();
1306
1307         err = vxlan_mdb_config_init(&cfg, dev, tb, 0, extack);
1308         if (err)
1309                 return err;
1310
1311         err = __vxlan_mdb_del(&cfg, extack);
1312
1313         vxlan_mdb_config_fini(&cfg);
1314         return err;
1315 }
1316
1317 static const struct nla_policy
1318 vxlan_mdbe_attrs_del_bulk_pol[MDBE_ATTR_MAX + 1] = {
1319         [MDBE_ATTR_RTPROT] = NLA_POLICY_MIN(NLA_U8, RTPROT_STATIC),
1320         [MDBE_ATTR_DST] = NLA_POLICY_RANGE(NLA_BINARY,
1321                                            sizeof(struct in_addr),
1322                                            sizeof(struct in6_addr)),
1323         [MDBE_ATTR_DST_PORT] = { .type = NLA_U16 },
1324         [MDBE_ATTR_VNI] = NLA_POLICY_FULL_RANGE(NLA_U32, &vni_range),
1325         [MDBE_ATTR_SRC_VNI] = NLA_POLICY_FULL_RANGE(NLA_U32, &vni_range),
1326         [MDBE_ATTR_STATE_MASK] = NLA_POLICY_MASK(NLA_U8, MDB_PERMANENT),
1327 };
1328
1329 static int vxlan_mdb_flush_desc_init(struct vxlan_dev *vxlan,
1330                                      struct vxlan_mdb_flush_desc *desc,
1331                                      struct nlattr *tb[],
1332                                      struct netlink_ext_ack *extack)
1333 {
1334         struct br_mdb_entry *entry = nla_data(tb[MDBA_SET_ENTRY]);
1335         struct nlattr *mdbe_attrs[MDBE_ATTR_MAX + 1];
1336         int err;
1337
1338         if (entry->ifindex && entry->ifindex != vxlan->dev->ifindex) {
1339                 NL_SET_ERR_MSG_MOD(extack, "Invalid port net device");
1340                 return -EINVAL;
1341         }
1342
1343         if (entry->vid) {
1344                 NL_SET_ERR_MSG_MOD(extack, "VID must not be specified");
1345                 return -EINVAL;
1346         }
1347
1348         if (!tb[MDBA_SET_ENTRY_ATTRS])
1349                 return 0;
1350
1351         err = nla_parse_nested(mdbe_attrs, MDBE_ATTR_MAX,
1352                                tb[MDBA_SET_ENTRY_ATTRS],
1353                                vxlan_mdbe_attrs_del_bulk_pol, extack);
1354         if (err)
1355                 return err;
1356
1357         if (mdbe_attrs[MDBE_ATTR_STATE_MASK]) {
1358                 u8 state_mask = nla_get_u8(mdbe_attrs[MDBE_ATTR_STATE_MASK]);
1359
1360                 if ((state_mask & MDB_PERMANENT) && !(entry->state & MDB_PERMANENT)) {
1361                         NL_SET_ERR_MSG_MOD(extack, "Only permanent MDB entries are supported");
1362                         return -EINVAL;
1363                 }
1364         }
1365
1366         if (mdbe_attrs[MDBE_ATTR_RTPROT])
1367                 desc->rt_protocol = nla_get_u8(mdbe_attrs[MDBE_ATTR_RTPROT]);
1368
1369         if (mdbe_attrs[MDBE_ATTR_DST])
1370                 vxlan_nla_get_addr(&desc->remote_ip, mdbe_attrs[MDBE_ATTR_DST]);
1371
1372         if (mdbe_attrs[MDBE_ATTR_DST_PORT])
1373                 desc->remote_port =
1374                         cpu_to_be16(nla_get_u16(mdbe_attrs[MDBE_ATTR_DST_PORT]));
1375
1376         if (mdbe_attrs[MDBE_ATTR_VNI])
1377                 desc->remote_vni =
1378                         cpu_to_be32(nla_get_u32(mdbe_attrs[MDBE_ATTR_VNI]));
1379
1380         if (mdbe_attrs[MDBE_ATTR_SRC_VNI])
1381                 desc->src_vni =
1382                         cpu_to_be32(nla_get_u32(mdbe_attrs[MDBE_ATTR_SRC_VNI]));
1383
1384         return 0;
1385 }
1386
1387 static void vxlan_mdb_remotes_flush(struct vxlan_dev *vxlan,
1388                                     struct vxlan_mdb_entry *mdb_entry,
1389                                     const struct vxlan_mdb_flush_desc *desc)
1390 {
1391         struct vxlan_mdb_remote *remote, *tmp;
1392
1393         list_for_each_entry_safe(remote, tmp, &mdb_entry->remotes, list) {
1394                 struct vxlan_rdst *rd = rtnl_dereference(remote->rd);
1395                 __be32 remote_vni;
1396
1397                 if (desc->remote_ip.sa.sa_family &&
1398                     !vxlan_addr_equal(&desc->remote_ip, &rd->remote_ip))
1399                         continue;
1400
1401                 /* Encapsulation is performed with source VNI if remote VNI
1402                  * is not set.
1403                  */
1404                 remote_vni = rd->remote_vni ? : mdb_entry->key.vni;
1405                 if (desc->remote_vni && desc->remote_vni != remote_vni)
1406                         continue;
1407
1408                 if (desc->remote_port && desc->remote_port != rd->remote_port)
1409                         continue;
1410
1411                 if (desc->rt_protocol &&
1412                     desc->rt_protocol != remote->rt_protocol)
1413                         continue;
1414
1415                 vxlan_mdb_remote_del(vxlan, mdb_entry, remote);
1416         }
1417 }
1418
1419 static void vxlan_mdb_flush(struct vxlan_dev *vxlan,
1420                             const struct vxlan_mdb_flush_desc *desc)
1421 {
1422         struct vxlan_mdb_entry *mdb_entry;
1423         struct hlist_node *tmp;
1424
1425         /* The removal of an entry cannot trigger the removal of another entry
1426          * since entries are always added to the head of the list.
1427          */
1428         hlist_for_each_entry_safe(mdb_entry, tmp, &vxlan->mdb_list, mdb_node) {
1429                 if (desc->src_vni && desc->src_vni != mdb_entry->key.vni)
1430                         continue;
1431
1432                 vxlan_mdb_remotes_flush(vxlan, mdb_entry, desc);
1433                 /* Entry will only be removed if its remotes list is empty. */
1434                 vxlan_mdb_entry_put(vxlan, mdb_entry);
1435         }
1436 }
1437
1438 int vxlan_mdb_del_bulk(struct net_device *dev, struct nlattr *tb[],
1439                        struct netlink_ext_ack *extack)
1440 {
1441         struct vxlan_dev *vxlan = netdev_priv(dev);
1442         struct vxlan_mdb_flush_desc desc = {};
1443         int err;
1444
1445         ASSERT_RTNL();
1446
1447         err = vxlan_mdb_flush_desc_init(vxlan, &desc, tb, extack);
1448         if (err)
1449                 return err;
1450
1451         vxlan_mdb_flush(vxlan, &desc);
1452
1453         return 0;
1454 }
1455
1456 static const struct nla_policy vxlan_mdbe_attrs_get_pol[MDBE_ATTR_MAX + 1] = {
1457         [MDBE_ATTR_SOURCE] = NLA_POLICY_RANGE(NLA_BINARY,
1458                                               sizeof(struct in_addr),
1459                                               sizeof(struct in6_addr)),
1460         [MDBE_ATTR_SRC_VNI] = NLA_POLICY_FULL_RANGE(NLA_U32, &vni_range),
1461 };
1462
1463 static int vxlan_mdb_get_parse(struct net_device *dev, struct nlattr *tb[],
1464                                struct vxlan_mdb_entry_key *group,
1465                                struct netlink_ext_ack *extack)
1466 {
1467         struct br_mdb_entry *entry = nla_data(tb[MDBA_GET_ENTRY]);
1468         struct nlattr *mdbe_attrs[MDBE_ATTR_MAX + 1];
1469         struct vxlan_dev *vxlan = netdev_priv(dev);
1470         int err;
1471
1472         memset(group, 0, sizeof(*group));
1473         group->vni = vxlan->default_dst.remote_vni;
1474
1475         if (!tb[MDBA_GET_ENTRY_ATTRS]) {
1476                 vxlan_mdb_group_set(group, entry, NULL);
1477                 return 0;
1478         }
1479
1480         err = nla_parse_nested(mdbe_attrs, MDBE_ATTR_MAX,
1481                                tb[MDBA_GET_ENTRY_ATTRS],
1482                                vxlan_mdbe_attrs_get_pol, extack);
1483         if (err)
1484                 return err;
1485
1486         if (mdbe_attrs[MDBE_ATTR_SOURCE] &&
1487             !vxlan_mdb_is_valid_source(mdbe_attrs[MDBE_ATTR_SOURCE],
1488                                        entry->addr.proto, extack))
1489                 return -EINVAL;
1490
1491         vxlan_mdb_group_set(group, entry, mdbe_attrs[MDBE_ATTR_SOURCE]);
1492
1493         if (mdbe_attrs[MDBE_ATTR_SRC_VNI])
1494                 group->vni =
1495                         cpu_to_be32(nla_get_u32(mdbe_attrs[MDBE_ATTR_SRC_VNI]));
1496
1497         return 0;
1498 }
1499
1500 static struct sk_buff *
1501 vxlan_mdb_get_reply_alloc(const struct vxlan_dev *vxlan,
1502                           const struct vxlan_mdb_entry *mdb_entry)
1503 {
1504         struct vxlan_mdb_remote *remote;
1505         size_t nlmsg_size;
1506
1507         nlmsg_size = NLMSG_ALIGN(sizeof(struct br_port_msg)) +
1508                      /* MDBA_MDB */
1509                      nla_total_size(0) +
1510                      /* MDBA_MDB_ENTRY */
1511                      nla_total_size(0);
1512
1513         list_for_each_entry(remote, &mdb_entry->remotes, list)
1514                 nlmsg_size += vxlan_mdb_nlmsg_remote_size(vxlan, mdb_entry,
1515                                                           remote);
1516
1517         return nlmsg_new(nlmsg_size, GFP_KERNEL);
1518 }
1519
1520 static int
1521 vxlan_mdb_get_reply_fill(const struct vxlan_dev *vxlan,
1522                          struct sk_buff *skb,
1523                          const struct vxlan_mdb_entry *mdb_entry,
1524                          u32 portid, u32 seq)
1525 {
1526         struct nlattr *mdb_nest, *mdb_entry_nest;
1527         struct vxlan_mdb_remote *remote;
1528         struct br_port_msg *bpm;
1529         struct nlmsghdr *nlh;
1530         int err;
1531
1532         nlh = nlmsg_put(skb, portid, seq, RTM_NEWMDB, sizeof(*bpm), 0);
1533         if (!nlh)
1534                 return -EMSGSIZE;
1535
1536         bpm = nlmsg_data(nlh);
1537         memset(bpm, 0, sizeof(*bpm));
1538         bpm->family  = AF_BRIDGE;
1539         bpm->ifindex = vxlan->dev->ifindex;
1540         mdb_nest = nla_nest_start_noflag(skb, MDBA_MDB);
1541         if (!mdb_nest) {
1542                 err = -EMSGSIZE;
1543                 goto cancel;
1544         }
1545         mdb_entry_nest = nla_nest_start_noflag(skb, MDBA_MDB_ENTRY);
1546         if (!mdb_entry_nest) {
1547                 err = -EMSGSIZE;
1548                 goto cancel;
1549         }
1550
1551         list_for_each_entry(remote, &mdb_entry->remotes, list) {
1552                 err = vxlan_mdb_entry_info_fill(vxlan, skb, mdb_entry, remote);
1553                 if (err)
1554                         goto cancel;
1555         }
1556
1557         nla_nest_end(skb, mdb_entry_nest);
1558         nla_nest_end(skb, mdb_nest);
1559         nlmsg_end(skb, nlh);
1560
1561         return 0;
1562
1563 cancel:
1564         nlmsg_cancel(skb, nlh);
1565         return err;
1566 }
1567
1568 int vxlan_mdb_get(struct net_device *dev, struct nlattr *tb[], u32 portid,
1569                   u32 seq, struct netlink_ext_ack *extack)
1570 {
1571         struct vxlan_dev *vxlan = netdev_priv(dev);
1572         struct vxlan_mdb_entry *mdb_entry;
1573         struct vxlan_mdb_entry_key group;
1574         struct sk_buff *skb;
1575         int err;
1576
1577         ASSERT_RTNL();
1578
1579         err = vxlan_mdb_get_parse(dev, tb, &group, extack);
1580         if (err)
1581                 return err;
1582
1583         mdb_entry = vxlan_mdb_entry_lookup(vxlan, &group);
1584         if (!mdb_entry) {
1585                 NL_SET_ERR_MSG_MOD(extack, "MDB entry not found");
1586                 return -ENOENT;
1587         }
1588
1589         skb = vxlan_mdb_get_reply_alloc(vxlan, mdb_entry);
1590         if (!skb)
1591                 return -ENOMEM;
1592
1593         err = vxlan_mdb_get_reply_fill(vxlan, skb, mdb_entry, portid, seq);
1594         if (err) {
1595                 NL_SET_ERR_MSG_MOD(extack, "Failed to fill MDB get reply");
1596                 goto free;
1597         }
1598
1599         return rtnl_unicast(skb, dev_net(dev), portid);
1600
1601 free:
1602         kfree_skb(skb);
1603         return err;
1604 }
1605
1606 struct vxlan_mdb_entry *vxlan_mdb_entry_skb_get(struct vxlan_dev *vxlan,
1607                                                 struct sk_buff *skb,
1608                                                 __be32 src_vni)
1609 {
1610         struct vxlan_mdb_entry *mdb_entry;
1611         struct vxlan_mdb_entry_key group;
1612
1613         if (!is_multicast_ether_addr(eth_hdr(skb)->h_dest) ||
1614             is_broadcast_ether_addr(eth_hdr(skb)->h_dest))
1615                 return NULL;
1616
1617         /* When not in collect metadata mode, 'src_vni' is zero, but MDB
1618          * entries are stored with the VNI of the VXLAN device.
1619          */
1620         if (!(vxlan->cfg.flags & VXLAN_F_COLLECT_METADATA))
1621                 src_vni = vxlan->default_dst.remote_vni;
1622
1623         memset(&group, 0, sizeof(group));
1624         group.vni = src_vni;
1625
1626         switch (skb->protocol) {
1627         case htons(ETH_P_IP):
1628                 if (!pskb_may_pull(skb, sizeof(struct iphdr)))
1629                         return NULL;
1630                 group.dst.sa.sa_family = AF_INET;
1631                 group.dst.sin.sin_addr.s_addr = ip_hdr(skb)->daddr;
1632                 group.src.sa.sa_family = AF_INET;
1633                 group.src.sin.sin_addr.s_addr = ip_hdr(skb)->saddr;
1634                 break;
1635 #if IS_ENABLED(CONFIG_IPV6)
1636         case htons(ETH_P_IPV6):
1637                 if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
1638                         return NULL;
1639                 group.dst.sa.sa_family = AF_INET6;
1640                 group.dst.sin6.sin6_addr = ipv6_hdr(skb)->daddr;
1641                 group.src.sa.sa_family = AF_INET6;
1642                 group.src.sin6.sin6_addr = ipv6_hdr(skb)->saddr;
1643                 break;
1644 #endif
1645         default:
1646                 return NULL;
1647         }
1648
1649         mdb_entry = vxlan_mdb_entry_lookup(vxlan, &group);
1650         if (mdb_entry)
1651                 return mdb_entry;
1652
1653         memset(&group.src, 0, sizeof(group.src));
1654         mdb_entry = vxlan_mdb_entry_lookup(vxlan, &group);
1655         if (mdb_entry)
1656                 return mdb_entry;
1657
1658         /* No (S, G) or (*, G) found. Look up the all-zeros entry, but only if
1659          * the destination IP address is not link-local multicast since we want
1660          * to transmit such traffic together with broadcast and unknown unicast
1661          * traffic.
1662          */
1663         switch (skb->protocol) {
1664         case htons(ETH_P_IP):
1665                 if (ipv4_is_local_multicast(group.dst.sin.sin_addr.s_addr))
1666                         return NULL;
1667                 group.dst.sin.sin_addr.s_addr = 0;
1668                 break;
1669 #if IS_ENABLED(CONFIG_IPV6)
1670         case htons(ETH_P_IPV6):
1671                 if (ipv6_addr_type(&group.dst.sin6.sin6_addr) &
1672                     IPV6_ADDR_LINKLOCAL)
1673                         return NULL;
1674                 memset(&group.dst.sin6.sin6_addr, 0,
1675                        sizeof(group.dst.sin6.sin6_addr));
1676                 break;
1677 #endif
1678         default:
1679                 return NULL;
1680         }
1681
1682         return vxlan_mdb_entry_lookup(vxlan, &group);
1683 }
1684
1685 netdev_tx_t vxlan_mdb_xmit(struct vxlan_dev *vxlan,
1686                            const struct vxlan_mdb_entry *mdb_entry,
1687                            struct sk_buff *skb)
1688 {
1689         struct vxlan_mdb_remote *remote, *fremote = NULL;
1690         __be32 src_vni = mdb_entry->key.vni;
1691
1692         list_for_each_entry_rcu(remote, &mdb_entry->remotes, list) {
1693                 struct sk_buff *skb1;
1694
1695                 if ((vxlan_mdb_is_star_g(&mdb_entry->key) &&
1696                      READ_ONCE(remote->filter_mode) == MCAST_INCLUDE) ||
1697                     (READ_ONCE(remote->flags) & VXLAN_MDB_REMOTE_F_BLOCKED))
1698                         continue;
1699
1700                 if (!fremote) {
1701                         fremote = remote;
1702                         continue;
1703                 }
1704
1705                 skb1 = skb_clone(skb, GFP_ATOMIC);
1706                 if (skb1)
1707                         vxlan_xmit_one(skb1, vxlan->dev, src_vni,
1708                                        rcu_dereference(remote->rd), false);
1709         }
1710
1711         if (fremote)
1712                 vxlan_xmit_one(skb, vxlan->dev, src_vni,
1713                                rcu_dereference(fremote->rd), false);
1714         else
1715                 kfree_skb(skb);
1716
1717         return NETDEV_TX_OK;
1718 }
1719
1720 static void vxlan_mdb_check_empty(void *ptr, void *arg)
1721 {
1722         WARN_ON_ONCE(1);
1723 }
1724
1725 int vxlan_mdb_init(struct vxlan_dev *vxlan)
1726 {
1727         int err;
1728
1729         err = rhashtable_init(&vxlan->mdb_tbl, &vxlan_mdb_rht_params);
1730         if (err)
1731                 return err;
1732
1733         INIT_HLIST_HEAD(&vxlan->mdb_list);
1734
1735         return 0;
1736 }
1737
1738 void vxlan_mdb_fini(struct vxlan_dev *vxlan)
1739 {
1740         struct vxlan_mdb_flush_desc desc = {};
1741
1742         vxlan_mdb_flush(vxlan, &desc);
1743         WARN_ON_ONCE(vxlan->cfg.flags & VXLAN_F_MDB);
1744         rhashtable_free_and_destroy(&vxlan->mdb_tbl, vxlan_mdb_check_empty,
1745                                     NULL);
1746 }