1 // SPDX-License-Identifier: GPL-2.0-only
2
3 #include <linux/bitfield.h>
4 #include <linux/list.h>
5 #include <linux/netdevice.h>
6 #include <linux/xarray.h>
7 #include <net/net_namespace.h>
8 #include <net/psp.h>
9 #include <net/udp.h>
10
11 #include "psp.h"
12 #include "psp-nl-gen.h"
13
/* Registry of all PSP devices, indexed by device id (ids start at 1). */
DEFINE_XARRAY_ALLOC1(psp_devs);
/* Protects the psp_devs xarray; see the "PSP locking" DOC comment. */
struct mutex psp_devs_lock;
16
17 /**
18 * DOC: PSP locking
19 *
20 * psp_devs_lock protects the psp_devs xarray.
21 * Ordering is take the psp_devs_lock and then the instance lock.
22 * Each instance is protected by RCU, and has a refcount.
23 * When driver unregisters the instance gets flushed, but struct sticks around.
24 */
25
26 /**
27 * psp_dev_check_access() - check if user in a given net ns can access PSP dev
28 * @psd: PSP device structure user is trying to access
29 * @net: net namespace user is in
30 *
31 * Return: 0 if PSP device should be visible in @net, errno otherwise.
32 */
psp_dev_check_access(struct psp_dev * psd,struct net * net)33 int psp_dev_check_access(struct psp_dev *psd, struct net *net)
34 {
35 if (dev_net(psd->main_netdev) == net)
36 return 0;
37 return -ENOENT;
38 }
39
40 /**
41 * psp_dev_create() - create and register PSP device
42 * @netdev: main netdevice
43 * @psd_ops: driver callbacks
44 * @psd_caps: device capabilities
45 * @priv_ptr: back-pointer to driver private data
46 *
47 * Return: pointer to allocated PSP device, or ERR_PTR.
48 */
49 struct psp_dev *
psp_dev_create(struct net_device * netdev,struct psp_dev_ops * psd_ops,struct psp_dev_caps * psd_caps,void * priv_ptr)50 psp_dev_create(struct net_device *netdev,
51 struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps,
52 void *priv_ptr)
53 {
54 struct psp_dev *psd;
55 static u32 last_id;
56 int err;
57
58 if (WARN_ON(!psd_caps->versions ||
59 !psd_ops->set_config ||
60 !psd_ops->key_rotate ||
61 !psd_ops->rx_spi_alloc ||
62 !psd_ops->tx_key_add ||
63 !psd_ops->tx_key_del ||
64 !psd_ops->get_stats))
65 return ERR_PTR(-EINVAL);
66
67 psd = kzalloc_obj(*psd);
68 if (!psd)
69 return ERR_PTR(-ENOMEM);
70
71 psd->main_netdev = netdev;
72 psd->ops = psd_ops;
73 psd->caps = psd_caps;
74 psd->drv_priv = priv_ptr;
75
76 mutex_init(&psd->lock);
77 INIT_LIST_HEAD(&psd->active_assocs);
78 INIT_LIST_HEAD(&psd->prev_assocs);
79 INIT_LIST_HEAD(&psd->stale_assocs);
80 refcount_set(&psd->refcnt, 1);
81
82 mutex_lock(&psp_devs_lock);
83 err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b,
84 &last_id, GFP_KERNEL);
85 if (err) {
86 mutex_unlock(&psp_devs_lock);
87 kfree(psd);
88 return ERR_PTR(err);
89 }
90 mutex_lock(&psd->lock);
91 mutex_unlock(&psp_devs_lock);
92
93 psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF);
94
95 rcu_assign_pointer(netdev->psp_dev, psd);
96
97 mutex_unlock(&psd->lock);
98
99 return psd;
100 }
101 EXPORT_SYMBOL(psp_dev_create);
102
void psp_dev_free(struct psp_dev *psd)
{
	/* psp_dev_unregister() only stored NULL at this id to keep it
	 * reserved; erase it now that the last reference is gone so the
	 * id can be recycled.
	 */
	mutex_lock(&psp_devs_lock);
	xa_erase(&psp_devs, psd->id);
	mutex_unlock(&psp_devs_lock);

	mutex_destroy(&psd->lock);
	/* RCU readers may still hold the pointer; free after a grace period */
	kfree_rcu(psd, rcu);
}
112
113 /**
114 * psp_dev_unregister() - unregister PSP device
115 * @psd: PSP device structure
116 */
void psp_dev_unregister(struct psp_dev *psd)
{
	struct psp_assoc *pas, *next;

	/* Lock order (see "PSP locking"): psp_devs_lock, then instance lock */
	mutex_lock(&psp_devs_lock);
	mutex_lock(&psd->lock);

	psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF);

	/* Wait until psp_dev_free() to call xa_erase() to prevent a
	 * different psd from being added to the xarray with this id, while
	 * there are still references to this psd being held.
	 */
	xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL);
	mutex_unlock(&psp_devs_lock);

	/* Funnel every association (active and previous alike) onto the
	 * stale list, then remove their device keys.
	 */
	list_splice_init(&psd->active_assocs, &psd->prev_assocs);
	list_splice_init(&psd->prev_assocs, &psd->stale_assocs);
	list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list)
		psp_dev_tx_key_del(psd, pas);

	/* Unpublish from the netdev so new lookups no longer find us */
	rcu_assign_pointer(psd->main_netdev->psp_dev, NULL);

	/* The driver must not be called back after unregister returns */
	psd->ops = NULL;
	psd->drv_priv = NULL;

	mutex_unlock(&psd->lock);

	/* Drop the registration reference taken in psp_dev_create() */
	psp_dev_put(psd);
}
EXPORT_SYMBOL(psp_dev_unregister);
148
/* Return the key length in bytes for @version, or 0 if unknown. */
unsigned int psp_key_size(u32 version)
{
	if (version == PSP_VERSION_HDR0_AES_GCM_128 ||
	    version == PSP_VERSION_HDR0_AES_GMAC_128)
		return 16;
	if (version == PSP_VERSION_HDR0_AES_GCM_256 ||
	    version == PSP_VERSION_HDR0_AES_GMAC_256)
		return 32;
	return 0;
}
EXPORT_SYMBOL(psp_key_size);
163
/* Fill in the UDP and PSP headers at the (already reserved) transport
 * header location.  @udp_len covers UDP header, PSP header and payload.
 */
static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi,
			      u8 ver, unsigned int udp_len, __be16 sport)
{
	struct udphdr *uh = udp_hdr(skb);
	struct psphdr *psph = (struct psphdr *)(uh + 1);

	uh->dest = htons(PSP_DEFAULT_UDP_PORT);
	/* Honor the caller-provided source port; previously @sport was
	 * silently ignored.  Fall back to flow-hash based source port
	 * selection only when the caller passed 0.
	 */
	uh->source = sport ?: udp_flow_src_port(net, skb, 0, 0, false);
	uh->check = 0;
	uh->len = htons(udp_len);

	psph->nexthdr = IPPROTO_TCP;
	psph->hdrlen = PSP_HDRLEN_NOOPT;
	psph->crypt_offset = 0;
	psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) |
		      FIELD_PREP(PSPHDR_VERFL_ONE, 1);
	psph->spi = spi;
	memset(&psph->iv, 0, sizeof(psph->iv));
}
183
184 /* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling
185 * them in.
186 */
bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
			 u8 ver, __be16 sport)
{
	u32 network_len = skb_network_header_len(skb);
	u32 ethr_len = skb_mac_header_len(skb);
	u32 bufflen = ethr_len + network_len;

	/* Make room for UDP + PSP headers in front of the TCP header */
	if (skb_cow_head(skb, PSP_ENCAP_HLEN))
		return false;

	/* Extend the packet at the front, shift the L2+L3 headers down,
	 * leaving a PSP_ENCAP_HLEN gap between L3 and the TCP header.
	 */
	skb_push(skb, PSP_ENCAP_HLEN);
	skb->mac_header -= PSP_ENCAP_HLEN;
	skb->network_header -= PSP_ENCAP_HLEN;
	skb->transport_header -= PSP_ENCAP_HLEN;
	memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen);

	if (skb->protocol == htons(ETH_P_IP)) {
		/* L3 now carries UDP; grow the length and redo the csum */
		ip_hdr(skb)->protocol = IPPROTO_UDP;
		be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN);
		ip_hdr(skb)->check = 0;
		ip_hdr(skb)->check =
			ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl);
	} else if (skb->protocol == htons(ETH_P_IPV6)) {
		ipv6_hdr(skb)->nexthdr = IPPROTO_UDP;
		be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN);
	} else {
		/* Only IPv4/IPv6 TCP packets can be PSP-encapsulated */
		return false;
	}

	/* Record the tunnel layout so offloads can find the inner TCP */
	skb_set_inner_ipproto(skb, IPPROTO_TCP);
	skb_set_inner_transport_header(skb, skb_transport_offset(skb) +
					    PSP_ENCAP_HLEN);
	skb->encapsulation = 1;
	psp_write_headers(net, skb, spi, ver,
			  skb->len - skb_transport_offset(skb), sport);

	return true;
}
EXPORT_SYMBOL(psp_dev_encapsulate);
226
227 /* Receive handler for PSP packets.
228 *
229 * Presently it accepts only already-authenticated packets and does not
230 * support optional fields, such as virtualization cookies. The caller should
231 * ensure that skb->data is pointing to the mac header, and that skb->mac_len
232 * is set. This function does not currently adjust skb->csum (CHECKSUM_COMPLETE
233 * is not supported).
234 */
int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv)
{
	int l2_hlen = 0, l3_hlen, encap;
	struct psp_skb_ext *pse;
	struct psphdr *psph;
	struct ethhdr *eth;
	struct udphdr *uh;
	__be16 proto;
	bool is_udp;

	/* Resolve the L3 protocol; __vlan_get_protocol() skips VLAN tags
	 * and sets l2_hlen to the full MAC header length including them.
	 */
	eth = (struct ethhdr *)skb->data;
	proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen);
	if (proto == htons(ETH_P_IP))
		l3_hlen = sizeof(struct iphdr);
	else if (proto == htons(ETH_P_IPV6))
		l3_hlen = sizeof(struct ipv6hdr);
	else
		return -EINVAL;

	/* Ensure L2 + minimal L3 + UDP + PSP headers are in linear data */
	if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)))
		return -EINVAL;

	if (proto == htons(ETH_P_IP)) {
		struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);

		is_udp = iph->protocol == IPPROTO_UDP;
		/* Re-pull if the IPv4 header carries options */
		l3_hlen = iph->ihl * 4;
		if (l3_hlen != sizeof(struct iphdr) &&
		    !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))
			return -EINVAL;
	} else {
		struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);

		/* NOTE(review): IPv6 extension headers are not walked;
		 * only an immediately following UDP header is accepted.
		 */
		is_udp = ipv6h->nexthdr == IPPROTO_UDP;
	}

	if (unlikely(!is_udp))
		return -EINVAL;

	uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen);
	if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT)))
		return -EINVAL;

	/* Stash PSP metadata in an skb extension for later consumers */
	pse = skb_ext_add(skb, SKB_EXT_PSP);
	if (!pse)
		return -EINVAL;

	psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen +
				 sizeof(struct udphdr));
	pse->spi = psph->spi;
	pse->dev_id = dev_id;
	pse->generation = generation;
	pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl);

	/* Bytes removed from the L3 payload: the UDP+PSP headers plus the
	 * trailing ICV when the caller asks for it to be stripped.
	 */
	encap = PSP_ENCAP_HLEN;
	encap += strip_icv ? PSP_TRL_SIZE : 0;

	if (proto == htons(ETH_P_IP)) {
		struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);

		/* Rewrite L3 to describe the decapsulated inner packet */
		iph->protocol = psph->nexthdr;
		iph->tot_len = htons(ntohs(iph->tot_len) - encap);
		iph->check = 0;
		iph->check = ip_fast_csum((u8 *)iph, iph->ihl);
	} else {
		struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);

		ipv6h->nexthdr = psph->nexthdr;
		ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap);
	}

	/* Slide the L2+L3 headers forward over the UDP+PSP headers, then
	 * drop the now-duplicated leading bytes.
	 */
	memmove(skb->data + PSP_ENCAP_HLEN, skb->data, l2_hlen + l3_hlen);
	skb_pull(skb, PSP_ENCAP_HLEN);

	if (strip_icv)
		pskb_trim(skb, skb->len - PSP_TRL_SIZE);

	return 0;
}
EXPORT_SYMBOL(psp_dev_rcv);
315
psp_init(void)316 static int __init psp_init(void)
317 {
318 mutex_init(&psp_devs_lock);
319
320 return genl_register_family(&psp_nl_family);
321 }
322
323 subsys_initcall(psp_init);
324