GNU Linux-libre 5.10.215-gnu1
[releases.git] / drivers / infiniband / sw / rxe / rxe_mr.c
1 // SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
2 /*
3  * Copyright (c) 2016 Mellanox Technologies Ltd. All rights reserved.
4  * Copyright (c) 2015 System Fabric Works, Inc. All rights reserved.
5  */
6
7 #include "rxe.h"
8 #include "rxe_loc.h"
9
10 /*
11  * lfsr (linear feedback shift register) with period 255
12  */
13 static u8 rxe_get_key(void)
14 {
15         static u32 key = 1;
16
17         key = key << 1;
18
19         key |= (0 != (key & 0x100)) ^ (0 != (key & 0x10))
20                 ^ (0 != (key & 0x80)) ^ (0 != (key & 0x40));
21
22         key &= 0xff;
23
24         return key;
25 }
26
27 int mem_check_range(struct rxe_mem *mem, u64 iova, size_t length)
28 {
29         switch (mem->type) {
30         case RXE_MEM_TYPE_DMA:
31                 return 0;
32
33         case RXE_MEM_TYPE_MR:
34         case RXE_MEM_TYPE_FMR:
35                 if (iova < mem->iova ||
36                     length > mem->length ||
37                     iova > mem->iova + mem->length - length)
38                         return -EFAULT;
39                 return 0;
40
41         default:
42                 return -EFAULT;
43         }
44 }
45
46 #define IB_ACCESS_REMOTE        (IB_ACCESS_REMOTE_READ          \
47                                 | IB_ACCESS_REMOTE_WRITE        \
48                                 | IB_ACCESS_REMOTE_ATOMIC)
49
50 static void rxe_mem_init(int access, struct rxe_mem *mem)
51 {
52         u32 lkey = mem->pelem.index << 8 | rxe_get_key();
53         u32 rkey = (access & IB_ACCESS_REMOTE) ? lkey : 0;
54
55         mem->ibmr.lkey          = lkey;
56         mem->ibmr.rkey          = rkey;
57         mem->state              = RXE_MEM_STATE_INVALID;
58         mem->type               = RXE_MEM_TYPE_NONE;
59         mem->map_shift          = ilog2(RXE_BUF_PER_MAP);
60 }
61
62 void rxe_mem_cleanup(struct rxe_pool_entry *arg)
63 {
64         struct rxe_mem *mem = container_of(arg, typeof(*mem), pelem);
65         int i;
66
67         ib_umem_release(mem->umem);
68
69         if (mem->map) {
70                 for (i = 0; i < mem->num_map; i++)
71                         kfree(mem->map[i]);
72
73                 kfree(mem->map);
74         }
75 }
76
77 static int rxe_mem_alloc(struct rxe_mem *mem, int num_buf)
78 {
79         int i;
80         int num_map;
81         struct rxe_map **map = mem->map;
82
83         num_map = (num_buf + RXE_BUF_PER_MAP - 1) / RXE_BUF_PER_MAP;
84
85         mem->map = kmalloc_array(num_map, sizeof(*map), GFP_KERNEL);
86         if (!mem->map)
87                 goto err1;
88
89         for (i = 0; i < num_map; i++) {
90                 mem->map[i] = kmalloc(sizeof(**map), GFP_KERNEL);
91                 if (!mem->map[i])
92                         goto err2;
93         }
94
95         BUILD_BUG_ON(!is_power_of_2(RXE_BUF_PER_MAP));
96
97         mem->map_shift  = ilog2(RXE_BUF_PER_MAP);
98         mem->map_mask   = RXE_BUF_PER_MAP - 1;
99
100         mem->num_buf = num_buf;
101         mem->num_map = num_map;
102         mem->max_buf = num_map * RXE_BUF_PER_MAP;
103
104         return 0;
105
106 err2:
107         for (i--; i >= 0; i--)
108                 kfree(mem->map[i]);
109
110         kfree(mem->map);
111 err1:
112         return -ENOMEM;
113 }
114
115 void rxe_mem_init_dma(struct rxe_pd *pd,
116                       int access, struct rxe_mem *mem)
117 {
118         rxe_mem_init(access, mem);
119
120         mem->ibmr.pd            = &pd->ibpd;
121         mem->access             = access;
122         mem->state              = RXE_MEM_STATE_VALID;
123         mem->type               = RXE_MEM_TYPE_DMA;
124 }
125
126 int rxe_mem_init_user(struct rxe_pd *pd, u64 start,
127                       u64 length, u64 iova, int access, struct ib_udata *udata,
128                       struct rxe_mem *mem)
129 {
130         struct rxe_map          **map;
131         struct rxe_phys_buf     *buf = NULL;
132         struct ib_umem          *umem;
133         struct sg_page_iter     sg_iter;
134         int                     num_buf;
135         void                    *vaddr;
136         int err;
137
138         umem = ib_umem_get(pd->ibpd.device, start, length, access);
139         if (IS_ERR(umem)) {
140                 pr_warn("err %d from rxe_umem_get\n",
141                         (int)PTR_ERR(umem));
142                 err = PTR_ERR(umem);
143                 goto err1;
144         }
145
146         mem->umem = umem;
147         num_buf = ib_umem_num_pages(umem);
148
149         rxe_mem_init(access, mem);
150
151         err = rxe_mem_alloc(mem, num_buf);
152         if (err) {
153                 pr_warn("err %d from rxe_mem_alloc\n", err);
154                 ib_umem_release(umem);
155                 goto err1;
156         }
157
158         mem->page_shift         = PAGE_SHIFT;
159         mem->page_mask = PAGE_SIZE - 1;
160
161         num_buf                 = 0;
162         map                     = mem->map;
163         if (length > 0) {
164                 buf = map[0]->buf;
165
166                 for_each_sg_page(umem->sg_head.sgl, &sg_iter, umem->nmap, 0) {
167                         if (num_buf >= RXE_BUF_PER_MAP) {
168                                 map++;
169                                 buf = map[0]->buf;
170                                 num_buf = 0;
171                         }
172
173                         vaddr = page_address(sg_page_iter_page(&sg_iter));
174                         if (!vaddr) {
175                                 pr_warn("null vaddr\n");
176                                 ib_umem_release(umem);
177                                 err = -ENOMEM;
178                                 goto err1;
179                         }
180
181                         buf->addr = (uintptr_t)vaddr;
182                         buf->size = PAGE_SIZE;
183                         num_buf++;
184                         buf++;
185
186                 }
187         }
188
189         mem->ibmr.pd            = &pd->ibpd;
190         mem->umem               = umem;
191         mem->access             = access;
192         mem->length             = length;
193         mem->iova               = iova;
194         mem->va                 = start;
195         mem->offset             = ib_umem_offset(umem);
196         mem->state              = RXE_MEM_STATE_VALID;
197         mem->type               = RXE_MEM_TYPE_MR;
198
199         return 0;
200
201 err1:
202         return err;
203 }
204
205 int rxe_mem_init_fast(struct rxe_pd *pd,
206                       int max_pages, struct rxe_mem *mem)
207 {
208         int err;
209
210         rxe_mem_init(0, mem);
211
212         /* In fastreg, we also set the rkey */
213         mem->ibmr.rkey = mem->ibmr.lkey;
214
215         err = rxe_mem_alloc(mem, max_pages);
216         if (err)
217                 goto err1;
218
219         mem->ibmr.pd            = &pd->ibpd;
220         mem->max_buf            = max_pages;
221         mem->state              = RXE_MEM_STATE_FREE;
222         mem->type               = RXE_MEM_TYPE_MR;
223
224         return 0;
225
226 err1:
227         return err;
228 }
229
230 static void lookup_iova(
231         struct rxe_mem  *mem,
232         u64                     iova,
233         int                     *m_out,
234         int                     *n_out,
235         size_t                  *offset_out)
236 {
237         size_t                  offset = iova - mem->iova + mem->offset;
238         int                     map_index;
239         int                     buf_index;
240         u64                     length;
241
242         if (likely(mem->page_shift)) {
243                 *offset_out = offset & mem->page_mask;
244                 offset >>= mem->page_shift;
245                 *n_out = offset & mem->map_mask;
246                 *m_out = offset >> mem->map_shift;
247         } else {
248                 map_index = 0;
249                 buf_index = 0;
250
251                 length = mem->map[map_index]->buf[buf_index].size;
252
253                 while (offset >= length) {
254                         offset -= length;
255                         buf_index++;
256
257                         if (buf_index == RXE_BUF_PER_MAP) {
258                                 map_index++;
259                                 buf_index = 0;
260                         }
261                         length = mem->map[map_index]->buf[buf_index].size;
262                 }
263
264                 *m_out = map_index;
265                 *n_out = buf_index;
266                 *offset_out = offset;
267         }
268 }
269
270 void *iova_to_vaddr(struct rxe_mem *mem, u64 iova, int length)
271 {
272         size_t offset;
273         int m, n;
274         void *addr;
275
276         if (mem->state != RXE_MEM_STATE_VALID) {
277                 pr_warn("mem not in valid state\n");
278                 addr = NULL;
279                 goto out;
280         }
281
282         if (!mem->map) {
283                 addr = (void *)(uintptr_t)iova;
284                 goto out;
285         }
286
287         if (mem_check_range(mem, iova, length)) {
288                 pr_warn("range violation\n");
289                 addr = NULL;
290                 goto out;
291         }
292
293         lookup_iova(mem, iova, &m, &n, &offset);
294
295         if (offset + length > mem->map[m]->buf[n].size) {
296                 pr_warn("crosses page boundary\n");
297                 addr = NULL;
298                 goto out;
299         }
300
301         addr = (void *)(uintptr_t)mem->map[m]->buf[n].addr + offset;
302
303 out:
304         return addr;
305 }
306
307 /* copy data from a range (vaddr, vaddr+length-1) to or from
308  * a mem object starting at iova. Compute incremental value of
309  * crc32 if crcp is not zero. caller must hold a reference to mem
310  */
311 int rxe_mem_copy(struct rxe_mem *mem, u64 iova, void *addr, int length,
312                  enum copy_direction dir, u32 *crcp)
313 {
314         int                     err;
315         int                     bytes;
316         u8                      *va;
317         struct rxe_map          **map;
318         struct rxe_phys_buf     *buf;
319         int                     m;
320         int                     i;
321         size_t                  offset;
322         u32                     crc = crcp ? (*crcp) : 0;
323
324         if (length == 0)
325                 return 0;
326
327         if (mem->type == RXE_MEM_TYPE_DMA) {
328                 u8 *src, *dest;
329
330                 src  = (dir == to_mem_obj) ?
331                         addr : ((void *)(uintptr_t)iova);
332
333                 dest = (dir == to_mem_obj) ?
334                         ((void *)(uintptr_t)iova) : addr;
335
336                 memcpy(dest, src, length);
337
338                 if (crcp)
339                         *crcp = rxe_crc32(to_rdev(mem->ibmr.device),
340                                         *crcp, dest, length);
341
342                 return 0;
343         }
344
345         WARN_ON_ONCE(!mem->map);
346
347         err = mem_check_range(mem, iova, length);
348         if (err) {
349                 err = -EFAULT;
350                 goto err1;
351         }
352
353         lookup_iova(mem, iova, &m, &i, &offset);
354
355         map     = mem->map + m;
356         buf     = map[0]->buf + i;
357
358         while (length > 0) {
359                 u8 *src, *dest;
360
361                 va      = (u8 *)(uintptr_t)buf->addr + offset;
362                 src  = (dir == to_mem_obj) ? addr : va;
363                 dest = (dir == to_mem_obj) ? va : addr;
364
365                 bytes   = buf->size - offset;
366
367                 if (bytes > length)
368                         bytes = length;
369
370                 memcpy(dest, src, bytes);
371
372                 if (crcp)
373                         crc = rxe_crc32(to_rdev(mem->ibmr.device),
374                                         crc, dest, bytes);
375
376                 length  -= bytes;
377                 addr    += bytes;
378
379                 offset  = 0;
380                 buf++;
381                 i++;
382
383                 if (i == RXE_BUF_PER_MAP) {
384                         i = 0;
385                         map++;
386                         buf = map[0]->buf;
387                 }
388         }
389
390         if (crcp)
391                 *crcp = crc;
392
393         return 0;
394
395 err1:
396         return err;
397 }
398
399 /* copy data in or out of a wqe, i.e. sg list
400  * under the control of a dma descriptor
401  */
402 int copy_data(
403         struct rxe_pd           *pd,
404         int                     access,
405         struct rxe_dma_info     *dma,
406         void                    *addr,
407         int                     length,
408         enum copy_direction     dir,
409         u32                     *crcp)
410 {
411         int                     bytes;
412         struct rxe_sge          *sge    = &dma->sge[dma->cur_sge];
413         int                     offset  = dma->sge_offset;
414         int                     resid   = dma->resid;
415         struct rxe_mem          *mem    = NULL;
416         u64                     iova;
417         int                     err;
418
419         if (length == 0)
420                 return 0;
421
422         if (length > resid) {
423                 err = -EINVAL;
424                 goto err2;
425         }
426
427         if (sge->length && (offset < sge->length)) {
428                 mem = lookup_mem(pd, access, sge->lkey, lookup_local);
429                 if (!mem) {
430                         err = -EINVAL;
431                         goto err1;
432                 }
433         }
434
435         while (length > 0) {
436                 bytes = length;
437
438                 if (offset >= sge->length) {
439                         if (mem) {
440                                 rxe_drop_ref(mem);
441                                 mem = NULL;
442                         }
443                         sge++;
444                         dma->cur_sge++;
445                         offset = 0;
446
447                         if (dma->cur_sge >= dma->num_sge) {
448                                 err = -ENOSPC;
449                                 goto err2;
450                         }
451
452                         if (sge->length) {
453                                 mem = lookup_mem(pd, access, sge->lkey,
454                                                  lookup_local);
455                                 if (!mem) {
456                                         err = -EINVAL;
457                                         goto err1;
458                                 }
459                         } else {
460                                 continue;
461                         }
462                 }
463
464                 if (bytes > sge->length - offset)
465                         bytes = sge->length - offset;
466
467                 if (bytes > 0) {
468                         iova = sge->addr + offset;
469
470                         err = rxe_mem_copy(mem, iova, addr, bytes, dir, crcp);
471                         if (err)
472                                 goto err2;
473
474                         offset  += bytes;
475                         resid   -= bytes;
476                         length  -= bytes;
477                         addr    += bytes;
478                 }
479         }
480
481         dma->sge_offset = offset;
482         dma->resid      = resid;
483
484         if (mem)
485                 rxe_drop_ref(mem);
486
487         return 0;
488
489 err2:
490         if (mem)
491                 rxe_drop_ref(mem);
492 err1:
493         return err;
494 }
495
496 int advance_dma_data(struct rxe_dma_info *dma, unsigned int length)
497 {
498         struct rxe_sge          *sge    = &dma->sge[dma->cur_sge];
499         int                     offset  = dma->sge_offset;
500         int                     resid   = dma->resid;
501
502         while (length) {
503                 unsigned int bytes;
504
505                 if (offset >= sge->length) {
506                         sge++;
507                         dma->cur_sge++;
508                         offset = 0;
509                         if (dma->cur_sge >= dma->num_sge)
510                                 return -ENOSPC;
511                 }
512
513                 bytes = length;
514
515                 if (bytes > sge->length - offset)
516                         bytes = sge->length - offset;
517
518                 offset  += bytes;
519                 resid   -= bytes;
520                 length  -= bytes;
521         }
522
523         dma->sge_offset = offset;
524         dma->resid      = resid;
525
526         return 0;
527 }
528
529 /* (1) find the mem (mr or mw) corresponding to lkey/rkey
530  *     depending on lookup_type
531  * (2) verify that the (qp) pd matches the mem pd
532  * (3) verify that the mem can support the requested access
533  * (4) verify that mem state is valid
534  */
535 struct rxe_mem *lookup_mem(struct rxe_pd *pd, int access, u32 key,
536                            enum lookup_type type)
537 {
538         struct rxe_mem *mem;
539         struct rxe_dev *rxe = to_rdev(pd->ibpd.device);
540         int index = key >> 8;
541
542         mem = rxe_pool_get_index(&rxe->mr_pool, index);
543         if (!mem)
544                 return NULL;
545
546         if (unlikely((type == lookup_local && mr_lkey(mem) != key) ||
547                      (type == lookup_remote && mr_rkey(mem) != key) ||
548                      mr_pd(mem) != pd ||
549                      (access && !(access & mem->access)) ||
550                      mem->state != RXE_MEM_STATE_VALID)) {
551                 rxe_drop_ref(mem);
552                 mem = NULL;
553         }
554
555         return mem;
556 }