GNU Linux-libre 6.1.24-gnu
[releases.git] / drivers / vfio / vfio_iommu_type1.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * VFIO: IOMMU DMA mapping support for Type1 IOMMU
4  *
5  * Copyright (C) 2012 Red Hat, Inc.  All rights reserved.
6  *     Author: Alex Williamson <alex.williamson@redhat.com>
7  *
8  * Derived from original vfio:
9  * Copyright 2010 Cisco Systems, Inc.  All rights reserved.
10  * Author: Tom Lyon, pugs@cisco.com
11  *
12  * We arbitrarily define a Type1 IOMMU as one matching the below code.
13  * It could be called the x86 IOMMU as it's designed for AMD-Vi & Intel
14  * VT-d, but that makes it harder to re-use as theoretically anyone
15  * implementing a similar IOMMU could make use of this.  We expect the
16  * IOMMU to support the IOMMU API and have few to no restrictions around
17  * the IOVA range that can be mapped.  The Type1 IOMMU is currently
18  * optimized for relatively static mappings of a userspace process with
19  * userspace pages pinned into memory.  We also assume devices and IOMMU
20  * domains are PCI based as the IOMMU API is still centered around a
21  * device/bus interface rather than a group interface.
22  */
23
24 #include <linux/compat.h>
25 #include <linux/device.h>
26 #include <linux/fs.h>
27 #include <linux/highmem.h>
28 #include <linux/iommu.h>
29 #include <linux/module.h>
30 #include <linux/mm.h>
31 #include <linux/kthread.h>
32 #include <linux/rbtree.h>
33 #include <linux/sched/signal.h>
34 #include <linux/sched/mm.h>
35 #include <linux/slab.h>
36 #include <linux/uaccess.h>
37 #include <linux/vfio.h>
38 #include <linux/workqueue.h>
39 #include <linux/notifier.h>
40 #include <linux/irqdomain.h>
41 #include "vfio.h"
42
43 #define DRIVER_VERSION  "0.2"
44 #define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
45 #define DRIVER_DESC     "Type1 IOMMU driver for VFIO"
46
47 static bool allow_unsafe_interrupts;
48 module_param_named(allow_unsafe_interrupts,
49                    allow_unsafe_interrupts, bool, S_IRUGO | S_IWUSR);
50 MODULE_PARM_DESC(allow_unsafe_interrupts,
51                  "Enable VFIO IOMMU support for on platforms without interrupt remapping support.");
52
53 static bool disable_hugepages;
54 module_param_named(disable_hugepages,
55                    disable_hugepages, bool, S_IRUGO | S_IWUSR);
56 MODULE_PARM_DESC(disable_hugepages,
57                  "Disable VFIO IOMMU support for IOMMU hugepages.");
58
59 static unsigned int dma_entry_limit __read_mostly = U16_MAX;
60 module_param_named(dma_entry_limit, dma_entry_limit, uint, 0644);
61 MODULE_PARM_DESC(dma_entry_limit,
62                  "Maximum number of user DMA mappings per container (65535).");
63
64 struct vfio_iommu {
65         struct list_head        domain_list;
66         struct list_head        iova_list;
67         struct mutex            lock;
68         struct rb_root          dma_list;
69         struct list_head        device_list;
70         struct mutex            device_list_lock;
71         unsigned int            dma_avail;
72         unsigned int            vaddr_invalid_count;
73         uint64_t                pgsize_bitmap;
74         uint64_t                num_non_pinned_groups;
75         wait_queue_head_t       vaddr_wait;
76         bool                    v2;
77         bool                    nesting;
78         bool                    dirty_page_tracking;
79         bool                    container_open;
80         struct list_head        emulated_iommu_groups;
81 };
82
83 struct vfio_domain {
84         struct iommu_domain     *domain;
85         struct list_head        next;
86         struct list_head        group_list;
87         bool                    fgsp : 1;       /* Fine-grained super pages */
88         bool                    enforce_cache_coherency : 1;
89 };
90
91 struct vfio_dma {
92         struct rb_node          node;
93         dma_addr_t              iova;           /* Device address */
94         unsigned long           vaddr;          /* Process virtual addr */
95         size_t                  size;           /* Map size (bytes) */
96         int                     prot;           /* IOMMU_READ/WRITE */
97         bool                    iommu_mapped;
98         bool                    lock_cap;       /* capable(CAP_IPC_LOCK) */
99         bool                    vaddr_invalid;
100         struct task_struct      *task;
101         struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
102         unsigned long           *bitmap;
103         struct mm_struct        *mm;
104         size_t                  locked_vm;
105 };
106
107 struct vfio_batch {
108         struct page             **pages;        /* for pin_user_pages_remote */
109         struct page             *fallback_page; /* if pages alloc fails */
110         int                     capacity;       /* length of pages array */
111         int                     size;           /* of batch currently */
112         int                     offset;         /* of next entry in pages */
113 };
114
115 struct vfio_iommu_group {
116         struct iommu_group      *iommu_group;
117         struct list_head        next;
118         bool                    pinned_page_dirty_scope;
119 };
120
121 struct vfio_iova {
122         struct list_head        list;
123         dma_addr_t              start;
124         dma_addr_t              end;
125 };
126
127 /*
128  * Guest RAM pinning working set or DMA target
129  */
130 struct vfio_pfn {
131         struct rb_node          node;
132         dma_addr_t              iova;           /* Device address */
133         unsigned long           pfn;            /* Host pfn */
134         unsigned int            ref_count;
135 };
136
137 struct vfio_regions {
138         struct list_head list;
139         dma_addr_t iova;
140         phys_addr_t phys;
141         size_t len;
142 };
143
144 #define DIRTY_BITMAP_BYTES(n)   (ALIGN(n, BITS_PER_TYPE(u64)) / BITS_PER_BYTE)
145
146 /*
147  * Input argument of number of bits to bitmap_set() is unsigned integer, which
148  * further casts to signed integer for unaligned multi-bit operation,
149  * __bitmap_set().
150  * Then maximum bitmap size supported is 2^31 bits divided by 2^3 bits/byte,
151  * that is 2^28 (256 MB) which maps to 2^31 * 2^12 = 2^43 (8TB) on 4K page
152  * system.
153  */
154 #define DIRTY_BITMAP_PAGES_MAX   ((u64)INT_MAX)
155 #define DIRTY_BITMAP_SIZE_MAX    DIRTY_BITMAP_BYTES(DIRTY_BITMAP_PAGES_MAX)
156
157 #define WAITED 1
158
159 static int put_pfn(unsigned long pfn, int prot);
160
161 static struct vfio_iommu_group*
162 vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
163                             struct iommu_group *iommu_group);
164
165 /*
166  * This code handles mapping and unmapping of user data buffers
167  * into DMA'ble space using the IOMMU
168  */
169
170 static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
171                                       dma_addr_t start, size_t size)
172 {
173         struct rb_node *node = iommu->dma_list.rb_node;
174
175         while (node) {
176                 struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
177
178                 if (start + size <= dma->iova)
179                         node = node->rb_left;
180                 else if (start >= dma->iova + dma->size)
181                         node = node->rb_right;
182                 else
183                         return dma;
184         }
185
186         return NULL;
187 }
188
189 static struct rb_node *vfio_find_dma_first_node(struct vfio_iommu *iommu,
190                                                 dma_addr_t start, u64 size)
191 {
192         struct rb_node *res = NULL;
193         struct rb_node *node = iommu->dma_list.rb_node;
194         struct vfio_dma *dma_res = NULL;
195
196         while (node) {
197                 struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
198
199                 if (start < dma->iova + dma->size) {
200                         res = node;
201                         dma_res = dma;
202                         if (start >= dma->iova)
203                                 break;
204                         node = node->rb_left;
205                 } else {
206                         node = node->rb_right;
207                 }
208         }
209         if (res && size && dma_res->iova >= start + size)
210                 res = NULL;
211         return res;
212 }
213
214 static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
215 {
216         struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
217         struct vfio_dma *dma;
218
219         while (*link) {
220                 parent = *link;
221                 dma = rb_entry(parent, struct vfio_dma, node);
222
223                 if (new->iova + new->size <= dma->iova)
224                         link = &(*link)->rb_left;
225                 else
226                         link = &(*link)->rb_right;
227         }
228
229         rb_link_node(&new->node, parent, link);
230         rb_insert_color(&new->node, &iommu->dma_list);
231 }
232
233 static void vfio_unlink_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
234 {
235         rb_erase(&old->node, &iommu->dma_list);
236 }
237
238
239 static int vfio_dma_bitmap_alloc(struct vfio_dma *dma, size_t pgsize)
240 {
241         uint64_t npages = dma->size / pgsize;
242
243         if (npages > DIRTY_BITMAP_PAGES_MAX)
244                 return -EINVAL;
245
246         /*
247          * Allocate extra 64 bits that are used to calculate shift required for
248          * bitmap_shift_left() to manipulate and club unaligned number of pages
249          * in adjacent vfio_dma ranges.
250          */
251         dma->bitmap = kvzalloc(DIRTY_BITMAP_BYTES(npages) + sizeof(u64),
252                                GFP_KERNEL);
253         if (!dma->bitmap)
254                 return -ENOMEM;
255
256         return 0;
257 }
258
259 static void vfio_dma_bitmap_free(struct vfio_dma *dma)
260 {
261         kvfree(dma->bitmap);
262         dma->bitmap = NULL;
263 }
264
265 static void vfio_dma_populate_bitmap(struct vfio_dma *dma, size_t pgsize)
266 {
267         struct rb_node *p;
268         unsigned long pgshift = __ffs(pgsize);
269
270         for (p = rb_first(&dma->pfn_list); p; p = rb_next(p)) {
271                 struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn, node);
272
273                 bitmap_set(dma->bitmap, (vpfn->iova - dma->iova) >> pgshift, 1);
274         }
275 }
276
277 static void vfio_iommu_populate_bitmap_full(struct vfio_iommu *iommu)
278 {
279         struct rb_node *n;
280         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
281
282         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
283                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
284
285                 bitmap_set(dma->bitmap, 0, dma->size >> pgshift);
286         }
287 }
288
289 static int vfio_dma_bitmap_alloc_all(struct vfio_iommu *iommu, size_t pgsize)
290 {
291         struct rb_node *n;
292
293         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
294                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
295                 int ret;
296
297                 ret = vfio_dma_bitmap_alloc(dma, pgsize);
298                 if (ret) {
299                         struct rb_node *p;
300
301                         for (p = rb_prev(n); p; p = rb_prev(p)) {
302                                 struct vfio_dma *dma = rb_entry(n,
303                                                         struct vfio_dma, node);
304
305                                 vfio_dma_bitmap_free(dma);
306                         }
307                         return ret;
308                 }
309                 vfio_dma_populate_bitmap(dma, pgsize);
310         }
311         return 0;
312 }
313
314 static void vfio_dma_bitmap_free_all(struct vfio_iommu *iommu)
315 {
316         struct rb_node *n;
317
318         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
319                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
320
321                 vfio_dma_bitmap_free(dma);
322         }
323 }
324
325 /*
326  * Helper Functions for host iova-pfn list
327  */
328 static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
329 {
330         struct vfio_pfn *vpfn;
331         struct rb_node *node = dma->pfn_list.rb_node;
332
333         while (node) {
334                 vpfn = rb_entry(node, struct vfio_pfn, node);
335
336                 if (iova < vpfn->iova)
337                         node = node->rb_left;
338                 else if (iova > vpfn->iova)
339                         node = node->rb_right;
340                 else
341                         return vpfn;
342         }
343         return NULL;
344 }
345
346 static void vfio_link_pfn(struct vfio_dma *dma,
347                           struct vfio_pfn *new)
348 {
349         struct rb_node **link, *parent = NULL;
350         struct vfio_pfn *vpfn;
351
352         link = &dma->pfn_list.rb_node;
353         while (*link) {
354                 parent = *link;
355                 vpfn = rb_entry(parent, struct vfio_pfn, node);
356
357                 if (new->iova < vpfn->iova)
358                         link = &(*link)->rb_left;
359                 else
360                         link = &(*link)->rb_right;
361         }
362
363         rb_link_node(&new->node, parent, link);
364         rb_insert_color(&new->node, &dma->pfn_list);
365 }
366
367 static void vfio_unlink_pfn(struct vfio_dma *dma, struct vfio_pfn *old)
368 {
369         rb_erase(&old->node, &dma->pfn_list);
370 }
371
372 static int vfio_add_to_pfn_list(struct vfio_dma *dma, dma_addr_t iova,
373                                 unsigned long pfn)
374 {
375         struct vfio_pfn *vpfn;
376
377         vpfn = kzalloc(sizeof(*vpfn), GFP_KERNEL);
378         if (!vpfn)
379                 return -ENOMEM;
380
381         vpfn->iova = iova;
382         vpfn->pfn = pfn;
383         vpfn->ref_count = 1;
384         vfio_link_pfn(dma, vpfn);
385         return 0;
386 }
387
388 static void vfio_remove_from_pfn_list(struct vfio_dma *dma,
389                                       struct vfio_pfn *vpfn)
390 {
391         vfio_unlink_pfn(dma, vpfn);
392         kfree(vpfn);
393 }
394
395 static struct vfio_pfn *vfio_iova_get_vfio_pfn(struct vfio_dma *dma,
396                                                unsigned long iova)
397 {
398         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
399
400         if (vpfn)
401                 vpfn->ref_count++;
402         return vpfn;
403 }
404
405 static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
406 {
407         int ret = 0;
408
409         vpfn->ref_count--;
410         if (!vpfn->ref_count) {
411                 ret = put_pfn(vpfn->pfn, dma->prot);
412                 vfio_remove_from_pfn_list(dma, vpfn);
413         }
414         return ret;
415 }
416
417 static int mm_lock_acct(struct task_struct *task, struct mm_struct *mm,
418                         bool lock_cap, long npage)
419 {
420         int ret = mmap_write_lock_killable(mm);
421
422         if (ret)
423                 return ret;
424
425         ret = __account_locked_vm(mm, abs(npage), npage > 0, task, lock_cap);
426         mmap_write_unlock(mm);
427         return ret;
428 }
429
430 static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
431 {
432         struct mm_struct *mm;
433         int ret;
434
435         if (!npage)
436                 return 0;
437
438         mm = dma->mm;
439         if (async && !mmget_not_zero(mm))
440                 return -ESRCH; /* process exited */
441
442         ret = mm_lock_acct(dma->task, mm, dma->lock_cap, npage);
443         if (!ret)
444                 dma->locked_vm += npage;
445
446         if (async)
447                 mmput(mm);
448
449         return ret;
450 }
451
452 /*
453  * Some mappings aren't backed by a struct page, for example an mmap'd
454  * MMIO range for our own or another device.  These use a different
455  * pfn conversion and shouldn't be tracked as locked pages.
456  * For compound pages, any driver that sets the reserved bit in head
457  * page needs to set the reserved bit in all subpages to be safe.
458  */
459 static bool is_invalid_reserved_pfn(unsigned long pfn)
460 {
461         if (pfn_valid(pfn))
462                 return PageReserved(pfn_to_page(pfn));
463
464         return true;
465 }
466
467 static int put_pfn(unsigned long pfn, int prot)
468 {
469         if (!is_invalid_reserved_pfn(pfn)) {
470                 struct page *page = pfn_to_page(pfn);
471
472                 unpin_user_pages_dirty_lock(&page, 1, prot & IOMMU_WRITE);
473                 return 1;
474         }
475         return 0;
476 }
477
478 #define VFIO_BATCH_MAX_CAPACITY (PAGE_SIZE / sizeof(struct page *))
479
480 static void vfio_batch_init(struct vfio_batch *batch)
481 {
482         batch->size = 0;
483         batch->offset = 0;
484
485         if (unlikely(disable_hugepages))
486                 goto fallback;
487
488         batch->pages = (struct page **) __get_free_page(GFP_KERNEL);
489         if (!batch->pages)
490                 goto fallback;
491
492         batch->capacity = VFIO_BATCH_MAX_CAPACITY;
493         return;
494
495 fallback:
496         batch->pages = &batch->fallback_page;
497         batch->capacity = 1;
498 }
499
500 static void vfio_batch_unpin(struct vfio_batch *batch, struct vfio_dma *dma)
501 {
502         while (batch->size) {
503                 unsigned long pfn = page_to_pfn(batch->pages[batch->offset]);
504
505                 put_pfn(pfn, dma->prot);
506                 batch->offset++;
507                 batch->size--;
508         }
509 }
510
511 static void vfio_batch_fini(struct vfio_batch *batch)
512 {
513         if (batch->capacity == VFIO_BATCH_MAX_CAPACITY)
514                 free_page((unsigned long)batch->pages);
515 }
516
517 static int follow_fault_pfn(struct vm_area_struct *vma, struct mm_struct *mm,
518                             unsigned long vaddr, unsigned long *pfn,
519                             bool write_fault)
520 {
521         pte_t *ptep;
522         spinlock_t *ptl;
523         int ret;
524
525         ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
526         if (ret) {
527                 bool unlocked = false;
528
529                 ret = fixup_user_fault(mm, vaddr,
530                                        FAULT_FLAG_REMOTE |
531                                        (write_fault ?  FAULT_FLAG_WRITE : 0),
532                                        &unlocked);
533                 if (unlocked)
534                         return -EAGAIN;
535
536                 if (ret)
537                         return ret;
538
539                 ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
540                 if (ret)
541                         return ret;
542         }
543
544         if (write_fault && !pte_write(*ptep))
545                 ret = -EFAULT;
546         else
547                 *pfn = pte_pfn(*ptep);
548
549         pte_unmap_unlock(ptep, ptl);
550         return ret;
551 }
552
553 /*
554  * Returns the positive number of pfns successfully obtained or a negative
555  * error code.
556  */
557 static int vaddr_get_pfns(struct mm_struct *mm, unsigned long vaddr,
558                           long npages, int prot, unsigned long *pfn,
559                           struct page **pages)
560 {
561         struct vm_area_struct *vma;
562         unsigned int flags = 0;
563         int ret;
564
565         if (prot & IOMMU_WRITE)
566                 flags |= FOLL_WRITE;
567
568         mmap_read_lock(mm);
569         ret = pin_user_pages_remote(mm, vaddr, npages, flags | FOLL_LONGTERM,
570                                     pages, NULL, NULL);
571         if (ret > 0) {
572                 int i;
573
574                 /*
575                  * The zero page is always resident, we don't need to pin it
576                  * and it falls into our invalid/reserved test so we don't
577                  * unpin in put_pfn().  Unpin all zero pages in the batch here.
578                  */
579                 for (i = 0 ; i < ret; i++) {
580                         if (unlikely(is_zero_pfn(page_to_pfn(pages[i]))))
581                                 unpin_user_page(pages[i]);
582                 }
583
584                 *pfn = page_to_pfn(pages[0]);
585                 goto done;
586         }
587
588         vaddr = untagged_addr(vaddr);
589
590 retry:
591         vma = vma_lookup(mm, vaddr);
592
593         if (vma && vma->vm_flags & VM_PFNMAP) {
594                 ret = follow_fault_pfn(vma, mm, vaddr, pfn, prot & IOMMU_WRITE);
595                 if (ret == -EAGAIN)
596                         goto retry;
597
598                 if (!ret) {
599                         if (is_invalid_reserved_pfn(*pfn))
600                                 ret = 1;
601                         else
602                                 ret = -EFAULT;
603                 }
604         }
605 done:
606         mmap_read_unlock(mm);
607         return ret;
608 }
609
610 static int vfio_wait(struct vfio_iommu *iommu)
611 {
612         DEFINE_WAIT(wait);
613
614         prepare_to_wait(&iommu->vaddr_wait, &wait, TASK_KILLABLE);
615         mutex_unlock(&iommu->lock);
616         schedule();
617         mutex_lock(&iommu->lock);
618         finish_wait(&iommu->vaddr_wait, &wait);
619         if (kthread_should_stop() || !iommu->container_open ||
620             fatal_signal_pending(current)) {
621                 return -EFAULT;
622         }
623         return WAITED;
624 }
625
626 /*
627  * Find dma struct and wait for its vaddr to be valid.  iommu lock is dropped
628  * if the task waits, but is re-locked on return.  Return result in *dma_p.
629  * Return 0 on success with no waiting, WAITED on success if waited, and -errno
630  * on error.
631  */
632 static int vfio_find_dma_valid(struct vfio_iommu *iommu, dma_addr_t start,
633                                size_t size, struct vfio_dma **dma_p)
634 {
635         int ret = 0;
636
637         do {
638                 *dma_p = vfio_find_dma(iommu, start, size);
639                 if (!*dma_p)
640                         return -EINVAL;
641                 else if (!(*dma_p)->vaddr_invalid)
642                         return ret;
643                 else
644                         ret = vfio_wait(iommu);
645         } while (ret == WAITED);
646
647         return ret;
648 }
649
650 /*
651  * Wait for all vaddr in the dma_list to become valid.  iommu lock is dropped
652  * if the task waits, but is re-locked on return.  Return 0 on success with no
653  * waiting, WAITED on success if waited, and -errno on error.
654  */
655 static int vfio_wait_all_valid(struct vfio_iommu *iommu)
656 {
657         int ret = 0;
658
659         while (iommu->vaddr_invalid_count && ret >= 0)
660                 ret = vfio_wait(iommu);
661
662         return ret;
663 }
664
665 /*
666  * Attempt to pin pages.  We really don't want to track all the pfns and
667  * the iommu can only map chunks of consecutive pfns anyway, so get the
668  * first page and all consecutive pages with the same locking.
669  */
670 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
671                                   long npage, unsigned long *pfn_base,
672                                   unsigned long limit, struct vfio_batch *batch)
673 {
674         unsigned long pfn;
675         struct mm_struct *mm = current->mm;
676         long ret, pinned = 0, lock_acct = 0;
677         bool rsvd;
678         dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
679
680         /* This code path is only user initiated */
681         if (!mm)
682                 return -ENODEV;
683
684         if (batch->size) {
685                 /* Leftover pages in batch from an earlier call. */
686                 *pfn_base = page_to_pfn(batch->pages[batch->offset]);
687                 pfn = *pfn_base;
688                 rsvd = is_invalid_reserved_pfn(*pfn_base);
689         } else {
690                 *pfn_base = 0;
691         }
692
693         while (npage) {
694                 if (!batch->size) {
695                         /* Empty batch, so refill it. */
696                         long req_pages = min_t(long, npage, batch->capacity);
697
698                         ret = vaddr_get_pfns(mm, vaddr, req_pages, dma->prot,
699                                              &pfn, batch->pages);
700                         if (ret < 0)
701                                 goto unpin_out;
702
703                         batch->size = ret;
704                         batch->offset = 0;
705
706                         if (!*pfn_base) {
707                                 *pfn_base = pfn;
708                                 rsvd = is_invalid_reserved_pfn(*pfn_base);
709                         }
710                 }
711
712                 /*
713                  * pfn is preset for the first iteration of this inner loop and
714                  * updated at the end to handle a VM_PFNMAP pfn.  In that case,
715                  * batch->pages isn't valid (there's no struct page), so allow
716                  * batch->pages to be touched only when there's more than one
717                  * pfn to check, which guarantees the pfns are from a
718                  * !VM_PFNMAP vma.
719                  */
720                 while (true) {
721                         if (pfn != *pfn_base + pinned ||
722                             rsvd != is_invalid_reserved_pfn(pfn))
723                                 goto out;
724
725                         /*
726                          * Reserved pages aren't counted against the user,
727                          * externally pinned pages are already counted against
728                          * the user.
729                          */
730                         if (!rsvd && !vfio_find_vpfn(dma, iova)) {
731                                 if (!dma->lock_cap &&
732                                     mm->locked_vm + lock_acct + 1 > limit) {
733                                         pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
734                                                 __func__, limit << PAGE_SHIFT);
735                                         ret = -ENOMEM;
736                                         goto unpin_out;
737                                 }
738                                 lock_acct++;
739                         }
740
741                         pinned++;
742                         npage--;
743                         vaddr += PAGE_SIZE;
744                         iova += PAGE_SIZE;
745                         batch->offset++;
746                         batch->size--;
747
748                         if (!batch->size)
749                                 break;
750
751                         pfn = page_to_pfn(batch->pages[batch->offset]);
752                 }
753
754                 if (unlikely(disable_hugepages))
755                         break;
756         }
757
758 out:
759         ret = vfio_lock_acct(dma, lock_acct, false);
760
761 unpin_out:
762         if (batch->size == 1 && !batch->offset) {
763                 /* May be a VM_PFNMAP pfn, which the batch can't remember. */
764                 put_pfn(pfn, dma->prot);
765                 batch->size = 0;
766         }
767
768         if (ret < 0) {
769                 if (pinned && !rsvd) {
770                         for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
771                                 put_pfn(pfn, dma->prot);
772                 }
773                 vfio_batch_unpin(batch, dma);
774
775                 return ret;
776         }
777
778         return pinned;
779 }
780
781 static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
782                                     unsigned long pfn, long npage,
783                                     bool do_accounting)
784 {
785         long unlocked = 0, locked = 0;
786         long i;
787
788         for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
789                 if (put_pfn(pfn++, dma->prot)) {
790                         unlocked++;
791                         if (vfio_find_vpfn(dma, iova))
792                                 locked++;
793                 }
794         }
795
796         if (do_accounting)
797                 vfio_lock_acct(dma, locked - unlocked, true);
798
799         return unlocked;
800 }
801
802 static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
803                                   unsigned long *pfn_base, bool do_accounting)
804 {
805         struct page *pages[1];
806         struct mm_struct *mm;
807         int ret;
808
809         mm = dma->mm;
810         if (!mmget_not_zero(mm))
811                 return -ENODEV;
812
813         ret = vaddr_get_pfns(mm, vaddr, 1, dma->prot, pfn_base, pages);
814         if (ret != 1)
815                 goto out;
816
817         ret = 0;
818
819         if (do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
820                 ret = vfio_lock_acct(dma, 1, false);
821                 if (ret) {
822                         put_pfn(*pfn_base, dma->prot);
823                         if (ret == -ENOMEM)
824                                 pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
825                                         "(%ld) exceeded\n", __func__,
826                                         dma->task->comm, task_pid_nr(dma->task),
827                                         task_rlimit(dma->task, RLIMIT_MEMLOCK));
828                 }
829         }
830
831 out:
832         mmput(mm);
833         return ret;
834 }
835
836 static int vfio_unpin_page_external(struct vfio_dma *dma, dma_addr_t iova,
837                                     bool do_accounting)
838 {
839         int unlocked;
840         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
841
842         if (!vpfn)
843                 return 0;
844
845         unlocked = vfio_iova_put_vfio_pfn(dma, vpfn);
846
847         if (do_accounting)
848                 vfio_lock_acct(dma, -unlocked, true);
849
850         return unlocked;
851 }
852
853 static int vfio_iommu_type1_pin_pages(void *iommu_data,
854                                       struct iommu_group *iommu_group,
855                                       dma_addr_t user_iova,
856                                       int npage, int prot,
857                                       struct page **pages)
858 {
859         struct vfio_iommu *iommu = iommu_data;
860         struct vfio_iommu_group *group;
861         int i, j, ret;
862         unsigned long remote_vaddr;
863         struct vfio_dma *dma;
864         bool do_accounting;
865         dma_addr_t iova;
866
867         if (!iommu || !pages)
868                 return -EINVAL;
869
870         /* Supported for v2 version only */
871         if (!iommu->v2)
872                 return -EACCES;
873
874         mutex_lock(&iommu->lock);
875
876         if (WARN_ONCE(iommu->vaddr_invalid_count,
877                       "vfio_pin_pages not allowed with VFIO_UPDATE_VADDR\n")) {
878                 ret = -EBUSY;
879                 goto pin_done;
880         }
881
882         /*
883          * Wait for all necessary vaddr's to be valid so they can be used in
884          * the main loop without dropping the lock, to avoid racing vs unmap.
885          */
886 again:
887         if (iommu->vaddr_invalid_count) {
888                 for (i = 0; i < npage; i++) {
889                         iova = user_iova + PAGE_SIZE * i;
890                         ret = vfio_find_dma_valid(iommu, iova, PAGE_SIZE, &dma);
891                         if (ret < 0)
892                                 goto pin_done;
893                         if (ret == WAITED)
894                                 goto again;
895                 }
896         }
897
898         /* Fail if no dma_umap notifier is registered */
899         if (list_empty(&iommu->device_list)) {
900                 ret = -EINVAL;
901                 goto pin_done;
902         }
903
904         /*
905          * If iommu capable domain exist in the container then all pages are
906          * already pinned and accounted. Accounting should be done if there is no
907          * iommu capable domain in the container.
908          */
909         do_accounting = list_empty(&iommu->domain_list);
910
911         for (i = 0; i < npage; i++) {
912                 unsigned long phys_pfn;
913                 struct vfio_pfn *vpfn;
914
915                 iova = user_iova + PAGE_SIZE * i;
916                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
917                 if (!dma) {
918                         ret = -EINVAL;
919                         goto pin_unwind;
920                 }
921
922                 if ((dma->prot & prot) != prot) {
923                         ret = -EPERM;
924                         goto pin_unwind;
925                 }
926
927                 vpfn = vfio_iova_get_vfio_pfn(dma, iova);
928                 if (vpfn) {
929                         pages[i] = pfn_to_page(vpfn->pfn);
930                         continue;
931                 }
932
933                 remote_vaddr = dma->vaddr + (iova - dma->iova);
934                 ret = vfio_pin_page_external(dma, remote_vaddr, &phys_pfn,
935                                              do_accounting);
936                 if (ret)
937                         goto pin_unwind;
938
939                 ret = vfio_add_to_pfn_list(dma, iova, phys_pfn);
940                 if (ret) {
941                         if (put_pfn(phys_pfn, dma->prot) && do_accounting)
942                                 vfio_lock_acct(dma, -1, true);
943                         goto pin_unwind;
944                 }
945
946                 pages[i] = pfn_to_page(phys_pfn);
947
948                 if (iommu->dirty_page_tracking) {
949                         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
950
951                         /*
952                          * Bitmap populated with the smallest supported page
953                          * size
954                          */
955                         bitmap_set(dma->bitmap,
956                                    (iova - dma->iova) >> pgshift, 1);
957                 }
958         }
959         ret = i;
960
961         group = vfio_iommu_find_iommu_group(iommu, iommu_group);
962         if (!group->pinned_page_dirty_scope) {
963                 group->pinned_page_dirty_scope = true;
964                 iommu->num_non_pinned_groups--;
965         }
966
967         goto pin_done;
968
969 pin_unwind:
970         pages[i] = NULL;
971         for (j = 0; j < i; j++) {
972                 dma_addr_t iova;
973
974                 iova = user_iova + PAGE_SIZE * j;
975                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
976                 vfio_unpin_page_external(dma, iova, do_accounting);
977                 pages[j] = NULL;
978         }
979 pin_done:
980         mutex_unlock(&iommu->lock);
981         return ret;
982 }
983
984 static void vfio_iommu_type1_unpin_pages(void *iommu_data,
985                                          dma_addr_t user_iova, int npage)
986 {
987         struct vfio_iommu *iommu = iommu_data;
988         bool do_accounting;
989         int i;
990
991         /* Supported for v2 version only */
992         if (WARN_ON(!iommu->v2))
993                 return;
994
995         mutex_lock(&iommu->lock);
996
997         do_accounting = list_empty(&iommu->domain_list);
998         for (i = 0; i < npage; i++) {
999                 dma_addr_t iova = user_iova + PAGE_SIZE * i;
1000                 struct vfio_dma *dma;
1001
1002                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
1003                 if (!dma)
1004                         break;
1005
1006                 vfio_unpin_page_external(dma, iova, do_accounting);
1007         }
1008
1009         mutex_unlock(&iommu->lock);
1010
1011         WARN_ON(i != npage);
1012 }
1013
1014 static long vfio_sync_unpin(struct vfio_dma *dma, struct vfio_domain *domain,
1015                             struct list_head *regions,
1016                             struct iommu_iotlb_gather *iotlb_gather)
1017 {
1018         long unlocked = 0;
1019         struct vfio_regions *entry, *next;
1020
1021         iommu_iotlb_sync(domain->domain, iotlb_gather);
1022
1023         list_for_each_entry_safe(entry, next, regions, list) {
1024                 unlocked += vfio_unpin_pages_remote(dma,
1025                                                     entry->iova,
1026                                                     entry->phys >> PAGE_SHIFT,
1027                                                     entry->len >> PAGE_SHIFT,
1028                                                     false);
1029                 list_del(&entry->list);
1030                 kfree(entry);
1031         }
1032
1033         cond_resched();
1034
1035         return unlocked;
1036 }
1037
1038 /*
1039  * Generally, VFIO needs to unpin remote pages after each IOTLB flush.
1040  * Therefore, when using IOTLB flush sync interface, VFIO need to keep track
1041  * of these regions (currently using a list).
1042  *
1043  * This value specifies maximum number of regions for each IOTLB flush sync.
1044  */
1045 #define VFIO_IOMMU_TLB_SYNC_MAX         512
1046
1047 static size_t unmap_unpin_fast(struct vfio_domain *domain,
1048                                struct vfio_dma *dma, dma_addr_t *iova,
1049                                size_t len, phys_addr_t phys, long *unlocked,
1050                                struct list_head *unmapped_list,
1051                                int *unmapped_cnt,
1052                                struct iommu_iotlb_gather *iotlb_gather)
1053 {
1054         size_t unmapped = 0;
1055         struct vfio_regions *entry = kzalloc(sizeof(*entry), GFP_KERNEL);
1056
1057         if (entry) {
1058                 unmapped = iommu_unmap_fast(domain->domain, *iova, len,
1059                                             iotlb_gather);
1060
1061                 if (!unmapped) {
1062                         kfree(entry);
1063                 } else {
1064                         entry->iova = *iova;
1065                         entry->phys = phys;
1066                         entry->len  = unmapped;
1067                         list_add_tail(&entry->list, unmapped_list);
1068
1069                         *iova += unmapped;
1070                         (*unmapped_cnt)++;
1071                 }
1072         }
1073
1074         /*
1075          * Sync if the number of fast-unmap regions hits the limit
1076          * or in case of errors.
1077          */
1078         if (*unmapped_cnt >= VFIO_IOMMU_TLB_SYNC_MAX || !unmapped) {
1079                 *unlocked += vfio_sync_unpin(dma, domain, unmapped_list,
1080                                              iotlb_gather);
1081                 *unmapped_cnt = 0;
1082         }
1083
1084         return unmapped;
1085 }
1086
1087 static size_t unmap_unpin_slow(struct vfio_domain *domain,
1088                                struct vfio_dma *dma, dma_addr_t *iova,
1089                                size_t len, phys_addr_t phys,
1090                                long *unlocked)
1091 {
1092         size_t unmapped = iommu_unmap(domain->domain, *iova, len);
1093
1094         if (unmapped) {
1095                 *unlocked += vfio_unpin_pages_remote(dma, *iova,
1096                                                      phys >> PAGE_SHIFT,
1097                                                      unmapped >> PAGE_SHIFT,
1098                                                      false);
1099                 *iova += unmapped;
1100                 cond_resched();
1101         }
1102         return unmapped;
1103 }
1104
1105 static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
1106                              bool do_accounting)
1107 {
1108         dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
1109         struct vfio_domain *domain, *d;
1110         LIST_HEAD(unmapped_region_list);
1111         struct iommu_iotlb_gather iotlb_gather;
1112         int unmapped_region_cnt = 0;
1113         long unlocked = 0;
1114
1115         if (!dma->size)
1116                 return 0;
1117
1118         if (list_empty(&iommu->domain_list))
1119                 return 0;
1120
1121         /*
1122          * We use the IOMMU to track the physical addresses, otherwise we'd
1123          * need a much more complicated tracking system.  Unfortunately that
1124          * means we need to use one of the iommu domains to figure out the
1125          * pfns to unpin.  The rest need to be unmapped in advance so we have
1126          * no iommu translations remaining when the pages are unpinned.
1127          */
1128         domain = d = list_first_entry(&iommu->domain_list,
1129                                       struct vfio_domain, next);
1130
1131         list_for_each_entry_continue(d, &iommu->domain_list, next) {
1132                 iommu_unmap(d->domain, dma->iova, dma->size);
1133                 cond_resched();
1134         }
1135
1136         iommu_iotlb_gather_init(&iotlb_gather);
1137         while (iova < end) {
1138                 size_t unmapped, len;
1139                 phys_addr_t phys, next;
1140
1141                 phys = iommu_iova_to_phys(domain->domain, iova);
1142                 if (WARN_ON(!phys)) {
1143                         iova += PAGE_SIZE;
1144                         continue;
1145                 }
1146
1147                 /*
1148                  * To optimize for fewer iommu_unmap() calls, each of which
1149                  * may require hardware cache flushing, try to find the
1150                  * largest contiguous physical memory chunk to unmap.
1151                  */
1152                 for (len = PAGE_SIZE;
1153                      !domain->fgsp && iova + len < end; len += PAGE_SIZE) {
1154                         next = iommu_iova_to_phys(domain->domain, iova + len);
1155                         if (next != phys + len)
1156                                 break;
1157                 }
1158
1159                 /*
1160                  * First, try to use fast unmap/unpin. In case of failure,
1161                  * switch to slow unmap/unpin path.
1162                  */
1163                 unmapped = unmap_unpin_fast(domain, dma, &iova, len, phys,
1164                                             &unlocked, &unmapped_region_list,
1165                                             &unmapped_region_cnt,
1166                                             &iotlb_gather);
1167                 if (!unmapped) {
1168                         unmapped = unmap_unpin_slow(domain, dma, &iova, len,
1169                                                     phys, &unlocked);
1170                         if (WARN_ON(!unmapped))
1171                                 break;
1172                 }
1173         }
1174
1175         dma->iommu_mapped = false;
1176
1177         if (unmapped_region_cnt) {
1178                 unlocked += vfio_sync_unpin(dma, domain, &unmapped_region_list,
1179                                             &iotlb_gather);
1180         }
1181
1182         if (do_accounting) {
1183                 vfio_lock_acct(dma, -unlocked, true);
1184                 return 0;
1185         }
1186         return unlocked;
1187 }
1188
1189 static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
1190 {
1191         WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list));
1192         vfio_unmap_unpin(iommu, dma, true);
1193         vfio_unlink_dma(iommu, dma);
1194         put_task_struct(dma->task);
1195         mmdrop(dma->mm);
1196         vfio_dma_bitmap_free(dma);
1197         if (dma->vaddr_invalid) {
1198                 iommu->vaddr_invalid_count--;
1199                 wake_up_all(&iommu->vaddr_wait);
1200         }
1201         kfree(dma);
1202         iommu->dma_avail++;
1203 }
1204
1205 static void vfio_update_pgsize_bitmap(struct vfio_iommu *iommu)
1206 {
1207         struct vfio_domain *domain;
1208
1209         iommu->pgsize_bitmap = ULONG_MAX;
1210
1211         list_for_each_entry(domain, &iommu->domain_list, next)
1212                 iommu->pgsize_bitmap &= domain->domain->pgsize_bitmap;
1213
1214         /*
1215          * In case the IOMMU supports page sizes smaller than PAGE_SIZE
1216          * we pretend PAGE_SIZE is supported and hide sub-PAGE_SIZE sizes.
1217          * That way the user will be able to map/unmap buffers whose size/
1218          * start address is aligned with PAGE_SIZE. Pinning code uses that
1219          * granularity while iommu driver can use the sub-PAGE_SIZE size
1220          * to map the buffer.
1221          */
1222         if (iommu->pgsize_bitmap & ~PAGE_MASK) {
1223                 iommu->pgsize_bitmap &= PAGE_MASK;
1224                 iommu->pgsize_bitmap |= PAGE_SIZE;
1225         }
1226 }
1227
1228 static int update_user_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1229                               struct vfio_dma *dma, dma_addr_t base_iova,
1230                               size_t pgsize)
1231 {
1232         unsigned long pgshift = __ffs(pgsize);
1233         unsigned long nbits = dma->size >> pgshift;
1234         unsigned long bit_offset = (dma->iova - base_iova) >> pgshift;
1235         unsigned long copy_offset = bit_offset / BITS_PER_LONG;
1236         unsigned long shift = bit_offset % BITS_PER_LONG;
1237         unsigned long leftover;
1238
1239         /*
1240          * mark all pages dirty if any IOMMU capable device is not able
1241          * to report dirty pages and all pages are pinned and mapped.
1242          */
1243         if (iommu->num_non_pinned_groups && dma->iommu_mapped)
1244                 bitmap_set(dma->bitmap, 0, nbits);
1245
1246         if (shift) {
1247                 bitmap_shift_left(dma->bitmap, dma->bitmap, shift,
1248                                   nbits + shift);
1249
1250                 if (copy_from_user(&leftover,
1251                                    (void __user *)(bitmap + copy_offset),
1252                                    sizeof(leftover)))
1253                         return -EFAULT;
1254
1255                 bitmap_or(dma->bitmap, dma->bitmap, &leftover, shift);
1256         }
1257
1258         if (copy_to_user((void __user *)(bitmap + copy_offset), dma->bitmap,
1259                          DIRTY_BITMAP_BYTES(nbits + shift)))
1260                 return -EFAULT;
1261
1262         return 0;
1263 }
1264
1265 static int vfio_iova_dirty_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1266                                   dma_addr_t iova, size_t size, size_t pgsize)
1267 {
1268         struct vfio_dma *dma;
1269         struct rb_node *n;
1270         unsigned long pgshift = __ffs(pgsize);
1271         int ret;
1272
1273         /*
1274          * GET_BITMAP request must fully cover vfio_dma mappings.  Multiple
1275          * vfio_dma mappings may be clubbed by specifying large ranges, but
1276          * there must not be any previous mappings bisected by the range.
1277          * An error will be returned if these conditions are not met.
1278          */
1279         dma = vfio_find_dma(iommu, iova, 1);
1280         if (dma && dma->iova != iova)
1281                 return -EINVAL;
1282
1283         dma = vfio_find_dma(iommu, iova + size - 1, 0);
1284         if (dma && dma->iova + dma->size != iova + size)
1285                 return -EINVAL;
1286
1287         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1288                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1289
1290                 if (dma->iova < iova)
1291                         continue;
1292
1293                 if (dma->iova > iova + size - 1)
1294                         break;
1295
1296                 ret = update_user_bitmap(bitmap, iommu, dma, iova, pgsize);
1297                 if (ret)
1298                         return ret;
1299
1300                 /*
1301                  * Re-populate bitmap to include all pinned pages which are
1302                  * considered as dirty but exclude pages which are unpinned and
1303                  * pages which are marked dirty by vfio_dma_rw()
1304                  */
1305                 bitmap_clear(dma->bitmap, 0, dma->size >> pgshift);
1306                 vfio_dma_populate_bitmap(dma, pgsize);
1307         }
1308         return 0;
1309 }
1310
1311 static int verify_bitmap_size(uint64_t npages, uint64_t bitmap_size)
1312 {
1313         if (!npages || !bitmap_size || (bitmap_size > DIRTY_BITMAP_SIZE_MAX) ||
1314             (bitmap_size < DIRTY_BITMAP_BYTES(npages)))
1315                 return -EINVAL;
1316
1317         return 0;
1318 }
1319
1320 /*
1321  * Notify VFIO drivers using vfio_register_emulated_iommu_dev() to invalidate
1322  * and unmap iovas within the range we're about to unmap. Drivers MUST unpin
1323  * pages in response to an invalidation.
1324  */
1325 static void vfio_notify_dma_unmap(struct vfio_iommu *iommu,
1326                                   struct vfio_dma *dma)
1327 {
1328         struct vfio_device *device;
1329
1330         if (list_empty(&iommu->device_list))
1331                 return;
1332
1333         /*
1334          * The device is expected to call vfio_unpin_pages() for any IOVA it has
1335          * pinned within the range. Since vfio_unpin_pages() will eventually
1336          * call back down to this code and try to obtain the iommu->lock we must
1337          * drop it.
1338          */
1339         mutex_lock(&iommu->device_list_lock);
1340         mutex_unlock(&iommu->lock);
1341
1342         list_for_each_entry(device, &iommu->device_list, iommu_entry)
1343                 device->ops->dma_unmap(device, dma->iova, dma->size);
1344
1345         mutex_unlock(&iommu->device_list_lock);
1346         mutex_lock(&iommu->lock);
1347 }
1348
1349 static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
1350                              struct vfio_iommu_type1_dma_unmap *unmap,
1351                              struct vfio_bitmap *bitmap)
1352 {
1353         struct vfio_dma *dma, *dma_last = NULL;
1354         size_t unmapped = 0, pgsize;
1355         int ret = -EINVAL, retries = 0;
1356         unsigned long pgshift;
1357         dma_addr_t iova = unmap->iova;
1358         u64 size = unmap->size;
1359         bool unmap_all = unmap->flags & VFIO_DMA_UNMAP_FLAG_ALL;
1360         bool invalidate_vaddr = unmap->flags & VFIO_DMA_UNMAP_FLAG_VADDR;
1361         struct rb_node *n, *first_n;
1362
1363         mutex_lock(&iommu->lock);
1364
1365         /* Cannot update vaddr if mdev is present. */
1366         if (invalidate_vaddr && !list_empty(&iommu->emulated_iommu_groups)) {
1367                 ret = -EBUSY;
1368                 goto unlock;
1369         }
1370
1371         pgshift = __ffs(iommu->pgsize_bitmap);
1372         pgsize = (size_t)1 << pgshift;
1373
1374         if (iova & (pgsize - 1))
1375                 goto unlock;
1376
1377         if (unmap_all) {
1378                 if (iova || size)
1379                         goto unlock;
1380                 size = U64_MAX;
1381         } else if (!size || size & (pgsize - 1) ||
1382                    iova + size - 1 < iova || size > SIZE_MAX) {
1383                 goto unlock;
1384         }
1385
1386         /* When dirty tracking is enabled, allow only min supported pgsize */
1387         if ((unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
1388             (!iommu->dirty_page_tracking || (bitmap->pgsize != pgsize))) {
1389                 goto unlock;
1390         }
1391
1392         WARN_ON((pgsize - 1) & PAGE_MASK);
1393 again:
1394         /*
1395          * vfio-iommu-type1 (v1) - User mappings were coalesced together to
1396          * avoid tracking individual mappings.  This means that the granularity
1397          * of the original mapping was lost and the user was allowed to attempt
1398          * to unmap any range.  Depending on the contiguousness of physical
1399          * memory and page sizes supported by the IOMMU, arbitrary unmaps may
1400          * or may not have worked.  We only guaranteed unmap granularity
1401          * matching the original mapping; even though it was untracked here,
1402          * the original mappings are reflected in IOMMU mappings.  This
1403          * resulted in a couple unusual behaviors.  First, if a range is not
1404          * able to be unmapped, ex. a set of 4k pages that was mapped as a
1405          * 2M hugepage into the IOMMU, the unmap ioctl returns success but with
1406          * a zero sized unmap.  Also, if an unmap request overlaps the first
1407          * address of a hugepage, the IOMMU will unmap the entire hugepage.
1408          * This also returns success and the returned unmap size reflects the
1409          * actual size unmapped.
1410          *
1411          * We attempt to maintain compatibility with this "v1" interface, but
1412          * we take control out of the hands of the IOMMU.  Therefore, an unmap
1413          * request offset from the beginning of the original mapping will
1414          * return success with zero sized unmap.  And an unmap request covering
1415          * the first iova of mapping will unmap the entire range.
1416          *
1417          * The v2 version of this interface intends to be more deterministic.
1418          * Unmap requests must fully cover previous mappings.  Multiple
1419          * mappings may still be unmaped by specifying large ranges, but there
1420          * must not be any previous mappings bisected by the range.  An error
1421          * will be returned if these conditions are not met.  The v2 interface
1422          * will only return success and a size of zero if there were no
1423          * mappings within the range.
1424          */
1425         if (iommu->v2 && !unmap_all) {
1426                 dma = vfio_find_dma(iommu, iova, 1);
1427                 if (dma && dma->iova != iova)
1428                         goto unlock;
1429
1430                 dma = vfio_find_dma(iommu, iova + size - 1, 0);
1431                 if (dma && dma->iova + dma->size != iova + size)
1432                         goto unlock;
1433         }
1434
1435         ret = 0;
1436         n = first_n = vfio_find_dma_first_node(iommu, iova, size);
1437
1438         while (n) {
1439                 dma = rb_entry(n, struct vfio_dma, node);
1440                 if (dma->iova >= iova + size)
1441                         break;
1442
1443                 if (!iommu->v2 && iova > dma->iova)
1444                         break;
1445
1446                 if (invalidate_vaddr) {
1447                         if (dma->vaddr_invalid) {
1448                                 struct rb_node *last_n = n;
1449
1450                                 for (n = first_n; n != last_n; n = rb_next(n)) {
1451                                         dma = rb_entry(n,
1452                                                        struct vfio_dma, node);
1453                                         dma->vaddr_invalid = false;
1454                                         iommu->vaddr_invalid_count--;
1455                                 }
1456                                 ret = -EINVAL;
1457                                 unmapped = 0;
1458                                 break;
1459                         }
1460                         dma->vaddr_invalid = true;
1461                         iommu->vaddr_invalid_count++;
1462                         unmapped += dma->size;
1463                         n = rb_next(n);
1464                         continue;
1465                 }
1466
1467                 if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
1468                         if (dma_last == dma) {
1469                                 BUG_ON(++retries > 10);
1470                         } else {
1471                                 dma_last = dma;
1472                                 retries = 0;
1473                         }
1474
1475                         vfio_notify_dma_unmap(iommu, dma);
1476                         goto again;
1477                 }
1478
1479                 if (unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
1480                         ret = update_user_bitmap(bitmap->data, iommu, dma,
1481                                                  iova, pgsize);
1482                         if (ret)
1483                                 break;
1484                 }
1485
1486                 unmapped += dma->size;
1487                 n = rb_next(n);
1488                 vfio_remove_dma(iommu, dma);
1489         }
1490
1491 unlock:
1492         mutex_unlock(&iommu->lock);
1493
1494         /* Report how much was unmapped */
1495         unmap->size = unmapped;
1496
1497         return ret;
1498 }
1499
1500 static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova,
1501                           unsigned long pfn, long npage, int prot)
1502 {
1503         struct vfio_domain *d;
1504         int ret;
1505
1506         list_for_each_entry(d, &iommu->domain_list, next) {
1507                 ret = iommu_map(d->domain, iova, (phys_addr_t)pfn << PAGE_SHIFT,
1508                                 npage << PAGE_SHIFT, prot | IOMMU_CACHE);
1509                 if (ret)
1510                         goto unwind;
1511
1512                 cond_resched();
1513         }
1514
1515         return 0;
1516
1517 unwind:
1518         list_for_each_entry_continue_reverse(d, &iommu->domain_list, next) {
1519                 iommu_unmap(d->domain, iova, npage << PAGE_SHIFT);
1520                 cond_resched();
1521         }
1522
1523         return ret;
1524 }
1525
1526 static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
1527                             size_t map_size)
1528 {
1529         dma_addr_t iova = dma->iova;
1530         unsigned long vaddr = dma->vaddr;
1531         struct vfio_batch batch;
1532         size_t size = map_size;
1533         long npage;
1534         unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1535         int ret = 0;
1536
1537         vfio_batch_init(&batch);
1538
1539         while (size) {
1540                 /* Pin a contiguous chunk of memory */
1541                 npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
1542                                               size >> PAGE_SHIFT, &pfn, limit,
1543                                               &batch);
1544                 if (npage <= 0) {
1545                         WARN_ON(!npage);
1546                         ret = (int)npage;
1547                         break;
1548                 }
1549
1550                 /* Map it! */
1551                 ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage,
1552                                      dma->prot);
1553                 if (ret) {
1554                         vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
1555                                                 npage, true);
1556                         vfio_batch_unpin(&batch, dma);
1557                         break;
1558                 }
1559
1560                 size -= npage << PAGE_SHIFT;
1561                 dma->size += npage << PAGE_SHIFT;
1562         }
1563
1564         vfio_batch_fini(&batch);
1565         dma->iommu_mapped = true;
1566
1567         if (ret)
1568                 vfio_remove_dma(iommu, dma);
1569
1570         return ret;
1571 }
1572
1573 /*
1574  * Check dma map request is within a valid iova range
1575  */
1576 static bool vfio_iommu_iova_dma_valid(struct vfio_iommu *iommu,
1577                                       dma_addr_t start, dma_addr_t end)
1578 {
1579         struct list_head *iova = &iommu->iova_list;
1580         struct vfio_iova *node;
1581
1582         list_for_each_entry(node, iova, list) {
1583                 if (start >= node->start && end <= node->end)
1584                         return true;
1585         }
1586
1587         /*
1588          * Check for list_empty() as well since a container with
1589          * a single mdev device will have an empty list.
1590          */
1591         return list_empty(iova);
1592 }
1593
1594 static int vfio_change_dma_owner(struct vfio_dma *dma)
1595 {
1596         struct task_struct *task = current->group_leader;
1597         struct mm_struct *mm = current->mm;
1598         long npage = dma->locked_vm;
1599         bool lock_cap;
1600         int ret;
1601
1602         if (mm == dma->mm)
1603                 return 0;
1604
1605         lock_cap = capable(CAP_IPC_LOCK);
1606         ret = mm_lock_acct(task, mm, lock_cap, npage);
1607         if (ret)
1608                 return ret;
1609
1610         if (mmget_not_zero(dma->mm)) {
1611                 mm_lock_acct(dma->task, dma->mm, dma->lock_cap, -npage);
1612                 mmput(dma->mm);
1613         }
1614
1615         if (dma->task != task) {
1616                 put_task_struct(dma->task);
1617                 dma->task = get_task_struct(task);
1618         }
1619         mmdrop(dma->mm);
1620         dma->mm = mm;
1621         mmgrab(dma->mm);
1622         dma->lock_cap = lock_cap;
1623         return 0;
1624 }
1625
1626 static int vfio_dma_do_map(struct vfio_iommu *iommu,
1627                            struct vfio_iommu_type1_dma_map *map)
1628 {
1629         bool set_vaddr = map->flags & VFIO_DMA_MAP_FLAG_VADDR;
1630         dma_addr_t iova = map->iova;
1631         unsigned long vaddr = map->vaddr;
1632         size_t size = map->size;
1633         int ret = 0, prot = 0;
1634         size_t pgsize;
1635         struct vfio_dma *dma;
1636
1637         /* Verify that none of our __u64 fields overflow */
1638         if (map->size != size || map->vaddr != vaddr || map->iova != iova)
1639                 return -EINVAL;
1640
1641         /* READ/WRITE from device perspective */
1642         if (map->flags & VFIO_DMA_MAP_FLAG_WRITE)
1643                 prot |= IOMMU_WRITE;
1644         if (map->flags & VFIO_DMA_MAP_FLAG_READ)
1645                 prot |= IOMMU_READ;
1646
1647         if ((prot && set_vaddr) || (!prot && !set_vaddr))
1648                 return -EINVAL;
1649
1650         mutex_lock(&iommu->lock);
1651
1652         pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
1653
1654         WARN_ON((pgsize - 1) & PAGE_MASK);
1655
1656         if (!size || (size | iova | vaddr) & (pgsize - 1)) {
1657                 ret = -EINVAL;
1658                 goto out_unlock;
1659         }
1660
1661         /* Don't allow IOVA or virtual address wrap */
1662         if (iova + size - 1 < iova || vaddr + size - 1 < vaddr) {
1663                 ret = -EINVAL;
1664                 goto out_unlock;
1665         }
1666
1667         dma = vfio_find_dma(iommu, iova, size);
1668         if (set_vaddr) {
1669                 if (!dma) {
1670                         ret = -ENOENT;
1671                 } else if (!dma->vaddr_invalid || dma->iova != iova ||
1672                            dma->size != size) {
1673                         ret = -EINVAL;
1674                 } else {
1675                         ret = vfio_change_dma_owner(dma);
1676                         if (ret)
1677                                 goto out_unlock;
1678                         dma->vaddr = vaddr;
1679                         dma->vaddr_invalid = false;
1680                         iommu->vaddr_invalid_count--;
1681                         wake_up_all(&iommu->vaddr_wait);
1682                 }
1683                 goto out_unlock;
1684         } else if (dma) {
1685                 ret = -EEXIST;
1686                 goto out_unlock;
1687         }
1688
1689         if (!iommu->dma_avail) {
1690                 ret = -ENOSPC;
1691                 goto out_unlock;
1692         }
1693
1694         if (!vfio_iommu_iova_dma_valid(iommu, iova, iova + size - 1)) {
1695                 ret = -EINVAL;
1696                 goto out_unlock;
1697         }
1698
1699         dma = kzalloc(sizeof(*dma), GFP_KERNEL);
1700         if (!dma) {
1701                 ret = -ENOMEM;
1702                 goto out_unlock;
1703         }
1704
1705         iommu->dma_avail--;
1706         dma->iova = iova;
1707         dma->vaddr = vaddr;
1708         dma->prot = prot;
1709
1710         /*
1711          * We need to be able to both add to a task's locked memory and test
1712          * against the locked memory limit and we need to be able to do both
1713          * outside of this call path as pinning can be asynchronous via the
1714          * external interfaces for mdev devices.  RLIMIT_MEMLOCK requires a
1715          * task_struct. Save the group_leader so that all DMA tracking uses
1716          * the same task, to make debugging easier.  VM locked pages requires
1717          * an mm_struct, so grab the mm in case the task dies.
1718          */
1719         get_task_struct(current->group_leader);
1720         dma->task = current->group_leader;
1721         dma->lock_cap = capable(CAP_IPC_LOCK);
1722         dma->mm = current->mm;
1723         mmgrab(dma->mm);
1724
1725         dma->pfn_list = RB_ROOT;
1726
1727         /* Insert zero-sized and grow as we map chunks of it */
1728         vfio_link_dma(iommu, dma);
1729
1730         /* Don't pin and map if container doesn't contain IOMMU capable domain*/
1731         if (list_empty(&iommu->domain_list))
1732                 dma->size = size;
1733         else
1734                 ret = vfio_pin_map_dma(iommu, dma, size);
1735
1736         if (!ret && iommu->dirty_page_tracking) {
1737                 ret = vfio_dma_bitmap_alloc(dma, pgsize);
1738                 if (ret)
1739                         vfio_remove_dma(iommu, dma);
1740         }
1741
1742 out_unlock:
1743         mutex_unlock(&iommu->lock);
1744         return ret;
1745 }
1746
1747 static int vfio_iommu_replay(struct vfio_iommu *iommu,
1748                              struct vfio_domain *domain)
1749 {
1750         struct vfio_batch batch;
1751         struct vfio_domain *d = NULL;
1752         struct rb_node *n;
1753         unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1754         int ret;
1755
1756         ret = vfio_wait_all_valid(iommu);
1757         if (ret < 0)
1758                 return ret;
1759
1760         /* Arbitrarily pick the first domain in the list for lookups */
1761         if (!list_empty(&iommu->domain_list))
1762                 d = list_first_entry(&iommu->domain_list,
1763                                      struct vfio_domain, next);
1764
1765         vfio_batch_init(&batch);
1766
1767         n = rb_first(&iommu->dma_list);
1768
1769         for (; n; n = rb_next(n)) {
1770                 struct vfio_dma *dma;
1771                 dma_addr_t iova;
1772
1773                 dma = rb_entry(n, struct vfio_dma, node);
1774                 iova = dma->iova;
1775
1776                 while (iova < dma->iova + dma->size) {
1777                         phys_addr_t phys;
1778                         size_t size;
1779
1780                         if (dma->iommu_mapped) {
1781                                 phys_addr_t p;
1782                                 dma_addr_t i;
1783
1784                                 if (WARN_ON(!d)) { /* mapped w/o a domain?! */
1785                                         ret = -EINVAL;
1786                                         goto unwind;
1787                                 }
1788
1789                                 phys = iommu_iova_to_phys(d->domain, iova);
1790
1791                                 if (WARN_ON(!phys)) {
1792                                         iova += PAGE_SIZE;
1793                                         continue;
1794                                 }
1795
1796                                 size = PAGE_SIZE;
1797                                 p = phys + size;
1798                                 i = iova + size;
1799                                 while (i < dma->iova + dma->size &&
1800                                        p == iommu_iova_to_phys(d->domain, i)) {
1801                                         size += PAGE_SIZE;
1802                                         p += PAGE_SIZE;
1803                                         i += PAGE_SIZE;
1804                                 }
1805                         } else {
1806                                 unsigned long pfn;
1807                                 unsigned long vaddr = dma->vaddr +
1808                                                      (iova - dma->iova);
1809                                 size_t n = dma->iova + dma->size - iova;
1810                                 long npage;
1811
1812                                 npage = vfio_pin_pages_remote(dma, vaddr,
1813                                                               n >> PAGE_SHIFT,
1814                                                               &pfn, limit,
1815                                                               &batch);
1816                                 if (npage <= 0) {
1817                                         WARN_ON(!npage);
1818                                         ret = (int)npage;
1819                                         goto unwind;
1820                                 }
1821
1822                                 phys = pfn << PAGE_SHIFT;
1823                                 size = npage << PAGE_SHIFT;
1824                         }
1825
1826                         ret = iommu_map(domain->domain, iova, phys,
1827                                         size, dma->prot | IOMMU_CACHE);
1828                         if (ret) {
1829                                 if (!dma->iommu_mapped) {
1830                                         vfio_unpin_pages_remote(dma, iova,
1831                                                         phys >> PAGE_SHIFT,
1832                                                         size >> PAGE_SHIFT,
1833                                                         true);
1834                                         vfio_batch_unpin(&batch, dma);
1835                                 }
1836                                 goto unwind;
1837                         }
1838
1839                         iova += size;
1840                 }
1841         }
1842
1843         /* All dmas are now mapped, defer to second tree walk for unwind */
1844         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1845                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1846
1847                 dma->iommu_mapped = true;
1848         }
1849
1850         vfio_batch_fini(&batch);
1851         return 0;
1852
1853 unwind:
1854         for (; n; n = rb_prev(n)) {
1855                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1856                 dma_addr_t iova;
1857
1858                 if (dma->iommu_mapped) {
1859                         iommu_unmap(domain->domain, dma->iova, dma->size);
1860                         continue;
1861                 }
1862
1863                 iova = dma->iova;
1864                 while (iova < dma->iova + dma->size) {
1865                         phys_addr_t phys, p;
1866                         size_t size;
1867                         dma_addr_t i;
1868
1869                         phys = iommu_iova_to_phys(domain->domain, iova);
1870                         if (!phys) {
1871                                 iova += PAGE_SIZE;
1872                                 continue;
1873                         }
1874
1875                         size = PAGE_SIZE;
1876                         p = phys + size;
1877                         i = iova + size;
1878                         while (i < dma->iova + dma->size &&
1879                                p == iommu_iova_to_phys(domain->domain, i)) {
1880                                 size += PAGE_SIZE;
1881                                 p += PAGE_SIZE;
1882                                 i += PAGE_SIZE;
1883                         }
1884
1885                         iommu_unmap(domain->domain, iova, size);
1886                         vfio_unpin_pages_remote(dma, iova, phys >> PAGE_SHIFT,
1887                                                 size >> PAGE_SHIFT, true);
1888                 }
1889         }
1890
1891         vfio_batch_fini(&batch);
1892         return ret;
1893 }
1894
1895 /*
1896  * We change our unmap behavior slightly depending on whether the IOMMU
1897  * supports fine-grained superpages.  IOMMUs like AMD-Vi will use a superpage
1898  * for practically any contiguous power-of-two mapping we give it.  This means
1899  * we don't need to look for contiguous chunks ourselves to make unmapping
1900  * more efficient.  On IOMMUs with coarse-grained super pages, like Intel VT-d
1901  * with discrete 2M/1G/512G/1T superpages, identifying contiguous chunks
1902  * significantly boosts non-hugetlbfs mappings and doesn't seem to hurt when
1903  * hugetlbfs is in use.
1904  */
1905 static void vfio_test_domain_fgsp(struct vfio_domain *domain, struct list_head *regions)
1906 {
1907         int ret, order = get_order(PAGE_SIZE * 2);
1908         struct vfio_iova *region;
1909         struct page *pages;
1910         dma_addr_t start;
1911
1912         pages = alloc_pages(GFP_KERNEL | __GFP_ZERO, order);
1913         if (!pages)
1914                 return;
1915
1916         list_for_each_entry(region, regions, list) {
1917                 start = ALIGN(region->start, PAGE_SIZE * 2);
1918                 if (start >= region->end || (region->end - start < PAGE_SIZE * 2))
1919                         continue;
1920
1921                 ret = iommu_map(domain->domain, start, page_to_phys(pages), PAGE_SIZE * 2,
1922                                 IOMMU_READ | IOMMU_WRITE | IOMMU_CACHE);
1923                 if (!ret) {
1924                         size_t unmapped = iommu_unmap(domain->domain, start, PAGE_SIZE);
1925
1926                         if (unmapped == PAGE_SIZE)
1927                                 iommu_unmap(domain->domain, start + PAGE_SIZE, PAGE_SIZE);
1928                         else
1929                                 domain->fgsp = true;
1930                 }
1931                 break;
1932         }
1933
1934         __free_pages(pages, order);
1935 }
1936
1937 static struct vfio_iommu_group *find_iommu_group(struct vfio_domain *domain,
1938                                                  struct iommu_group *iommu_group)
1939 {
1940         struct vfio_iommu_group *g;
1941
1942         list_for_each_entry(g, &domain->group_list, next) {
1943                 if (g->iommu_group == iommu_group)
1944                         return g;
1945         }
1946
1947         return NULL;
1948 }
1949
1950 static struct vfio_iommu_group*
1951 vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
1952                             struct iommu_group *iommu_group)
1953 {
1954         struct vfio_iommu_group *group;
1955         struct vfio_domain *domain;
1956
1957         list_for_each_entry(domain, &iommu->domain_list, next) {
1958                 group = find_iommu_group(domain, iommu_group);
1959                 if (group)
1960                         return group;
1961         }
1962
1963         list_for_each_entry(group, &iommu->emulated_iommu_groups, next)
1964                 if (group->iommu_group == iommu_group)
1965                         return group;
1966         return NULL;
1967 }
1968
1969 static bool vfio_iommu_has_sw_msi(struct list_head *group_resv_regions,
1970                                   phys_addr_t *base)
1971 {
1972         struct iommu_resv_region *region;
1973         bool ret = false;
1974
1975         list_for_each_entry(region, group_resv_regions, list) {
1976                 /*
1977                  * The presence of any 'real' MSI regions should take
1978                  * precedence over the software-managed one if the
1979                  * IOMMU driver happens to advertise both types.
1980                  */
1981                 if (region->type == IOMMU_RESV_MSI) {
1982                         ret = false;
1983                         break;
1984                 }
1985
1986                 if (region->type == IOMMU_RESV_SW_MSI) {
1987                         *base = region->start;
1988                         ret = true;
1989                 }
1990         }
1991
1992         return ret;
1993 }
1994
1995 /*
1996  * This is a helper function to insert an address range to iova list.
1997  * The list is initially created with a single entry corresponding to
1998  * the IOMMU domain geometry to which the device group is attached.
1999  * The list aperture gets modified when a new domain is added to the
2000  * container if the new aperture doesn't conflict with the current one
2001  * or with any existing dma mappings. The list is also modified to
2002  * exclude any reserved regions associated with the device group.
2003  */
2004 static int vfio_iommu_iova_insert(struct list_head *head,
2005                                   dma_addr_t start, dma_addr_t end)
2006 {
2007         struct vfio_iova *region;
2008
2009         region = kmalloc(sizeof(*region), GFP_KERNEL);
2010         if (!region)
2011                 return -ENOMEM;
2012
2013         INIT_LIST_HEAD(&region->list);
2014         region->start = start;
2015         region->end = end;
2016
2017         list_add_tail(&region->list, head);
2018         return 0;
2019 }
2020
2021 /*
2022  * Check the new iommu aperture conflicts with existing aper or with any
2023  * existing dma mappings.
2024  */
2025 static bool vfio_iommu_aper_conflict(struct vfio_iommu *iommu,
2026                                      dma_addr_t start, dma_addr_t end)
2027 {
2028         struct vfio_iova *first, *last;
2029         struct list_head *iova = &iommu->iova_list;
2030
2031         if (list_empty(iova))
2032                 return false;
2033
2034         /* Disjoint sets, return conflict */
2035         first = list_first_entry(iova, struct vfio_iova, list);
2036         last = list_last_entry(iova, struct vfio_iova, list);
2037         if (start > last->end || end < first->start)
2038                 return true;
2039
2040         /* Check for any existing dma mappings below the new start */
2041         if (start > first->start) {
2042                 if (vfio_find_dma(iommu, first->start, start - first->start))
2043                         return true;
2044         }
2045
2046         /* Check for any existing dma mappings beyond the new end */
2047         if (end < last->end) {
2048                 if (vfio_find_dma(iommu, end + 1, last->end - end))
2049                         return true;
2050         }
2051
2052         return false;
2053 }
2054
2055 /*
2056  * Resize iommu iova aperture window. This is called only if the new
2057  * aperture has no conflict with existing aperture and dma mappings.
2058  */
2059 static int vfio_iommu_aper_resize(struct list_head *iova,
2060                                   dma_addr_t start, dma_addr_t end)
2061 {
2062         struct vfio_iova *node, *next;
2063
2064         if (list_empty(iova))
2065                 return vfio_iommu_iova_insert(iova, start, end);
2066
2067         /* Adjust iova list start */
2068         list_for_each_entry_safe(node, next, iova, list) {
2069                 if (start < node->start)
2070                         break;
2071                 if (start >= node->start && start < node->end) {
2072                         node->start = start;
2073                         break;
2074                 }
2075                 /* Delete nodes before new start */
2076                 list_del(&node->list);
2077                 kfree(node);
2078         }
2079
2080         /* Adjust iova list end */
2081         list_for_each_entry_safe(node, next, iova, list) {
2082                 if (end > node->end)
2083                         continue;
2084                 if (end > node->start && end <= node->end) {
2085                         node->end = end;
2086                         continue;
2087                 }
2088                 /* Delete nodes after new end */
2089                 list_del(&node->list);
2090                 kfree(node);
2091         }
2092
2093         return 0;
2094 }
2095
2096 /*
2097  * Check reserved region conflicts with existing dma mappings
2098  */
2099 static bool vfio_iommu_resv_conflict(struct vfio_iommu *iommu,
2100                                      struct list_head *resv_regions)
2101 {
2102         struct iommu_resv_region *region;
2103
2104         /* Check for conflict with existing dma mappings */
2105         list_for_each_entry(region, resv_regions, list) {
2106                 if (region->type == IOMMU_RESV_DIRECT_RELAXABLE)
2107                         continue;
2108
2109                 if (vfio_find_dma(iommu, region->start, region->length))
2110                         return true;
2111         }
2112
2113         return false;
2114 }
2115
2116 /*
2117  * Check iova region overlap with  reserved regions and
2118  * exclude them from the iommu iova range
2119  */
2120 static int vfio_iommu_resv_exclude(struct list_head *iova,
2121                                    struct list_head *resv_regions)
2122 {
2123         struct iommu_resv_region *resv;
2124         struct vfio_iova *n, *next;
2125
2126         list_for_each_entry(resv, resv_regions, list) {
2127                 phys_addr_t start, end;
2128
2129                 if (resv->type == IOMMU_RESV_DIRECT_RELAXABLE)
2130                         continue;
2131
2132                 start = resv->start;
2133                 end = resv->start + resv->length - 1;
2134
2135                 list_for_each_entry_safe(n, next, iova, list) {
2136                         int ret = 0;
2137
2138                         /* No overlap */
2139                         if (start > n->end || end < n->start)
2140                                 continue;
2141                         /*
2142                          * Insert a new node if current node overlaps with the
2143                          * reserve region to exclude that from valid iova range.
2144                          * Note that, new node is inserted before the current
2145                          * node and finally the current node is deleted keeping
2146                          * the list updated and sorted.
2147                          */
2148                         if (start > n->start)
2149                                 ret = vfio_iommu_iova_insert(&n->list, n->start,
2150                                                              start - 1);
2151                         if (!ret && end < n->end)
2152                                 ret = vfio_iommu_iova_insert(&n->list, end + 1,
2153                                                              n->end);
2154                         if (ret)
2155                                 return ret;
2156
2157                         list_del(&n->list);
2158                         kfree(n);
2159                 }
2160         }
2161
2162         if (list_empty(iova))
2163                 return -EINVAL;
2164
2165         return 0;
2166 }
2167
2168 static void vfio_iommu_resv_free(struct list_head *resv_regions)
2169 {
2170         struct iommu_resv_region *n, *next;
2171
2172         list_for_each_entry_safe(n, next, resv_regions, list) {
2173                 list_del(&n->list);
2174                 kfree(n);
2175         }
2176 }
2177
2178 static void vfio_iommu_iova_free(struct list_head *iova)
2179 {
2180         struct vfio_iova *n, *next;
2181
2182         list_for_each_entry_safe(n, next, iova, list) {
2183                 list_del(&n->list);
2184                 kfree(n);
2185         }
2186 }
2187
2188 static int vfio_iommu_iova_get_copy(struct vfio_iommu *iommu,
2189                                     struct list_head *iova_copy)
2190 {
2191         struct list_head *iova = &iommu->iova_list;
2192         struct vfio_iova *n;
2193         int ret;
2194
2195         list_for_each_entry(n, iova, list) {
2196                 ret = vfio_iommu_iova_insert(iova_copy, n->start, n->end);
2197                 if (ret)
2198                         goto out_free;
2199         }
2200
2201         return 0;
2202
2203 out_free:
2204         vfio_iommu_iova_free(iova_copy);
2205         return ret;
2206 }
2207
2208 static void vfio_iommu_iova_insert_copy(struct vfio_iommu *iommu,
2209                                         struct list_head *iova_copy)
2210 {
2211         struct list_head *iova = &iommu->iova_list;
2212
2213         vfio_iommu_iova_free(iova);
2214
2215         list_splice_tail(iova_copy, iova);
2216 }
2217
2218 /* Redundantly walks non-present capabilities to simplify caller */
2219 static int vfio_iommu_device_capable(struct device *dev, void *data)
2220 {
2221         return device_iommu_capable(dev, (enum iommu_cap)data);
2222 }
2223
2224 static int vfio_iommu_domain_alloc(struct device *dev, void *data)
2225 {
2226         struct iommu_domain **domain = data;
2227
2228         *domain = iommu_domain_alloc(dev->bus);
2229         return 1; /* Don't iterate */
2230 }
2231
2232 static int vfio_iommu_type1_attach_group(void *iommu_data,
2233                 struct iommu_group *iommu_group, enum vfio_group_type type)
2234 {
2235         struct vfio_iommu *iommu = iommu_data;
2236         struct vfio_iommu_group *group;
2237         struct vfio_domain *domain, *d;
2238         bool resv_msi, msi_remap;
2239         phys_addr_t resv_msi_base = 0;
2240         struct iommu_domain_geometry *geo;
2241         LIST_HEAD(iova_copy);
2242         LIST_HEAD(group_resv_regions);
2243         int ret = -EBUSY;
2244
2245         mutex_lock(&iommu->lock);
2246
2247         /* Attach could require pinning, so disallow while vaddr is invalid. */
2248         if (iommu->vaddr_invalid_count)
2249                 goto out_unlock;
2250
2251         /* Check for duplicates */
2252         ret = -EINVAL;
2253         if (vfio_iommu_find_iommu_group(iommu, iommu_group))
2254                 goto out_unlock;
2255
2256         ret = -ENOMEM;
2257         group = kzalloc(sizeof(*group), GFP_KERNEL);
2258         if (!group)
2259                 goto out_unlock;
2260         group->iommu_group = iommu_group;
2261
2262         if (type == VFIO_EMULATED_IOMMU) {
2263                 list_add(&group->next, &iommu->emulated_iommu_groups);
2264                 /*
2265                  * An emulated IOMMU group cannot dirty memory directly, it can
2266                  * only use interfaces that provide dirty tracking.
2267                  * The iommu scope can only be promoted with the addition of a
2268                  * dirty tracking group.
2269                  */
2270                 group->pinned_page_dirty_scope = true;
2271                 ret = 0;
2272                 goto out_unlock;
2273         }
2274
2275         ret = -ENOMEM;
2276         domain = kzalloc(sizeof(*domain), GFP_KERNEL);
2277         if (!domain)
2278                 goto out_free_group;
2279
2280         /*
2281          * Going via the iommu_group iterator avoids races, and trivially gives
2282          * us a representative device for the IOMMU API call. We don't actually
2283          * want to iterate beyond the first device (if any).
2284          */
2285         ret = -EIO;
2286         iommu_group_for_each_dev(iommu_group, &domain->domain,
2287                                  vfio_iommu_domain_alloc);
2288         if (!domain->domain)
2289                 goto out_free_domain;
2290
2291         if (iommu->nesting) {
2292                 ret = iommu_enable_nesting(domain->domain);
2293                 if (ret)
2294                         goto out_domain;
2295         }
2296
2297         ret = iommu_attach_group(domain->domain, group->iommu_group);
2298         if (ret)
2299                 goto out_domain;
2300
2301         /* Get aperture info */
2302         geo = &domain->domain->geometry;
2303         if (vfio_iommu_aper_conflict(iommu, geo->aperture_start,
2304                                      geo->aperture_end)) {
2305                 ret = -EINVAL;
2306                 goto out_detach;
2307         }
2308
2309         ret = iommu_get_group_resv_regions(iommu_group, &group_resv_regions);
2310         if (ret)
2311                 goto out_detach;
2312
2313         if (vfio_iommu_resv_conflict(iommu, &group_resv_regions)) {
2314                 ret = -EINVAL;
2315                 goto out_detach;
2316         }
2317
2318         /*
2319          * We don't want to work on the original iova list as the list
2320          * gets modified and in case of failure we have to retain the
2321          * original list. Get a copy here.
2322          */
2323         ret = vfio_iommu_iova_get_copy(iommu, &iova_copy);
2324         if (ret)
2325                 goto out_detach;
2326
2327         ret = vfio_iommu_aper_resize(&iova_copy, geo->aperture_start,
2328                                      geo->aperture_end);
2329         if (ret)
2330                 goto out_detach;
2331
2332         ret = vfio_iommu_resv_exclude(&iova_copy, &group_resv_regions);
2333         if (ret)
2334                 goto out_detach;
2335
2336         resv_msi = vfio_iommu_has_sw_msi(&group_resv_regions, &resv_msi_base);
2337
2338         INIT_LIST_HEAD(&domain->group_list);
2339         list_add(&group->next, &domain->group_list);
2340
2341         msi_remap = irq_domain_check_msi_remap() ||
2342                     iommu_group_for_each_dev(iommu_group, (void *)IOMMU_CAP_INTR_REMAP,
2343                                              vfio_iommu_device_capable);
2344
2345         if (!allow_unsafe_interrupts && !msi_remap) {
2346                 pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
2347                        __func__);
2348                 ret = -EPERM;
2349                 goto out_detach;
2350         }
2351
2352         /*
2353          * If the IOMMU can block non-coherent operations (ie PCIe TLPs with
2354          * no-snoop set) then VFIO always turns this feature on because on Intel
2355          * platforms it optimizes KVM to disable wbinvd emulation.
2356          */
2357         if (domain->domain->ops->enforce_cache_coherency)
2358                 domain->enforce_cache_coherency =
2359                         domain->domain->ops->enforce_cache_coherency(
2360                                 domain->domain);
2361
2362         /*
2363          * Try to match an existing compatible domain.  We don't want to
2364          * preclude an IOMMU driver supporting multiple bus_types and being
2365          * able to include different bus_types in the same IOMMU domain, so
2366          * we test whether the domains use the same iommu_ops rather than
2367          * testing if they're on the same bus_type.
2368          */
2369         list_for_each_entry(d, &iommu->domain_list, next) {
2370                 if (d->domain->ops == domain->domain->ops &&
2371                     d->enforce_cache_coherency ==
2372                             domain->enforce_cache_coherency) {
2373                         iommu_detach_group(domain->domain, group->iommu_group);
2374                         if (!iommu_attach_group(d->domain,
2375                                                 group->iommu_group)) {
2376                                 list_add(&group->next, &d->group_list);
2377                                 iommu_domain_free(domain->domain);
2378                                 kfree(domain);
2379                                 goto done;
2380                         }
2381
2382                         ret = iommu_attach_group(domain->domain,
2383                                                  group->iommu_group);
2384                         if (ret)
2385                                 goto out_domain;
2386                 }
2387         }
2388
2389         vfio_test_domain_fgsp(domain, &iova_copy);
2390
2391         /* replay mappings on new domains */
2392         ret = vfio_iommu_replay(iommu, domain);
2393         if (ret)
2394                 goto out_detach;
2395
2396         if (resv_msi) {
2397                 ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
2398                 if (ret && ret != -ENODEV)
2399                         goto out_detach;
2400         }
2401
2402         list_add(&domain->next, &iommu->domain_list);
2403         vfio_update_pgsize_bitmap(iommu);
2404 done:
2405         /* Delete the old one and insert new iova list */
2406         vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2407
2408         /*
2409          * An iommu backed group can dirty memory directly and therefore
2410          * demotes the iommu scope until it declares itself dirty tracking
2411          * capable via the page pinning interface.
2412          */
2413         iommu->num_non_pinned_groups++;
2414         mutex_unlock(&iommu->lock);
2415         vfio_iommu_resv_free(&group_resv_regions);
2416
2417         return 0;
2418
2419 out_detach:
2420         iommu_detach_group(domain->domain, group->iommu_group);
2421 out_domain:
2422         iommu_domain_free(domain->domain);
2423         vfio_iommu_iova_free(&iova_copy);
2424         vfio_iommu_resv_free(&group_resv_regions);
2425 out_free_domain:
2426         kfree(domain);
2427 out_free_group:
2428         kfree(group);
2429 out_unlock:
2430         mutex_unlock(&iommu->lock);
2431         return ret;
2432 }
2433
2434 static void vfio_iommu_unmap_unpin_all(struct vfio_iommu *iommu)
2435 {
2436         struct rb_node *node;
2437
2438         while ((node = rb_first(&iommu->dma_list)))
2439                 vfio_remove_dma(iommu, rb_entry(node, struct vfio_dma, node));
2440 }
2441
2442 static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
2443 {
2444         struct rb_node *n, *p;
2445
2446         n = rb_first(&iommu->dma_list);
2447         for (; n; n = rb_next(n)) {
2448                 struct vfio_dma *dma;
2449                 long locked = 0, unlocked = 0;
2450
2451                 dma = rb_entry(n, struct vfio_dma, node);
2452                 unlocked += vfio_unmap_unpin(iommu, dma, false);
2453                 p = rb_first(&dma->pfn_list);
2454                 for (; p; p = rb_next(p)) {
2455                         struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn,
2456                                                          node);
2457
2458                         if (!is_invalid_reserved_pfn(vpfn->pfn))
2459                                 locked++;
2460                 }
2461                 vfio_lock_acct(dma, locked - unlocked, true);
2462         }
2463 }
2464
2465 /*
2466  * Called when a domain is removed in detach. It is possible that
2467  * the removed domain decided the iova aperture window. Modify the
2468  * iova aperture with the smallest window among existing domains.
2469  */
2470 static void vfio_iommu_aper_expand(struct vfio_iommu *iommu,
2471                                    struct list_head *iova_copy)
2472 {
2473         struct vfio_domain *domain;
2474         struct vfio_iova *node;
2475         dma_addr_t start = 0;
2476         dma_addr_t end = (dma_addr_t)~0;
2477
2478         if (list_empty(iova_copy))
2479                 return;
2480
2481         list_for_each_entry(domain, &iommu->domain_list, next) {
2482                 struct iommu_domain_geometry *geo = &domain->domain->geometry;
2483
2484                 if (geo->aperture_start > start)
2485                         start = geo->aperture_start;
2486                 if (geo->aperture_end < end)
2487                         end = geo->aperture_end;
2488         }
2489
2490         /* Modify aperture limits. The new aper is either same or bigger */
2491         node = list_first_entry(iova_copy, struct vfio_iova, list);
2492         node->start = start;
2493         node = list_last_entry(iova_copy, struct vfio_iova, list);
2494         node->end = end;
2495 }
2496
2497 /*
2498  * Called when a group is detached. The reserved regions for that
2499  * group can be part of valid iova now. But since reserved regions
2500  * may be duplicated among groups, populate the iova valid regions
2501  * list again.
2502  */
2503 static int vfio_iommu_resv_refresh(struct vfio_iommu *iommu,
2504                                    struct list_head *iova_copy)
2505 {
2506         struct vfio_domain *d;
2507         struct vfio_iommu_group *g;
2508         struct vfio_iova *node;
2509         dma_addr_t start, end;
2510         LIST_HEAD(resv_regions);
2511         int ret;
2512
2513         if (list_empty(iova_copy))
2514                 return -EINVAL;
2515
2516         list_for_each_entry(d, &iommu->domain_list, next) {
2517                 list_for_each_entry(g, &d->group_list, next) {
2518                         ret = iommu_get_group_resv_regions(g->iommu_group,
2519                                                            &resv_regions);
2520                         if (ret)
2521                                 goto done;
2522                 }
2523         }
2524
2525         node = list_first_entry(iova_copy, struct vfio_iova, list);
2526         start = node->start;
2527         node = list_last_entry(iova_copy, struct vfio_iova, list);
2528         end = node->end;
2529
2530         /* purge the iova list and create new one */
2531         vfio_iommu_iova_free(iova_copy);
2532
2533         ret = vfio_iommu_aper_resize(iova_copy, start, end);
2534         if (ret)
2535                 goto done;
2536
2537         /* Exclude current reserved regions from iova ranges */
2538         ret = vfio_iommu_resv_exclude(iova_copy, &resv_regions);
2539 done:
2540         vfio_iommu_resv_free(&resv_regions);
2541         return ret;
2542 }
2543
2544 static void vfio_iommu_type1_detach_group(void *iommu_data,
2545                                           struct iommu_group *iommu_group)
2546 {
2547         struct vfio_iommu *iommu = iommu_data;
2548         struct vfio_domain *domain;
2549         struct vfio_iommu_group *group;
2550         bool update_dirty_scope = false;
2551         LIST_HEAD(iova_copy);
2552
2553         mutex_lock(&iommu->lock);
2554         list_for_each_entry(group, &iommu->emulated_iommu_groups, next) {
2555                 if (group->iommu_group != iommu_group)
2556                         continue;
2557                 update_dirty_scope = !group->pinned_page_dirty_scope;
2558                 list_del(&group->next);
2559                 kfree(group);
2560
2561                 if (list_empty(&iommu->emulated_iommu_groups) &&
2562                     list_empty(&iommu->domain_list)) {
2563                         WARN_ON(!list_empty(&iommu->device_list));
2564                         vfio_iommu_unmap_unpin_all(iommu);
2565                 }
2566                 goto detach_group_done;
2567         }
2568
2569         /*
2570          * Get a copy of iova list. This will be used to update
2571          * and to replace the current one later. Please note that
2572          * we will leave the original list as it is if update fails.
2573          */
2574         vfio_iommu_iova_get_copy(iommu, &iova_copy);
2575
2576         list_for_each_entry(domain, &iommu->domain_list, next) {
2577                 group = find_iommu_group(domain, iommu_group);
2578                 if (!group)
2579                         continue;
2580
2581                 iommu_detach_group(domain->domain, group->iommu_group);
2582                 update_dirty_scope = !group->pinned_page_dirty_scope;
2583                 list_del(&group->next);
2584                 kfree(group);
2585                 /*
2586                  * Group ownership provides privilege, if the group list is
2587                  * empty, the domain goes away. If it's the last domain with
2588                  * iommu and external domain doesn't exist, then all the
2589                  * mappings go away too. If it's the last domain with iommu and
2590                  * external domain exist, update accounting
2591                  */
2592                 if (list_empty(&domain->group_list)) {
2593                         if (list_is_singular(&iommu->domain_list)) {
2594                                 if (list_empty(&iommu->emulated_iommu_groups)) {
2595                                         WARN_ON(!list_empty(
2596                                                 &iommu->device_list));
2597                                         vfio_iommu_unmap_unpin_all(iommu);
2598                                 } else {
2599                                         vfio_iommu_unmap_unpin_reaccount(iommu);
2600                                 }
2601                         }
2602                         iommu_domain_free(domain->domain);
2603                         list_del(&domain->next);
2604                         kfree(domain);
2605                         vfio_iommu_aper_expand(iommu, &iova_copy);
2606                         vfio_update_pgsize_bitmap(iommu);
2607                 }
2608                 break;
2609         }
2610
2611         if (!vfio_iommu_resv_refresh(iommu, &iova_copy))
2612                 vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2613         else
2614                 vfio_iommu_iova_free(&iova_copy);
2615
2616 detach_group_done:
2617         /*
2618          * Removal of a group without dirty tracking may allow the iommu scope
2619          * to be promoted.
2620          */
2621         if (update_dirty_scope) {
2622                 iommu->num_non_pinned_groups--;
2623                 if (iommu->dirty_page_tracking)
2624                         vfio_iommu_populate_bitmap_full(iommu);
2625         }
2626         mutex_unlock(&iommu->lock);
2627 }
2628
2629 static void *vfio_iommu_type1_open(unsigned long arg)
2630 {
2631         struct vfio_iommu *iommu;
2632
2633         iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
2634         if (!iommu)
2635                 return ERR_PTR(-ENOMEM);
2636
2637         switch (arg) {
2638         case VFIO_TYPE1_IOMMU:
2639                 break;
2640         case VFIO_TYPE1_NESTING_IOMMU:
2641                 iommu->nesting = true;
2642                 fallthrough;
2643         case VFIO_TYPE1v2_IOMMU:
2644                 iommu->v2 = true;
2645                 break;
2646         default:
2647                 kfree(iommu);
2648                 return ERR_PTR(-EINVAL);
2649         }
2650
2651         INIT_LIST_HEAD(&iommu->domain_list);
2652         INIT_LIST_HEAD(&iommu->iova_list);
2653         iommu->dma_list = RB_ROOT;
2654         iommu->dma_avail = dma_entry_limit;
2655         iommu->container_open = true;
2656         mutex_init(&iommu->lock);
2657         mutex_init(&iommu->device_list_lock);
2658         INIT_LIST_HEAD(&iommu->device_list);
2659         init_waitqueue_head(&iommu->vaddr_wait);
2660         iommu->pgsize_bitmap = PAGE_MASK;
2661         INIT_LIST_HEAD(&iommu->emulated_iommu_groups);
2662
2663         return iommu;
2664 }
2665
2666 static void vfio_release_domain(struct vfio_domain *domain)
2667 {
2668         struct vfio_iommu_group *group, *group_tmp;
2669
2670         list_for_each_entry_safe(group, group_tmp,
2671                                  &domain->group_list, next) {
2672                 iommu_detach_group(domain->domain, group->iommu_group);
2673                 list_del(&group->next);
2674                 kfree(group);
2675         }
2676
2677         iommu_domain_free(domain->domain);
2678 }
2679
2680 static void vfio_iommu_type1_release(void *iommu_data)
2681 {
2682         struct vfio_iommu *iommu = iommu_data;
2683         struct vfio_domain *domain, *domain_tmp;
2684         struct vfio_iommu_group *group, *next_group;
2685
2686         list_for_each_entry_safe(group, next_group,
2687                         &iommu->emulated_iommu_groups, next) {
2688                 list_del(&group->next);
2689                 kfree(group);
2690         }
2691
2692         vfio_iommu_unmap_unpin_all(iommu);
2693
2694         list_for_each_entry_safe(domain, domain_tmp,
2695                                  &iommu->domain_list, next) {
2696                 vfio_release_domain(domain);
2697                 list_del(&domain->next);
2698                 kfree(domain);
2699         }
2700
2701         vfio_iommu_iova_free(&iommu->iova_list);
2702
2703         kfree(iommu);
2704 }
2705
2706 static int vfio_domains_have_enforce_cache_coherency(struct vfio_iommu *iommu)
2707 {
2708         struct vfio_domain *domain;
2709         int ret = 1;
2710
2711         mutex_lock(&iommu->lock);
2712         list_for_each_entry(domain, &iommu->domain_list, next) {
2713                 if (!(domain->enforce_cache_coherency)) {
2714                         ret = 0;
2715                         break;
2716                 }
2717         }
2718         mutex_unlock(&iommu->lock);
2719
2720         return ret;
2721 }
2722
2723 static bool vfio_iommu_has_emulated(struct vfio_iommu *iommu)
2724 {
2725         bool ret;
2726
2727         mutex_lock(&iommu->lock);
2728         ret = !list_empty(&iommu->emulated_iommu_groups);
2729         mutex_unlock(&iommu->lock);
2730         return ret;
2731 }
2732
2733 static int vfio_iommu_type1_check_extension(struct vfio_iommu *iommu,
2734                                             unsigned long arg)
2735 {
2736         switch (arg) {
2737         case VFIO_TYPE1_IOMMU:
2738         case VFIO_TYPE1v2_IOMMU:
2739         case VFIO_TYPE1_NESTING_IOMMU:
2740         case VFIO_UNMAP_ALL:
2741                 return 1;
2742         case VFIO_UPDATE_VADDR:
2743                 /*
2744                  * Disable this feature if mdevs are present.  They cannot
2745                  * safely pin/unpin/rw while vaddrs are being updated.
2746                  */
2747                 return iommu && !vfio_iommu_has_emulated(iommu);
2748         case VFIO_DMA_CC_IOMMU:
2749                 if (!iommu)
2750                         return 0;
2751                 return vfio_domains_have_enforce_cache_coherency(iommu);
2752         default:
2753                 return 0;
2754         }
2755 }
2756
2757 static int vfio_iommu_iova_add_cap(struct vfio_info_cap *caps,
2758                  struct vfio_iommu_type1_info_cap_iova_range *cap_iovas,
2759                  size_t size)
2760 {
2761         struct vfio_info_cap_header *header;
2762         struct vfio_iommu_type1_info_cap_iova_range *iova_cap;
2763
2764         header = vfio_info_cap_add(caps, size,
2765                                    VFIO_IOMMU_TYPE1_INFO_CAP_IOVA_RANGE, 1);
2766         if (IS_ERR(header))
2767                 return PTR_ERR(header);
2768
2769         iova_cap = container_of(header,
2770                                 struct vfio_iommu_type1_info_cap_iova_range,
2771                                 header);
2772         iova_cap->nr_iovas = cap_iovas->nr_iovas;
2773         memcpy(iova_cap->iova_ranges, cap_iovas->iova_ranges,
2774                cap_iovas->nr_iovas * sizeof(*cap_iovas->iova_ranges));
2775         return 0;
2776 }
2777
2778 static int vfio_iommu_iova_build_caps(struct vfio_iommu *iommu,
2779                                       struct vfio_info_cap *caps)
2780 {
2781         struct vfio_iommu_type1_info_cap_iova_range *cap_iovas;
2782         struct vfio_iova *iova;
2783         size_t size;
2784         int iovas = 0, i = 0, ret;
2785
2786         list_for_each_entry(iova, &iommu->iova_list, list)
2787                 iovas++;
2788
2789         if (!iovas) {
2790                 /*
2791                  * Return 0 as a container with a single mdev device
2792                  * will have an empty list
2793                  */
2794                 return 0;
2795         }
2796
2797         size = struct_size(cap_iovas, iova_ranges, iovas);
2798
2799         cap_iovas = kzalloc(size, GFP_KERNEL);
2800         if (!cap_iovas)
2801                 return -ENOMEM;
2802
2803         cap_iovas->nr_iovas = iovas;
2804
2805         list_for_each_entry(iova, &iommu->iova_list, list) {
2806                 cap_iovas->iova_ranges[i].start = iova->start;
2807                 cap_iovas->iova_ranges[i].end = iova->end;
2808                 i++;
2809         }
2810
2811         ret = vfio_iommu_iova_add_cap(caps, cap_iovas, size);
2812
2813         kfree(cap_iovas);
2814         return ret;
2815 }
2816
2817 static int vfio_iommu_migration_build_caps(struct vfio_iommu *iommu,
2818                                            struct vfio_info_cap *caps)
2819 {
2820         struct vfio_iommu_type1_info_cap_migration cap_mig;
2821
2822         cap_mig.header.id = VFIO_IOMMU_TYPE1_INFO_CAP_MIGRATION;
2823         cap_mig.header.version = 1;
2824
2825         cap_mig.flags = 0;
2826         /* support minimum pgsize */
2827         cap_mig.pgsize_bitmap = (size_t)1 << __ffs(iommu->pgsize_bitmap);
2828         cap_mig.max_dirty_bitmap_size = DIRTY_BITMAP_SIZE_MAX;
2829
2830         return vfio_info_add_capability(caps, &cap_mig.header, sizeof(cap_mig));
2831 }
2832
2833 static int vfio_iommu_dma_avail_build_caps(struct vfio_iommu *iommu,
2834                                            struct vfio_info_cap *caps)
2835 {
2836         struct vfio_iommu_type1_info_dma_avail cap_dma_avail;
2837
2838         cap_dma_avail.header.id = VFIO_IOMMU_TYPE1_INFO_DMA_AVAIL;
2839         cap_dma_avail.header.version = 1;
2840
2841         cap_dma_avail.avail = iommu->dma_avail;
2842
2843         return vfio_info_add_capability(caps, &cap_dma_avail.header,
2844                                         sizeof(cap_dma_avail));
2845 }
2846
2847 static int vfio_iommu_type1_get_info(struct vfio_iommu *iommu,
2848                                      unsigned long arg)
2849 {
2850         struct vfio_iommu_type1_info info;
2851         unsigned long minsz;
2852         struct vfio_info_cap caps = { .buf = NULL, .size = 0 };
2853         unsigned long capsz;
2854         int ret;
2855
2856         minsz = offsetofend(struct vfio_iommu_type1_info, iova_pgsizes);
2857
2858         /* For backward compatibility, cannot require this */
2859         capsz = offsetofend(struct vfio_iommu_type1_info, cap_offset);
2860
2861         if (copy_from_user(&info, (void __user *)arg, minsz))
2862                 return -EFAULT;
2863
2864         if (info.argsz < minsz)
2865                 return -EINVAL;
2866
2867         if (info.argsz >= capsz) {
2868                 minsz = capsz;
2869                 info.cap_offset = 0; /* output, no-recopy necessary */
2870         }
2871
2872         mutex_lock(&iommu->lock);
2873         info.flags = VFIO_IOMMU_INFO_PGSIZES;
2874
2875         info.iova_pgsizes = iommu->pgsize_bitmap;
2876
2877         ret = vfio_iommu_migration_build_caps(iommu, &caps);
2878
2879         if (!ret)
2880                 ret = vfio_iommu_dma_avail_build_caps(iommu, &caps);
2881
2882         if (!ret)
2883                 ret = vfio_iommu_iova_build_caps(iommu, &caps);
2884
2885         mutex_unlock(&iommu->lock);
2886
2887         if (ret)
2888                 return ret;
2889
2890         if (caps.size) {
2891                 info.flags |= VFIO_IOMMU_INFO_CAPS;
2892
2893                 if (info.argsz < sizeof(info) + caps.size) {
2894                         info.argsz = sizeof(info) + caps.size;
2895                 } else {
2896                         vfio_info_cap_shift(&caps, sizeof(info));
2897                         if (copy_to_user((void __user *)arg +
2898                                         sizeof(info), caps.buf,
2899                                         caps.size)) {
2900                                 kfree(caps.buf);
2901                                 return -EFAULT;
2902                         }
2903                         info.cap_offset = sizeof(info);
2904                 }
2905
2906                 kfree(caps.buf);
2907         }
2908
2909         return copy_to_user((void __user *)arg, &info, minsz) ?
2910                         -EFAULT : 0;
2911 }
2912
2913 static int vfio_iommu_type1_map_dma(struct vfio_iommu *iommu,
2914                                     unsigned long arg)
2915 {
2916         struct vfio_iommu_type1_dma_map map;
2917         unsigned long minsz;
2918         uint32_t mask = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE |
2919                         VFIO_DMA_MAP_FLAG_VADDR;
2920
2921         minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
2922
2923         if (copy_from_user(&map, (void __user *)arg, minsz))
2924                 return -EFAULT;
2925
2926         if (map.argsz < minsz || map.flags & ~mask)
2927                 return -EINVAL;
2928
2929         return vfio_dma_do_map(iommu, &map);
2930 }
2931
2932 static int vfio_iommu_type1_unmap_dma(struct vfio_iommu *iommu,
2933                                       unsigned long arg)
2934 {
2935         struct vfio_iommu_type1_dma_unmap unmap;
2936         struct vfio_bitmap bitmap = { 0 };
2937         uint32_t mask = VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP |
2938                         VFIO_DMA_UNMAP_FLAG_VADDR |
2939                         VFIO_DMA_UNMAP_FLAG_ALL;
2940         unsigned long minsz;
2941         int ret;
2942
2943         minsz = offsetofend(struct vfio_iommu_type1_dma_unmap, size);
2944
2945         if (copy_from_user(&unmap, (void __user *)arg, minsz))
2946                 return -EFAULT;
2947
2948         if (unmap.argsz < minsz || unmap.flags & ~mask)
2949                 return -EINVAL;
2950
2951         if ((unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
2952             (unmap.flags & (VFIO_DMA_UNMAP_FLAG_ALL |
2953                             VFIO_DMA_UNMAP_FLAG_VADDR)))
2954                 return -EINVAL;
2955
2956         if (unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
2957                 unsigned long pgshift;
2958
2959                 if (unmap.argsz < (minsz + sizeof(bitmap)))
2960                         return -EINVAL;
2961
2962                 if (copy_from_user(&bitmap,
2963                                    (void __user *)(arg + minsz),
2964                                    sizeof(bitmap)))
2965                         return -EFAULT;
2966
2967                 if (!access_ok((void __user *)bitmap.data, bitmap.size))
2968                         return -EINVAL;
2969
2970                 pgshift = __ffs(bitmap.pgsize);
2971                 ret = verify_bitmap_size(unmap.size >> pgshift,
2972                                          bitmap.size);
2973                 if (ret)
2974                         return ret;
2975         }
2976
2977         ret = vfio_dma_do_unmap(iommu, &unmap, &bitmap);
2978         if (ret)
2979                 return ret;
2980
2981         return copy_to_user((void __user *)arg, &unmap, minsz) ?
2982                         -EFAULT : 0;
2983 }
2984
2985 static int vfio_iommu_type1_dirty_pages(struct vfio_iommu *iommu,
2986                                         unsigned long arg)
2987 {
2988         struct vfio_iommu_type1_dirty_bitmap dirty;
2989         uint32_t mask = VFIO_IOMMU_DIRTY_PAGES_FLAG_START |
2990                         VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP |
2991                         VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP;
2992         unsigned long minsz;
2993         int ret = 0;
2994
2995         if (!iommu->v2)
2996                 return -EACCES;
2997
2998         minsz = offsetofend(struct vfio_iommu_type1_dirty_bitmap, flags);
2999
3000         if (copy_from_user(&dirty, (void __user *)arg, minsz))
3001                 return -EFAULT;
3002
3003         if (dirty.argsz < minsz || dirty.flags & ~mask)
3004                 return -EINVAL;
3005
3006         /* only one flag should be set at a time */
3007         if (__ffs(dirty.flags) != __fls(dirty.flags))
3008                 return -EINVAL;
3009
3010         if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_START) {
3011                 size_t pgsize;
3012
3013                 mutex_lock(&iommu->lock);
3014                 pgsize = 1 << __ffs(iommu->pgsize_bitmap);
3015                 if (!iommu->dirty_page_tracking) {
3016                         ret = vfio_dma_bitmap_alloc_all(iommu, pgsize);
3017                         if (!ret)
3018                                 iommu->dirty_page_tracking = true;
3019                 }
3020                 mutex_unlock(&iommu->lock);
3021                 return ret;
3022         } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP) {
3023                 mutex_lock(&iommu->lock);
3024                 if (iommu->dirty_page_tracking) {
3025                         iommu->dirty_page_tracking = false;
3026                         vfio_dma_bitmap_free_all(iommu);
3027                 }
3028                 mutex_unlock(&iommu->lock);
3029                 return 0;
3030         } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP) {
3031                 struct vfio_iommu_type1_dirty_bitmap_get range;
3032                 unsigned long pgshift;
3033                 size_t data_size = dirty.argsz - minsz;
3034                 size_t iommu_pgsize;
3035
3036                 if (!data_size || data_size < sizeof(range))
3037                         return -EINVAL;
3038
3039                 if (copy_from_user(&range, (void __user *)(arg + minsz),
3040                                    sizeof(range)))
3041                         return -EFAULT;
3042
3043                 if (range.iova + range.size < range.iova)
3044                         return -EINVAL;
3045                 if (!access_ok((void __user *)range.bitmap.data,
3046                                range.bitmap.size))
3047                         return -EINVAL;
3048
3049                 pgshift = __ffs(range.bitmap.pgsize);
3050                 ret = verify_bitmap_size(range.size >> pgshift,
3051                                          range.bitmap.size);
3052                 if (ret)
3053                         return ret;
3054
3055                 mutex_lock(&iommu->lock);
3056
3057                 iommu_pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
3058
3059                 /* allow only smallest supported pgsize */
3060                 if (range.bitmap.pgsize != iommu_pgsize) {
3061                         ret = -EINVAL;
3062                         goto out_unlock;
3063                 }
3064                 if (range.iova & (iommu_pgsize - 1)) {
3065                         ret = -EINVAL;
3066                         goto out_unlock;
3067                 }
3068                 if (!range.size || range.size & (iommu_pgsize - 1)) {
3069                         ret = -EINVAL;
3070                         goto out_unlock;
3071                 }
3072
3073                 if (iommu->dirty_page_tracking)
3074                         ret = vfio_iova_dirty_bitmap(range.bitmap.data,
3075                                                      iommu, range.iova,
3076                                                      range.size,
3077                                                      range.bitmap.pgsize);
3078                 else
3079                         ret = -EINVAL;
3080 out_unlock:
3081                 mutex_unlock(&iommu->lock);
3082
3083                 return ret;
3084         }
3085
3086         return -EINVAL;
3087 }
3088
3089 static long vfio_iommu_type1_ioctl(void *iommu_data,
3090                                    unsigned int cmd, unsigned long arg)
3091 {
3092         struct vfio_iommu *iommu = iommu_data;
3093
3094         switch (cmd) {
3095         case VFIO_CHECK_EXTENSION:
3096                 return vfio_iommu_type1_check_extension(iommu, arg);
3097         case VFIO_IOMMU_GET_INFO:
3098                 return vfio_iommu_type1_get_info(iommu, arg);
3099         case VFIO_IOMMU_MAP_DMA:
3100                 return vfio_iommu_type1_map_dma(iommu, arg);
3101         case VFIO_IOMMU_UNMAP_DMA:
3102                 return vfio_iommu_type1_unmap_dma(iommu, arg);
3103         case VFIO_IOMMU_DIRTY_PAGES:
3104                 return vfio_iommu_type1_dirty_pages(iommu, arg);
3105         default:
3106                 return -ENOTTY;
3107         }
3108 }
3109
3110 static void vfio_iommu_type1_register_device(void *iommu_data,
3111                                              struct vfio_device *vdev)
3112 {
3113         struct vfio_iommu *iommu = iommu_data;
3114
3115         if (!vdev->ops->dma_unmap)
3116                 return;
3117
3118         /*
3119          * list_empty(&iommu->device_list) is tested under the iommu->lock while
3120          * iteration for dma_unmap must be done under the device_list_lock.
3121          * Holding both locks here allows avoiding the device_list_lock in
3122          * several fast paths. See vfio_notify_dma_unmap()
3123          */
3124         mutex_lock(&iommu->lock);
3125         mutex_lock(&iommu->device_list_lock);
3126         list_add(&vdev->iommu_entry, &iommu->device_list);
3127         mutex_unlock(&iommu->device_list_lock);
3128         mutex_unlock(&iommu->lock);
3129 }
3130
3131 static void vfio_iommu_type1_unregister_device(void *iommu_data,
3132                                                struct vfio_device *vdev)
3133 {
3134         struct vfio_iommu *iommu = iommu_data;
3135
3136         if (!vdev->ops->dma_unmap)
3137                 return;
3138
3139         mutex_lock(&iommu->lock);
3140         mutex_lock(&iommu->device_list_lock);
3141         list_del(&vdev->iommu_entry);
3142         mutex_unlock(&iommu->device_list_lock);
3143         mutex_unlock(&iommu->lock);
3144 }
3145
3146 static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
3147                                          dma_addr_t user_iova, void *data,
3148                                          size_t count, bool write,
3149                                          size_t *copied)
3150 {
3151         struct mm_struct *mm;
3152         unsigned long vaddr;
3153         struct vfio_dma *dma;
3154         bool kthread = current->mm == NULL;
3155         size_t offset;
3156         int ret;
3157
3158         *copied = 0;
3159
3160         ret = vfio_find_dma_valid(iommu, user_iova, 1, &dma);
3161         if (ret < 0)
3162                 return ret;
3163
3164         if ((write && !(dma->prot & IOMMU_WRITE)) ||
3165                         !(dma->prot & IOMMU_READ))
3166                 return -EPERM;
3167
3168         mm = dma->mm;
3169         if (!mmget_not_zero(mm))
3170                 return -EPERM;
3171
3172         if (kthread)
3173                 kthread_use_mm(mm);
3174         else if (current->mm != mm)
3175                 goto out;
3176
3177         offset = user_iova - dma->iova;
3178
3179         if (count > dma->size - offset)
3180                 count = dma->size - offset;
3181
3182         vaddr = dma->vaddr + offset;
3183
3184         if (write) {
3185                 *copied = copy_to_user((void __user *)vaddr, data,
3186                                          count) ? 0 : count;
3187                 if (*copied && iommu->dirty_page_tracking) {
3188                         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
3189                         /*
3190                          * Bitmap populated with the smallest supported page
3191                          * size
3192                          */
3193                         bitmap_set(dma->bitmap, offset >> pgshift,
3194                                    ((offset + *copied - 1) >> pgshift) -
3195                                    (offset >> pgshift) + 1);
3196                 }
3197         } else
3198                 *copied = copy_from_user(data, (void __user *)vaddr,
3199                                            count) ? 0 : count;
3200         if (kthread)
3201                 kthread_unuse_mm(mm);
3202 out:
3203         mmput(mm);
3204         return *copied ? 0 : -EFAULT;
3205 }
3206
3207 static int vfio_iommu_type1_dma_rw(void *iommu_data, dma_addr_t user_iova,
3208                                    void *data, size_t count, bool write)
3209 {
3210         struct vfio_iommu *iommu = iommu_data;
3211         int ret = 0;
3212         size_t done;
3213
3214         mutex_lock(&iommu->lock);
3215
3216         if (WARN_ONCE(iommu->vaddr_invalid_count,
3217                       "vfio_dma_rw not allowed with VFIO_UPDATE_VADDR\n")) {
3218                 ret = -EBUSY;
3219                 goto out;
3220         }
3221
3222         while (count > 0) {
3223                 ret = vfio_iommu_type1_dma_rw_chunk(iommu, user_iova, data,
3224                                                     count, write, &done);
3225                 if (ret)
3226                         break;
3227
3228                 count -= done;
3229                 data += done;
3230                 user_iova += done;
3231         }
3232
3233 out:
3234         mutex_unlock(&iommu->lock);
3235         return ret;
3236 }
3237
3238 static struct iommu_domain *
3239 vfio_iommu_type1_group_iommu_domain(void *iommu_data,
3240                                     struct iommu_group *iommu_group)
3241 {
3242         struct iommu_domain *domain = ERR_PTR(-ENODEV);
3243         struct vfio_iommu *iommu = iommu_data;
3244         struct vfio_domain *d;
3245
3246         if (!iommu || !iommu_group)
3247                 return ERR_PTR(-EINVAL);
3248
3249         mutex_lock(&iommu->lock);
3250         list_for_each_entry(d, &iommu->domain_list, next) {
3251                 if (find_iommu_group(d, iommu_group)) {
3252                         domain = d->domain;
3253                         break;
3254                 }
3255         }
3256         mutex_unlock(&iommu->lock);
3257
3258         return domain;
3259 }
3260
3261 static void vfio_iommu_type1_notify(void *iommu_data,
3262                                     enum vfio_iommu_notify_type event)
3263 {
3264         struct vfio_iommu *iommu = iommu_data;
3265
3266         if (event != VFIO_IOMMU_CONTAINER_CLOSE)
3267                 return;
3268         mutex_lock(&iommu->lock);
3269         iommu->container_open = false;
3270         mutex_unlock(&iommu->lock);
3271         wake_up_all(&iommu->vaddr_wait);
3272 }
3273
3274 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
3275         .name                   = "vfio-iommu-type1",
3276         .owner                  = THIS_MODULE,
3277         .open                   = vfio_iommu_type1_open,
3278         .release                = vfio_iommu_type1_release,
3279         .ioctl                  = vfio_iommu_type1_ioctl,
3280         .attach_group           = vfio_iommu_type1_attach_group,
3281         .detach_group           = vfio_iommu_type1_detach_group,
3282         .pin_pages              = vfio_iommu_type1_pin_pages,
3283         .unpin_pages            = vfio_iommu_type1_unpin_pages,
3284         .register_device        = vfio_iommu_type1_register_device,
3285         .unregister_device      = vfio_iommu_type1_unregister_device,
3286         .dma_rw                 = vfio_iommu_type1_dma_rw,
3287         .group_iommu_domain     = vfio_iommu_type1_group_iommu_domain,
3288         .notify                 = vfio_iommu_type1_notify,
3289 };
3290
3291 static int __init vfio_iommu_type1_init(void)
3292 {
3293         return vfio_register_iommu_driver(&vfio_iommu_driver_ops_type1);
3294 }
3295
3296 static void __exit vfio_iommu_type1_cleanup(void)
3297 {
3298         vfio_unregister_iommu_driver(&vfio_iommu_driver_ops_type1);
3299 }
3300
3301 module_init(vfio_iommu_type1_init);
3302 module_exit(vfio_iommu_type1_cleanup);
3303
3304 MODULE_VERSION(DRIVER_VERSION);
3305 MODULE_LICENSE("GPL v2");
3306 MODULE_AUTHOR(DRIVER_AUTHOR);
3307 MODULE_DESCRIPTION(DRIVER_DESC);