GNU Linux-libre 5.19-rc6-gnu
[releases.git] / drivers / infiniband / sw / rdmavt / mcast.c
1 // SPDX-License-Identifier: GPL-2.0 or BSD-3-Clause
2 /*
3  * Copyright(c) 2016 Intel Corporation.
4  */
5
6 #include <linux/slab.h>
7 #include <linux/sched.h>
8 #include <linux/rculist.h>
9 #include <rdma/rdma_vt.h>
10 #include <rdma/rdmavt_qp.h>
11
12 #include "mcast.h"
13
14 /**
15  * rvt_driver_mcast_init - init resources for multicast
16  * @rdi: rvt dev struct
17  *
18  * This is per device that registers with rdmavt
19  */
20 void rvt_driver_mcast_init(struct rvt_dev_info *rdi)
21 {
22         /*
23          * Anything that needs setup for multicast on a per driver or per rdi
24          * basis should be done in here.
25          */
26         spin_lock_init(&rdi->n_mcast_grps_lock);
27 }
28
29 /**
30  * rvt_mcast_qp_alloc - alloc a struct to link a QP to mcast GID struct
31  * @qp: the QP to link
32  */
33 static struct rvt_mcast_qp *rvt_mcast_qp_alloc(struct rvt_qp *qp)
34 {
35         struct rvt_mcast_qp *mqp;
36
37         mqp = kmalloc(sizeof(*mqp), GFP_KERNEL);
38         if (!mqp)
39                 goto bail;
40
41         mqp->qp = qp;
42         rvt_get_qp(qp);
43
44 bail:
45         return mqp;
46 }
47
48 static void rvt_mcast_qp_free(struct rvt_mcast_qp *mqp)
49 {
50         struct rvt_qp *qp = mqp->qp;
51
52         /* Notify hfi1_destroy_qp() if it is waiting. */
53         rvt_put_qp(qp);
54
55         kfree(mqp);
56 }
57
58 /**
59  * rvt_mcast_alloc - allocate the multicast GID structure
60  * @mgid: the multicast GID
61  * @lid: the muilticast LID (host order)
62  *
63  * A list of QPs will be attached to this structure.
64  */
65 static struct rvt_mcast *rvt_mcast_alloc(union ib_gid *mgid, u16 lid)
66 {
67         struct rvt_mcast *mcast;
68
69         mcast = kzalloc(sizeof(*mcast), GFP_KERNEL);
70         if (!mcast)
71                 goto bail;
72
73         mcast->mcast_addr.mgid = *mgid;
74         mcast->mcast_addr.lid = lid;
75
76         INIT_LIST_HEAD(&mcast->qp_list);
77         init_waitqueue_head(&mcast->wait);
78         atomic_set(&mcast->refcount, 0);
79
80 bail:
81         return mcast;
82 }
83
84 static void rvt_mcast_free(struct rvt_mcast *mcast)
85 {
86         struct rvt_mcast_qp *p, *tmp;
87
88         list_for_each_entry_safe(p, tmp, &mcast->qp_list, list)
89                 rvt_mcast_qp_free(p);
90
91         kfree(mcast);
92 }
93
94 /**
95  * rvt_mcast_find - search the global table for the given multicast GID/LID
96  * NOTE: It is valid to have 1 MLID with multiple MGIDs.  It is not valid
97  * to have 1 MGID with multiple MLIDs.
98  * @ibp: the IB port structure
99  * @mgid: the multicast GID to search for
100  * @lid: the multicast LID portion of the multicast address (host order)
101  *
102  * The caller is responsible for decrementing the reference count if found.
103  *
104  * Return: NULL if not found.
105  */
106 struct rvt_mcast *rvt_mcast_find(struct rvt_ibport *ibp, union ib_gid *mgid,
107                                  u16 lid)
108 {
109         struct rb_node *n;
110         unsigned long flags;
111         struct rvt_mcast *found = NULL;
112
113         spin_lock_irqsave(&ibp->lock, flags);
114         n = ibp->mcast_tree.rb_node;
115         while (n) {
116                 int ret;
117                 struct rvt_mcast *mcast;
118
119                 mcast = rb_entry(n, struct rvt_mcast, rb_node);
120
121                 ret = memcmp(mgid->raw, mcast->mcast_addr.mgid.raw,
122                              sizeof(*mgid));
123                 if (ret < 0) {
124                         n = n->rb_left;
125                 } else if (ret > 0) {
126                         n = n->rb_right;
127                 } else {
128                         /* MGID/MLID must match */
129                         if (mcast->mcast_addr.lid == lid) {
130                                 atomic_inc(&mcast->refcount);
131                                 found = mcast;
132                         }
133                         break;
134                 }
135         }
136         spin_unlock_irqrestore(&ibp->lock, flags);
137         return found;
138 }
139 EXPORT_SYMBOL(rvt_mcast_find);
140
141 /*
142  * rvt_mcast_add - insert mcast GID into table and attach QP struct
143  * @mcast: the mcast GID table
144  * @mqp: the QP to attach
145  *
146  * Return: zero if both were added.  Return EEXIST if the GID was already in
147  * the table but the QP was added.  Return ESRCH if the QP was already
148  * attached and neither structure was added. Return EINVAL if the MGID was
149  * found, but the MLID did NOT match.
150  */
151 static int rvt_mcast_add(struct rvt_dev_info *rdi, struct rvt_ibport *ibp,
152                          struct rvt_mcast *mcast, struct rvt_mcast_qp *mqp)
153 {
154         struct rb_node **n = &ibp->mcast_tree.rb_node;
155         struct rb_node *pn = NULL;
156         int ret;
157
158         spin_lock_irq(&ibp->lock);
159
160         while (*n) {
161                 struct rvt_mcast *tmcast;
162                 struct rvt_mcast_qp *p;
163
164                 pn = *n;
165                 tmcast = rb_entry(pn, struct rvt_mcast, rb_node);
166
167                 ret = memcmp(mcast->mcast_addr.mgid.raw,
168                              tmcast->mcast_addr.mgid.raw,
169                              sizeof(mcast->mcast_addr.mgid));
170                 if (ret < 0) {
171                         n = &pn->rb_left;
172                         continue;
173                 }
174                 if (ret > 0) {
175                         n = &pn->rb_right;
176                         continue;
177                 }
178
179                 if (tmcast->mcast_addr.lid != mcast->mcast_addr.lid) {
180                         ret = EINVAL;
181                         goto bail;
182                 }
183
184                 /* Search the QP list to see if this is already there. */
185                 list_for_each_entry_rcu(p, &tmcast->qp_list, list) {
186                         if (p->qp == mqp->qp) {
187                                 ret = ESRCH;
188                                 goto bail;
189                         }
190                 }
191                 if (tmcast->n_attached ==
192                     rdi->dparms.props.max_mcast_qp_attach) {
193                         ret = ENOMEM;
194                         goto bail;
195                 }
196
197                 tmcast->n_attached++;
198
199                 list_add_tail_rcu(&mqp->list, &tmcast->qp_list);
200                 ret = EEXIST;
201                 goto bail;
202         }
203
204         spin_lock(&rdi->n_mcast_grps_lock);
205         if (rdi->n_mcast_grps_allocated == rdi->dparms.props.max_mcast_grp) {
206                 spin_unlock(&rdi->n_mcast_grps_lock);
207                 ret = ENOMEM;
208                 goto bail;
209         }
210
211         rdi->n_mcast_grps_allocated++;
212         spin_unlock(&rdi->n_mcast_grps_lock);
213
214         mcast->n_attached++;
215
216         list_add_tail_rcu(&mqp->list, &mcast->qp_list);
217
218         atomic_inc(&mcast->refcount);
219         rb_link_node(&mcast->rb_node, pn, n);
220         rb_insert_color(&mcast->rb_node, &ibp->mcast_tree);
221
222         ret = 0;
223
224 bail:
225         spin_unlock_irq(&ibp->lock);
226
227         return ret;
228 }
229
230 /**
231  * rvt_attach_mcast - attach a qp to a multicast group
232  * @ibqp: Infiniband qp
233  * @gid: multicast guid
234  * @lid: multicast lid
235  *
236  * Return: 0 on success
237  */
238 int rvt_attach_mcast(struct ib_qp *ibqp, union ib_gid *gid, u16 lid)
239 {
240         struct rvt_qp *qp = ibqp_to_rvtqp(ibqp);
241         struct rvt_dev_info *rdi = ib_to_rvt(ibqp->device);
242         struct rvt_ibport *ibp = rdi->ports[qp->port_num - 1];
243         struct rvt_mcast *mcast;
244         struct rvt_mcast_qp *mqp;
245         int ret = -ENOMEM;
246
247         if (ibqp->qp_num <= 1 || qp->state == IB_QPS_RESET)
248                 return -EINVAL;
249
250         /*
251          * Allocate data structures since its better to do this outside of
252          * spin locks and it will most likely be needed.
253          */
254         mcast = rvt_mcast_alloc(gid, lid);
255         if (!mcast)
256                 return -ENOMEM;
257
258         mqp = rvt_mcast_qp_alloc(qp);
259         if (!mqp)
260                 goto bail_mcast;
261
262         switch (rvt_mcast_add(rdi, ibp, mcast, mqp)) {
263         case ESRCH:
264                 /* Neither was used: OK to attach the same QP twice. */
265                 ret = 0;
266                 goto bail_mqp;
267         case EEXIST: /* The mcast wasn't used */
268                 ret = 0;
269                 goto bail_mcast;
270         case ENOMEM:
271                 /* Exceeded the maximum number of mcast groups. */
272                 ret = -ENOMEM;
273                 goto bail_mqp;
274         case EINVAL:
275                 /* Invalid MGID/MLID pair */
276                 ret = -EINVAL;
277                 goto bail_mqp;
278         default:
279                 break;
280         }
281
282         return 0;
283
284 bail_mqp:
285         rvt_mcast_qp_free(mqp);
286
287 bail_mcast:
288         rvt_mcast_free(mcast);
289
290         return ret;
291 }
292
293 /**
294  * rvt_detach_mcast - remove a qp from a multicast group
295  * @ibqp: Infiniband qp
296  * @gid: multicast guid
297  * @lid: multicast lid
298  *
299  * Return: 0 on success
300  */
301 int rvt_detach_mcast(struct ib_qp *ibqp, union ib_gid *gid, u16 lid)
302 {
303         struct rvt_qp *qp = ibqp_to_rvtqp(ibqp);
304         struct rvt_dev_info *rdi = ib_to_rvt(ibqp->device);
305         struct rvt_ibport *ibp = rdi->ports[qp->port_num - 1];
306         struct rvt_mcast *mcast = NULL;
307         struct rvt_mcast_qp *p, *tmp, *delp = NULL;
308         struct rb_node *n;
309         int last = 0;
310         int ret = 0;
311
312         if (ibqp->qp_num <= 1)
313                 return -EINVAL;
314
315         spin_lock_irq(&ibp->lock);
316
317         /* Find the GID in the mcast table. */
318         n = ibp->mcast_tree.rb_node;
319         while (1) {
320                 if (!n) {
321                         spin_unlock_irq(&ibp->lock);
322                         return -EINVAL;
323                 }
324
325                 mcast = rb_entry(n, struct rvt_mcast, rb_node);
326                 ret = memcmp(gid->raw, mcast->mcast_addr.mgid.raw,
327                              sizeof(*gid));
328                 if (ret < 0) {
329                         n = n->rb_left;
330                 } else if (ret > 0) {
331                         n = n->rb_right;
332                 } else {
333                         /* MGID/MLID must match */
334                         if (mcast->mcast_addr.lid != lid) {
335                                 spin_unlock_irq(&ibp->lock);
336                                 return -EINVAL;
337                         }
338                         break;
339                 }
340         }
341
342         /* Search the QP list. */
343         list_for_each_entry_safe(p, tmp, &mcast->qp_list, list) {
344                 if (p->qp != qp)
345                         continue;
346                 /*
347                  * We found it, so remove it, but don't poison the forward
348                  * link until we are sure there are no list walkers.
349                  */
350                 list_del_rcu(&p->list);
351                 mcast->n_attached--;
352                 delp = p;
353
354                 /* If this was the last attached QP, remove the GID too. */
355                 if (list_empty(&mcast->qp_list)) {
356                         rb_erase(&mcast->rb_node, &ibp->mcast_tree);
357                         last = 1;
358                 }
359                 break;
360         }
361
362         spin_unlock_irq(&ibp->lock);
363         /* QP not attached */
364         if (!delp)
365                 return -EINVAL;
366
367         /*
368          * Wait for any list walkers to finish before freeing the
369          * list element.
370          */
371         wait_event(mcast->wait, atomic_read(&mcast->refcount) <= 1);
372         rvt_mcast_qp_free(delp);
373
374         if (last) {
375                 atomic_dec(&mcast->refcount);
376                 wait_event(mcast->wait, !atomic_read(&mcast->refcount));
377                 rvt_mcast_free(mcast);
378                 spin_lock_irq(&rdi->n_mcast_grps_lock);
379                 rdi->n_mcast_grps_allocated--;
380                 spin_unlock_irq(&rdi->n_mcast_grps_lock);
381         }
382
383         return 0;
384 }
385
386 /**
387  * rvt_mcast_tree_empty - determine if any qps are attached to any mcast group
388  * @rdi: rvt dev struct
389  *
390  * Return: in use count
391  */
392 int rvt_mcast_tree_empty(struct rvt_dev_info *rdi)
393 {
394         int i;
395         int in_use = 0;
396
397         for (i = 0; i < rdi->dparms.nports; i++)
398                 if (rdi->ports[i]->mcast_tree.rb_node)
399                         in_use++;
400         return in_use;
401 }