GNU Linux-libre 5.19-rc6-gnu
[releases.git] / drivers/iommu/amd/iommu_v2.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
4  * Author: Joerg Roedel <jroedel@suse.de>
5  */
6
7 #define pr_fmt(fmt)     "AMD-Vi: " fmt
8
9 #include <linux/refcount.h>
10 #include <linux/mmu_notifier.h>
11 #include <linux/amd-iommu.h>
12 #include <linux/mm_types.h>
13 #include <linux/profile.h>
14 #include <linux/module.h>
15 #include <linux/sched.h>
16 #include <linux/sched/mm.h>
17 #include <linux/wait.h>
18 #include <linux/pci.h>
19 #include <linux/gfp.h>
20 #include <linux/cc_platform.h>
21
22 #include "amd_iommu.h"
23
24 MODULE_LICENSE("GPL v2");
25 MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");
26
27 #define PRI_QUEUE_SIZE          512
28
29 struct pri_queue {
30         atomic_t inflight;
31         bool finish;
32         int status;
33 };
34
35 struct pasid_state {
36         struct list_head list;                  /* For global state-list */
37         refcount_t count;                               /* Reference count */
38         unsigned mmu_notifier_count;            /* Counting nested mmu_notifier
39                                                    calls */
40         struct mm_struct *mm;                   /* mm_struct for the faults */
41         struct mmu_notifier mn;                 /* mmu_notifier handle */
42         struct pri_queue pri[PRI_QUEUE_SIZE];   /* PRI tag states */
43         struct device_state *device_state;      /* Link to our device_state */
44         u32 pasid;                              /* PASID index */
45         bool invalid;                           /* Used during setup and
46                                                    teardown of the pasid */
47         spinlock_t lock;                        /* Protect pri_queues and
48                                                    mmu_notifier_count */
49         wait_queue_head_t wq;                   /* To wait for count == 0 */
50 };
51
52 struct device_state {
53         struct list_head list;
54         u16 devid;
55         atomic_t count;
56         struct pci_dev *pdev;
57         struct pasid_state **states;
58         struct iommu_domain *domain;
59         int pasid_levels;
60         int max_pasids;
61         amd_iommu_invalid_ppr_cb inv_ppr_cb;
62         amd_iommu_invalidate_ctx inv_ctx_cb;
63         spinlock_t lock;
64         wait_queue_head_t wq;
65 };
66
67 struct fault {
68         struct work_struct work;
69         struct device_state *dev_state;
70         struct pasid_state *state;
71         struct mm_struct *mm;
72         u64 address;
73         u32 pasid;
74         u16 tag;
75         u16 finish;
76         u16 flags;
77 };
78
79 static LIST_HEAD(state_list);
80 static DEFINE_SPINLOCK(state_lock);
81
82 static struct workqueue_struct *iommu_wq;
83
84 static void free_pasid_states(struct device_state *dev_state);
85
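/* Build the 16-bit PCI requestor ID (bus number << 8 | devfn) for @pdev */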
86 static u16 device_id(struct pci_dev *pdev)
87 {
88         u16 devid;
89
90         devid = pdev->bus->number;
91         devid = (devid << 8) | pdev->devfn;
92
93         return devid;
94 }
95
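/*
 * Look up the device_state for @devid in the global state_list.
 * Caller must hold state_lock.
 */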
96 static struct device_state *__get_device_state(u16 devid)
97 {
98         struct device_state *dev_state;
99
100         list_for_each_entry(dev_state, &state_list, list) {
101                 if (dev_state->devid == devid)
102                         return dev_state;
103         }
104
105         return NULL;
106 }
107
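/*
 * Like __get_device_state(), but takes state_lock itself and grabs a
 * reference on the device_state; drop it with put_device_state().
 */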
108 static struct device_state *get_device_state(u16 devid)
109 {
110         struct device_state *dev_state;
111         unsigned long flags;
112
113         spin_lock_irqsave(&state_lock, flags);
114         dev_state = __get_device_state(devid);
115         if (dev_state != NULL)
116                 atomic_inc(&dev_state->count);
117         spin_unlock_irqrestore(&state_lock, flags);
118
119         return dev_state;
120 }
121
122 static void free_device_state(struct device_state *dev_state)
123 {
124         struct iommu_group *group;
125
126         /* Get rid of any remaining pasid states */
127         free_pasid_states(dev_state);
128
129         /*
130          * Wait until the last reference is dropped before freeing
131          * the device state.
132          */
133         wait_event(dev_state->wq, !atomic_read(&dev_state->count));
134
135         /*
136          * First detach device from domain - No more PRI requests will arrive
137          * from that device after it is unbound from the IOMMUv2 domain.
138          */
139         group = iommu_group_get(&dev_state->pdev->dev);
140         if (WARN_ON(!group))
141                 return;
142
143         iommu_detach_group(dev_state->domain, group);
144
145         iommu_group_put(group);
146
147         /* Everything is down now, free the IOMMUv2 domain */
148         iommu_domain_free(dev_state->domain);
149
150         /* Finally get rid of the device-state */
151         kfree(dev_state);
152 }
153
154 static void put_device_state(struct device_state *dev_state)
155 {
156         if (atomic_dec_and_test(&dev_state->count))
157                 wake_up(&dev_state->wq);
158 }
159
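/*
 * Walk the multi-level pasid_state table (9 bits, i.e. 512 slots, per
 * level) and return a pointer to the leaf slot for @pasid. If @alloc is
 * true, missing intermediate levels are allocated with GFP_ATOMIC.
 */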
160 /* Must be called under dev_state->lock */
161 static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
162                                                   u32 pasid, bool alloc)
163 {
164         struct pasid_state **root, **ptr;
165         int level, index;
166
167         level = dev_state->pasid_levels;
168         root  = dev_state->states;
169
170         while (true) {
171
172                 index = (pasid >> (9 * level)) & 0x1ff;
173                 ptr   = &root[index];
174
175                 if (level == 0)
176                         break;
177
178                 if (*ptr == NULL) {
179                         if (!alloc)
180                                 return NULL;
181
182                         *ptr = (void *)get_zeroed_page(GFP_ATOMIC);
183                         if (*ptr == NULL)
184                                 return NULL;
185                 }
186
187                 root   = (struct pasid_state **)*ptr;
188                 level -= 1;
189         }
190
191         return ptr;
192 }
193
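/*
 * Install @pasid_state in the slot for @pasid. Fails with -ENOMEM if a
 * table level cannot be allocated or if the slot is already occupied.
 */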
194 static int set_pasid_state(struct device_state *dev_state,
195                            struct pasid_state *pasid_state,
196                            u32 pasid)
197 {
198         struct pasid_state **ptr;
199         unsigned long flags;
200         int ret;
201
202         spin_lock_irqsave(&dev_state->lock, flags);
203         ptr = __get_pasid_state_ptr(dev_state, pasid, true);
204
205         ret = -ENOMEM;
206         if (ptr == NULL)
207                 goto out_unlock;
208
209         ret = -ENOMEM;
210         if (*ptr != NULL)
211                 goto out_unlock;
212
213         *ptr = pasid_state;
214
215         ret = 0;
216
217 out_unlock:
218         spin_unlock_irqrestore(&dev_state->lock, flags);
219
220         return ret;
221 }
222
223 static void clear_pasid_state(struct device_state *dev_state, u32 pasid)
224 {
225         struct pasid_state **ptr;
226         unsigned long flags;
227
228         spin_lock_irqsave(&dev_state->lock, flags);
229         ptr = __get_pasid_state_ptr(dev_state, pasid, true);
230
231         if (ptr == NULL)
232                 goto out_unlock;
233
234         *ptr = NULL;
235
236 out_unlock:
237         spin_unlock_irqrestore(&dev_state->lock, flags);
238 }
239
240 static struct pasid_state *get_pasid_state(struct device_state *dev_state,
241                                            u32 pasid)
242 {
243         struct pasid_state **ptr, *ret = NULL;
244         unsigned long flags;
245
246         spin_lock_irqsave(&dev_state->lock, flags);
247         ptr = __get_pasid_state_ptr(dev_state, pasid, false);
248
249         if (ptr == NULL)
250                 goto out_unlock;
251
252         ret = *ptr;
253         if (ret)
254                 refcount_inc(&ret->count);
255
256 out_unlock:
257         spin_unlock_irqrestore(&dev_state->lock, flags);
258
259         return ret;
260 }
261
262 static void free_pasid_state(struct pasid_state *pasid_state)
263 {
264         kfree(pasid_state);
265 }
266
267 static void put_pasid_state(struct pasid_state *pasid_state)
268 {
269         if (refcount_dec_and_test(&pasid_state->count))
270                 wake_up(&pasid_state->wq);
271 }
272
273 static void put_pasid_state_wait(struct pasid_state *pasid_state)
274 {
275         refcount_dec(&pasid_state->count);
276         wait_event(pasid_state->wq, !refcount_read(&pasid_state->count));
277         free_pasid_state(pasid_state);
278 }
279
280 static void unbind_pasid(struct pasid_state *pasid_state)
281 {
282         struct iommu_domain *domain;
283
284         domain = pasid_state->device_state->domain;
285
286         /*
287          * Mark pasid_state as invalid; no more faults will be added to the
288          * work queue after this is visible everywhere.
289          */
290         pasid_state->invalid = true;
291
292         /* Make sure this is visible */
293         smp_wmb();
294
295         /* After this the device/pasid can't access the mm anymore */
296         amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);
297
298         /* Make sure no more pending faults are in the queue */
299         flush_workqueue(iommu_wq);
300 }
301
302 static void free_pasid_states_level1(struct pasid_state **tbl)
303 {
304         int i;
305
306         for (i = 0; i < 512; ++i) {
307                 if (tbl[i] == NULL)
308                         continue;
309
310                 free_page((unsigned long)tbl[i]);
311         }
312 }
313
314 static void free_pasid_states_level2(struct pasid_state **tbl)
315 {
316         struct pasid_state **ptr;
317         int i;
318
319         for (i = 0; i < 512; ++i) {
320                 if (tbl[i] == NULL)
321                         continue;
322
323                 ptr = (struct pasid_state **)tbl[i];
324                 free_pasid_states_level1(ptr);
325         }
326 }
327
328 static void free_pasid_states(struct device_state *dev_state)
329 {
330         struct pasid_state *pasid_state;
331         int i;
332
333         for (i = 0; i < dev_state->max_pasids; ++i) {
334                 pasid_state = get_pasid_state(dev_state, i);
335                 if (pasid_state == NULL)
336                         continue;
337
338                 put_pasid_state(pasid_state);
339
340                 /*
341                  * This will call the mn_release function and
342                  * unbind the PASID
343                  */
344                 mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
345
346                 put_pasid_state_wait(pasid_state); /* Reference taken in
347                                                       amd_iommu_bind_pasid */
348
349                 /* Drop reference taken in amd_iommu_bind_pasid */
350                 put_device_state(dev_state);
351         }
352
353         if (dev_state->pasid_levels == 2)
354                 free_pasid_states_level2(dev_state->states);
355         else if (dev_state->pasid_levels == 1)
356                 free_pasid_states_level1(dev_state->states);
357         else
358                 BUG_ON(dev_state->pasid_levels != 0);
359
360         free_page((unsigned long)dev_state->states);
361 }
362
363 static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
364 {
365         return container_of(mn, struct pasid_state, mn);
366 }
367
368 static void mn_invalidate_range(struct mmu_notifier *mn,
369                                 struct mm_struct *mm,
370                                 unsigned long start, unsigned long end)
371 {
372         struct pasid_state *pasid_state;
373         struct device_state *dev_state;
374
375         pasid_state = mn_to_state(mn);
376         dev_state   = pasid_state->device_state;
377
378         if ((start ^ (end - 1)) < PAGE_SIZE)
379                 amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
380                                      start);
381         else
382                 amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
383 }
384
385 static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
386 {
387         struct pasid_state *pasid_state;
388         struct device_state *dev_state;
389         bool run_inv_ctx_cb;
390
391         might_sleep();
392
393         pasid_state    = mn_to_state(mn);
394         dev_state      = pasid_state->device_state;
395         run_inv_ctx_cb = !pasid_state->invalid;
396
397         if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
398                 dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);
399
400         unbind_pasid(pasid_state);
401 }
402
403 static const struct mmu_notifier_ops iommu_mn = {
404         .release                = mn_release,
405         .invalidate_range       = mn_invalidate_range,
406 };
407
408 static void set_pri_tag_status(struct pasid_state *pasid_state,
409                                u16 tag, int status)
410 {
411         unsigned long flags;
412
413         spin_lock_irqsave(&pasid_state->lock, flags);
414         pasid_state->pri[tag].status = status;
415         spin_unlock_irqrestore(&pasid_state->lock, flags);
416 }
417
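/*
 * Drop one in-flight fault for @tag. When the last one completes and the
 * hardware requested a completion, send the PPR response and reset the
 * per-tag state.
 */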
418 static void finish_pri_tag(struct device_state *dev_state,
419                            struct pasid_state *pasid_state,
420                            u16 tag)
421 {
422         unsigned long flags;
423
424         spin_lock_irqsave(&pasid_state->lock, flags);
425         if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
426             pasid_state->pri[tag].finish) {
427                 amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
428                                        pasid_state->pri[tag].status, tag);
429                 pasid_state->pri[tag].finish = false;
430                 pasid_state->pri[tag].status = PPR_SUCCESS;
431         }
432         spin_unlock_irqrestore(&pasid_state->lock, flags);
433 }
434
435 static void handle_fault_error(struct fault *fault)
436 {
437         int status;
438
439         if (!fault->dev_state->inv_ppr_cb) {
440                 set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
441                 return;
442         }
443
444         status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
445                                               fault->pasid,
446                                               fault->address,
447                                               fault->flags);
448         switch (status) {
449         case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
450                 set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
451                 break;
452         case AMD_IOMMU_INV_PRI_RSP_INVALID:
453                 set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
454                 break;
455         case AMD_IOMMU_INV_PRI_RSP_FAIL:
456                 set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
457                 break;
458         default:
459                 BUG();
460         }
461 }
462
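/* Return true if the fault requests access rights the VMA does not grant */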
463 static bool access_error(struct vm_area_struct *vma, struct fault *fault)
464 {
465         unsigned long requested = 0;
466
467         if (fault->flags & PPR_FAULT_EXEC)
468                 requested |= VM_EXEC;
469
470         if (fault->flags & PPR_FAULT_READ)
471                 requested |= VM_READ;
472
473         if (fault->flags & PPR_FAULT_WRITE)
474                 requested |= VM_WRITE;
475
476         return (requested & ~vma->vm_flags) != 0;
477 }
478
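/*
 * Workqueue handler for one queued PPR fault: look up the VMA, check
 * permissions and call handle_mm_fault() under the mmap read lock, then
 * complete the PRI tag and drop the pasid_state reference.
 */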
479 static void do_fault(struct work_struct *work)
480 {
481         struct fault *fault = container_of(work, struct fault, work);
482         struct vm_area_struct *vma;
483         vm_fault_t ret = VM_FAULT_ERROR;
484         unsigned int flags = 0;
485         struct mm_struct *mm;
486         u64 address;
487
488         mm = fault->state->mm;
489         address = fault->address;
490
491         if (fault->flags & PPR_FAULT_USER)
492                 flags |= FAULT_FLAG_USER;
493         if (fault->flags & PPR_FAULT_WRITE)
494                 flags |= FAULT_FLAG_WRITE;
495         flags |= FAULT_FLAG_REMOTE;
496
497         mmap_read_lock(mm);
498         vma = find_extend_vma(mm, address);
499         if (!vma || address < vma->vm_start)
500                 /* failed to get a vma in the right range */
501                 goto out;
502
503         /* Check if we have the right permissions on the vma */
504         if (access_error(vma, fault))
505                 goto out;
506
507         ret = handle_mm_fault(vma, address, flags, NULL);
508 out:
509         mmap_read_unlock(mm);
510
511         if (ret & VM_FAULT_ERROR)
512                 /* failed to service fault */
513                 handle_fault_error(fault);
514
515         finish_pri_tag(fault->dev_state, fault->state, fault->tag);
516
517         put_pasid_state(fault->state);
518
519         kfree(fault);
520 }
521
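/*
 * Called by the AMD IOMMU driver for each incoming PPR. Looks up the
 * device_state and pasid_state for the fault and queues a struct fault
 * on iommu_wq. An unknown or invalid PASID is answered directly with
 * PPR_INVALID.
 */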
522 static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
523 {
524         struct amd_iommu_fault *iommu_fault;
525         struct pasid_state *pasid_state;
526         struct device_state *dev_state;
527         struct pci_dev *pdev = NULL;
528         unsigned long flags;
529         struct fault *fault;
530         bool finish;
531         u16 tag, devid;
532         int ret;
533
534         iommu_fault = data;
535         tag         = iommu_fault->tag & 0x1ff;
536         finish      = (iommu_fault->tag >> 9) & 1;
537
538         devid = iommu_fault->device_id;
539         pdev = pci_get_domain_bus_and_slot(0, PCI_BUS_NUM(devid),
540                                            devid & 0xff);
541         if (!pdev)
542                 return -ENODEV;
543
544         ret = NOTIFY_DONE;
545
546         /* In a kdump kernel the PCI device is not yet initialized -> send INVALID */
547         if (amd_iommu_is_attach_deferred(&pdev->dev)) {
548                 amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
549                                        PPR_INVALID, tag);
550                 goto out;
551         }
552
553         dev_state = get_device_state(iommu_fault->device_id);
554         if (dev_state == NULL)
555                 goto out;
556
557         pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
558         if (pasid_state == NULL || pasid_state->invalid) {
559                 /* We know the device but not the PASID -> send INVALID */
560                 amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
561                                        PPR_INVALID, tag);
562                 goto out_drop_state;
563         }
564
565         spin_lock_irqsave(&pasid_state->lock, flags);
566         atomic_inc(&pasid_state->pri[tag].inflight);
567         if (finish)
568                 pasid_state->pri[tag].finish = true;
569         spin_unlock_irqrestore(&pasid_state->lock, flags);
570
571         fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
572         if (fault == NULL) {
573                 /* We are OOM - send success and let the device re-fault */
574                 finish_pri_tag(dev_state, pasid_state, tag);
575                 goto out_drop_state;
576         }
577
578         fault->dev_state = dev_state;
579         fault->address   = iommu_fault->address;
580         fault->state     = pasid_state;
581         fault->tag       = tag;
582         fault->finish    = finish;
583         fault->pasid     = iommu_fault->pasid;
584         fault->flags     = iommu_fault->flags;
585         INIT_WORK(&fault->work, do_fault);
586
587         queue_work(iommu_wq, &fault->work);
588
589         ret = NOTIFY_OK;
590
591 out_drop_state:
592
593         if (ret != NOTIFY_OK && pasid_state)
594                 put_pasid_state(pasid_state);
595
596         put_device_state(dev_state);
597
598 out:
        /* Drop the reference taken by pci_get_domain_bus_and_slot() above */
        pci_dev_put(pdev);

599         return ret;
600 }
601
602 static struct notifier_block ppr_nb = {
603         .notifier_call = ppr_notifier,
604 };
605
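/*
 * Bind @pasid on @pdev to the address space of @task: register an
 * mmu_notifier on the task's mm and program its page-table root into the
 * GCR3 table, so that subsequent PPR faults for this PASID are handled
 * against that mm.
 */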
606 int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid,
607                          struct task_struct *task)
608 {
609         struct pasid_state *pasid_state;
610         struct device_state *dev_state;
611         struct mm_struct *mm;
612         u16 devid;
613         int ret;
614
615         might_sleep();
616
617         if (!amd_iommu_v2_supported())
618                 return -ENODEV;
619
620         devid     = device_id(pdev);
621         dev_state = get_device_state(devid);
622
623         if (dev_state == NULL)
624                 return -EINVAL;
625
626         ret = -EINVAL;
627         if (pasid >= dev_state->max_pasids)
628                 goto out;
629
630         ret = -ENOMEM;
631         pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
632         if (pasid_state == NULL)
633                 goto out;
634
635
636         refcount_set(&pasid_state->count, 1);
637         init_waitqueue_head(&pasid_state->wq);
638         spin_lock_init(&pasid_state->lock);
639
640         mm                        = get_task_mm(task);
641         pasid_state->mm           = mm;
642         pasid_state->device_state = dev_state;
643         pasid_state->pasid        = pasid;
644         pasid_state->invalid      = true; /* Mark as valid only if we are
645                                              done with setting up the pasid */
646         pasid_state->mn.ops       = &iommu_mn;
647
648         if (pasid_state->mm == NULL)
649                 goto out_free;
650
651         mmu_notifier_register(&pasid_state->mn, mm);
652
653         ret = set_pasid_state(dev_state, pasid_state, pasid);
654         if (ret)
655                 goto out_unregister;
656
657         ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
658                                         __pa(pasid_state->mm->pgd));
659         if (ret)
660                 goto out_clear_state;
661
662         /* Now we are ready to handle faults */
663         pasid_state->invalid = false;
664
665         /*
666          * Drop the reference to the mm_struct here. We rely on the
667          * mmu_notifier release call-back to inform us when the mm
668          * is going away.
669          */
670         mmput(mm);
671
672         return 0;
673
674 out_clear_state:
675         clear_pasid_state(dev_state, pasid);
676
677 out_unregister:
678         mmu_notifier_unregister(&pasid_state->mn, mm);
679         mmput(mm);
680
681 out_free:
682         free_pasid_state(pasid_state);
683
684 out:
685         put_device_state(dev_state);
686
687         return ret;
688 }
689 EXPORT_SYMBOL(amd_iommu_bind_pasid);
690
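/*
 * Undo amd_iommu_bind_pasid(): clear the pasid_state slot, unregister the
 * mmu_notifier and drop the references taken at bind time.
 */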
691 void amd_iommu_unbind_pasid(struct pci_dev *pdev, u32 pasid)
692 {
693         struct pasid_state *pasid_state;
694         struct device_state *dev_state;
695         u16 devid;
696
697         might_sleep();
698
699         if (!amd_iommu_v2_supported())
700                 return;
701
702         devid = device_id(pdev);
703         dev_state = get_device_state(devid);
704         if (dev_state == NULL)
705                 return;
706
707         if (pasid >= dev_state->max_pasids)
708                 goto out;
709
710         pasid_state = get_pasid_state(dev_state, pasid);
711         if (pasid_state == NULL)
712                 goto out;
713         /*
714          * Drop reference taken here. We are safe because we still hold
715          * the reference taken in the amd_iommu_bind_pasid function.
716          */
717         put_pasid_state(pasid_state);
718
719         /* Clear the pasid state so that the pasid can be re-used */
720         clear_pasid_state(dev_state, pasid_state->pasid);
721
722         /*
723          * Call mmu_notifier_unregister to drop our reference
724          * to pasid_state->mm
725          */
726         mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
727
728         put_pasid_state_wait(pasid_state); /* Reference taken in
729                                               amd_iommu_bind_pasid */
730 out:
731         /* Drop reference taken in this function */
732         put_device_state(dev_state);
733
734         /* Drop reference taken in amd_iommu_bind_pasid */
735         put_device_state(dev_state);
736 }
737 EXPORT_SYMBOL(amd_iommu_unbind_pasid);
738
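/*
 * Set up IOMMUv2 (PRI/PASID) state for @pdev with support for @pasids
 * PASIDs: allocate a device_state, an IOMMUv2 direct-mapped domain and
 * the pasid_state table, and attach the device's group to that domain.
 *
 * Illustrative driver-side call sequence (a sketch only; my_invalid_ppr_cb,
 * num_pasids and pasid are placeholders, and error handling is the
 * caller's, not defined here):
 *
 *	ret = amd_iommu_init_device(pdev, num_pasids);
 *	if (ret)
 *		return ret;
 *	amd_iommu_set_invalid_ppr_cb(pdev, my_invalid_ppr_cb);
 *	ret = amd_iommu_bind_pasid(pdev, pasid, current);
 *	...
 *	amd_iommu_unbind_pasid(pdev, pasid);
 *	amd_iommu_free_device(pdev);
 */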
739 int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
740 {
741         struct device_state *dev_state;
742         struct iommu_group *group;
743         unsigned long flags;
744         int ret, tmp;
745         u16 devid;
746
747         might_sleep();
748
749         /*
750          * When memory encryption is active the device is likely not in a
751          * direct-mapped domain. Forbid using IOMMUv2 functionality for now.
752          */
753         if (cc_platform_has(CC_ATTR_MEM_ENCRYPT))
754                 return -ENODEV;
755
756         if (!amd_iommu_v2_supported())
757                 return -ENODEV;
758
759         if (pasids <= 0 || pasids > (PASID_MASK + 1))
760                 return -EINVAL;
761
762         devid = device_id(pdev);
763
764         dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
765         if (dev_state == NULL)
766                 return -ENOMEM;
767
768         spin_lock_init(&dev_state->lock);
769         init_waitqueue_head(&dev_state->wq);
770         dev_state->pdev  = pdev;
771         dev_state->devid = devid;
772
773         tmp = pasids;
774         for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
775                 dev_state->pasid_levels += 1;
776
777         atomic_set(&dev_state->count, 1);
778         dev_state->max_pasids = pasids;
779
780         ret = -ENOMEM;
781         dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
782         if (dev_state->states == NULL)
783                 goto out_free_dev_state;
784
785         dev_state->domain = iommu_domain_alloc(&pci_bus_type);
786         if (dev_state->domain == NULL)
787                 goto out_free_states;
788
789         amd_iommu_domain_direct_map(dev_state->domain);
790
791         ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
792         if (ret)
793                 goto out_free_domain;
794
795         group = iommu_group_get(&pdev->dev);
796         if (!group) {
797                 ret = -EINVAL;
798                 goto out_free_domain;
799         }
800
801         ret = iommu_attach_group(dev_state->domain, group);
802         if (ret != 0)
803                 goto out_drop_group;
804
805         iommu_group_put(group);
806
807         spin_lock_irqsave(&state_lock, flags);
808
809         if (__get_device_state(devid) != NULL) {
810                 spin_unlock_irqrestore(&state_lock, flags);
811                 ret = -EBUSY;
812                 goto out_free_domain;
813         }
814
815         list_add_tail(&dev_state->list, &state_list);
816
817         spin_unlock_irqrestore(&state_lock, flags);
818
819         return 0;
820
821 out_drop_group:
822         iommu_group_put(group);
823
824 out_free_domain:
825         iommu_domain_free(dev_state->domain);
826
827 out_free_states:
828         free_page((unsigned long)dev_state->states);
829
830 out_free_dev_state:
831         kfree(dev_state);
832
833         return ret;
834 }
835 EXPORT_SYMBOL(amd_iommu_init_device);
836
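/*
 * Tear down the IOMMUv2 state set up by amd_iommu_init_device(): remove
 * the device_state from the global list and free it once the last
 * reference is gone.
 */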
837 void amd_iommu_free_device(struct pci_dev *pdev)
838 {
839         struct device_state *dev_state;
840         unsigned long flags;
841         u16 devid;
842
843         if (!amd_iommu_v2_supported())
844                 return;
845
846         devid = device_id(pdev);
847
848         spin_lock_irqsave(&state_lock, flags);
849
850         dev_state = __get_device_state(devid);
851         if (dev_state == NULL) {
852                 spin_unlock_irqrestore(&state_lock, flags);
853                 return;
854         }
855
856         list_del(&dev_state->list);
857
858         spin_unlock_irqrestore(&state_lock, flags);
859
860         put_device_state(dev_state);
861         free_device_state(dev_state);
862 }
863 EXPORT_SYMBOL(amd_iommu_free_device);
864
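/*
 * Register a driver callback that is invoked from handle_fault_error()
 * when a PPR fault cannot be resolved by handle_mm_fault(); its return
 * value selects the PPR response (success, invalid or failure).
 */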
865 int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
866                                  amd_iommu_invalid_ppr_cb cb)
867 {
868         struct device_state *dev_state;
869         unsigned long flags;
870         u16 devid;
871         int ret;
872
873         if (!amd_iommu_v2_supported())
874                 return -ENODEV;
875
876         devid = device_id(pdev);
877
878         spin_lock_irqsave(&state_lock, flags);
879
880         ret = -EINVAL;
881         dev_state = __get_device_state(devid);
882         if (dev_state == NULL)
883                 goto out_unlock;
884
885         dev_state->inv_ppr_cb = cb;
886
887         ret = 0;
888
889 out_unlock:
890         spin_unlock_irqrestore(&state_lock, flags);
891
892         return ret;
893 }
894 EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);
895
896 int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
897                                     amd_iommu_invalidate_ctx cb)
898 {
899         struct device_state *dev_state;
900         unsigned long flags;
901         u16 devid;
902         int ret;
903
904         if (!amd_iommu_v2_supported())
905                 return -ENODEV;
906
907         devid = device_id(pdev);
908
909         spin_lock_irqsave(&state_lock, flags);
910
911         ret = -EINVAL;
912         dev_state = __get_device_state(devid);
913         if (dev_state == NULL)
914                 goto out_unlock;
915
916         dev_state->inv_ctx_cb = cb;
917
918         ret = 0;
919
920 out_unlock:
921         spin_unlock_irqrestore(&state_lock, flags);
922
923         return ret;
924 }
925 EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);
926
927 static int __init amd_iommu_v2_init(void)
928 {
929         int ret;
930
931         if (!amd_iommu_v2_supported()) {
932                 pr_info("AMD IOMMUv2 functionality not available on this system - This is not a bug.\n");
933                 /*
934                  * Load anyway to provide the symbols to other modules
935                  * which may use AMD IOMMUv2 optionally.
936                  */
937                 return 0;
938         }
939
940         ret = -ENOMEM;
941         iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
942         if (iommu_wq == NULL)
943                 goto out;
944
945         amd_iommu_register_ppr_notifier(&ppr_nb);
946
947         pr_info("AMD IOMMUv2 loaded and initialized\n");
948
949         return 0;
950
951 out:
952         return ret;
953 }
954
955 static void __exit amd_iommu_v2_exit(void)
956 {
957         struct device_state *dev_state, *next;
958         unsigned long flags;
959         LIST_HEAD(freelist);
960
961         if (!amd_iommu_v2_supported())
962                 return;
963
964         amd_iommu_unregister_ppr_notifier(&ppr_nb);
965
966         flush_workqueue(iommu_wq);
967
968         /*
969          * The loop below might call flush_workqueue(), so call
970          * destroy_workqueue() after it
971          */
972         spin_lock_irqsave(&state_lock, flags);
973
974         list_for_each_entry_safe(dev_state, next, &state_list, list) {
975                 WARN_ON_ONCE(1);
976
977                 put_device_state(dev_state);
978                 list_del(&dev_state->list);
979                 list_add_tail(&dev_state->list, &freelist);
980         }
981
982         spin_unlock_irqrestore(&state_lock, flags);
983
984         /*
985          * Since free_device_state waits on the count to be zero,
986          * we need to free dev_state outside the spinlock.
987          */
988         list_for_each_entry_safe(dev_state, next, &freelist, list) {
989                 list_del(&dev_state->list);
990                 free_device_state(dev_state);
991         }
992
993         destroy_workqueue(iommu_wq);
994 }
995
996 module_init(amd_iommu_v2_init);
997 module_exit(amd_iommu_v2_exit);