GNU Linux-libre 5.10.219-gnu1
[releases.git] / drivers / vfio / vfio_iommu_type1.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * VFIO: IOMMU DMA mapping support for Type1 IOMMU
4  *
5  * Copyright (C) 2012 Red Hat, Inc.  All rights reserved.
6  *     Author: Alex Williamson <alex.williamson@redhat.com>
7  *
8  * Derived from original vfio:
9  * Copyright 2010 Cisco Systems, Inc.  All rights reserved.
10  * Author: Tom Lyon, pugs@cisco.com
11  *
12  * We arbitrarily define a Type1 IOMMU as one matching the below code.
13  * It could be called the x86 IOMMU as it's designed for AMD-Vi & Intel
14  * VT-d, but that makes it harder to re-use as theoretically anyone
15  * implementing a similar IOMMU could make use of this.  We expect the
16  * IOMMU to support the IOMMU API and have few to no restrictions around
17  * the IOVA range that can be mapped.  The Type1 IOMMU is currently
18  * optimized for relatively static mappings of a userspace process with
19  * userpsace pages pinned into memory.  We also assume devices and IOMMU
20  * domains are PCI based as the IOMMU API is still centered around a
21  * device/bus interface rather than a group interface.
22  */
23
24 #include <linux/compat.h>
25 #include <linux/device.h>
26 #include <linux/fs.h>
27 #include <linux/highmem.h>
28 #include <linux/iommu.h>
29 #include <linux/module.h>
30 #include <linux/mm.h>
31 #include <linux/kthread.h>
32 #include <linux/rbtree.h>
33 #include <linux/sched/signal.h>
34 #include <linux/sched/mm.h>
35 #include <linux/slab.h>
36 #include <linux/uaccess.h>
37 #include <linux/vfio.h>
38 #include <linux/workqueue.h>
39 #include <linux/mdev.h>
40 #include <linux/notifier.h>
41 #include <linux/dma-iommu.h>
42 #include <linux/irqdomain.h>
43
44 #define DRIVER_VERSION  "0.2"
45 #define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
46 #define DRIVER_DESC     "Type1 IOMMU driver for VFIO"
47
48 static bool allow_unsafe_interrupts;
49 module_param_named(allow_unsafe_interrupts,
50                    allow_unsafe_interrupts, bool, S_IRUGO | S_IWUSR);
51 MODULE_PARM_DESC(allow_unsafe_interrupts,
52                  "Enable VFIO IOMMU support for on platforms without interrupt remapping support.");
53
54 static bool disable_hugepages;
55 module_param_named(disable_hugepages,
56                    disable_hugepages, bool, S_IRUGO | S_IWUSR);
57 MODULE_PARM_DESC(disable_hugepages,
58                  "Disable VFIO IOMMU support for IOMMU hugepages.");
59
60 static unsigned int dma_entry_limit __read_mostly = U16_MAX;
61 module_param_named(dma_entry_limit, dma_entry_limit, uint, 0644);
62 MODULE_PARM_DESC(dma_entry_limit,
63                  "Maximum number of user DMA mappings per container (65535).");
64
65 struct vfio_iommu {
66         struct list_head        domain_list;
67         struct list_head        iova_list;
68         struct vfio_domain      *external_domain; /* domain for external user */
69         struct mutex            lock;
70         struct rb_root          dma_list;
71         struct blocking_notifier_head notifier;
72         unsigned int            dma_avail;
73         uint64_t                pgsize_bitmap;
74         bool                    v2;
75         bool                    nesting;
76         bool                    dirty_page_tracking;
77         bool                    pinned_page_dirty_scope;
78 };
79
80 struct vfio_domain {
81         struct iommu_domain     *domain;
82         struct list_head        next;
83         struct list_head        group_list;
84         int                     prot;           /* IOMMU_CACHE */
85         bool                    fgsp;           /* Fine-grained super pages */
86 };
87
88 struct vfio_dma {
89         struct rb_node          node;
90         dma_addr_t              iova;           /* Device address */
91         unsigned long           vaddr;          /* Process virtual addr */
92         size_t                  size;           /* Map size (bytes) */
93         int                     prot;           /* IOMMU_READ/WRITE */
94         bool                    iommu_mapped;
95         bool                    lock_cap;       /* capable(CAP_IPC_LOCK) */
96         struct task_struct      *task;
97         struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
98         unsigned long           *bitmap;
99         struct mm_struct        *mm;
100 };
101
102 struct vfio_batch {
103         struct page             **pages;        /* for pin_user_pages_remote */
104         struct page             *fallback_page; /* if pages alloc fails */
105         int                     capacity;       /* length of pages array */
106 };
107
108 struct vfio_group {
109         struct iommu_group      *iommu_group;
110         struct list_head        next;
111         bool                    mdev_group;     /* An mdev group */
112         bool                    pinned_page_dirty_scope;
113 };
114
115 struct vfio_iova {
116         struct list_head        list;
117         dma_addr_t              start;
118         dma_addr_t              end;
119 };
120
121 /*
122  * Guest RAM pinning working set or DMA target
123  */
124 struct vfio_pfn {
125         struct rb_node          node;
126         dma_addr_t              iova;           /* Device address */
127         unsigned long           pfn;            /* Host pfn */
128         unsigned int            ref_count;
129 };
130
131 struct vfio_regions {
132         struct list_head list;
133         dma_addr_t iova;
134         phys_addr_t phys;
135         size_t len;
136 };
137
138 #define IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu) \
139                                         (!list_empty(&iommu->domain_list))
140
141 #define DIRTY_BITMAP_BYTES(n)   (ALIGN(n, BITS_PER_TYPE(u64)) / BITS_PER_BYTE)
142
143 /*
144  * Input argument of number of bits to bitmap_set() is unsigned integer, which
145  * further casts to signed integer for unaligned multi-bit operation,
146  * __bitmap_set().
147  * Then maximum bitmap size supported is 2^31 bits divided by 2^3 bits/byte,
148  * that is 2^28 (256 MB) which maps to 2^31 * 2^12 = 2^43 (8TB) on 4K page
149  * system.
150  */
151 #define DIRTY_BITMAP_PAGES_MAX   ((u64)INT_MAX)
152 #define DIRTY_BITMAP_SIZE_MAX    DIRTY_BITMAP_BYTES(DIRTY_BITMAP_PAGES_MAX)
153
154 static int put_pfn(unsigned long pfn, int prot);
155
156 static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
157                                                struct iommu_group *iommu_group);
158
159 static void update_pinned_page_dirty_scope(struct vfio_iommu *iommu);
160 /*
161  * This code handles mapping and unmapping of user data buffers
162  * into DMA'ble space using the IOMMU
163  */
164
165 static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
166                                       dma_addr_t start, size_t size)
167 {
168         struct rb_node *node = iommu->dma_list.rb_node;
169
170         while (node) {
171                 struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
172
173                 if (start + size <= dma->iova)
174                         node = node->rb_left;
175                 else if (start >= dma->iova + dma->size)
176                         node = node->rb_right;
177                 else
178                         return dma;
179         }
180
181         return NULL;
182 }
183
184 static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
185 {
186         struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
187         struct vfio_dma *dma;
188
189         while (*link) {
190                 parent = *link;
191                 dma = rb_entry(parent, struct vfio_dma, node);
192
193                 if (new->iova + new->size <= dma->iova)
194                         link = &(*link)->rb_left;
195                 else
196                         link = &(*link)->rb_right;
197         }
198
199         rb_link_node(&new->node, parent, link);
200         rb_insert_color(&new->node, &iommu->dma_list);
201 }
202
203 static void vfio_unlink_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
204 {
205         rb_erase(&old->node, &iommu->dma_list);
206 }
207
208
209 static int vfio_dma_bitmap_alloc(struct vfio_dma *dma, size_t pgsize)
210 {
211         uint64_t npages = dma->size / pgsize;
212
213         if (npages > DIRTY_BITMAP_PAGES_MAX)
214                 return -EINVAL;
215
216         /*
217          * Allocate extra 64 bits that are used to calculate shift required for
218          * bitmap_shift_left() to manipulate and club unaligned number of pages
219          * in adjacent vfio_dma ranges.
220          */
221         dma->bitmap = kvzalloc(DIRTY_BITMAP_BYTES(npages) + sizeof(u64),
222                                GFP_KERNEL);
223         if (!dma->bitmap)
224                 return -ENOMEM;
225
226         return 0;
227 }
228
229 static void vfio_dma_bitmap_free(struct vfio_dma *dma)
230 {
231         kfree(dma->bitmap);
232         dma->bitmap = NULL;
233 }
234
235 static void vfio_dma_populate_bitmap(struct vfio_dma *dma, size_t pgsize)
236 {
237         struct rb_node *p;
238         unsigned long pgshift = __ffs(pgsize);
239
240         for (p = rb_first(&dma->pfn_list); p; p = rb_next(p)) {
241                 struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn, node);
242
243                 bitmap_set(dma->bitmap, (vpfn->iova - dma->iova) >> pgshift, 1);
244         }
245 }
246
247 static void vfio_iommu_populate_bitmap_full(struct vfio_iommu *iommu)
248 {
249         struct rb_node *n;
250         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
251
252         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
253                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
254
255                 bitmap_set(dma->bitmap, 0, dma->size >> pgshift);
256         }
257 }
258
259 static int vfio_dma_bitmap_alloc_all(struct vfio_iommu *iommu, size_t pgsize)
260 {
261         struct rb_node *n;
262
263         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
264                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
265                 int ret;
266
267                 ret = vfio_dma_bitmap_alloc(dma, pgsize);
268                 if (ret) {
269                         struct rb_node *p;
270
271                         for (p = rb_prev(n); p; p = rb_prev(p)) {
272                                 struct vfio_dma *dma = rb_entry(n,
273                                                         struct vfio_dma, node);
274
275                                 vfio_dma_bitmap_free(dma);
276                         }
277                         return ret;
278                 }
279                 vfio_dma_populate_bitmap(dma, pgsize);
280         }
281         return 0;
282 }
283
284 static void vfio_dma_bitmap_free_all(struct vfio_iommu *iommu)
285 {
286         struct rb_node *n;
287
288         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
289                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
290
291                 vfio_dma_bitmap_free(dma);
292         }
293 }
294
295 /*
296  * Helper Functions for host iova-pfn list
297  */
298 static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
299 {
300         struct vfio_pfn *vpfn;
301         struct rb_node *node = dma->pfn_list.rb_node;
302
303         while (node) {
304                 vpfn = rb_entry(node, struct vfio_pfn, node);
305
306                 if (iova < vpfn->iova)
307                         node = node->rb_left;
308                 else if (iova > vpfn->iova)
309                         node = node->rb_right;
310                 else
311                         return vpfn;
312         }
313         return NULL;
314 }
315
316 static void vfio_link_pfn(struct vfio_dma *dma,
317                           struct vfio_pfn *new)
318 {
319         struct rb_node **link, *parent = NULL;
320         struct vfio_pfn *vpfn;
321
322         link = &dma->pfn_list.rb_node;
323         while (*link) {
324                 parent = *link;
325                 vpfn = rb_entry(parent, struct vfio_pfn, node);
326
327                 if (new->iova < vpfn->iova)
328                         link = &(*link)->rb_left;
329                 else
330                         link = &(*link)->rb_right;
331         }
332
333         rb_link_node(&new->node, parent, link);
334         rb_insert_color(&new->node, &dma->pfn_list);
335 }
336
337 static void vfio_unlink_pfn(struct vfio_dma *dma, struct vfio_pfn *old)
338 {
339         rb_erase(&old->node, &dma->pfn_list);
340 }
341
342 static int vfio_add_to_pfn_list(struct vfio_dma *dma, dma_addr_t iova,
343                                 unsigned long pfn)
344 {
345         struct vfio_pfn *vpfn;
346
347         vpfn = kzalloc(sizeof(*vpfn), GFP_KERNEL);
348         if (!vpfn)
349                 return -ENOMEM;
350
351         vpfn->iova = iova;
352         vpfn->pfn = pfn;
353         vpfn->ref_count = 1;
354         vfio_link_pfn(dma, vpfn);
355         return 0;
356 }
357
358 static void vfio_remove_from_pfn_list(struct vfio_dma *dma,
359                                       struct vfio_pfn *vpfn)
360 {
361         vfio_unlink_pfn(dma, vpfn);
362         kfree(vpfn);
363 }
364
365 static struct vfio_pfn *vfio_iova_get_vfio_pfn(struct vfio_dma *dma,
366                                                unsigned long iova)
367 {
368         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
369
370         if (vpfn)
371                 vpfn->ref_count++;
372         return vpfn;
373 }
374
375 static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
376 {
377         int ret = 0;
378
379         vpfn->ref_count--;
380         if (!vpfn->ref_count) {
381                 ret = put_pfn(vpfn->pfn, dma->prot);
382                 vfio_remove_from_pfn_list(dma, vpfn);
383         }
384         return ret;
385 }
386
387 static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
388 {
389         struct mm_struct *mm;
390         int ret;
391
392         if (!npage)
393                 return 0;
394
395         mm = dma->mm;
396         if (async && !mmget_not_zero(mm))
397                 return -ESRCH; /* process exited */
398
399         ret = mmap_write_lock_killable(mm);
400         if (!ret) {
401                 ret = __account_locked_vm(mm, abs(npage), npage > 0, dma->task,
402                                           dma->lock_cap);
403                 mmap_write_unlock(mm);
404         }
405
406         if (async)
407                 mmput(mm);
408
409         return ret;
410 }
411
412 /*
413  * Some mappings aren't backed by a struct page, for example an mmap'd
414  * MMIO range for our own or another device.  These use a different
415  * pfn conversion and shouldn't be tracked as locked pages.
416  * For compound pages, any driver that sets the reserved bit in head
417  * page needs to set the reserved bit in all subpages to be safe.
418  */
419 static bool is_invalid_reserved_pfn(unsigned long pfn)
420 {
421         if (pfn_valid(pfn))
422                 return PageReserved(pfn_to_page(pfn));
423
424         return true;
425 }
426
427 static int put_pfn(unsigned long pfn, int prot)
428 {
429         if (!is_invalid_reserved_pfn(pfn)) {
430                 struct page *page = pfn_to_page(pfn);
431
432                 unpin_user_pages_dirty_lock(&page, 1, prot & IOMMU_WRITE);
433                 return 1;
434         }
435         return 0;
436 }
437
438 #define VFIO_BATCH_MAX_CAPACITY (PAGE_SIZE / sizeof(struct page *))
439
440 static void vfio_batch_init(struct vfio_batch *batch)
441 {
442         if (unlikely(disable_hugepages))
443                 goto fallback;
444
445         batch->pages = (struct page **) __get_free_page(GFP_KERNEL);
446         if (!batch->pages)
447                 goto fallback;
448
449         batch->capacity = VFIO_BATCH_MAX_CAPACITY;
450         return;
451
452 fallback:
453         batch->pages = &batch->fallback_page;
454         batch->capacity = 1;
455 }
456
457 static void vfio_batch_fini(struct vfio_batch *batch)
458 {
459         if (batch->capacity == VFIO_BATCH_MAX_CAPACITY)
460                 free_page((unsigned long)batch->pages);
461 }
462
463 static int follow_fault_pfn(struct vm_area_struct *vma, struct mm_struct *mm,
464                             unsigned long vaddr, unsigned long *pfn,
465                             bool write_fault)
466 {
467         pte_t *ptep;
468         spinlock_t *ptl;
469         int ret;
470
471         ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
472         if (ret) {
473                 bool unlocked = false;
474
475                 ret = fixup_user_fault(mm, vaddr,
476                                        FAULT_FLAG_REMOTE |
477                                        (write_fault ?  FAULT_FLAG_WRITE : 0),
478                                        &unlocked);
479                 if (unlocked)
480                         return -EAGAIN;
481
482                 if (ret)
483                         return ret;
484
485                 ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
486                 if (ret)
487                         return ret;
488         }
489
490         if (write_fault && !pte_write(*ptep))
491                 ret = -EFAULT;
492         else
493                 *pfn = pte_pfn(*ptep);
494
495         pte_unmap_unlock(ptep, ptl);
496         return ret;
497 }
498
499 /*
500  * Returns the positive number of pfns successfully obtained or a negative
501  * error code.
502  */
503 static int vaddr_get_pfns(struct mm_struct *mm, unsigned long vaddr,
504                           long npages, int prot, unsigned long *pfn,
505                           struct page **pages)
506 {
507         struct vm_area_struct *vma;
508         unsigned int flags = 0;
509         int ret;
510
511         if (prot & IOMMU_WRITE)
512                 flags |= FOLL_WRITE;
513
514         mmap_read_lock(mm);
515         ret = pin_user_pages_remote(mm, vaddr, npages, flags | FOLL_LONGTERM,
516                                     pages, NULL, NULL);
517         if (ret > 0) {
518                 int i;
519
520                 /*
521                  * The zero page is always resident, we don't need to pin it
522                  * and it falls into our invalid/reserved test so we don't
523                  * unpin in put_pfn().  Unpin all zero pages in the batch here.
524                  */
525                 for (i = 0 ; i < ret; i++) {
526                         if (unlikely(is_zero_pfn(page_to_pfn(pages[i]))))
527                                 unpin_user_page(pages[i]);
528                 }
529
530                 *pfn = page_to_pfn(pages[0]);
531                 goto done;
532         }
533
534         vaddr = untagged_addr(vaddr);
535
536 retry:
537         vma = find_vma_intersection(mm, vaddr, vaddr + 1);
538
539         if (vma && vma->vm_flags & VM_PFNMAP) {
540                 ret = follow_fault_pfn(vma, mm, vaddr, pfn, prot & IOMMU_WRITE);
541                 if (ret == -EAGAIN)
542                         goto retry;
543
544                 if (!ret) {
545                         if (is_invalid_reserved_pfn(*pfn))
546                                 ret = 1;
547                         else
548                                 ret = -EFAULT;
549                 }
550         }
551 done:
552         mmap_read_unlock(mm);
553         return ret;
554 }
555
556 /*
557  * Attempt to pin pages.  We really don't want to track all the pfns and
558  * the iommu can only map chunks of consecutive pfns anyway, so get the
559  * first page and all consecutive pages with the same locking.
560  */
561 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
562                                   long npage, unsigned long *pfn_base,
563                                   unsigned long limit, struct vfio_batch *batch)
564 {
565         unsigned long pfn = 0;
566         long ret, pinned = 0, lock_acct = 0;
567         bool rsvd;
568         dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
569
570         /* This code path is only user initiated */
571         if (!current->mm)
572                 return -ENODEV;
573
574         ret = vaddr_get_pfns(current->mm, vaddr, 1, dma->prot, pfn_base,
575                              batch->pages);
576         if (ret < 0)
577                 return ret;
578
579         pinned++;
580         rsvd = is_invalid_reserved_pfn(*pfn_base);
581
582         /*
583          * Reserved pages aren't counted against the user, externally pinned
584          * pages are already counted against the user.
585          */
586         if (!rsvd && !vfio_find_vpfn(dma, iova)) {
587                 if (!dma->lock_cap && current->mm->locked_vm + 1 > limit) {
588                         put_pfn(*pfn_base, dma->prot);
589                         pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
590                                         limit << PAGE_SHIFT);
591                         return -ENOMEM;
592                 }
593                 lock_acct++;
594         }
595
596         if (unlikely(disable_hugepages))
597                 goto out;
598
599         /* Lock all the consecutive pages from pfn_base */
600         for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
601              pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
602                 ret = vaddr_get_pfns(current->mm, vaddr, 1, dma->prot, &pfn,
603                                      batch->pages);
604                 if (ret < 0)
605                         break;
606
607                 if (pfn != *pfn_base + pinned ||
608                     rsvd != is_invalid_reserved_pfn(pfn)) {
609                         put_pfn(pfn, dma->prot);
610                         break;
611                 }
612
613                 if (!rsvd && !vfio_find_vpfn(dma, iova)) {
614                         if (!dma->lock_cap &&
615                             current->mm->locked_vm + lock_acct + 1 > limit) {
616                                 put_pfn(pfn, dma->prot);
617                                 pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
618                                         __func__, limit << PAGE_SHIFT);
619                                 ret = -ENOMEM;
620                                 goto unpin_out;
621                         }
622                         lock_acct++;
623                 }
624         }
625
626 out:
627         ret = vfio_lock_acct(dma, lock_acct, false);
628
629 unpin_out:
630         if (ret < 0) {
631                 if (!rsvd) {
632                         for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
633                                 put_pfn(pfn, dma->prot);
634                 }
635
636                 return ret;
637         }
638
639         return pinned;
640 }
641
642 static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
643                                     unsigned long pfn, long npage,
644                                     bool do_accounting)
645 {
646         long unlocked = 0, locked = 0;
647         long i;
648
649         for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
650                 if (put_pfn(pfn++, dma->prot)) {
651                         unlocked++;
652                         if (vfio_find_vpfn(dma, iova))
653                                 locked++;
654                 }
655         }
656
657         if (do_accounting)
658                 vfio_lock_acct(dma, locked - unlocked, true);
659
660         return unlocked;
661 }
662
663 static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
664                                   unsigned long *pfn_base, bool do_accounting)
665 {
666         struct page *pages[1];
667         struct mm_struct *mm;
668         int ret;
669
670         mm = dma->mm;
671         if (!mmget_not_zero(mm))
672                 return -ENODEV;
673
674         ret = vaddr_get_pfns(mm, vaddr, 1, dma->prot, pfn_base, pages);
675         if (ret != 1)
676                 goto out;
677
678         ret = 0;
679
680         if (do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
681                 ret = vfio_lock_acct(dma, 1, false);
682                 if (ret) {
683                         put_pfn(*pfn_base, dma->prot);
684                         if (ret == -ENOMEM)
685                                 pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
686                                         "(%ld) exceeded\n", __func__,
687                                         dma->task->comm, task_pid_nr(dma->task),
688                                         task_rlimit(dma->task, RLIMIT_MEMLOCK));
689                 }
690         }
691
692 out:
693         mmput(mm);
694         return ret;
695 }
696
697 static int vfio_unpin_page_external(struct vfio_dma *dma, dma_addr_t iova,
698                                     bool do_accounting)
699 {
700         int unlocked;
701         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
702
703         if (!vpfn)
704                 return 0;
705
706         unlocked = vfio_iova_put_vfio_pfn(dma, vpfn);
707
708         if (do_accounting)
709                 vfio_lock_acct(dma, -unlocked, true);
710
711         return unlocked;
712 }
713
714 static int vfio_iommu_type1_pin_pages(void *iommu_data,
715                                       struct iommu_group *iommu_group,
716                                       unsigned long *user_pfn,
717                                       int npage, int prot,
718                                       unsigned long *phys_pfn)
719 {
720         struct vfio_iommu *iommu = iommu_data;
721         struct vfio_group *group;
722         int i, j, ret;
723         unsigned long remote_vaddr;
724         struct vfio_dma *dma;
725         bool do_accounting;
726
727         if (!iommu || !user_pfn || !phys_pfn)
728                 return -EINVAL;
729
730         /* Supported for v2 version only */
731         if (!iommu->v2)
732                 return -EACCES;
733
734         mutex_lock(&iommu->lock);
735
736         /* Fail if notifier list is empty */
737         if (!iommu->notifier.head) {
738                 ret = -EINVAL;
739                 goto pin_done;
740         }
741
742         /*
743          * If iommu capable domain exist in the container then all pages are
744          * already pinned and accounted. Accouting should be done if there is no
745          * iommu capable domain in the container.
746          */
747         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
748
749         for (i = 0; i < npage; i++) {
750                 dma_addr_t iova;
751                 struct vfio_pfn *vpfn;
752
753                 iova = user_pfn[i] << PAGE_SHIFT;
754                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
755                 if (!dma) {
756                         ret = -EINVAL;
757                         goto pin_unwind;
758                 }
759
760                 if ((dma->prot & prot) != prot) {
761                         ret = -EPERM;
762                         goto pin_unwind;
763                 }
764
765                 vpfn = vfio_iova_get_vfio_pfn(dma, iova);
766                 if (vpfn) {
767                         phys_pfn[i] = vpfn->pfn;
768                         continue;
769                 }
770
771                 remote_vaddr = dma->vaddr + (iova - dma->iova);
772                 ret = vfio_pin_page_external(dma, remote_vaddr, &phys_pfn[i],
773                                              do_accounting);
774                 if (ret)
775                         goto pin_unwind;
776
777                 ret = vfio_add_to_pfn_list(dma, iova, phys_pfn[i]);
778                 if (ret) {
779                         if (put_pfn(phys_pfn[i], dma->prot) && do_accounting)
780                                 vfio_lock_acct(dma, -1, true);
781                         goto pin_unwind;
782                 }
783
784                 if (iommu->dirty_page_tracking) {
785                         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
786
787                         /*
788                          * Bitmap populated with the smallest supported page
789                          * size
790                          */
791                         bitmap_set(dma->bitmap,
792                                    (iova - dma->iova) >> pgshift, 1);
793                 }
794         }
795         ret = i;
796
797         group = vfio_iommu_find_iommu_group(iommu, iommu_group);
798         if (!group->pinned_page_dirty_scope) {
799                 group->pinned_page_dirty_scope = true;
800                 update_pinned_page_dirty_scope(iommu);
801         }
802
803         goto pin_done;
804
805 pin_unwind:
806         phys_pfn[i] = 0;
807         for (j = 0; j < i; j++) {
808                 dma_addr_t iova;
809
810                 iova = user_pfn[j] << PAGE_SHIFT;
811                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
812                 vfio_unpin_page_external(dma, iova, do_accounting);
813                 phys_pfn[j] = 0;
814         }
815 pin_done:
816         mutex_unlock(&iommu->lock);
817         return ret;
818 }
819
820 static int vfio_iommu_type1_unpin_pages(void *iommu_data,
821                                         unsigned long *user_pfn,
822                                         int npage)
823 {
824         struct vfio_iommu *iommu = iommu_data;
825         bool do_accounting;
826         int i;
827
828         if (!iommu || !user_pfn)
829                 return -EINVAL;
830
831         /* Supported for v2 version only */
832         if (!iommu->v2)
833                 return -EACCES;
834
835         mutex_lock(&iommu->lock);
836
837         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
838         for (i = 0; i < npage; i++) {
839                 struct vfio_dma *dma;
840                 dma_addr_t iova;
841
842                 iova = user_pfn[i] << PAGE_SHIFT;
843                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
844                 if (!dma)
845                         goto unpin_exit;
846                 vfio_unpin_page_external(dma, iova, do_accounting);
847         }
848
849 unpin_exit:
850         mutex_unlock(&iommu->lock);
851         return i > npage ? npage : (i > 0 ? i : -EINVAL);
852 }
853
854 static long vfio_sync_unpin(struct vfio_dma *dma, struct vfio_domain *domain,
855                             struct list_head *regions,
856                             struct iommu_iotlb_gather *iotlb_gather)
857 {
858         long unlocked = 0;
859         struct vfio_regions *entry, *next;
860
861         iommu_iotlb_sync(domain->domain, iotlb_gather);
862
863         list_for_each_entry_safe(entry, next, regions, list) {
864                 unlocked += vfio_unpin_pages_remote(dma,
865                                                     entry->iova,
866                                                     entry->phys >> PAGE_SHIFT,
867                                                     entry->len >> PAGE_SHIFT,
868                                                     false);
869                 list_del(&entry->list);
870                 kfree(entry);
871         }
872
873         cond_resched();
874
875         return unlocked;
876 }
877
878 /*
879  * Generally, VFIO needs to unpin remote pages after each IOTLB flush.
880  * Therefore, when using IOTLB flush sync interface, VFIO need to keep track
881  * of these regions (currently using a list).
882  *
883  * This value specifies maximum number of regions for each IOTLB flush sync.
884  */
885 #define VFIO_IOMMU_TLB_SYNC_MAX         512
886
887 static size_t unmap_unpin_fast(struct vfio_domain *domain,
888                                struct vfio_dma *dma, dma_addr_t *iova,
889                                size_t len, phys_addr_t phys, long *unlocked,
890                                struct list_head *unmapped_list,
891                                int *unmapped_cnt,
892                                struct iommu_iotlb_gather *iotlb_gather)
893 {
894         size_t unmapped = 0;
895         struct vfio_regions *entry = kzalloc(sizeof(*entry), GFP_KERNEL);
896
897         if (entry) {
898                 unmapped = iommu_unmap_fast(domain->domain, *iova, len,
899                                             iotlb_gather);
900
901                 if (!unmapped) {
902                         kfree(entry);
903                 } else {
904                         entry->iova = *iova;
905                         entry->phys = phys;
906                         entry->len  = unmapped;
907                         list_add_tail(&entry->list, unmapped_list);
908
909                         *iova += unmapped;
910                         (*unmapped_cnt)++;
911                 }
912         }
913
914         /*
915          * Sync if the number of fast-unmap regions hits the limit
916          * or in case of errors.
917          */
918         if (*unmapped_cnt >= VFIO_IOMMU_TLB_SYNC_MAX || !unmapped) {
919                 *unlocked += vfio_sync_unpin(dma, domain, unmapped_list,
920                                              iotlb_gather);
921                 *unmapped_cnt = 0;
922         }
923
924         return unmapped;
925 }
926
927 static size_t unmap_unpin_slow(struct vfio_domain *domain,
928                                struct vfio_dma *dma, dma_addr_t *iova,
929                                size_t len, phys_addr_t phys,
930                                long *unlocked)
931 {
932         size_t unmapped = iommu_unmap(domain->domain, *iova, len);
933
934         if (unmapped) {
935                 *unlocked += vfio_unpin_pages_remote(dma, *iova,
936                                                      phys >> PAGE_SHIFT,
937                                                      unmapped >> PAGE_SHIFT,
938                                                      false);
939                 *iova += unmapped;
940                 cond_resched();
941         }
942         return unmapped;
943 }
944
945 static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
946                              bool do_accounting)
947 {
948         dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
949         struct vfio_domain *domain, *d;
950         LIST_HEAD(unmapped_region_list);
951         struct iommu_iotlb_gather iotlb_gather;
952         int unmapped_region_cnt = 0;
953         long unlocked = 0;
954
955         if (!dma->size)
956                 return 0;
957
958         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
959                 return 0;
960
961         /*
962          * We use the IOMMU to track the physical addresses, otherwise we'd
963          * need a much more complicated tracking system.  Unfortunately that
964          * means we need to use one of the iommu domains to figure out the
965          * pfns to unpin.  The rest need to be unmapped in advance so we have
966          * no iommu translations remaining when the pages are unpinned.
967          */
968         domain = d = list_first_entry(&iommu->domain_list,
969                                       struct vfio_domain, next);
970
971         list_for_each_entry_continue(d, &iommu->domain_list, next) {
972                 iommu_unmap(d->domain, dma->iova, dma->size);
973                 cond_resched();
974         }
975
976         iommu_iotlb_gather_init(&iotlb_gather);
977         while (iova < end) {
978                 size_t unmapped, len;
979                 phys_addr_t phys, next;
980
981                 phys = iommu_iova_to_phys(domain->domain, iova);
982                 if (WARN_ON(!phys)) {
983                         iova += PAGE_SIZE;
984                         continue;
985                 }
986
987                 /*
988                  * To optimize for fewer iommu_unmap() calls, each of which
989                  * may require hardware cache flushing, try to find the
990                  * largest contiguous physical memory chunk to unmap.
991                  */
992                 for (len = PAGE_SIZE;
993                      !domain->fgsp && iova + len < end; len += PAGE_SIZE) {
994                         next = iommu_iova_to_phys(domain->domain, iova + len);
995                         if (next != phys + len)
996                                 break;
997                 }
998
999                 /*
1000                  * First, try to use fast unmap/unpin. In case of failure,
1001                  * switch to slow unmap/unpin path.
1002                  */
1003                 unmapped = unmap_unpin_fast(domain, dma, &iova, len, phys,
1004                                             &unlocked, &unmapped_region_list,
1005                                             &unmapped_region_cnt,
1006                                             &iotlb_gather);
1007                 if (!unmapped) {
1008                         unmapped = unmap_unpin_slow(domain, dma, &iova, len,
1009                                                     phys, &unlocked);
1010                         if (WARN_ON(!unmapped))
1011                                 break;
1012                 }
1013         }
1014
1015         dma->iommu_mapped = false;
1016
1017         if (unmapped_region_cnt) {
1018                 unlocked += vfio_sync_unpin(dma, domain, &unmapped_region_list,
1019                                             &iotlb_gather);
1020         }
1021
1022         if (do_accounting) {
1023                 vfio_lock_acct(dma, -unlocked, true);
1024                 return 0;
1025         }
1026         return unlocked;
1027 }
1028
1029 static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
1030 {
1031         WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list));
1032         vfio_unmap_unpin(iommu, dma, true);
1033         vfio_unlink_dma(iommu, dma);
1034         put_task_struct(dma->task);
1035         mmdrop(dma->mm);
1036         vfio_dma_bitmap_free(dma);
1037         kfree(dma);
1038         iommu->dma_avail++;
1039 }
1040
1041 static void vfio_update_pgsize_bitmap(struct vfio_iommu *iommu)
1042 {
1043         struct vfio_domain *domain;
1044
1045         iommu->pgsize_bitmap = ULONG_MAX;
1046
1047         list_for_each_entry(domain, &iommu->domain_list, next)
1048                 iommu->pgsize_bitmap &= domain->domain->pgsize_bitmap;
1049
1050         /*
1051          * In case the IOMMU supports page sizes smaller than PAGE_SIZE
1052          * we pretend PAGE_SIZE is supported and hide sub-PAGE_SIZE sizes.
1053          * That way the user will be able to map/unmap buffers whose size/
1054          * start address is aligned with PAGE_SIZE. Pinning code uses that
1055          * granularity while iommu driver can use the sub-PAGE_SIZE size
1056          * to map the buffer.
1057          */
1058         if (iommu->pgsize_bitmap & ~PAGE_MASK) {
1059                 iommu->pgsize_bitmap &= PAGE_MASK;
1060                 iommu->pgsize_bitmap |= PAGE_SIZE;
1061         }
1062 }
1063
1064 static int update_user_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1065                               struct vfio_dma *dma, dma_addr_t base_iova,
1066                               size_t pgsize)
1067 {
1068         unsigned long pgshift = __ffs(pgsize);
1069         unsigned long nbits = dma->size >> pgshift;
1070         unsigned long bit_offset = (dma->iova - base_iova) >> pgshift;
1071         unsigned long copy_offset = bit_offset / BITS_PER_LONG;
1072         unsigned long shift = bit_offset % BITS_PER_LONG;
1073         unsigned long leftover;
1074
1075         /*
1076          * mark all pages dirty if any IOMMU capable device is not able
1077          * to report dirty pages and all pages are pinned and mapped.
1078          */
1079         if (!iommu->pinned_page_dirty_scope && dma->iommu_mapped)
1080                 bitmap_set(dma->bitmap, 0, nbits);
1081
1082         if (shift) {
1083                 bitmap_shift_left(dma->bitmap, dma->bitmap, shift,
1084                                   nbits + shift);
1085
1086                 if (copy_from_user(&leftover,
1087                                    (void __user *)(bitmap + copy_offset),
1088                                    sizeof(leftover)))
1089                         return -EFAULT;
1090
1091                 bitmap_or(dma->bitmap, dma->bitmap, &leftover, shift);
1092         }
1093
1094         if (copy_to_user((void __user *)(bitmap + copy_offset), dma->bitmap,
1095                          DIRTY_BITMAP_BYTES(nbits + shift)))
1096                 return -EFAULT;
1097
1098         return 0;
1099 }
1100
1101 static int vfio_iova_dirty_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1102                                   dma_addr_t iova, size_t size, size_t pgsize)
1103 {
1104         struct vfio_dma *dma;
1105         struct rb_node *n;
1106         unsigned long pgshift = __ffs(pgsize);
1107         int ret;
1108
1109         /*
1110          * GET_BITMAP request must fully cover vfio_dma mappings.  Multiple
1111          * vfio_dma mappings may be clubbed by specifying large ranges, but
1112          * there must not be any previous mappings bisected by the range.
1113          * An error will be returned if these conditions are not met.
1114          */
1115         dma = vfio_find_dma(iommu, iova, 1);
1116         if (dma && dma->iova != iova)
1117                 return -EINVAL;
1118
1119         dma = vfio_find_dma(iommu, iova + size - 1, 0);
1120         if (dma && dma->iova + dma->size != iova + size)
1121                 return -EINVAL;
1122
1123         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1124                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1125
1126                 if (dma->iova < iova)
1127                         continue;
1128
1129                 if (dma->iova > iova + size - 1)
1130                         break;
1131
1132                 ret = update_user_bitmap(bitmap, iommu, dma, iova, pgsize);
1133                 if (ret)
1134                         return ret;
1135
1136                 /*
1137                  * Re-populate bitmap to include all pinned pages which are
1138                  * considered as dirty but exclude pages which are unpinned and
1139                  * pages which are marked dirty by vfio_dma_rw()
1140                  */
1141                 bitmap_clear(dma->bitmap, 0, dma->size >> pgshift);
1142                 vfio_dma_populate_bitmap(dma, pgsize);
1143         }
1144         return 0;
1145 }
1146
1147 static int verify_bitmap_size(uint64_t npages, uint64_t bitmap_size)
1148 {
1149         if (!npages || !bitmap_size || (bitmap_size > DIRTY_BITMAP_SIZE_MAX) ||
1150             (bitmap_size < DIRTY_BITMAP_BYTES(npages)))
1151                 return -EINVAL;
1152
1153         return 0;
1154 }
1155
1156 static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
1157                              struct vfio_iommu_type1_dma_unmap *unmap,
1158                              struct vfio_bitmap *bitmap)
1159 {
1160         struct vfio_dma *dma, *dma_last = NULL;
1161         size_t unmapped = 0, pgsize;
1162         int ret = 0, retries = 0;
1163         unsigned long pgshift;
1164
1165         mutex_lock(&iommu->lock);
1166
1167         pgshift = __ffs(iommu->pgsize_bitmap);
1168         pgsize = (size_t)1 << pgshift;
1169
1170         if (unmap->iova & (pgsize - 1)) {
1171                 ret = -EINVAL;
1172                 goto unlock;
1173         }
1174
1175         if (!unmap->size || unmap->size & (pgsize - 1)) {
1176                 ret = -EINVAL;
1177                 goto unlock;
1178         }
1179
1180         if (unmap->iova + unmap->size - 1 < unmap->iova ||
1181             unmap->size > SIZE_MAX) {
1182                 ret = -EINVAL;
1183                 goto unlock;
1184         }
1185
1186         /* When dirty tracking is enabled, allow only min supported pgsize */
1187         if ((unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
1188             (!iommu->dirty_page_tracking || (bitmap->pgsize != pgsize))) {
1189                 ret = -EINVAL;
1190                 goto unlock;
1191         }
1192
1193         WARN_ON((pgsize - 1) & PAGE_MASK);
1194 again:
1195         /*
1196          * vfio-iommu-type1 (v1) - User mappings were coalesced together to
1197          * avoid tracking individual mappings.  This means that the granularity
1198          * of the original mapping was lost and the user was allowed to attempt
1199          * to unmap any range.  Depending on the contiguousness of physical
1200          * memory and page sizes supported by the IOMMU, arbitrary unmaps may
1201          * or may not have worked.  We only guaranteed unmap granularity
1202          * matching the original mapping; even though it was untracked here,
1203          * the original mappings are reflected in IOMMU mappings.  This
1204          * resulted in a couple unusual behaviors.  First, if a range is not
1205          * able to be unmapped, ex. a set of 4k pages that was mapped as a
1206          * 2M hugepage into the IOMMU, the unmap ioctl returns success but with
1207          * a zero sized unmap.  Also, if an unmap request overlaps the first
1208          * address of a hugepage, the IOMMU will unmap the entire hugepage.
1209          * This also returns success and the returned unmap size reflects the
1210          * actual size unmapped.
1211          *
1212          * We attempt to maintain compatibility with this "v1" interface, but
1213          * we take control out of the hands of the IOMMU.  Therefore, an unmap
1214          * request offset from the beginning of the original mapping will
1215          * return success with zero sized unmap.  And an unmap request covering
1216          * the first iova of mapping will unmap the entire range.
1217          *
1218          * The v2 version of this interface intends to be more deterministic.
1219          * Unmap requests must fully cover previous mappings.  Multiple
1220          * mappings may still be unmaped by specifying large ranges, but there
1221          * must not be any previous mappings bisected by the range.  An error
1222          * will be returned if these conditions are not met.  The v2 interface
1223          * will only return success and a size of zero if there were no
1224          * mappings within the range.
1225          */
1226         if (iommu->v2) {
1227                 dma = vfio_find_dma(iommu, unmap->iova, 1);
1228                 if (dma && dma->iova != unmap->iova) {
1229                         ret = -EINVAL;
1230                         goto unlock;
1231                 }
1232                 dma = vfio_find_dma(iommu, unmap->iova + unmap->size - 1, 0);
1233                 if (dma && dma->iova + dma->size != unmap->iova + unmap->size) {
1234                         ret = -EINVAL;
1235                         goto unlock;
1236                 }
1237         }
1238
1239         while ((dma = vfio_find_dma(iommu, unmap->iova, unmap->size))) {
1240                 if (!iommu->v2 && unmap->iova > dma->iova)
1241                         break;
1242                 /*
1243                  * Task with same address space who mapped this iova range is
1244                  * allowed to unmap the iova range.
1245                  */
1246                 if (dma->task->mm != current->mm)
1247                         break;
1248
1249                 if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
1250                         struct vfio_iommu_type1_dma_unmap nb_unmap;
1251
1252                         if (dma_last == dma) {
1253                                 BUG_ON(++retries > 10);
1254                         } else {
1255                                 dma_last = dma;
1256                                 retries = 0;
1257                         }
1258
1259                         nb_unmap.iova = dma->iova;
1260                         nb_unmap.size = dma->size;
1261
1262                         /*
1263                          * Notify anyone (mdev vendor drivers) to invalidate and
1264                          * unmap iovas within the range we're about to unmap.
1265                          * Vendor drivers MUST unpin pages in response to an
1266                          * invalidation.
1267                          */
1268                         mutex_unlock(&iommu->lock);
1269                         blocking_notifier_call_chain(&iommu->notifier,
1270                                                     VFIO_IOMMU_NOTIFY_DMA_UNMAP,
1271                                                     &nb_unmap);
1272                         mutex_lock(&iommu->lock);
1273                         goto again;
1274                 }
1275
1276                 if (unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
1277                         ret = update_user_bitmap(bitmap->data, iommu, dma,
1278                                                  unmap->iova, pgsize);
1279                         if (ret)
1280                                 break;
1281                 }
1282
1283                 unmapped += dma->size;
1284                 vfio_remove_dma(iommu, dma);
1285         }
1286
1287 unlock:
1288         mutex_unlock(&iommu->lock);
1289
1290         /* Report how much was unmapped */
1291         unmap->size = unmapped;
1292
1293         return ret;
1294 }
1295
1296 static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova,
1297                           unsigned long pfn, long npage, int prot)
1298 {
1299         struct vfio_domain *d;
1300         int ret;
1301
1302         list_for_each_entry(d, &iommu->domain_list, next) {
1303                 ret = iommu_map(d->domain, iova, (phys_addr_t)pfn << PAGE_SHIFT,
1304                                 npage << PAGE_SHIFT, prot | d->prot);
1305                 if (ret)
1306                         goto unwind;
1307
1308                 cond_resched();
1309         }
1310
1311         return 0;
1312
1313 unwind:
1314         list_for_each_entry_continue_reverse(d, &iommu->domain_list, next) {
1315                 iommu_unmap(d->domain, iova, npage << PAGE_SHIFT);
1316                 cond_resched();
1317         }
1318
1319         return ret;
1320 }
1321
1322 static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
1323                             size_t map_size)
1324 {
1325         dma_addr_t iova = dma->iova;
1326         unsigned long vaddr = dma->vaddr;
1327         struct vfio_batch batch;
1328         size_t size = map_size;
1329         long npage;
1330         unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1331         int ret = 0;
1332
1333         vfio_batch_init(&batch);
1334
1335         while (size) {
1336                 /* Pin a contiguous chunk of memory */
1337                 npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
1338                                               size >> PAGE_SHIFT, &pfn, limit,
1339                                               &batch);
1340                 if (npage <= 0) {
1341                         WARN_ON(!npage);
1342                         ret = (int)npage;
1343                         break;
1344                 }
1345
1346                 /* Map it! */
1347                 ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage,
1348                                      dma->prot);
1349                 if (ret) {
1350                         vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
1351                                                 npage, true);
1352                         break;
1353                 }
1354
1355                 size -= npage << PAGE_SHIFT;
1356                 dma->size += npage << PAGE_SHIFT;
1357         }
1358
1359         vfio_batch_fini(&batch);
1360         dma->iommu_mapped = true;
1361
1362         if (ret)
1363                 vfio_remove_dma(iommu, dma);
1364
1365         return ret;
1366 }
1367
1368 /*
1369  * Check dma map request is within a valid iova range
1370  */
1371 static bool vfio_iommu_iova_dma_valid(struct vfio_iommu *iommu,
1372                                       dma_addr_t start, dma_addr_t end)
1373 {
1374         struct list_head *iova = &iommu->iova_list;
1375         struct vfio_iova *node;
1376
1377         list_for_each_entry(node, iova, list) {
1378                 if (start >= node->start && end <= node->end)
1379                         return true;
1380         }
1381
1382         /*
1383          * Check for list_empty() as well since a container with
1384          * a single mdev device will have an empty list.
1385          */
1386         return list_empty(iova);
1387 }
1388
1389 static int vfio_dma_do_map(struct vfio_iommu *iommu,
1390                            struct vfio_iommu_type1_dma_map *map)
1391 {
1392         dma_addr_t iova = map->iova;
1393         unsigned long vaddr = map->vaddr;
1394         size_t size = map->size;
1395         int ret = 0, prot = 0;
1396         size_t pgsize;
1397         struct vfio_dma *dma;
1398
1399         /* Verify that none of our __u64 fields overflow */
1400         if (map->size != size || map->vaddr != vaddr || map->iova != iova)
1401                 return -EINVAL;
1402
1403         /* READ/WRITE from device perspective */
1404         if (map->flags & VFIO_DMA_MAP_FLAG_WRITE)
1405                 prot |= IOMMU_WRITE;
1406         if (map->flags & VFIO_DMA_MAP_FLAG_READ)
1407                 prot |= IOMMU_READ;
1408
1409         mutex_lock(&iommu->lock);
1410
1411         pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
1412
1413         WARN_ON((pgsize - 1) & PAGE_MASK);
1414
1415         if (!prot || !size || (size | iova | vaddr) & (pgsize - 1)) {
1416                 ret = -EINVAL;
1417                 goto out_unlock;
1418         }
1419
1420         /* Don't allow IOVA or virtual address wrap */
1421         if (iova + size - 1 < iova || vaddr + size - 1 < vaddr) {
1422                 ret = -EINVAL;
1423                 goto out_unlock;
1424         }
1425
1426         if (vfio_find_dma(iommu, iova, size)) {
1427                 ret = -EEXIST;
1428                 goto out_unlock;
1429         }
1430
1431         if (!iommu->dma_avail) {
1432                 ret = -ENOSPC;
1433                 goto out_unlock;
1434         }
1435
1436         if (!vfio_iommu_iova_dma_valid(iommu, iova, iova + size - 1)) {
1437                 ret = -EINVAL;
1438                 goto out_unlock;
1439         }
1440
1441         dma = kzalloc(sizeof(*dma), GFP_KERNEL);
1442         if (!dma) {
1443                 ret = -ENOMEM;
1444                 goto out_unlock;
1445         }
1446
1447         iommu->dma_avail--;
1448         dma->iova = iova;
1449         dma->vaddr = vaddr;
1450         dma->prot = prot;
1451
1452         /*
1453          * We need to be able to both add to a task's locked memory and test
1454          * against the locked memory limit and we need to be able to do both
1455          * outside of this call path as pinning can be asynchronous via the
1456          * external interfaces for mdev devices.  RLIMIT_MEMLOCK requires a
1457          * task_struct. Save the group_leader so that all DMA tracking uses
1458          * the same task, to make debugging easier.  VM locked pages requires
1459          * an mm_struct, so grab the mm in case the task dies.
1460          */
1461         get_task_struct(current->group_leader);
1462         dma->task = current->group_leader;
1463         dma->lock_cap = capable(CAP_IPC_LOCK);
1464         dma->mm = current->mm;
1465         mmgrab(dma->mm);
1466
1467         dma->pfn_list = RB_ROOT;
1468
1469         /* Insert zero-sized and grow as we map chunks of it */
1470         vfio_link_dma(iommu, dma);
1471
1472         /* Don't pin and map if container doesn't contain IOMMU capable domain*/
1473         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1474                 dma->size = size;
1475         else
1476                 ret = vfio_pin_map_dma(iommu, dma, size);
1477
1478         if (!ret && iommu->dirty_page_tracking) {
1479                 ret = vfio_dma_bitmap_alloc(dma, pgsize);
1480                 if (ret)
1481                         vfio_remove_dma(iommu, dma);
1482         }
1483
1484 out_unlock:
1485         mutex_unlock(&iommu->lock);
1486         return ret;
1487 }
1488
1489 static int vfio_bus_type(struct device *dev, void *data)
1490 {
1491         struct bus_type **bus = data;
1492
1493         if (*bus && *bus != dev->bus)
1494                 return -EINVAL;
1495
1496         *bus = dev->bus;
1497
1498         return 0;
1499 }
1500
1501 static int vfio_iommu_replay(struct vfio_iommu *iommu,
1502                              struct vfio_domain *domain)
1503 {
1504         struct vfio_batch batch;
1505         struct vfio_domain *d = NULL;
1506         struct rb_node *n;
1507         unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1508         int ret;
1509
1510         /* Arbitrarily pick the first domain in the list for lookups */
1511         if (!list_empty(&iommu->domain_list))
1512                 d = list_first_entry(&iommu->domain_list,
1513                                      struct vfio_domain, next);
1514
1515         vfio_batch_init(&batch);
1516
1517         n = rb_first(&iommu->dma_list);
1518
1519         for (; n; n = rb_next(n)) {
1520                 struct vfio_dma *dma;
1521                 dma_addr_t iova;
1522
1523                 dma = rb_entry(n, struct vfio_dma, node);
1524                 iova = dma->iova;
1525
1526                 while (iova < dma->iova + dma->size) {
1527                         phys_addr_t phys;
1528                         size_t size;
1529
1530                         if (dma->iommu_mapped) {
1531                                 phys_addr_t p;
1532                                 dma_addr_t i;
1533
1534                                 if (WARN_ON(!d)) { /* mapped w/o a domain?! */
1535                                         ret = -EINVAL;
1536                                         goto unwind;
1537                                 }
1538
1539                                 phys = iommu_iova_to_phys(d->domain, iova);
1540
1541                                 if (WARN_ON(!phys)) {
1542                                         iova += PAGE_SIZE;
1543                                         continue;
1544                                 }
1545
1546                                 size = PAGE_SIZE;
1547                                 p = phys + size;
1548                                 i = iova + size;
1549                                 while (i < dma->iova + dma->size &&
1550                                        p == iommu_iova_to_phys(d->domain, i)) {
1551                                         size += PAGE_SIZE;
1552                                         p += PAGE_SIZE;
1553                                         i += PAGE_SIZE;
1554                                 }
1555                         } else {
1556                                 unsigned long pfn;
1557                                 unsigned long vaddr = dma->vaddr +
1558                                                      (iova - dma->iova);
1559                                 size_t n = dma->iova + dma->size - iova;
1560                                 long npage;
1561
1562                                 npage = vfio_pin_pages_remote(dma, vaddr,
1563                                                               n >> PAGE_SHIFT,
1564                                                               &pfn, limit,
1565                                                               &batch);
1566                                 if (npage <= 0) {
1567                                         WARN_ON(!npage);
1568                                         ret = (int)npage;
1569                                         goto unwind;
1570                                 }
1571
1572                                 phys = pfn << PAGE_SHIFT;
1573                                 size = npage << PAGE_SHIFT;
1574                         }
1575
1576                         ret = iommu_map(domain->domain, iova, phys,
1577                                         size, dma->prot | domain->prot);
1578                         if (ret) {
1579                                 if (!dma->iommu_mapped)
1580                                         vfio_unpin_pages_remote(dma, iova,
1581                                                         phys >> PAGE_SHIFT,
1582                                                         size >> PAGE_SHIFT,
1583                                                         true);
1584                                 goto unwind;
1585                         }
1586
1587                         iova += size;
1588                 }
1589         }
1590
1591         /* All dmas are now mapped, defer to second tree walk for unwind */
1592         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1593                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1594
1595                 dma->iommu_mapped = true;
1596         }
1597
1598         vfio_batch_fini(&batch);
1599         return 0;
1600
1601 unwind:
1602         for (; n; n = rb_prev(n)) {
1603                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1604                 dma_addr_t iova;
1605
1606                 if (dma->iommu_mapped) {
1607                         iommu_unmap(domain->domain, dma->iova, dma->size);
1608                         continue;
1609                 }
1610
1611                 iova = dma->iova;
1612                 while (iova < dma->iova + dma->size) {
1613                         phys_addr_t phys, p;
1614                         size_t size;
1615                         dma_addr_t i;
1616
1617                         phys = iommu_iova_to_phys(domain->domain, iova);
1618                         if (!phys) {
1619                                 iova += PAGE_SIZE;
1620                                 continue;
1621                         }
1622
1623                         size = PAGE_SIZE;
1624                         p = phys + size;
1625                         i = iova + size;
1626                         while (i < dma->iova + dma->size &&
1627                                p == iommu_iova_to_phys(domain->domain, i)) {
1628                                 size += PAGE_SIZE;
1629                                 p += PAGE_SIZE;
1630                                 i += PAGE_SIZE;
1631                         }
1632
1633                         iommu_unmap(domain->domain, iova, size);
1634                         vfio_unpin_pages_remote(dma, iova, phys >> PAGE_SHIFT,
1635                                                 size >> PAGE_SHIFT, true);
1636                 }
1637         }
1638
1639         vfio_batch_fini(&batch);
1640         return ret;
1641 }
1642
1643 /*
1644  * We change our unmap behavior slightly depending on whether the IOMMU
1645  * supports fine-grained superpages.  IOMMUs like AMD-Vi will use a superpage
1646  * for practically any contiguous power-of-two mapping we give it.  This means
1647  * we don't need to look for contiguous chunks ourselves to make unmapping
1648  * more efficient.  On IOMMUs with coarse-grained super pages, like Intel VT-d
1649  * with discrete 2M/1G/512G/1T superpages, identifying contiguous chunks
1650  * significantly boosts non-hugetlbfs mappings and doesn't seem to hurt when
1651  * hugetlbfs is in use.
1652  */
1653 static void vfio_test_domain_fgsp(struct vfio_domain *domain)
1654 {
1655         struct page *pages;
1656         int ret, order = get_order(PAGE_SIZE * 2);
1657
1658         pages = alloc_pages(GFP_KERNEL | __GFP_ZERO, order);
1659         if (!pages)
1660                 return;
1661
1662         ret = iommu_map(domain->domain, 0, page_to_phys(pages), PAGE_SIZE * 2,
1663                         IOMMU_READ | IOMMU_WRITE | domain->prot);
1664         if (!ret) {
1665                 size_t unmapped = iommu_unmap(domain->domain, 0, PAGE_SIZE);
1666
1667                 if (unmapped == PAGE_SIZE)
1668                         iommu_unmap(domain->domain, PAGE_SIZE, PAGE_SIZE);
1669                 else
1670                         domain->fgsp = true;
1671         }
1672
1673         __free_pages(pages, order);
1674 }
1675
1676 static struct vfio_group *find_iommu_group(struct vfio_domain *domain,
1677                                            struct iommu_group *iommu_group)
1678 {
1679         struct vfio_group *g;
1680
1681         list_for_each_entry(g, &domain->group_list, next) {
1682                 if (g->iommu_group == iommu_group)
1683                         return g;
1684         }
1685
1686         return NULL;
1687 }
1688
1689 static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
1690                                                struct iommu_group *iommu_group)
1691 {
1692         struct vfio_domain *domain;
1693         struct vfio_group *group = NULL;
1694
1695         list_for_each_entry(domain, &iommu->domain_list, next) {
1696                 group = find_iommu_group(domain, iommu_group);
1697                 if (group)
1698                         return group;
1699         }
1700
1701         if (iommu->external_domain)
1702                 group = find_iommu_group(iommu->external_domain, iommu_group);
1703
1704         return group;
1705 }
1706
1707 static void update_pinned_page_dirty_scope(struct vfio_iommu *iommu)
1708 {
1709         struct vfio_domain *domain;
1710         struct vfio_group *group;
1711
1712         list_for_each_entry(domain, &iommu->domain_list, next) {
1713                 list_for_each_entry(group, &domain->group_list, next) {
1714                         if (!group->pinned_page_dirty_scope) {
1715                                 iommu->pinned_page_dirty_scope = false;
1716                                 return;
1717                         }
1718                 }
1719         }
1720
1721         if (iommu->external_domain) {
1722                 domain = iommu->external_domain;
1723                 list_for_each_entry(group, &domain->group_list, next) {
1724                         if (!group->pinned_page_dirty_scope) {
1725                                 iommu->pinned_page_dirty_scope = false;
1726                                 return;
1727                         }
1728                 }
1729         }
1730
1731         iommu->pinned_page_dirty_scope = true;
1732 }
1733
1734 static bool vfio_iommu_has_sw_msi(struct list_head *group_resv_regions,
1735                                   phys_addr_t *base)
1736 {
1737         struct iommu_resv_region *region;
1738         bool ret = false;
1739
1740         list_for_each_entry(region, group_resv_regions, list) {
1741                 /*
1742                  * The presence of any 'real' MSI regions should take
1743                  * precedence over the software-managed one if the
1744                  * IOMMU driver happens to advertise both types.
1745                  */
1746                 if (region->type == IOMMU_RESV_MSI) {
1747                         ret = false;
1748                         break;
1749                 }
1750
1751                 if (region->type == IOMMU_RESV_SW_MSI) {
1752                         *base = region->start;
1753                         ret = true;
1754                 }
1755         }
1756
1757         return ret;
1758 }
1759
1760 static struct device *vfio_mdev_get_iommu_device(struct device *dev)
1761 {
1762         struct device *(*fn)(struct device *dev);
1763         struct device *iommu_device;
1764
1765         fn = symbol_get(mdev_get_iommu_device);
1766         if (fn) {
1767                 iommu_device = fn(dev);
1768                 symbol_put(mdev_get_iommu_device);
1769
1770                 return iommu_device;
1771         }
1772
1773         return NULL;
1774 }
1775
1776 static int vfio_mdev_attach_domain(struct device *dev, void *data)
1777 {
1778         struct iommu_domain *domain = data;
1779         struct device *iommu_device;
1780
1781         iommu_device = vfio_mdev_get_iommu_device(dev);
1782         if (iommu_device) {
1783                 if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1784                         return iommu_aux_attach_device(domain, iommu_device);
1785                 else
1786                         return iommu_attach_device(domain, iommu_device);
1787         }
1788
1789         return -EINVAL;
1790 }
1791
1792 static int vfio_mdev_detach_domain(struct device *dev, void *data)
1793 {
1794         struct iommu_domain *domain = data;
1795         struct device *iommu_device;
1796
1797         iommu_device = vfio_mdev_get_iommu_device(dev);
1798         if (iommu_device) {
1799                 if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1800                         iommu_aux_detach_device(domain, iommu_device);
1801                 else
1802                         iommu_detach_device(domain, iommu_device);
1803         }
1804
1805         return 0;
1806 }
1807
1808 static int vfio_iommu_attach_group(struct vfio_domain *domain,
1809                                    struct vfio_group *group)
1810 {
1811         if (group->mdev_group)
1812                 return iommu_group_for_each_dev(group->iommu_group,
1813                                                 domain->domain,
1814                                                 vfio_mdev_attach_domain);
1815         else
1816                 return iommu_attach_group(domain->domain, group->iommu_group);
1817 }
1818
1819 static void vfio_iommu_detach_group(struct vfio_domain *domain,
1820                                     struct vfio_group *group)
1821 {
1822         if (group->mdev_group)
1823                 iommu_group_for_each_dev(group->iommu_group, domain->domain,
1824                                          vfio_mdev_detach_domain);
1825         else
1826                 iommu_detach_group(domain->domain, group->iommu_group);
1827 }
1828
1829 static bool vfio_bus_is_mdev(struct bus_type *bus)
1830 {
1831         struct bus_type *mdev_bus;
1832         bool ret = false;
1833
1834         mdev_bus = symbol_get(mdev_bus_type);
1835         if (mdev_bus) {
1836                 ret = (bus == mdev_bus);
1837                 symbol_put(mdev_bus_type);
1838         }
1839
1840         return ret;
1841 }
1842
1843 static int vfio_mdev_iommu_device(struct device *dev, void *data)
1844 {
1845         struct device **old = data, *new;
1846
1847         new = vfio_mdev_get_iommu_device(dev);
1848         if (!new || (*old && *old != new))
1849                 return -EINVAL;
1850
1851         *old = new;
1852
1853         return 0;
1854 }
1855
1856 /*
1857  * This is a helper function to insert an address range to iova list.
1858  * The list is initially created with a single entry corresponding to
1859  * the IOMMU domain geometry to which the device group is attached.
1860  * The list aperture gets modified when a new domain is added to the
1861  * container if the new aperture doesn't conflict with the current one
1862  * or with any existing dma mappings. The list is also modified to
1863  * exclude any reserved regions associated with the device group.
1864  */
1865 static int vfio_iommu_iova_insert(struct list_head *head,
1866                                   dma_addr_t start, dma_addr_t end)
1867 {
1868         struct vfio_iova *region;
1869
1870         region = kmalloc(sizeof(*region), GFP_KERNEL);
1871         if (!region)
1872                 return -ENOMEM;
1873
1874         INIT_LIST_HEAD(&region->list);
1875         region->start = start;
1876         region->end = end;
1877
1878         list_add_tail(&region->list, head);
1879         return 0;
1880 }
1881
1882 /*
1883  * Check the new iommu aperture conflicts with existing aper or with any
1884  * existing dma mappings.
1885  */
1886 static bool vfio_iommu_aper_conflict(struct vfio_iommu *iommu,
1887                                      dma_addr_t start, dma_addr_t end)
1888 {
1889         struct vfio_iova *first, *last;
1890         struct list_head *iova = &iommu->iova_list;
1891
1892         if (list_empty(iova))
1893                 return false;
1894
1895         /* Disjoint sets, return conflict */
1896         first = list_first_entry(iova, struct vfio_iova, list);
1897         last = list_last_entry(iova, struct vfio_iova, list);
1898         if (start > last->end || end < first->start)
1899                 return true;
1900
1901         /* Check for any existing dma mappings below the new start */
1902         if (start > first->start) {
1903                 if (vfio_find_dma(iommu, first->start, start - first->start))
1904                         return true;
1905         }
1906
1907         /* Check for any existing dma mappings beyond the new end */
1908         if (end < last->end) {
1909                 if (vfio_find_dma(iommu, end + 1, last->end - end))
1910                         return true;
1911         }
1912
1913         return false;
1914 }
1915
1916 /*
1917  * Resize iommu iova aperture window. This is called only if the new
1918  * aperture has no conflict with existing aperture and dma mappings.
1919  */
1920 static int vfio_iommu_aper_resize(struct list_head *iova,
1921                                   dma_addr_t start, dma_addr_t end)
1922 {
1923         struct vfio_iova *node, *next;
1924
1925         if (list_empty(iova))
1926                 return vfio_iommu_iova_insert(iova, start, end);
1927
1928         /* Adjust iova list start */
1929         list_for_each_entry_safe(node, next, iova, list) {
1930                 if (start < node->start)
1931                         break;
1932                 if (start >= node->start && start < node->end) {
1933                         node->start = start;
1934                         break;
1935                 }
1936                 /* Delete nodes before new start */
1937                 list_del(&node->list);
1938                 kfree(node);
1939         }
1940
1941         /* Adjust iova list end */
1942         list_for_each_entry_safe(node, next, iova, list) {
1943                 if (end > node->end)
1944                         continue;
1945                 if (end > node->start && end <= node->end) {
1946                         node->end = end;
1947                         continue;
1948                 }
1949                 /* Delete nodes after new end */
1950                 list_del(&node->list);
1951                 kfree(node);
1952         }
1953
1954         return 0;
1955 }
1956
1957 /*
1958  * Check reserved region conflicts with existing dma mappings
1959  */
1960 static bool vfio_iommu_resv_conflict(struct vfio_iommu *iommu,
1961                                      struct list_head *resv_regions)
1962 {
1963         struct iommu_resv_region *region;
1964
1965         /* Check for conflict with existing dma mappings */
1966         list_for_each_entry(region, resv_regions, list) {
1967                 if (region->type == IOMMU_RESV_DIRECT_RELAXABLE)
1968                         continue;
1969
1970                 if (vfio_find_dma(iommu, region->start, region->length))
1971                         return true;
1972         }
1973
1974         return false;
1975 }
1976
1977 /*
1978  * Check iova region overlap with  reserved regions and
1979  * exclude them from the iommu iova range
1980  */
1981 static int vfio_iommu_resv_exclude(struct list_head *iova,
1982                                    struct list_head *resv_regions)
1983 {
1984         struct iommu_resv_region *resv;
1985         struct vfio_iova *n, *next;
1986
1987         list_for_each_entry(resv, resv_regions, list) {
1988                 phys_addr_t start, end;
1989
1990                 if (resv->type == IOMMU_RESV_DIRECT_RELAXABLE)
1991                         continue;
1992
1993                 start = resv->start;
1994                 end = resv->start + resv->length - 1;
1995
1996                 list_for_each_entry_safe(n, next, iova, list) {
1997                         int ret = 0;
1998
1999                         /* No overlap */
2000                         if (start > n->end || end < n->start)
2001                                 continue;
2002                         /*
2003                          * Insert a new node if current node overlaps with the
2004                          * reserve region to exlude that from valid iova range.
2005                          * Note that, new node is inserted before the current
2006                          * node and finally the current node is deleted keeping
2007                          * the list updated and sorted.
2008                          */
2009                         if (start > n->start)
2010                                 ret = vfio_iommu_iova_insert(&n->list, n->start,
2011                                                              start - 1);
2012                         if (!ret && end < n->end)
2013                                 ret = vfio_iommu_iova_insert(&n->list, end + 1,
2014                                                              n->end);
2015                         if (ret)
2016                                 return ret;
2017
2018                         list_del(&n->list);
2019                         kfree(n);
2020                 }
2021         }
2022
2023         if (list_empty(iova))
2024                 return -EINVAL;
2025
2026         return 0;
2027 }
2028
2029 static void vfio_iommu_resv_free(struct list_head *resv_regions)
2030 {
2031         struct iommu_resv_region *n, *next;
2032
2033         list_for_each_entry_safe(n, next, resv_regions, list) {
2034                 list_del(&n->list);
2035                 kfree(n);
2036         }
2037 }
2038
2039 static void vfio_iommu_iova_free(struct list_head *iova)
2040 {
2041         struct vfio_iova *n, *next;
2042
2043         list_for_each_entry_safe(n, next, iova, list) {
2044                 list_del(&n->list);
2045                 kfree(n);
2046         }
2047 }
2048
2049 static int vfio_iommu_iova_get_copy(struct vfio_iommu *iommu,
2050                                     struct list_head *iova_copy)
2051 {
2052         struct list_head *iova = &iommu->iova_list;
2053         struct vfio_iova *n;
2054         int ret;
2055
2056         list_for_each_entry(n, iova, list) {
2057                 ret = vfio_iommu_iova_insert(iova_copy, n->start, n->end);
2058                 if (ret)
2059                         goto out_free;
2060         }
2061
2062         return 0;
2063
2064 out_free:
2065         vfio_iommu_iova_free(iova_copy);
2066         return ret;
2067 }
2068
2069 static void vfio_iommu_iova_insert_copy(struct vfio_iommu *iommu,
2070                                         struct list_head *iova_copy)
2071 {
2072         struct list_head *iova = &iommu->iova_list;
2073
2074         vfio_iommu_iova_free(iova);
2075
2076         list_splice_tail(iova_copy, iova);
2077 }
2078
2079 static int vfio_iommu_type1_attach_group(void *iommu_data,
2080                                          struct iommu_group *iommu_group)
2081 {
2082         struct vfio_iommu *iommu = iommu_data;
2083         struct vfio_group *group;
2084         struct vfio_domain *domain, *d;
2085         struct bus_type *bus = NULL;
2086         int ret;
2087         bool resv_msi, msi_remap;
2088         phys_addr_t resv_msi_base = 0;
2089         struct iommu_domain_geometry geo;
2090         LIST_HEAD(iova_copy);
2091         LIST_HEAD(group_resv_regions);
2092
2093         mutex_lock(&iommu->lock);
2094
2095         /* Check for duplicates */
2096         if (vfio_iommu_find_iommu_group(iommu, iommu_group)) {
2097                 mutex_unlock(&iommu->lock);
2098                 return -EINVAL;
2099         }
2100
2101         group = kzalloc(sizeof(*group), GFP_KERNEL);
2102         domain = kzalloc(sizeof(*domain), GFP_KERNEL);
2103         if (!group || !domain) {
2104                 ret = -ENOMEM;
2105                 goto out_free;
2106         }
2107
2108         group->iommu_group = iommu_group;
2109
2110         /* Determine bus_type in order to allocate a domain */
2111         ret = iommu_group_for_each_dev(iommu_group, &bus, vfio_bus_type);
2112         if (ret)
2113                 goto out_free;
2114
2115         if (vfio_bus_is_mdev(bus)) {
2116                 struct device *iommu_device = NULL;
2117
2118                 group->mdev_group = true;
2119
2120                 /* Determine the isolation type */
2121                 ret = iommu_group_for_each_dev(iommu_group, &iommu_device,
2122                                                vfio_mdev_iommu_device);
2123                 if (ret || !iommu_device) {
2124                         if (!iommu->external_domain) {
2125                                 INIT_LIST_HEAD(&domain->group_list);
2126                                 iommu->external_domain = domain;
2127                                 vfio_update_pgsize_bitmap(iommu);
2128                         } else {
2129                                 kfree(domain);
2130                         }
2131
2132                         list_add(&group->next,
2133                                  &iommu->external_domain->group_list);
2134                         /*
2135                          * Non-iommu backed group cannot dirty memory directly,
2136                          * it can only use interfaces that provide dirty
2137                          * tracking.
2138                          * The iommu scope can only be promoted with the
2139                          * addition of a dirty tracking group.
2140                          */
2141                         group->pinned_page_dirty_scope = true;
2142                         if (!iommu->pinned_page_dirty_scope)
2143                                 update_pinned_page_dirty_scope(iommu);
2144                         mutex_unlock(&iommu->lock);
2145
2146                         return 0;
2147                 }
2148
2149                 bus = iommu_device->bus;
2150         }
2151
2152         domain->domain = iommu_domain_alloc(bus);
2153         if (!domain->domain) {
2154                 ret = -EIO;
2155                 goto out_free;
2156         }
2157
2158         if (iommu->nesting) {
2159                 int attr = 1;
2160
2161                 ret = iommu_domain_set_attr(domain->domain, DOMAIN_ATTR_NESTING,
2162                                             &attr);
2163                 if (ret)
2164                         goto out_domain;
2165         }
2166
2167         ret = vfio_iommu_attach_group(domain, group);
2168         if (ret)
2169                 goto out_domain;
2170
2171         /* Get aperture info */
2172         iommu_domain_get_attr(domain->domain, DOMAIN_ATTR_GEOMETRY, &geo);
2173
2174         if (vfio_iommu_aper_conflict(iommu, geo.aperture_start,
2175                                      geo.aperture_end)) {
2176                 ret = -EINVAL;
2177                 goto out_detach;
2178         }
2179
2180         ret = iommu_get_group_resv_regions(iommu_group, &group_resv_regions);
2181         if (ret)
2182                 goto out_detach;
2183
2184         if (vfio_iommu_resv_conflict(iommu, &group_resv_regions)) {
2185                 ret = -EINVAL;
2186                 goto out_detach;
2187         }
2188
2189         /*
2190          * We don't want to work on the original iova list as the list
2191          * gets modified and in case of failure we have to retain the
2192          * original list. Get a copy here.
2193          */
2194         ret = vfio_iommu_iova_get_copy(iommu, &iova_copy);
2195         if (ret)
2196                 goto out_detach;
2197
2198         ret = vfio_iommu_aper_resize(&iova_copy, geo.aperture_start,
2199                                      geo.aperture_end);
2200         if (ret)
2201                 goto out_detach;
2202
2203         ret = vfio_iommu_resv_exclude(&iova_copy, &group_resv_regions);
2204         if (ret)
2205                 goto out_detach;
2206
2207         resv_msi = vfio_iommu_has_sw_msi(&group_resv_regions, &resv_msi_base);
2208
2209         INIT_LIST_HEAD(&domain->group_list);
2210         list_add(&group->next, &domain->group_list);
2211
2212         msi_remap = irq_domain_check_msi_remap() ||
2213                     iommu_capable(bus, IOMMU_CAP_INTR_REMAP);
2214
2215         if (!allow_unsafe_interrupts && !msi_remap) {
2216                 pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
2217                        __func__);
2218                 ret = -EPERM;
2219                 goto out_detach;
2220         }
2221
2222         if (iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
2223                 domain->prot |= IOMMU_CACHE;
2224
2225         /*
2226          * Try to match an existing compatible domain.  We don't want to
2227          * preclude an IOMMU driver supporting multiple bus_types and being
2228          * able to include different bus_types in the same IOMMU domain, so
2229          * we test whether the domains use the same iommu_ops rather than
2230          * testing if they're on the same bus_type.
2231          */
2232         list_for_each_entry(d, &iommu->domain_list, next) {
2233                 if (d->domain->ops == domain->domain->ops &&
2234                     d->prot == domain->prot) {
2235                         vfio_iommu_detach_group(domain, group);
2236                         if (!vfio_iommu_attach_group(d, group)) {
2237                                 list_add(&group->next, &d->group_list);
2238                                 iommu_domain_free(domain->domain);
2239                                 kfree(domain);
2240                                 goto done;
2241                         }
2242
2243                         ret = vfio_iommu_attach_group(domain, group);
2244                         if (ret)
2245                                 goto out_domain;
2246                 }
2247         }
2248
2249         vfio_test_domain_fgsp(domain);
2250
2251         /* replay mappings on new domains */
2252         ret = vfio_iommu_replay(iommu, domain);
2253         if (ret)
2254                 goto out_detach;
2255
2256         if (resv_msi) {
2257                 ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
2258                 if (ret && ret != -ENODEV)
2259                         goto out_detach;
2260         }
2261
2262         list_add(&domain->next, &iommu->domain_list);
2263         vfio_update_pgsize_bitmap(iommu);
2264 done:
2265         /* Delete the old one and insert new iova list */
2266         vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2267
2268         /*
2269          * An iommu backed group can dirty memory directly and therefore
2270          * demotes the iommu scope until it declares itself dirty tracking
2271          * capable via the page pinning interface.
2272          */
2273         iommu->pinned_page_dirty_scope = false;
2274         mutex_unlock(&iommu->lock);
2275         vfio_iommu_resv_free(&group_resv_regions);
2276
2277         return 0;
2278
2279 out_detach:
2280         vfio_iommu_detach_group(domain, group);
2281 out_domain:
2282         iommu_domain_free(domain->domain);
2283         vfio_iommu_iova_free(&iova_copy);
2284         vfio_iommu_resv_free(&group_resv_regions);
2285 out_free:
2286         kfree(domain);
2287         kfree(group);
2288         mutex_unlock(&iommu->lock);
2289         return ret;
2290 }
2291
2292 static void vfio_iommu_unmap_unpin_all(struct vfio_iommu *iommu)
2293 {
2294         struct rb_node *node;
2295
2296         while ((node = rb_first(&iommu->dma_list)))
2297                 vfio_remove_dma(iommu, rb_entry(node, struct vfio_dma, node));
2298 }
2299
2300 static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
2301 {
2302         struct rb_node *n, *p;
2303
2304         n = rb_first(&iommu->dma_list);
2305         for (; n; n = rb_next(n)) {
2306                 struct vfio_dma *dma;
2307                 long locked = 0, unlocked = 0;
2308
2309                 dma = rb_entry(n, struct vfio_dma, node);
2310                 unlocked += vfio_unmap_unpin(iommu, dma, false);
2311                 p = rb_first(&dma->pfn_list);
2312                 for (; p; p = rb_next(p)) {
2313                         struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn,
2314                                                          node);
2315
2316                         if (!is_invalid_reserved_pfn(vpfn->pfn))
2317                                 locked++;
2318                 }
2319                 vfio_lock_acct(dma, locked - unlocked, true);
2320         }
2321 }
2322
2323 /*
2324  * Called when a domain is removed in detach. It is possible that
2325  * the removed domain decided the iova aperture window. Modify the
2326  * iova aperture with the smallest window among existing domains.
2327  */
2328 static void vfio_iommu_aper_expand(struct vfio_iommu *iommu,
2329                                    struct list_head *iova_copy)
2330 {
2331         struct vfio_domain *domain;
2332         struct iommu_domain_geometry geo;
2333         struct vfio_iova *node;
2334         dma_addr_t start = 0;
2335         dma_addr_t end = (dma_addr_t)~0;
2336
2337         if (list_empty(iova_copy))
2338                 return;
2339
2340         list_for_each_entry(domain, &iommu->domain_list, next) {
2341                 iommu_domain_get_attr(domain->domain, DOMAIN_ATTR_GEOMETRY,
2342                                       &geo);
2343                 if (geo.aperture_start > start)
2344                         start = geo.aperture_start;
2345                 if (geo.aperture_end < end)
2346                         end = geo.aperture_end;
2347         }
2348
2349         /* Modify aperture limits. The new aper is either same or bigger */
2350         node = list_first_entry(iova_copy, struct vfio_iova, list);
2351         node->start = start;
2352         node = list_last_entry(iova_copy, struct vfio_iova, list);
2353         node->end = end;
2354 }
2355
2356 /*
2357  * Called when a group is detached. The reserved regions for that
2358  * group can be part of valid iova now. But since reserved regions
2359  * may be duplicated among groups, populate the iova valid regions
2360  * list again.
2361  */
2362 static int vfio_iommu_resv_refresh(struct vfio_iommu *iommu,
2363                                    struct list_head *iova_copy)
2364 {
2365         struct vfio_domain *d;
2366         struct vfio_group *g;
2367         struct vfio_iova *node;
2368         dma_addr_t start, end;
2369         LIST_HEAD(resv_regions);
2370         int ret;
2371
2372         if (list_empty(iova_copy))
2373                 return -EINVAL;
2374
2375         list_for_each_entry(d, &iommu->domain_list, next) {
2376                 list_for_each_entry(g, &d->group_list, next) {
2377                         ret = iommu_get_group_resv_regions(g->iommu_group,
2378                                                            &resv_regions);
2379                         if (ret)
2380                                 goto done;
2381                 }
2382         }
2383
2384         node = list_first_entry(iova_copy, struct vfio_iova, list);
2385         start = node->start;
2386         node = list_last_entry(iova_copy, struct vfio_iova, list);
2387         end = node->end;
2388
2389         /* purge the iova list and create new one */
2390         vfio_iommu_iova_free(iova_copy);
2391
2392         ret = vfio_iommu_aper_resize(iova_copy, start, end);
2393         if (ret)
2394                 goto done;
2395
2396         /* Exclude current reserved regions from iova ranges */
2397         ret = vfio_iommu_resv_exclude(iova_copy, &resv_regions);
2398 done:
2399         vfio_iommu_resv_free(&resv_regions);
2400         return ret;
2401 }
2402
2403 static void vfio_iommu_type1_detach_group(void *iommu_data,
2404                                           struct iommu_group *iommu_group)
2405 {
2406         struct vfio_iommu *iommu = iommu_data;
2407         struct vfio_domain *domain;
2408         struct vfio_group *group;
2409         bool update_dirty_scope = false;
2410         LIST_HEAD(iova_copy);
2411
2412         mutex_lock(&iommu->lock);
2413
2414         if (iommu->external_domain) {
2415                 group = find_iommu_group(iommu->external_domain, iommu_group);
2416                 if (group) {
2417                         update_dirty_scope = !group->pinned_page_dirty_scope;
2418                         list_del(&group->next);
2419                         kfree(group);
2420
2421                         if (list_empty(&iommu->external_domain->group_list)) {
2422                                 if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu)) {
2423                                         WARN_ON(iommu->notifier.head);
2424                                         vfio_iommu_unmap_unpin_all(iommu);
2425                                 }
2426
2427                                 kfree(iommu->external_domain);
2428                                 iommu->external_domain = NULL;
2429                         }
2430                         goto detach_group_done;
2431                 }
2432         }
2433
2434         /*
2435          * Get a copy of iova list. This will be used to update
2436          * and to replace the current one later. Please note that
2437          * we will leave the original list as it is if update fails.
2438          */
2439         vfio_iommu_iova_get_copy(iommu, &iova_copy);
2440
2441         list_for_each_entry(domain, &iommu->domain_list, next) {
2442                 group = find_iommu_group(domain, iommu_group);
2443                 if (!group)
2444                         continue;
2445
2446                 vfio_iommu_detach_group(domain, group);
2447                 update_dirty_scope = !group->pinned_page_dirty_scope;
2448                 list_del(&group->next);
2449                 kfree(group);
2450                 /*
2451                  * Group ownership provides privilege, if the group list is
2452                  * empty, the domain goes away. If it's the last domain with
2453                  * iommu and external domain doesn't exist, then all the
2454                  * mappings go away too. If it's the last domain with iommu and
2455                  * external domain exist, update accounting
2456                  */
2457                 if (list_empty(&domain->group_list)) {
2458                         if (list_is_singular(&iommu->domain_list)) {
2459                                 if (!iommu->external_domain) {
2460                                         WARN_ON(iommu->notifier.head);
2461                                         vfio_iommu_unmap_unpin_all(iommu);
2462                                 } else {
2463                                         vfio_iommu_unmap_unpin_reaccount(iommu);
2464                                 }
2465                         }
2466                         iommu_domain_free(domain->domain);
2467                         list_del(&domain->next);
2468                         kfree(domain);
2469                         vfio_iommu_aper_expand(iommu, &iova_copy);
2470                         vfio_update_pgsize_bitmap(iommu);
2471                 }
2472                 break;
2473         }
2474
2475         if (!vfio_iommu_resv_refresh(iommu, &iova_copy))
2476                 vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2477         else
2478                 vfio_iommu_iova_free(&iova_copy);
2479
2480 detach_group_done:
2481         /*
2482          * Removal of a group without dirty tracking may allow the iommu scope
2483          * to be promoted.
2484          */
2485         if (update_dirty_scope) {
2486                 update_pinned_page_dirty_scope(iommu);
2487                 if (iommu->dirty_page_tracking)
2488                         vfio_iommu_populate_bitmap_full(iommu);
2489         }
2490         mutex_unlock(&iommu->lock);
2491 }
2492
2493 static void *vfio_iommu_type1_open(unsigned long arg)
2494 {
2495         struct vfio_iommu *iommu;
2496
2497         iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
2498         if (!iommu)
2499                 return ERR_PTR(-ENOMEM);
2500
2501         switch (arg) {
2502         case VFIO_TYPE1_IOMMU:
2503                 break;
2504         case VFIO_TYPE1_NESTING_IOMMU:
2505                 iommu->nesting = true;
2506                 fallthrough;
2507         case VFIO_TYPE1v2_IOMMU:
2508                 iommu->v2 = true;
2509                 break;
2510         default:
2511                 kfree(iommu);
2512                 return ERR_PTR(-EINVAL);
2513         }
2514
2515         INIT_LIST_HEAD(&iommu->domain_list);
2516         INIT_LIST_HEAD(&iommu->iova_list);
2517         iommu->dma_list = RB_ROOT;
2518         iommu->dma_avail = dma_entry_limit;
2519         mutex_init(&iommu->lock);
2520         BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
2521
2522         return iommu;
2523 }
2524
2525 static void vfio_release_domain(struct vfio_domain *domain, bool external)
2526 {
2527         struct vfio_group *group, *group_tmp;
2528
2529         list_for_each_entry_safe(group, group_tmp,
2530                                  &domain->group_list, next) {
2531                 if (!external)
2532                         vfio_iommu_detach_group(domain, group);
2533                 list_del(&group->next);
2534                 kfree(group);
2535         }
2536
2537         if (!external)
2538                 iommu_domain_free(domain->domain);
2539 }
2540
2541 static void vfio_iommu_type1_release(void *iommu_data)
2542 {
2543         struct vfio_iommu *iommu = iommu_data;
2544         struct vfio_domain *domain, *domain_tmp;
2545
2546         if (iommu->external_domain) {
2547                 vfio_release_domain(iommu->external_domain, true);
2548                 kfree(iommu->external_domain);
2549         }
2550
2551         vfio_iommu_unmap_unpin_all(iommu);
2552
2553         list_for_each_entry_safe(domain, domain_tmp,
2554                                  &iommu->domain_list, next) {
2555                 vfio_release_domain(domain, false);
2556                 list_del(&domain->next);
2557                 kfree(domain);
2558         }
2559
2560         vfio_iommu_iova_free(&iommu->iova_list);
2561
2562         kfree(iommu);
2563 }
2564
2565 static int vfio_domains_have_iommu_cache(struct vfio_iommu *iommu)
2566 {
2567         struct vfio_domain *domain;
2568         int ret = 1;
2569
2570         mutex_lock(&iommu->lock);
2571         list_for_each_entry(domain, &iommu->domain_list, next) {
2572                 if (!(domain->prot & IOMMU_CACHE)) {
2573                         ret = 0;
2574                         break;
2575                 }
2576         }
2577         mutex_unlock(&iommu->lock);
2578
2579         return ret;
2580 }
2581
2582 static int vfio_iommu_type1_check_extension(struct vfio_iommu *iommu,
2583                                             unsigned long arg)
2584 {
2585         switch (arg) {
2586         case VFIO_TYPE1_IOMMU:
2587         case VFIO_TYPE1v2_IOMMU:
2588         case VFIO_TYPE1_NESTING_IOMMU:
2589                 return 1;
2590         case VFIO_DMA_CC_IOMMU:
2591                 if (!iommu)
2592                         return 0;
2593                 return vfio_domains_have_iommu_cache(iommu);
2594         default:
2595                 return 0;
2596         }
2597 }
2598
2599 static int vfio_iommu_iova_add_cap(struct vfio_info_cap *caps,
2600                  struct vfio_iommu_type1_info_cap_iova_range *cap_iovas,
2601                  size_t size)
2602 {
2603         struct vfio_info_cap_header *header;
2604         struct vfio_iommu_type1_info_cap_iova_range *iova_cap;
2605
2606         header = vfio_info_cap_add(caps, size,
2607                                    VFIO_IOMMU_TYPE1_INFO_CAP_IOVA_RANGE, 1);
2608         if (IS_ERR(header))
2609                 return PTR_ERR(header);
2610
2611         iova_cap = container_of(header,
2612                                 struct vfio_iommu_type1_info_cap_iova_range,
2613                                 header);
2614         iova_cap->nr_iovas = cap_iovas->nr_iovas;
2615         memcpy(iova_cap->iova_ranges, cap_iovas->iova_ranges,
2616                cap_iovas->nr_iovas * sizeof(*cap_iovas->iova_ranges));
2617         return 0;
2618 }
2619
2620 static int vfio_iommu_iova_build_caps(struct vfio_iommu *iommu,
2621                                       struct vfio_info_cap *caps)
2622 {
2623         struct vfio_iommu_type1_info_cap_iova_range *cap_iovas;
2624         struct vfio_iova *iova;
2625         size_t size;
2626         int iovas = 0, i = 0, ret;
2627
2628         list_for_each_entry(iova, &iommu->iova_list, list)
2629                 iovas++;
2630
2631         if (!iovas) {
2632                 /*
2633                  * Return 0 as a container with a single mdev device
2634                  * will have an empty list
2635                  */
2636                 return 0;
2637         }
2638
2639         size = sizeof(*cap_iovas) + (iovas * sizeof(*cap_iovas->iova_ranges));
2640
2641         cap_iovas = kzalloc(size, GFP_KERNEL);
2642         if (!cap_iovas)
2643                 return -ENOMEM;
2644
2645         cap_iovas->nr_iovas = iovas;
2646
2647         list_for_each_entry(iova, &iommu->iova_list, list) {
2648                 cap_iovas->iova_ranges[i].start = iova->start;
2649                 cap_iovas->iova_ranges[i].end = iova->end;
2650                 i++;
2651         }
2652
2653         ret = vfio_iommu_iova_add_cap(caps, cap_iovas, size);
2654
2655         kfree(cap_iovas);
2656         return ret;
2657 }
2658
2659 static int vfio_iommu_migration_build_caps(struct vfio_iommu *iommu,
2660                                            struct vfio_info_cap *caps)
2661 {
2662         struct vfio_iommu_type1_info_cap_migration cap_mig = {};
2663
2664         cap_mig.header.id = VFIO_IOMMU_TYPE1_INFO_CAP_MIGRATION;
2665         cap_mig.header.version = 1;
2666
2667         cap_mig.flags = 0;
2668         /* support minimum pgsize */
2669         cap_mig.pgsize_bitmap = (size_t)1 << __ffs(iommu->pgsize_bitmap);
2670         cap_mig.max_dirty_bitmap_size = DIRTY_BITMAP_SIZE_MAX;
2671
2672         return vfio_info_add_capability(caps, &cap_mig.header, sizeof(cap_mig));
2673 }
2674
2675 static int vfio_iommu_dma_avail_build_caps(struct vfio_iommu *iommu,
2676                                            struct vfio_info_cap *caps)
2677 {
2678         struct vfio_iommu_type1_info_dma_avail cap_dma_avail;
2679
2680         cap_dma_avail.header.id = VFIO_IOMMU_TYPE1_INFO_DMA_AVAIL;
2681         cap_dma_avail.header.version = 1;
2682
2683         cap_dma_avail.avail = iommu->dma_avail;
2684
2685         return vfio_info_add_capability(caps, &cap_dma_avail.header,
2686                                         sizeof(cap_dma_avail));
2687 }
2688
2689 static int vfio_iommu_type1_get_info(struct vfio_iommu *iommu,
2690                                      unsigned long arg)
2691 {
2692         struct vfio_iommu_type1_info info;
2693         unsigned long minsz;
2694         struct vfio_info_cap caps = { .buf = NULL, .size = 0 };
2695         unsigned long capsz;
2696         int ret;
2697
2698         minsz = offsetofend(struct vfio_iommu_type1_info, iova_pgsizes);
2699
2700         /* For backward compatibility, cannot require this */
2701         capsz = offsetofend(struct vfio_iommu_type1_info, cap_offset);
2702
2703         if (copy_from_user(&info, (void __user *)arg, minsz))
2704                 return -EFAULT;
2705
2706         if (info.argsz < minsz)
2707                 return -EINVAL;
2708
2709         if (info.argsz >= capsz) {
2710                 minsz = capsz;
2711                 info.cap_offset = 0; /* output, no-recopy necessary */
2712         }
2713
2714         mutex_lock(&iommu->lock);
2715         info.flags = VFIO_IOMMU_INFO_PGSIZES;
2716
2717         info.iova_pgsizes = iommu->pgsize_bitmap;
2718
2719         ret = vfio_iommu_migration_build_caps(iommu, &caps);
2720
2721         if (!ret)
2722                 ret = vfio_iommu_dma_avail_build_caps(iommu, &caps);
2723
2724         if (!ret)
2725                 ret = vfio_iommu_iova_build_caps(iommu, &caps);
2726
2727         mutex_unlock(&iommu->lock);
2728
2729         if (ret)
2730                 return ret;
2731
2732         if (caps.size) {
2733                 info.flags |= VFIO_IOMMU_INFO_CAPS;
2734
2735                 if (info.argsz < sizeof(info) + caps.size) {
2736                         info.argsz = sizeof(info) + caps.size;
2737                 } else {
2738                         vfio_info_cap_shift(&caps, sizeof(info));
2739                         if (copy_to_user((void __user *)arg +
2740                                         sizeof(info), caps.buf,
2741                                         caps.size)) {
2742                                 kfree(caps.buf);
2743                                 return -EFAULT;
2744                         }
2745                         info.cap_offset = sizeof(info);
2746                 }
2747
2748                 kfree(caps.buf);
2749         }
2750
2751         return copy_to_user((void __user *)arg, &info, minsz) ?
2752                         -EFAULT : 0;
2753 }
2754
2755 static int vfio_iommu_type1_map_dma(struct vfio_iommu *iommu,
2756                                     unsigned long arg)
2757 {
2758         struct vfio_iommu_type1_dma_map map;
2759         unsigned long minsz;
2760         uint32_t mask = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE;
2761
2762         minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
2763
2764         if (copy_from_user(&map, (void __user *)arg, minsz))
2765                 return -EFAULT;
2766
2767         if (map.argsz < minsz || map.flags & ~mask)
2768                 return -EINVAL;
2769
2770         return vfio_dma_do_map(iommu, &map);
2771 }
2772
2773 static int vfio_iommu_type1_unmap_dma(struct vfio_iommu *iommu,
2774                                       unsigned long arg)
2775 {
2776         struct vfio_iommu_type1_dma_unmap unmap;
2777         struct vfio_bitmap bitmap = { 0 };
2778         unsigned long minsz;
2779         int ret;
2780
2781         minsz = offsetofend(struct vfio_iommu_type1_dma_unmap, size);
2782
2783         if (copy_from_user(&unmap, (void __user *)arg, minsz))
2784                 return -EFAULT;
2785
2786         if (unmap.argsz < minsz ||
2787             unmap.flags & ~VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP)
2788                 return -EINVAL;
2789
2790         if (unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
2791                 unsigned long pgshift;
2792
2793                 if (unmap.argsz < (minsz + sizeof(bitmap)))
2794                         return -EINVAL;
2795
2796                 if (copy_from_user(&bitmap,
2797                                    (void __user *)(arg + minsz),
2798                                    sizeof(bitmap)))
2799                         return -EFAULT;
2800
2801                 if (!access_ok((void __user *)bitmap.data, bitmap.size))
2802                         return -EINVAL;
2803
2804                 pgshift = __ffs(bitmap.pgsize);
2805                 ret = verify_bitmap_size(unmap.size >> pgshift,
2806                                          bitmap.size);
2807                 if (ret)
2808                         return ret;
2809         }
2810
2811         ret = vfio_dma_do_unmap(iommu, &unmap, &bitmap);
2812         if (ret)
2813                 return ret;
2814
2815         return copy_to_user((void __user *)arg, &unmap, minsz) ?
2816                         -EFAULT : 0;
2817 }
2818
2819 static int vfio_iommu_type1_dirty_pages(struct vfio_iommu *iommu,
2820                                         unsigned long arg)
2821 {
2822         struct vfio_iommu_type1_dirty_bitmap dirty;
2823         uint32_t mask = VFIO_IOMMU_DIRTY_PAGES_FLAG_START |
2824                         VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP |
2825                         VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP;
2826         unsigned long minsz;
2827         int ret = 0;
2828
2829         if (!iommu->v2)
2830                 return -EACCES;
2831
2832         minsz = offsetofend(struct vfio_iommu_type1_dirty_bitmap, flags);
2833
2834         if (copy_from_user(&dirty, (void __user *)arg, minsz))
2835                 return -EFAULT;
2836
2837         if (dirty.argsz < minsz || dirty.flags & ~mask)
2838                 return -EINVAL;
2839
2840         /* only one flag should be set at a time */
2841         if (__ffs(dirty.flags) != __fls(dirty.flags))
2842                 return -EINVAL;
2843
2844         if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_START) {
2845                 size_t pgsize;
2846
2847                 mutex_lock(&iommu->lock);
2848                 pgsize = 1 << __ffs(iommu->pgsize_bitmap);
2849                 if (!iommu->dirty_page_tracking) {
2850                         ret = vfio_dma_bitmap_alloc_all(iommu, pgsize);
2851                         if (!ret)
2852                                 iommu->dirty_page_tracking = true;
2853                 }
2854                 mutex_unlock(&iommu->lock);
2855                 return ret;
2856         } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP) {
2857                 mutex_lock(&iommu->lock);
2858                 if (iommu->dirty_page_tracking) {
2859                         iommu->dirty_page_tracking = false;
2860                         vfio_dma_bitmap_free_all(iommu);
2861                 }
2862                 mutex_unlock(&iommu->lock);
2863                 return 0;
2864         } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP) {
2865                 struct vfio_iommu_type1_dirty_bitmap_get range;
2866                 unsigned long pgshift;
2867                 size_t data_size = dirty.argsz - minsz;
2868                 size_t iommu_pgsize;
2869
2870                 if (!data_size || data_size < sizeof(range))
2871                         return -EINVAL;
2872
2873                 if (copy_from_user(&range, (void __user *)(arg + minsz),
2874                                    sizeof(range)))
2875                         return -EFAULT;
2876
2877                 if (range.iova + range.size < range.iova)
2878                         return -EINVAL;
2879                 if (!access_ok((void __user *)range.bitmap.data,
2880                                range.bitmap.size))
2881                         return -EINVAL;
2882
2883                 pgshift = __ffs(range.bitmap.pgsize);
2884                 ret = verify_bitmap_size(range.size >> pgshift,
2885                                          range.bitmap.size);
2886                 if (ret)
2887                         return ret;
2888
2889                 mutex_lock(&iommu->lock);
2890
2891                 iommu_pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
2892
2893                 /* allow only smallest supported pgsize */
2894                 if (range.bitmap.pgsize != iommu_pgsize) {
2895                         ret = -EINVAL;
2896                         goto out_unlock;
2897                 }
2898                 if (range.iova & (iommu_pgsize - 1)) {
2899                         ret = -EINVAL;
2900                         goto out_unlock;
2901                 }
2902                 if (!range.size || range.size & (iommu_pgsize - 1)) {
2903                         ret = -EINVAL;
2904                         goto out_unlock;
2905                 }
2906
2907                 if (iommu->dirty_page_tracking)
2908                         ret = vfio_iova_dirty_bitmap(range.bitmap.data,
2909                                                      iommu, range.iova,
2910                                                      range.size,
2911                                                      range.bitmap.pgsize);
2912                 else
2913                         ret = -EINVAL;
2914 out_unlock:
2915                 mutex_unlock(&iommu->lock);
2916
2917                 return ret;
2918         }
2919
2920         return -EINVAL;
2921 }
2922
2923 static long vfio_iommu_type1_ioctl(void *iommu_data,
2924                                    unsigned int cmd, unsigned long arg)
2925 {
2926         struct vfio_iommu *iommu = iommu_data;
2927
2928         switch (cmd) {
2929         case VFIO_CHECK_EXTENSION:
2930                 return vfio_iommu_type1_check_extension(iommu, arg);
2931         case VFIO_IOMMU_GET_INFO:
2932                 return vfio_iommu_type1_get_info(iommu, arg);
2933         case VFIO_IOMMU_MAP_DMA:
2934                 return vfio_iommu_type1_map_dma(iommu, arg);
2935         case VFIO_IOMMU_UNMAP_DMA:
2936                 return vfio_iommu_type1_unmap_dma(iommu, arg);
2937         case VFIO_IOMMU_DIRTY_PAGES:
2938                 return vfio_iommu_type1_dirty_pages(iommu, arg);
2939         default:
2940                 return -ENOTTY;
2941         }
2942 }
2943
2944 static int vfio_iommu_type1_register_notifier(void *iommu_data,
2945                                               unsigned long *events,
2946                                               struct notifier_block *nb)
2947 {
2948         struct vfio_iommu *iommu = iommu_data;
2949
2950         /* clear known events */
2951         *events &= ~VFIO_IOMMU_NOTIFY_DMA_UNMAP;
2952
2953         /* refuse to register if still events remaining */
2954         if (*events)
2955                 return -EINVAL;
2956
2957         return blocking_notifier_chain_register(&iommu->notifier, nb);
2958 }
2959
2960 static int vfio_iommu_type1_unregister_notifier(void *iommu_data,
2961                                                 struct notifier_block *nb)
2962 {
2963         struct vfio_iommu *iommu = iommu_data;
2964
2965         return blocking_notifier_chain_unregister(&iommu->notifier, nb);
2966 }
2967
2968 static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
2969                                          dma_addr_t user_iova, void *data,
2970                                          size_t count, bool write,
2971                                          size_t *copied)
2972 {
2973         struct mm_struct *mm;
2974         unsigned long vaddr;
2975         struct vfio_dma *dma;
2976         bool kthread = current->mm == NULL;
2977         size_t offset;
2978
2979         *copied = 0;
2980
2981         dma = vfio_find_dma(iommu, user_iova, 1);
2982         if (!dma)
2983                 return -EINVAL;
2984
2985         if ((write && !(dma->prot & IOMMU_WRITE)) ||
2986                         !(dma->prot & IOMMU_READ))
2987                 return -EPERM;
2988
2989         mm = dma->mm;
2990         if (!mmget_not_zero(mm))
2991                 return -EPERM;
2992
2993         if (kthread)
2994                 kthread_use_mm(mm);
2995         else if (current->mm != mm)
2996                 goto out;
2997
2998         offset = user_iova - dma->iova;
2999
3000         if (count > dma->size - offset)
3001                 count = dma->size - offset;
3002
3003         vaddr = dma->vaddr + offset;
3004
3005         if (write) {
3006                 *copied = copy_to_user((void __user *)vaddr, data,
3007                                          count) ? 0 : count;
3008                 if (*copied && iommu->dirty_page_tracking) {
3009                         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
3010                         /*
3011                          * Bitmap populated with the smallest supported page
3012                          * size
3013                          */
3014                         bitmap_set(dma->bitmap, offset >> pgshift,
3015                                    ((offset + *copied - 1) >> pgshift) -
3016                                    (offset >> pgshift) + 1);
3017                 }
3018         } else
3019                 *copied = copy_from_user(data, (void __user *)vaddr,
3020                                            count) ? 0 : count;
3021         if (kthread)
3022                 kthread_unuse_mm(mm);
3023 out:
3024         mmput(mm);
3025         return *copied ? 0 : -EFAULT;
3026 }
3027
3028 static int vfio_iommu_type1_dma_rw(void *iommu_data, dma_addr_t user_iova,
3029                                    void *data, size_t count, bool write)
3030 {
3031         struct vfio_iommu *iommu = iommu_data;
3032         int ret = 0;
3033         size_t done;
3034
3035         mutex_lock(&iommu->lock);
3036         while (count > 0) {
3037                 ret = vfio_iommu_type1_dma_rw_chunk(iommu, user_iova, data,
3038                                                     count, write, &done);
3039                 if (ret)
3040                         break;
3041
3042                 count -= done;
3043                 data += done;
3044                 user_iova += done;
3045         }
3046
3047         mutex_unlock(&iommu->lock);
3048         return ret;
3049 }
3050
3051 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
3052         .name                   = "vfio-iommu-type1",
3053         .owner                  = THIS_MODULE,
3054         .open                   = vfio_iommu_type1_open,
3055         .release                = vfio_iommu_type1_release,
3056         .ioctl                  = vfio_iommu_type1_ioctl,
3057         .attach_group           = vfio_iommu_type1_attach_group,
3058         .detach_group           = vfio_iommu_type1_detach_group,
3059         .pin_pages              = vfio_iommu_type1_pin_pages,
3060         .unpin_pages            = vfio_iommu_type1_unpin_pages,
3061         .register_notifier      = vfio_iommu_type1_register_notifier,
3062         .unregister_notifier    = vfio_iommu_type1_unregister_notifier,
3063         .dma_rw                 = vfio_iommu_type1_dma_rw,
3064 };
3065
3066 static int __init vfio_iommu_type1_init(void)
3067 {
3068         return vfio_register_iommu_driver(&vfio_iommu_driver_ops_type1);
3069 }
3070
3071 static void __exit vfio_iommu_type1_cleanup(void)
3072 {
3073         vfio_unregister_iommu_driver(&vfio_iommu_driver_ops_type1);
3074 }
3075
3076 module_init(vfio_iommu_type1_init);
3077 module_exit(vfio_iommu_type1_cleanup);
3078
3079 MODULE_VERSION(DRIVER_VERSION);
3080 MODULE_LICENSE("GPL v2");
3081 MODULE_AUTHOR(DRIVER_AUTHOR);
3082 MODULE_DESCRIPTION(DRIVER_DESC);