GNU Linux-libre 5.19-rc6-gnu
[releases.git] / drivers / iommu / virtio-iommu.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Virtio driver for the paravirtualized IOMMU
4  *
5  * Copyright (C) 2019 Arm Limited
6  */
7
8 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
9
10 #include <linux/amba/bus.h>
11 #include <linux/delay.h>
12 #include <linux/dma-iommu.h>
13 #include <linux/dma-map-ops.h>
14 #include <linux/freezer.h>
15 #include <linux/interval_tree.h>
16 #include <linux/iommu.h>
17 #include <linux/module.h>
18 #include <linux/of_platform.h>
19 #include <linux/pci.h>
20 #include <linux/platform_device.h>
21 #include <linux/virtio.h>
22 #include <linux/virtio_config.h>
23 #include <linux/virtio_ids.h>
24 #include <linux/wait.h>
25
26 #include <uapi/linux/virtio_iommu.h>
27
28 #define MSI_IOVA_BASE                   0x8000000
29 #define MSI_IOVA_LENGTH                 0x100000
30
31 #define VIOMMU_REQUEST_VQ               0
32 #define VIOMMU_EVENT_VQ                 1
33 #define VIOMMU_NR_VQS                   2
34
35 struct viommu_dev {
36         struct iommu_device             iommu;
37         struct device                   *dev;
38         struct virtio_device            *vdev;
39
40         struct ida                      domain_ids;
41
42         struct virtqueue                *vqs[VIOMMU_NR_VQS];
43         spinlock_t                      request_lock;
44         struct list_head                requests;
45         void                            *evts;
46
47         /* Device configuration */
48         struct iommu_domain_geometry    geometry;
49         u64                             pgsize_bitmap;
50         u32                             first_domain;
51         u32                             last_domain;
52         /* Supported MAP flags */
53         u32                             map_flags;
54         u32                             probe_size;
55 };
56
57 struct viommu_mapping {
58         phys_addr_t                     paddr;
59         struct interval_tree_node       iova;
60         u32                             flags;
61 };
62
63 struct viommu_domain {
64         struct iommu_domain             domain;
65         struct viommu_dev               *viommu;
66         struct mutex                    mutex; /* protects viommu pointer */
67         unsigned int                    id;
68         u32                             map_flags;
69
70         spinlock_t                      mappings_lock;
71         struct rb_root_cached           mappings;
72
73         unsigned long                   nr_endpoints;
74         bool                            bypass;
75 };
76
77 struct viommu_endpoint {
78         struct device                   *dev;
79         struct viommu_dev               *viommu;
80         struct viommu_domain            *vdomain;
81         struct list_head                resv_regions;
82 };
83
84 struct viommu_request {
85         struct list_head                list;
86         void                            *writeback;
87         unsigned int                    write_offset;
88         unsigned int                    len;
89         char                            buf[];
90 };
91
92 #define VIOMMU_FAULT_RESV_MASK          0xffffff00
93
94 struct viommu_event {
95         union {
96                 u32                     head;
97                 struct virtio_iommu_fault fault;
98         };
99 };
100
101 #define to_viommu_domain(domain)        \
102         container_of(domain, struct viommu_domain, domain)
103
104 static int viommu_get_req_errno(void *buf, size_t len)
105 {
106         struct virtio_iommu_req_tail *tail = buf + len - sizeof(*tail);
107
108         switch (tail->status) {
109         case VIRTIO_IOMMU_S_OK:
110                 return 0;
111         case VIRTIO_IOMMU_S_UNSUPP:
112                 return -ENOSYS;
113         case VIRTIO_IOMMU_S_INVAL:
114                 return -EINVAL;
115         case VIRTIO_IOMMU_S_RANGE:
116                 return -ERANGE;
117         case VIRTIO_IOMMU_S_NOENT:
118                 return -ENOENT;
119         case VIRTIO_IOMMU_S_FAULT:
120                 return -EFAULT;
121         case VIRTIO_IOMMU_S_NOMEM:
122                 return -ENOMEM;
123         case VIRTIO_IOMMU_S_IOERR:
124         case VIRTIO_IOMMU_S_DEVERR:
125         default:
126                 return -EIO;
127         }
128 }
129
130 static void viommu_set_req_status(void *buf, size_t len, int status)
131 {
132         struct virtio_iommu_req_tail *tail = buf + len - sizeof(*tail);
133
134         tail->status = status;
135 }
136
137 static off_t viommu_get_write_desc_offset(struct viommu_dev *viommu,
138                                           struct virtio_iommu_req_head *req,
139                                           size_t len)
140 {
141         size_t tail_size = sizeof(struct virtio_iommu_req_tail);
142
143         if (req->type == VIRTIO_IOMMU_T_PROBE)
144                 return len - viommu->probe_size - tail_size;
145
146         return len - tail_size;
147 }
148
149 /*
150  * __viommu_sync_req - Complete all in-flight requests
151  *
152  * Wait for all added requests to complete. When this function returns, all
153  * requests that were in-flight at the time of the call have completed.
154  */
155 static int __viommu_sync_req(struct viommu_dev *viommu)
156 {
157         unsigned int len;
158         size_t write_len;
159         struct viommu_request *req;
160         struct virtqueue *vq = viommu->vqs[VIOMMU_REQUEST_VQ];
161
162         assert_spin_locked(&viommu->request_lock);
163
164         virtqueue_kick(vq);
165
166         while (!list_empty(&viommu->requests)) {
167                 len = 0;
168                 req = virtqueue_get_buf(vq, &len);
169                 if (!req)
170                         continue;
171
172                 if (!len)
173                         viommu_set_req_status(req->buf, req->len,
174                                               VIRTIO_IOMMU_S_IOERR);
175
176                 write_len = req->len - req->write_offset;
177                 if (req->writeback && len == write_len)
178                         memcpy(req->writeback, req->buf + req->write_offset,
179                                write_len);
180
181                 list_del(&req->list);
182                 kfree(req);
183         }
184
185         return 0;
186 }
187
188 static int viommu_sync_req(struct viommu_dev *viommu)
189 {
190         int ret;
191         unsigned long flags;
192
193         spin_lock_irqsave(&viommu->request_lock, flags);
194         ret = __viommu_sync_req(viommu);
195         if (ret)
196                 dev_dbg(viommu->dev, "could not sync requests (%d)\n", ret);
197         spin_unlock_irqrestore(&viommu->request_lock, flags);
198
199         return ret;
200 }
201
202 /*
203  * __viommu_add_request - Add one request to the queue
204  * @buf: pointer to the request buffer
205  * @len: length of the request buffer
206  * @writeback: copy data back to the buffer when the request completes.
207  *
208  * Add a request to the queue. Only synchronize the queue if it's already full.
209  * Otherwise don't kick the queue nor wait for requests to complete.
210  *
211  * When @writeback is true, data written by the device, including the request
212  * status, is copied into @buf after the request completes. This is unsafe if
213  * the caller allocates @buf on stack and drops the lock between add_req() and
214  * sync_req().
215  *
216  * Return 0 if the request was successfully added to the queue.
217  */
218 static int __viommu_add_req(struct viommu_dev *viommu, void *buf, size_t len,
219                             bool writeback)
220 {
221         int ret;
222         off_t write_offset;
223         struct viommu_request *req;
224         struct scatterlist top_sg, bottom_sg;
225         struct scatterlist *sg[2] = { &top_sg, &bottom_sg };
226         struct virtqueue *vq = viommu->vqs[VIOMMU_REQUEST_VQ];
227
228         assert_spin_locked(&viommu->request_lock);
229
230         write_offset = viommu_get_write_desc_offset(viommu, buf, len);
231         if (write_offset <= 0)
232                 return -EINVAL;
233
234         req = kzalloc(sizeof(*req) + len, GFP_ATOMIC);
235         if (!req)
236                 return -ENOMEM;
237
238         req->len = len;
239         if (writeback) {
240                 req->writeback = buf + write_offset;
241                 req->write_offset = write_offset;
242         }
243         memcpy(&req->buf, buf, write_offset);
244
245         sg_init_one(&top_sg, req->buf, write_offset);
246         sg_init_one(&bottom_sg, req->buf + write_offset, len - write_offset);
247
248         ret = virtqueue_add_sgs(vq, sg, 1, 1, req, GFP_ATOMIC);
249         if (ret == -ENOSPC) {
250                 /* If the queue is full, sync and retry */
251                 if (!__viommu_sync_req(viommu))
252                         ret = virtqueue_add_sgs(vq, sg, 1, 1, req, GFP_ATOMIC);
253         }
254         if (ret)
255                 goto err_free;
256
257         list_add_tail(&req->list, &viommu->requests);
258         return 0;
259
260 err_free:
261         kfree(req);
262         return ret;
263 }
264
265 static int viommu_add_req(struct viommu_dev *viommu, void *buf, size_t len)
266 {
267         int ret;
268         unsigned long flags;
269
270         spin_lock_irqsave(&viommu->request_lock, flags);
271         ret = __viommu_add_req(viommu, buf, len, false);
272         if (ret)
273                 dev_dbg(viommu->dev, "could not add request: %d\n", ret);
274         spin_unlock_irqrestore(&viommu->request_lock, flags);
275
276         return ret;
277 }
278
279 /*
280  * Send a request and wait for it to complete. Return the request status (as an
281  * errno)
282  */
283 static int viommu_send_req_sync(struct viommu_dev *viommu, void *buf,
284                                 size_t len)
285 {
286         int ret;
287         unsigned long flags;
288
289         spin_lock_irqsave(&viommu->request_lock, flags);
290
291         ret = __viommu_add_req(viommu, buf, len, true);
292         if (ret) {
293                 dev_dbg(viommu->dev, "could not add request (%d)\n", ret);
294                 goto out_unlock;
295         }
296
297         ret = __viommu_sync_req(viommu);
298         if (ret) {
299                 dev_dbg(viommu->dev, "could not sync requests (%d)\n", ret);
300                 /* Fall-through (get the actual request status) */
301         }
302
303         ret = viommu_get_req_errno(buf, len);
304 out_unlock:
305         spin_unlock_irqrestore(&viommu->request_lock, flags);
306         return ret;
307 }
308
309 /*
310  * viommu_add_mapping - add a mapping to the internal tree
311  *
312  * On success, return the new mapping. Otherwise return NULL.
313  */
314 static int viommu_add_mapping(struct viommu_domain *vdomain, u64 iova, u64 end,
315                               phys_addr_t paddr, u32 flags)
316 {
317         unsigned long irqflags;
318         struct viommu_mapping *mapping;
319
320         mapping = kzalloc(sizeof(*mapping), GFP_ATOMIC);
321         if (!mapping)
322                 return -ENOMEM;
323
324         mapping->paddr          = paddr;
325         mapping->iova.start     = iova;
326         mapping->iova.last      = end;
327         mapping->flags          = flags;
328
329         spin_lock_irqsave(&vdomain->mappings_lock, irqflags);
330         interval_tree_insert(&mapping->iova, &vdomain->mappings);
331         spin_unlock_irqrestore(&vdomain->mappings_lock, irqflags);
332
333         return 0;
334 }
335
336 /*
337  * viommu_del_mappings - remove mappings from the internal tree
338  *
339  * @vdomain: the domain
340  * @iova: start of the range
341  * @end: end of the range
342  *
343  * On success, returns the number of unmapped bytes
344  */
345 static size_t viommu_del_mappings(struct viommu_domain *vdomain,
346                                   u64 iova, u64 end)
347 {
348         size_t unmapped = 0;
349         unsigned long flags;
350         struct viommu_mapping *mapping = NULL;
351         struct interval_tree_node *node, *next;
352
353         spin_lock_irqsave(&vdomain->mappings_lock, flags);
354         next = interval_tree_iter_first(&vdomain->mappings, iova, end);
355         while (next) {
356                 node = next;
357                 mapping = container_of(node, struct viommu_mapping, iova);
358                 next = interval_tree_iter_next(node, iova, end);
359
360                 /* Trying to split a mapping? */
361                 if (mapping->iova.start < iova)
362                         break;
363
364                 /*
365                  * Virtio-iommu doesn't allow UNMAP to split a mapping created
366                  * with a single MAP request, so remove the full mapping.
367                  */
368                 unmapped += mapping->iova.last - mapping->iova.start + 1;
369
370                 interval_tree_remove(node, &vdomain->mappings);
371                 kfree(mapping);
372         }
373         spin_unlock_irqrestore(&vdomain->mappings_lock, flags);
374
375         return unmapped;
376 }
377
378 /*
379  * Fill the domain with identity mappings, skipping the device's reserved
380  * regions.
381  */
382 static int viommu_domain_map_identity(struct viommu_endpoint *vdev,
383                                       struct viommu_domain *vdomain)
384 {
385         int ret;
386         struct iommu_resv_region *resv;
387         u64 iova = vdomain->domain.geometry.aperture_start;
388         u64 limit = vdomain->domain.geometry.aperture_end;
389         u32 flags = VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE;
390         unsigned long granule = 1UL << __ffs(vdomain->domain.pgsize_bitmap);
391
392         iova = ALIGN(iova, granule);
393         limit = ALIGN_DOWN(limit + 1, granule) - 1;
394
395         list_for_each_entry(resv, &vdev->resv_regions, list) {
396                 u64 resv_start = ALIGN_DOWN(resv->start, granule);
397                 u64 resv_end = ALIGN(resv->start + resv->length, granule) - 1;
398
399                 if (resv_end < iova || resv_start > limit)
400                         /* No overlap */
401                         continue;
402
403                 if (resv_start > iova) {
404                         ret = viommu_add_mapping(vdomain, iova, resv_start - 1,
405                                                  (phys_addr_t)iova, flags);
406                         if (ret)
407                                 goto err_unmap;
408                 }
409
410                 if (resv_end >= limit)
411                         return 0;
412
413                 iova = resv_end + 1;
414         }
415
416         ret = viommu_add_mapping(vdomain, iova, limit, (phys_addr_t)iova,
417                                  flags);
418         if (ret)
419                 goto err_unmap;
420         return 0;
421
422 err_unmap:
423         viommu_del_mappings(vdomain, 0, iova);
424         return ret;
425 }
426
427 /*
428  * viommu_replay_mappings - re-send MAP requests
429  *
430  * When reattaching a domain that was previously detached from all endpoints,
431  * mappings were deleted from the device. Re-create the mappings available in
432  * the internal tree.
433  */
434 static int viommu_replay_mappings(struct viommu_domain *vdomain)
435 {
436         int ret = 0;
437         unsigned long flags;
438         struct viommu_mapping *mapping;
439         struct interval_tree_node *node;
440         struct virtio_iommu_req_map map;
441
442         spin_lock_irqsave(&vdomain->mappings_lock, flags);
443         node = interval_tree_iter_first(&vdomain->mappings, 0, -1UL);
444         while (node) {
445                 mapping = container_of(node, struct viommu_mapping, iova);
446                 map = (struct virtio_iommu_req_map) {
447                         .head.type      = VIRTIO_IOMMU_T_MAP,
448                         .domain         = cpu_to_le32(vdomain->id),
449                         .virt_start     = cpu_to_le64(mapping->iova.start),
450                         .virt_end       = cpu_to_le64(mapping->iova.last),
451                         .phys_start     = cpu_to_le64(mapping->paddr),
452                         .flags          = cpu_to_le32(mapping->flags),
453                 };
454
455                 ret = viommu_send_req_sync(vdomain->viommu, &map, sizeof(map));
456                 if (ret)
457                         break;
458
459                 node = interval_tree_iter_next(node, 0, -1UL);
460         }
461         spin_unlock_irqrestore(&vdomain->mappings_lock, flags);
462
463         return ret;
464 }
465
466 static int viommu_add_resv_mem(struct viommu_endpoint *vdev,
467                                struct virtio_iommu_probe_resv_mem *mem,
468                                size_t len)
469 {
470         size_t size;
471         u64 start64, end64;
472         phys_addr_t start, end;
473         struct iommu_resv_region *region = NULL, *next;
474         unsigned long prot = IOMMU_WRITE | IOMMU_NOEXEC | IOMMU_MMIO;
475
476         start = start64 = le64_to_cpu(mem->start);
477         end = end64 = le64_to_cpu(mem->end);
478         size = end64 - start64 + 1;
479
480         /* Catch any overflow, including the unlikely end64 - start64 + 1 = 0 */
481         if (start != start64 || end != end64 || size < end64 - start64)
482                 return -EOVERFLOW;
483
484         if (len < sizeof(*mem))
485                 return -EINVAL;
486
487         switch (mem->subtype) {
488         default:
489                 dev_warn(vdev->dev, "unknown resv mem subtype 0x%x\n",
490                          mem->subtype);
491                 fallthrough;
492         case VIRTIO_IOMMU_RESV_MEM_T_RESERVED:
493                 region = iommu_alloc_resv_region(start, size, 0,
494                                                  IOMMU_RESV_RESERVED);
495                 break;
496         case VIRTIO_IOMMU_RESV_MEM_T_MSI:
497                 region = iommu_alloc_resv_region(start, size, prot,
498                                                  IOMMU_RESV_MSI);
499                 break;
500         }
501         if (!region)
502                 return -ENOMEM;
503
504         /* Keep the list sorted */
505         list_for_each_entry(next, &vdev->resv_regions, list) {
506                 if (next->start > region->start)
507                         break;
508         }
509         list_add_tail(&region->list, &next->list);
510         return 0;
511 }
512
513 static int viommu_probe_endpoint(struct viommu_dev *viommu, struct device *dev)
514 {
515         int ret;
516         u16 type, len;
517         size_t cur = 0;
518         size_t probe_len;
519         struct virtio_iommu_req_probe *probe;
520         struct virtio_iommu_probe_property *prop;
521         struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
522         struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
523
524         if (!fwspec->num_ids)
525                 return -EINVAL;
526
527         probe_len = sizeof(*probe) + viommu->probe_size +
528                     sizeof(struct virtio_iommu_req_tail);
529         probe = kzalloc(probe_len, GFP_KERNEL);
530         if (!probe)
531                 return -ENOMEM;
532
533         probe->head.type = VIRTIO_IOMMU_T_PROBE;
534         /*
535          * For now, assume that properties of an endpoint that outputs multiple
536          * IDs are consistent. Only probe the first one.
537          */
538         probe->endpoint = cpu_to_le32(fwspec->ids[0]);
539
540         ret = viommu_send_req_sync(viommu, probe, probe_len);
541         if (ret)
542                 goto out_free;
543
544         prop = (void *)probe->properties;
545         type = le16_to_cpu(prop->type) & VIRTIO_IOMMU_PROBE_T_MASK;
546
547         while (type != VIRTIO_IOMMU_PROBE_T_NONE &&
548                cur < viommu->probe_size) {
549                 len = le16_to_cpu(prop->length) + sizeof(*prop);
550
551                 switch (type) {
552                 case VIRTIO_IOMMU_PROBE_T_RESV_MEM:
553                         ret = viommu_add_resv_mem(vdev, (void *)prop, len);
554                         break;
555                 default:
556                         dev_err(dev, "unknown viommu prop 0x%x\n", type);
557                 }
558
559                 if (ret)
560                         dev_err(dev, "failed to parse viommu prop 0x%x\n", type);
561
562                 cur += len;
563                 if (cur >= viommu->probe_size)
564                         break;
565
566                 prop = (void *)probe->properties + cur;
567                 type = le16_to_cpu(prop->type) & VIRTIO_IOMMU_PROBE_T_MASK;
568         }
569
570 out_free:
571         kfree(probe);
572         return ret;
573 }
574
575 static int viommu_fault_handler(struct viommu_dev *viommu,
576                                 struct virtio_iommu_fault *fault)
577 {
578         char *reason_str;
579
580         u8 reason       = fault->reason;
581         u32 flags       = le32_to_cpu(fault->flags);
582         u32 endpoint    = le32_to_cpu(fault->endpoint);
583         u64 address     = le64_to_cpu(fault->address);
584
585         switch (reason) {
586         case VIRTIO_IOMMU_FAULT_R_DOMAIN:
587                 reason_str = "domain";
588                 break;
589         case VIRTIO_IOMMU_FAULT_R_MAPPING:
590                 reason_str = "page";
591                 break;
592         case VIRTIO_IOMMU_FAULT_R_UNKNOWN:
593         default:
594                 reason_str = "unknown";
595                 break;
596         }
597
598         /* TODO: find EP by ID and report_iommu_fault */
599         if (flags & VIRTIO_IOMMU_FAULT_F_ADDRESS)
600                 dev_err_ratelimited(viommu->dev, "%s fault from EP %u at %#llx [%s%s%s]\n",
601                                     reason_str, endpoint, address,
602                                     flags & VIRTIO_IOMMU_FAULT_F_READ ? "R" : "",
603                                     flags & VIRTIO_IOMMU_FAULT_F_WRITE ? "W" : "",
604                                     flags & VIRTIO_IOMMU_FAULT_F_EXEC ? "X" : "");
605         else
606                 dev_err_ratelimited(viommu->dev, "%s fault from EP %u\n",
607                                     reason_str, endpoint);
608         return 0;
609 }
610
611 static void viommu_event_handler(struct virtqueue *vq)
612 {
613         int ret;
614         unsigned int len;
615         struct scatterlist sg[1];
616         struct viommu_event *evt;
617         struct viommu_dev *viommu = vq->vdev->priv;
618
619         while ((evt = virtqueue_get_buf(vq, &len)) != NULL) {
620                 if (len > sizeof(*evt)) {
621                         dev_err(viommu->dev,
622                                 "invalid event buffer (len %u != %zu)\n",
623                                 len, sizeof(*evt));
624                 } else if (!(evt->head & VIOMMU_FAULT_RESV_MASK)) {
625                         viommu_fault_handler(viommu, &evt->fault);
626                 }
627
628                 sg_init_one(sg, evt, sizeof(*evt));
629                 ret = virtqueue_add_inbuf(vq, sg, 1, evt, GFP_ATOMIC);
630                 if (ret)
631                         dev_err(viommu->dev, "could not add event buffer\n");
632         }
633
634         virtqueue_kick(vq);
635 }
636
637 /* IOMMU API */
638
639 static struct iommu_domain *viommu_domain_alloc(unsigned type)
640 {
641         struct viommu_domain *vdomain;
642
643         if (type != IOMMU_DOMAIN_UNMANAGED &&
644             type != IOMMU_DOMAIN_DMA &&
645             type != IOMMU_DOMAIN_IDENTITY)
646                 return NULL;
647
648         vdomain = kzalloc(sizeof(*vdomain), GFP_KERNEL);
649         if (!vdomain)
650                 return NULL;
651
652         mutex_init(&vdomain->mutex);
653         spin_lock_init(&vdomain->mappings_lock);
654         vdomain->mappings = RB_ROOT_CACHED;
655
656         return &vdomain->domain;
657 }
658
659 static int viommu_domain_finalise(struct viommu_endpoint *vdev,
660                                   struct iommu_domain *domain)
661 {
662         int ret;
663         unsigned long viommu_page_size;
664         struct viommu_dev *viommu = vdev->viommu;
665         struct viommu_domain *vdomain = to_viommu_domain(domain);
666
667         viommu_page_size = 1UL << __ffs(viommu->pgsize_bitmap);
668         if (viommu_page_size > PAGE_SIZE) {
669                 dev_err(vdev->dev,
670                         "granule 0x%lx larger than system page size 0x%lx\n",
671                         viommu_page_size, PAGE_SIZE);
672                 return -EINVAL;
673         }
674
675         ret = ida_alloc_range(&viommu->domain_ids, viommu->first_domain,
676                               viommu->last_domain, GFP_KERNEL);
677         if (ret < 0)
678                 return ret;
679
680         vdomain->id             = (unsigned int)ret;
681
682         domain->pgsize_bitmap   = viommu->pgsize_bitmap;
683         domain->geometry        = viommu->geometry;
684
685         vdomain->map_flags      = viommu->map_flags;
686         vdomain->viommu         = viommu;
687
688         if (domain->type == IOMMU_DOMAIN_IDENTITY) {
689                 if (virtio_has_feature(viommu->vdev,
690                                        VIRTIO_IOMMU_F_BYPASS_CONFIG)) {
691                         vdomain->bypass = true;
692                         return 0;
693                 }
694
695                 ret = viommu_domain_map_identity(vdev, vdomain);
696                 if (ret) {
697                         ida_free(&viommu->domain_ids, vdomain->id);
698                         vdomain->viommu = NULL;
699                         return -EOPNOTSUPP;
700                 }
701         }
702
703         return 0;
704 }
705
706 static void viommu_domain_free(struct iommu_domain *domain)
707 {
708         struct viommu_domain *vdomain = to_viommu_domain(domain);
709
710         /* Free all remaining mappings */
711         viommu_del_mappings(vdomain, 0, ULLONG_MAX);
712
713         if (vdomain->viommu)
714                 ida_free(&vdomain->viommu->domain_ids, vdomain->id);
715
716         kfree(vdomain);
717 }
718
719 static int viommu_attach_dev(struct iommu_domain *domain, struct device *dev)
720 {
721         int i;
722         int ret = 0;
723         struct virtio_iommu_req_attach req;
724         struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
725         struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
726         struct viommu_domain *vdomain = to_viommu_domain(domain);
727
728         mutex_lock(&vdomain->mutex);
729         if (!vdomain->viommu) {
730                 /*
731                  * Properly initialize the domain now that we know which viommu
732                  * owns it.
733                  */
734                 ret = viommu_domain_finalise(vdev, domain);
735         } else if (vdomain->viommu != vdev->viommu) {
736                 dev_err(dev, "cannot attach to foreign vIOMMU\n");
737                 ret = -EXDEV;
738         }
739         mutex_unlock(&vdomain->mutex);
740
741         if (ret)
742                 return ret;
743
744         /*
745          * In the virtio-iommu device, when attaching the endpoint to a new
746          * domain, it is detached from the old one and, if as a result the
747          * old domain isn't attached to any endpoint, all mappings are removed
748          * from the old domain and it is freed.
749          *
750          * In the driver the old domain still exists, and its mappings will be
751          * recreated if it gets reattached to an endpoint. Otherwise it will be
752          * freed explicitly.
753          *
754          * vdev->vdomain is protected by group->mutex
755          */
756         if (vdev->vdomain)
757                 vdev->vdomain->nr_endpoints--;
758
759         req = (struct virtio_iommu_req_attach) {
760                 .head.type      = VIRTIO_IOMMU_T_ATTACH,
761                 .domain         = cpu_to_le32(vdomain->id),
762         };
763
764         if (vdomain->bypass)
765                 req.flags |= cpu_to_le32(VIRTIO_IOMMU_ATTACH_F_BYPASS);
766
767         for (i = 0; i < fwspec->num_ids; i++) {
768                 req.endpoint = cpu_to_le32(fwspec->ids[i]);
769
770                 ret = viommu_send_req_sync(vdomain->viommu, &req, sizeof(req));
771                 if (ret)
772                         return ret;
773         }
774
775         if (!vdomain->nr_endpoints) {
776                 /*
777                  * This endpoint is the first to be attached to the domain.
778                  * Replay existing mappings (e.g. SW MSI).
779                  */
780                 ret = viommu_replay_mappings(vdomain);
781                 if (ret)
782                         return ret;
783         }
784
785         vdomain->nr_endpoints++;
786         vdev->vdomain = vdomain;
787
788         return 0;
789 }
790
791 static int viommu_map(struct iommu_domain *domain, unsigned long iova,
792                       phys_addr_t paddr, size_t size, int prot, gfp_t gfp)
793 {
794         int ret;
795         u32 flags;
796         u64 end = iova + size - 1;
797         struct virtio_iommu_req_map map;
798         struct viommu_domain *vdomain = to_viommu_domain(domain);
799
800         flags = (prot & IOMMU_READ ? VIRTIO_IOMMU_MAP_F_READ : 0) |
801                 (prot & IOMMU_WRITE ? VIRTIO_IOMMU_MAP_F_WRITE : 0) |
802                 (prot & IOMMU_MMIO ? VIRTIO_IOMMU_MAP_F_MMIO : 0);
803
804         if (flags & ~vdomain->map_flags)
805                 return -EINVAL;
806
807         ret = viommu_add_mapping(vdomain, iova, end, paddr, flags);
808         if (ret)
809                 return ret;
810
811         map = (struct virtio_iommu_req_map) {
812                 .head.type      = VIRTIO_IOMMU_T_MAP,
813                 .domain         = cpu_to_le32(vdomain->id),
814                 .virt_start     = cpu_to_le64(iova),
815                 .phys_start     = cpu_to_le64(paddr),
816                 .virt_end       = cpu_to_le64(end),
817                 .flags          = cpu_to_le32(flags),
818         };
819
820         if (!vdomain->nr_endpoints)
821                 return 0;
822
823         ret = viommu_send_req_sync(vdomain->viommu, &map, sizeof(map));
824         if (ret)
825                 viommu_del_mappings(vdomain, iova, end);
826
827         return ret;
828 }
829
830 static size_t viommu_unmap(struct iommu_domain *domain, unsigned long iova,
831                            size_t size, struct iommu_iotlb_gather *gather)
832 {
833         int ret = 0;
834         size_t unmapped;
835         struct virtio_iommu_req_unmap unmap;
836         struct viommu_domain *vdomain = to_viommu_domain(domain);
837
838         unmapped = viommu_del_mappings(vdomain, iova, iova + size - 1);
839         if (unmapped < size)
840                 return 0;
841
842         /* Device already removed all mappings after detach. */
843         if (!vdomain->nr_endpoints)
844                 return unmapped;
845
846         unmap = (struct virtio_iommu_req_unmap) {
847                 .head.type      = VIRTIO_IOMMU_T_UNMAP,
848                 .domain         = cpu_to_le32(vdomain->id),
849                 .virt_start     = cpu_to_le64(iova),
850                 .virt_end       = cpu_to_le64(iova + unmapped - 1),
851         };
852
853         ret = viommu_add_req(vdomain->viommu, &unmap, sizeof(unmap));
854         return ret ? 0 : unmapped;
855 }
856
857 static phys_addr_t viommu_iova_to_phys(struct iommu_domain *domain,
858                                        dma_addr_t iova)
859 {
860         u64 paddr = 0;
861         unsigned long flags;
862         struct viommu_mapping *mapping;
863         struct interval_tree_node *node;
864         struct viommu_domain *vdomain = to_viommu_domain(domain);
865
866         spin_lock_irqsave(&vdomain->mappings_lock, flags);
867         node = interval_tree_iter_first(&vdomain->mappings, iova, iova);
868         if (node) {
869                 mapping = container_of(node, struct viommu_mapping, iova);
870                 paddr = mapping->paddr + (iova - mapping->iova.start);
871         }
872         spin_unlock_irqrestore(&vdomain->mappings_lock, flags);
873
874         return paddr;
875 }
876
877 static void viommu_iotlb_sync(struct iommu_domain *domain,
878                               struct iommu_iotlb_gather *gather)
879 {
880         struct viommu_domain *vdomain = to_viommu_domain(domain);
881
882         viommu_sync_req(vdomain->viommu);
883 }
884
885 static void viommu_get_resv_regions(struct device *dev, struct list_head *head)
886 {
887         struct iommu_resv_region *entry, *new_entry, *msi = NULL;
888         struct viommu_endpoint *vdev = dev_iommu_priv_get(dev);
889         int prot = IOMMU_WRITE | IOMMU_NOEXEC | IOMMU_MMIO;
890
891         list_for_each_entry(entry, &vdev->resv_regions, list) {
892                 if (entry->type == IOMMU_RESV_MSI)
893                         msi = entry;
894
895                 new_entry = kmemdup(entry, sizeof(*entry), GFP_KERNEL);
896                 if (!new_entry)
897                         return;
898                 list_add_tail(&new_entry->list, head);
899         }
900
901         /*
902          * If the device didn't register any bypass MSI window, add a
903          * software-mapped region.
904          */
905         if (!msi) {
906                 msi = iommu_alloc_resv_region(MSI_IOVA_BASE, MSI_IOVA_LENGTH,
907                                               prot, IOMMU_RESV_SW_MSI);
908                 if (!msi)
909                         return;
910
911                 list_add_tail(&msi->list, head);
912         }
913
914         iommu_dma_get_resv_regions(dev, head);
915 }
916
917 static struct iommu_ops viommu_ops;
918 static struct virtio_driver virtio_iommu_drv;
919
920 static int viommu_match_node(struct device *dev, const void *data)
921 {
922         return dev->parent->fwnode == data;
923 }
924
925 static struct viommu_dev *viommu_get_by_fwnode(struct fwnode_handle *fwnode)
926 {
927         struct device *dev = driver_find_device(&virtio_iommu_drv.driver, NULL,
928                                                 fwnode, viommu_match_node);
929         put_device(dev);
930
931         return dev ? dev_to_virtio(dev)->priv : NULL;
932 }
933
934 static struct iommu_device *viommu_probe_device(struct device *dev)
935 {
936         int ret;
937         struct viommu_endpoint *vdev;
938         struct viommu_dev *viommu = NULL;
939         struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
940
941         if (!fwspec || fwspec->ops != &viommu_ops)
942                 return ERR_PTR(-ENODEV);
943
944         viommu = viommu_get_by_fwnode(fwspec->iommu_fwnode);
945         if (!viommu)
946                 return ERR_PTR(-ENODEV);
947
948         vdev = kzalloc(sizeof(*vdev), GFP_KERNEL);
949         if (!vdev)
950                 return ERR_PTR(-ENOMEM);
951
952         vdev->dev = dev;
953         vdev->viommu = viommu;
954         INIT_LIST_HEAD(&vdev->resv_regions);
955         dev_iommu_priv_set(dev, vdev);
956
957         if (viommu->probe_size) {
958                 /* Get additional information for this endpoint */
959                 ret = viommu_probe_endpoint(viommu, dev);
960                 if (ret)
961                         goto err_free_dev;
962         }
963
964         return &viommu->iommu;
965
966 err_free_dev:
967         generic_iommu_put_resv_regions(dev, &vdev->resv_regions);
968         kfree(vdev);
969
970         return ERR_PTR(ret);
971 }
972
973 static void viommu_probe_finalize(struct device *dev)
974 {
975 #ifndef CONFIG_ARCH_HAS_SETUP_DMA_OPS
976         /* First clear the DMA ops in case we're switching from a DMA domain */
977         set_dma_ops(dev, NULL);
978         iommu_setup_dma_ops(dev, 0, U64_MAX);
979 #endif
980 }
981
982 static void viommu_release_device(struct device *dev)
983 {
984         struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
985         struct viommu_endpoint *vdev;
986
987         if (!fwspec || fwspec->ops != &viommu_ops)
988                 return;
989
990         vdev = dev_iommu_priv_get(dev);
991
992         generic_iommu_put_resv_regions(dev, &vdev->resv_regions);
993         kfree(vdev);
994 }
995
996 static struct iommu_group *viommu_device_group(struct device *dev)
997 {
998         if (dev_is_pci(dev))
999                 return pci_device_group(dev);
1000         else
1001                 return generic_device_group(dev);
1002 }
1003
1004 static int viommu_of_xlate(struct device *dev, struct of_phandle_args *args)
1005 {
1006         return iommu_fwspec_add_ids(dev, args->args, 1);
1007 }
1008
1009 static struct iommu_ops viommu_ops = {
1010         .domain_alloc           = viommu_domain_alloc,
1011         .probe_device           = viommu_probe_device,
1012         .probe_finalize         = viommu_probe_finalize,
1013         .release_device         = viommu_release_device,
1014         .device_group           = viommu_device_group,
1015         .get_resv_regions       = viommu_get_resv_regions,
1016         .put_resv_regions       = generic_iommu_put_resv_regions,
1017         .of_xlate               = viommu_of_xlate,
1018         .owner                  = THIS_MODULE,
1019         .default_domain_ops = &(const struct iommu_domain_ops) {
1020                 .attach_dev             = viommu_attach_dev,
1021                 .map                    = viommu_map,
1022                 .unmap                  = viommu_unmap,
1023                 .iova_to_phys           = viommu_iova_to_phys,
1024                 .iotlb_sync             = viommu_iotlb_sync,
1025                 .free                   = viommu_domain_free,
1026         }
1027 };
1028
1029 static int viommu_init_vqs(struct viommu_dev *viommu)
1030 {
1031         struct virtio_device *vdev = dev_to_virtio(viommu->dev);
1032         const char *names[] = { "request", "event" };
1033         vq_callback_t *callbacks[] = {
1034                 NULL, /* No async requests */
1035                 viommu_event_handler,
1036         };
1037
1038         return virtio_find_vqs(vdev, VIOMMU_NR_VQS, viommu->vqs, callbacks,
1039                                names, NULL);
1040 }
1041
1042 static int viommu_fill_evtq(struct viommu_dev *viommu)
1043 {
1044         int i, ret;
1045         struct scatterlist sg[1];
1046         struct viommu_event *evts;
1047         struct virtqueue *vq = viommu->vqs[VIOMMU_EVENT_VQ];
1048         size_t nr_evts = vq->num_free;
1049
1050         viommu->evts = evts = devm_kmalloc_array(viommu->dev, nr_evts,
1051                                                  sizeof(*evts), GFP_KERNEL);
1052         if (!evts)
1053                 return -ENOMEM;
1054
1055         for (i = 0; i < nr_evts; i++) {
1056                 sg_init_one(sg, &evts[i], sizeof(*evts));
1057                 ret = virtqueue_add_inbuf(vq, sg, 1, &evts[i], GFP_KERNEL);
1058                 if (ret)
1059                         return ret;
1060         }
1061
1062         return 0;
1063 }
1064
1065 static int viommu_probe(struct virtio_device *vdev)
1066 {
1067         struct device *parent_dev = vdev->dev.parent;
1068         struct viommu_dev *viommu = NULL;
1069         struct device *dev = &vdev->dev;
1070         u64 input_start = 0;
1071         u64 input_end = -1UL;
1072         int ret;
1073
1074         if (!virtio_has_feature(vdev, VIRTIO_F_VERSION_1) ||
1075             !virtio_has_feature(vdev, VIRTIO_IOMMU_F_MAP_UNMAP))
1076                 return -ENODEV;
1077
1078         viommu = devm_kzalloc(dev, sizeof(*viommu), GFP_KERNEL);
1079         if (!viommu)
1080                 return -ENOMEM;
1081
1082         spin_lock_init(&viommu->request_lock);
1083         ida_init(&viommu->domain_ids);
1084         viommu->dev = dev;
1085         viommu->vdev = vdev;
1086         INIT_LIST_HEAD(&viommu->requests);
1087
1088         ret = viommu_init_vqs(viommu);
1089         if (ret)
1090                 return ret;
1091
1092         virtio_cread_le(vdev, struct virtio_iommu_config, page_size_mask,
1093                         &viommu->pgsize_bitmap);
1094
1095         if (!viommu->pgsize_bitmap) {
1096                 ret = -EINVAL;
1097                 goto err_free_vqs;
1098         }
1099
1100         viommu->map_flags = VIRTIO_IOMMU_MAP_F_READ | VIRTIO_IOMMU_MAP_F_WRITE;
1101         viommu->last_domain = ~0U;
1102
1103         /* Optional features */
1104         virtio_cread_le_feature(vdev, VIRTIO_IOMMU_F_INPUT_RANGE,
1105                                 struct virtio_iommu_config, input_range.start,
1106                                 &input_start);
1107
1108         virtio_cread_le_feature(vdev, VIRTIO_IOMMU_F_INPUT_RANGE,
1109                                 struct virtio_iommu_config, input_range.end,
1110                                 &input_end);
1111
1112         virtio_cread_le_feature(vdev, VIRTIO_IOMMU_F_DOMAIN_RANGE,
1113                                 struct virtio_iommu_config, domain_range.start,
1114                                 &viommu->first_domain);
1115
1116         virtio_cread_le_feature(vdev, VIRTIO_IOMMU_F_DOMAIN_RANGE,
1117                                 struct virtio_iommu_config, domain_range.end,
1118                                 &viommu->last_domain);
1119
1120         virtio_cread_le_feature(vdev, VIRTIO_IOMMU_F_PROBE,
1121                                 struct virtio_iommu_config, probe_size,
1122                                 &viommu->probe_size);
1123
1124         viommu->geometry = (struct iommu_domain_geometry) {
1125                 .aperture_start = input_start,
1126                 .aperture_end   = input_end,
1127                 .force_aperture = true,
1128         };
1129
1130         if (virtio_has_feature(vdev, VIRTIO_IOMMU_F_MMIO))
1131                 viommu->map_flags |= VIRTIO_IOMMU_MAP_F_MMIO;
1132
1133         viommu_ops.pgsize_bitmap = viommu->pgsize_bitmap;
1134
1135         virtio_device_ready(vdev);
1136
1137         /* Populate the event queue with buffers */
1138         ret = viommu_fill_evtq(viommu);
1139         if (ret)
1140                 goto err_free_vqs;
1141
1142         ret = iommu_device_sysfs_add(&viommu->iommu, dev, NULL, "%s",
1143                                      virtio_bus_name(vdev));
1144         if (ret)
1145                 goto err_free_vqs;
1146
1147         iommu_device_register(&viommu->iommu, &viommu_ops, parent_dev);
1148
1149 #ifdef CONFIG_PCI
1150         if (pci_bus_type.iommu_ops != &viommu_ops) {
1151                 ret = bus_set_iommu(&pci_bus_type, &viommu_ops);
1152                 if (ret)
1153                         goto err_unregister;
1154         }
1155 #endif
1156 #ifdef CONFIG_ARM_AMBA
1157         if (amba_bustype.iommu_ops != &viommu_ops) {
1158                 ret = bus_set_iommu(&amba_bustype, &viommu_ops);
1159                 if (ret)
1160                         goto err_unregister;
1161         }
1162 #endif
1163         if (platform_bus_type.iommu_ops != &viommu_ops) {
1164                 ret = bus_set_iommu(&platform_bus_type, &viommu_ops);
1165                 if (ret)
1166                         goto err_unregister;
1167         }
1168
1169         vdev->priv = viommu;
1170
1171         dev_info(dev, "input address: %u bits\n",
1172                  order_base_2(viommu->geometry.aperture_end));
1173         dev_info(dev, "page mask: %#llx\n", viommu->pgsize_bitmap);
1174
1175         return 0;
1176
1177 err_unregister:
1178         iommu_device_sysfs_remove(&viommu->iommu);
1179         iommu_device_unregister(&viommu->iommu);
1180 err_free_vqs:
1181         vdev->config->del_vqs(vdev);
1182
1183         return ret;
1184 }
1185
1186 static void viommu_remove(struct virtio_device *vdev)
1187 {
1188         struct viommu_dev *viommu = vdev->priv;
1189
1190         iommu_device_sysfs_remove(&viommu->iommu);
1191         iommu_device_unregister(&viommu->iommu);
1192
1193         /* Stop all virtqueues */
1194         virtio_reset_device(vdev);
1195         vdev->config->del_vqs(vdev);
1196
1197         dev_info(&vdev->dev, "device removed\n");
1198 }
1199
1200 static void viommu_config_changed(struct virtio_device *vdev)
1201 {
1202         dev_warn(&vdev->dev, "config changed\n");
1203 }
1204
1205 static unsigned int features[] = {
1206         VIRTIO_IOMMU_F_MAP_UNMAP,
1207         VIRTIO_IOMMU_F_INPUT_RANGE,
1208         VIRTIO_IOMMU_F_DOMAIN_RANGE,
1209         VIRTIO_IOMMU_F_PROBE,
1210         VIRTIO_IOMMU_F_MMIO,
1211         VIRTIO_IOMMU_F_BYPASS_CONFIG,
1212 };
1213
1214 static struct virtio_device_id id_table[] = {
1215         { VIRTIO_ID_IOMMU, VIRTIO_DEV_ANY_ID },
1216         { 0 },
1217 };
1218 MODULE_DEVICE_TABLE(virtio, id_table);
1219
1220 static struct virtio_driver virtio_iommu_drv = {
1221         .driver.name            = KBUILD_MODNAME,
1222         .driver.owner           = THIS_MODULE,
1223         .id_table               = id_table,
1224         .feature_table          = features,
1225         .feature_table_size     = ARRAY_SIZE(features),
1226         .probe                  = viommu_probe,
1227         .remove                 = viommu_remove,
1228         .config_changed         = viommu_config_changed,
1229 };
1230
1231 module_virtio_driver(virtio_iommu_drv);
1232
1233 MODULE_DESCRIPTION("Virtio IOMMU driver");
1234 MODULE_AUTHOR("Jean-Philippe Brucker <jean-philippe.brucker@arm.com>");
1235 MODULE_LICENSE("GPL v2");