1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3 * Copyright (C) 2024 Advanced Micro Devices, Inc.
4 */
5
6 #define pr_fmt(fmt) "AMD-Vi: " fmt
7 #define dev_fmt(fmt) pr_fmt(fmt)
8
9 #include <linux/iommu.h>
10 #include <linux/mm_types.h>
11
12 #include "amd_iommu.h"
13
is_pasid_enabled(struct iommu_dev_data * dev_data)14 static inline bool is_pasid_enabled(struct iommu_dev_data *dev_data)
15 {
16 if (dev_data->pasid_enabled && dev_data->max_pasids &&
17 dev_data->gcr3_info.gcr3_tbl != NULL)
18 return true;
19
20 return false;
21 }
22
is_pasid_valid(struct iommu_dev_data * dev_data,ioasid_t pasid)23 static inline bool is_pasid_valid(struct iommu_dev_data *dev_data,
24 ioasid_t pasid)
25 {
26 if (pasid > 0 && pasid < dev_data->max_pasids)
27 return true;
28
29 return false;
30 }
31
remove_dev_pasid(struct pdom_dev_data * pdom_dev_data)32 static void remove_dev_pasid(struct pdom_dev_data *pdom_dev_data)
33 {
34 /* Update GCR3 table and flush IOTLB */
35 amd_iommu_clear_gcr3(pdom_dev_data->dev_data, pdom_dev_data->pasid);
36
37 list_del(&pdom_dev_data->list);
38 kfree(pdom_dev_data);
39 }
40
41 /* Clear PASID from device GCR3 table and remove pdom_dev_data from list */
remove_pdom_dev_pasid(struct protection_domain * pdom,struct device * dev,ioasid_t pasid)42 static void remove_pdom_dev_pasid(struct protection_domain *pdom,
43 struct device *dev, ioasid_t pasid)
44 {
45 struct pdom_dev_data *pdom_dev_data;
46 struct iommu_dev_data *dev_data = dev_iommu_priv_get(dev);
47
48 lockdep_assert_held(&pdom->lock);
49
50 for_each_pdom_dev_data(pdom_dev_data, pdom) {
51 if (pdom_dev_data->dev_data == dev_data &&
52 pdom_dev_data->pasid == pasid) {
53 remove_dev_pasid(pdom_dev_data);
54 break;
55 }
56 }
57 }
58
sva_arch_invalidate_secondary_tlbs(struct mmu_notifier * mn,struct mm_struct * mm,unsigned long start,unsigned long end)59 static void sva_arch_invalidate_secondary_tlbs(struct mmu_notifier *mn,
60 struct mm_struct *mm,
61 unsigned long start, unsigned long end)
62 {
63 struct pdom_dev_data *pdom_dev_data;
64 struct protection_domain *sva_pdom;
65 unsigned long flags;
66
67 sva_pdom = container_of(mn, struct protection_domain, mn);
68
69 spin_lock_irqsave(&sva_pdom->lock, flags);
70
71 for_each_pdom_dev_data(pdom_dev_data, sva_pdom) {
72 amd_iommu_dev_flush_pasid_pages(pdom_dev_data->dev_data,
73 pdom_dev_data->pasid,
74 start, end - start);
75 }
76
77 spin_unlock_irqrestore(&sva_pdom->lock, flags);
78 }
79
sva_mn_release(struct mmu_notifier * mn,struct mm_struct * mm)80 static void sva_mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
81 {
82 struct pdom_dev_data *pdom_dev_data, *next;
83 struct protection_domain *sva_pdom;
84 unsigned long flags;
85
86 sva_pdom = container_of(mn, struct protection_domain, mn);
87
88 spin_lock_irqsave(&sva_pdom->lock, flags);
89
90 /* Assume dev_data_list contains same PASID with different devices */
91 for_each_pdom_dev_data_safe(pdom_dev_data, next, sva_pdom)
92 remove_dev_pasid(pdom_dev_data);
93
94 spin_unlock_irqrestore(&sva_pdom->lock, flags);
95 }
96
97 static const struct mmu_notifier_ops sva_mn = {
98 .arch_invalidate_secondary_tlbs = sva_arch_invalidate_secondary_tlbs,
99 .release = sva_mn_release,
100 };
101
/*
 * Bind @pasid on @dev to the SVA domain @domain. This programs the device's
 * GCR3 entry for @pasid with the physical address of domain->mm->pgd and
 * records the (device, PASID) pair on the domain's list, so that the
 * mmu_notifier callbacks can reach it.
 *
 * Returns 0 on success, or one of:
 *   -EINVAL  @pasid is out of range, or PASID is not enabled on the device;
 *   -ENOMEM  the tracking entry could not be allocated;
 *   otherwise, the error returned by amd_iommu_set_gcr3().
 */
int iommu_sva_set_dev_pasid(struct iommu_domain *domain,
			    struct device *dev, ioasid_t pasid)
{
	struct pdom_dev_data *pdom_dev_data;
	struct protection_domain *sva_pdom = to_pdomain(domain);
	struct iommu_dev_data *dev_data = dev_iommu_priv_get(dev);
	unsigned long flags;
	int ret;

	/* PASID zero is used for requests from the I/O device without PASID */
	if (!is_pasid_valid(dev_data, pasid))
		return -EINVAL;

	/* Make sure PASID is enabled */
	if (!is_pasid_enabled(dev_data))
		return -EINVAL;

	/*
	 * Allocate the tracking entry before taking the spinlock:
	 * GFP_KERNEL may sleep. An allocation failure is reported as -ENOMEM,
	 * not as an invalid-argument error.
	 */
	pdom_dev_data = kzalloc(sizeof(*pdom_dev_data), GFP_KERNEL);
	if (!pdom_dev_data)
		return -ENOMEM;

	pdom_dev_data->pasid = pasid;
	pdom_dev_data->dev_data = dev_data;

	spin_lock_irqsave(&sva_pdom->lock, flags);

	/* Setup GCR3 table */
	ret = amd_iommu_set_gcr3(dev_data, pasid,
				 iommu_virt_to_phys(domain->mm->pgd));
	if (ret) {
		kfree(pdom_dev_data);
		goto out_unlock;
	}

	/* Add PASID to protection domain pasid list */
	list_add(&pdom_dev_data->list, &sva_pdom->dev_data_list);

out_unlock:
	spin_unlock_irqrestore(&sva_pdom->lock, flags);
	return ret;
}
143
/*
 * Unbind @pasid on @dev from the SVA domain @domain. This clears the
 * device's GCR3 entry for @pasid and drops the matching tracking entry, if
 * there is one. Out-of-range PASIDs are silently ignored.
 */
void amd_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
				struct iommu_domain *domain)
{
	struct iommu_dev_data *dev_data = dev_iommu_priv_get(dev);
	struct protection_domain *sva_pdom;
	unsigned long flags;

	if (!is_pasid_valid(dev_data, pasid))
		return;

	sva_pdom = to_pdomain(domain);

	/* Remove PASID from dev_data_list while holding the domain lock. */
	spin_lock_irqsave(&sva_pdom->lock, flags);
	remove_pdom_dev_pasid(sva_pdom, dev, pasid);
	spin_unlock_irqrestore(&sva_pdom->lock, flags);
}
162
iommu_sva_domain_free(struct iommu_domain * domain)163 static void iommu_sva_domain_free(struct iommu_domain *domain)
164 {
165 struct protection_domain *sva_pdom = to_pdomain(domain);
166
167 if (sva_pdom->mn.ops)
168 mmu_notifier_unregister(&sva_pdom->mn, domain->mm);
169
170 amd_iommu_domain_free(domain);
171 }
172
173 static const struct iommu_domain_ops amd_sva_domain_ops = {
174 .set_dev_pasid = iommu_sva_set_dev_pasid,
175 .free = iommu_sva_domain_free
176 };
177
amd_iommu_domain_alloc_sva(struct device * dev,struct mm_struct * mm)178 struct iommu_domain *amd_iommu_domain_alloc_sva(struct device *dev,
179 struct mm_struct *mm)
180 {
181 struct protection_domain *pdom;
182 int ret;
183
184 pdom = protection_domain_alloc(IOMMU_DOMAIN_SVA, dev_to_node(dev));
185 if (!pdom)
186 return ERR_PTR(-ENOMEM);
187
188 pdom->domain.ops = &amd_sva_domain_ops;
189 pdom->mn.ops = &sva_mn;
190
191 ret = mmu_notifier_register(&pdom->mn, mm);
192 if (ret) {
193 protection_domain_free(pdom);
194 return ERR_PTR(ret);
195 }
196
197 return &pdom->domain;
198 }
199