xref: /linux/net/psp/psp_main.c (revision fcee7d82f27d6a8b1ddc5bbefda59b4e441e9bc0)
1 // SPDX-License-Identifier: GPL-2.0-only
2 
3 #include <linux/bitfield.h>
4 #include <linux/list.h>
5 #include <linux/netdevice.h>
6 #include <linux/xarray.h>
7 #include <net/net_namespace.h>
8 #include <net/psp.h>
9 #include <net/udp.h>
10 
11 #include "psp.h"
12 #include "psp-nl-gen.h"
13 
/* All registered PSP devices, indexed by psd->id (IDs are allocated from 1). */
DEFINE_XARRAY_ALLOC1(psp_devs);
/* Serializes updates to psp_devs. Statically initialized so that it is valid
 * from boot, independent of when psp_init() runs.
 */
DEFINE_MUTEX(psp_devs_lock);
16 
/**
 * DOC: PSP locking
 *
 * psp_devs_lock protects the psp_devs xarray.
 * Lock ordering: take psp_devs_lock first, then the instance lock
 * (psd->lock).
 * Each instance is protected by RCU and has a refcount.
 * When the driver unregisters, the instance is flushed, but the struct
 * itself stays allocated until its last reference is dropped.
 */
25 
26 /**
27  * psp_dev_check_access() - check if user in a given net ns can access PSP dev
28  * @psd:	PSP device structure user is trying to access
29  * @net:	net namespace user is in
30  *
31  * Return: 0 if PSP device should be visible in @net, errno otherwise.
32  */
psp_dev_check_access(struct psp_dev * psd,struct net * net)33 int psp_dev_check_access(struct psp_dev *psd, struct net *net)
34 {
35 	if (dev_net(psd->main_netdev) == net)
36 		return 0;
37 	return -ENOENT;
38 }
39 
40 /**
41  * psp_dev_create() - create and register PSP device
42  * @netdev:	main netdevice
43  * @psd_ops:	driver callbacks
44  * @psd_caps:	device capabilities
45  * @priv_ptr:	back-pointer to driver private data
46  *
47  * Return: pointer to allocated PSP device, or ERR_PTR.
48  */
49 struct psp_dev *
psp_dev_create(struct net_device * netdev,struct psp_dev_ops * psd_ops,struct psp_dev_caps * psd_caps,void * priv_ptr)50 psp_dev_create(struct net_device *netdev,
51 	       struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps,
52 	       void *priv_ptr)
53 {
54 	struct psp_dev *psd;
55 	static u32 last_id;
56 	int err;
57 
58 	if (WARN_ON(!psd_caps->versions ||
59 		    !psd_ops->set_config ||
60 		    !psd_ops->key_rotate ||
61 		    !psd_ops->rx_spi_alloc ||
62 		    !psd_ops->tx_key_add ||
63 		    !psd_ops->tx_key_del ||
64 		    !psd_ops->get_stats))
65 		return ERR_PTR(-EINVAL);
66 
67 	psd = kzalloc_obj(*psd);
68 	if (!psd)
69 		return ERR_PTR(-ENOMEM);
70 
71 	psd->main_netdev = netdev;
72 	psd->ops = psd_ops;
73 	psd->caps = psd_caps;
74 	psd->drv_priv = priv_ptr;
75 
76 	mutex_init(&psd->lock);
77 	INIT_LIST_HEAD(&psd->active_assocs);
78 	INIT_LIST_HEAD(&psd->prev_assocs);
79 	INIT_LIST_HEAD(&psd->stale_assocs);
80 	refcount_set(&psd->refcnt, 1);
81 
82 	mutex_lock(&psp_devs_lock);
83 	err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b,
84 			      &last_id, GFP_KERNEL);
85 	if (err) {
86 		mutex_unlock(&psp_devs_lock);
87 		kfree(psd);
88 		return ERR_PTR(err);
89 	}
90 	mutex_lock(&psd->lock);
91 	mutex_unlock(&psp_devs_lock);
92 
93 	psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF);
94 
95 	rcu_assign_pointer(netdev->psp_dev, psd);
96 
97 	mutex_unlock(&psd->lock);
98 
99 	return psd;
100 }
101 EXPORT_SYMBOL(psp_dev_create);
102 
psp_dev_free(struct psp_dev * psd)103 void psp_dev_free(struct psp_dev *psd)
104 {
105 	mutex_lock(&psp_devs_lock);
106 	xa_erase(&psp_devs, psd->id);
107 	mutex_unlock(&psp_devs_lock);
108 
109 	mutex_destroy(&psd->lock);
110 	kfree_rcu(psd, rcu);
111 }
112 
113 /**
114  * psp_dev_unregister() - unregister PSP device
115  * @psd:	PSP device structure
116  */
psp_dev_unregister(struct psp_dev * psd)117 void psp_dev_unregister(struct psp_dev *psd)
118 {
119 	struct psp_assoc *pas, *next;
120 
121 	mutex_lock(&psp_devs_lock);
122 	mutex_lock(&psd->lock);
123 
124 	psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF);
125 
126 	/* Wait until psp_dev_free() to call xa_erase() to prevent a
127 	 * different psd from being added to the xarray with this id, while
128 	 * there are still references to this psd being held.
129 	 */
130 	xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL);
131 	mutex_unlock(&psp_devs_lock);
132 
133 	list_splice_init(&psd->active_assocs, &psd->prev_assocs);
134 	list_splice_init(&psd->prev_assocs, &psd->stale_assocs);
135 	list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list)
136 		psp_dev_tx_key_del(psd, pas);
137 
138 	rcu_assign_pointer(psd->main_netdev->psp_dev, NULL);
139 
140 	psd->ops = NULL;
141 	psd->drv_priv = NULL;
142 
143 	mutex_unlock(&psd->lock);
144 
145 	psp_dev_put(psd);
146 }
147 EXPORT_SYMBOL(psp_dev_unregister);
148 
/**
 * psp_key_size() - key length used by a PSP version
 * @version:	PSP_VERSION_* value
 *
 * Return: key size in bytes (16 for the 128-bit AES suites, 32 for the
 * 256-bit ones), or 0 if @version is not recognized.
 */
unsigned int psp_key_size(u32 version)
{
	unsigned int bytes = 0;

	switch (version) {
	case PSP_VERSION_HDR0_AES_GCM_256:
	case PSP_VERSION_HDR0_AES_GMAC_256:
		bytes = 32;
		break;
	case PSP_VERSION_HDR0_AES_GCM_128:
	case PSP_VERSION_HDR0_AES_GMAC_128:
		bytes = 16;
		break;
	default:
		break;
	}

	return bytes;
}
EXPORT_SYMBOL(psp_key_size);
163 
/*
 * psp_write_headers() - fill in the UDP and PSP headers of an encapsulated skb
 * @net:	network namespace (for the local port range)
 * @skb:	skb whose transport header points at the new UDP header
 * @spi:	SPI to place in the PSP header
 * @ver:	PSP version for the verfl field
 * @udp_len:	value for the UDP length field (UDP header + payload)
 * @sport:	currently unused
 */
static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi,
			      u8 ver, unsigned int udp_len, __be16 sport)
{
	struct udphdr *uh = udp_hdr(skb);
	struct psphdr *psph = (struct psphdr *)(uh + 1);
	const struct sock *sk = skb->sk;
	__be16 src_port;

	/* Source port selection.
	 *
	 * The port needs entropy so that different flows spread across
	 * receive queues (RSS) at the peer. It must also stay fixed for the
	 * lifetime of one TCP flow so all of its packets land on the same
	 * queue.
	 *
	 * udp_flow_src_port() relies on sk_txhash (copied by
	 * skb_set_hash_from_sk() in __tcp_transmit_skb()). TCP reshuffles
	 * that field through sk_rethink_txhash(), so it is not stable.
	 * sk->sk_hash, by contrast, is constant for the whole flow.
	 */
	if (unlikely(!sk)) {
		src_port = udp_flow_src_port(net, skb, 0, 0, false);
	} else {
		int lo, hi;
		u32 hash;

		/* Cheap enough per packet; no need to cache in the socket. */
		inet_get_local_port_range(net, &lo, &hi);

		/* The port goes out on the wire, so fold the low half of the
		 * hash into the high half. reciprocal_scale() is driven by the
		 * upper 16 bits for a 16-bit range, so this both obfuscates
		 * the raw hash and keeps all of its bits relevant.
		 */
		hash = sk->sk_hash;
		hash ^= hash << 16;
		src_port = htons(reciprocal_scale(hash, hi - lo + 1) + lo);
	}

	/* UDP header: fixed PSP destination port, checksum left at zero. */
	uh->source = src_port;
	uh->dest = htons(PSP_DEFAULT_UDP_PORT);
	uh->len = htons(udp_len);
	uh->check = 0;

	/* PSP header carrying TCP, no options, zeroed IV (presumably filled
	 * in later by the device -- confirm against driver).
	 */
	psph->nexthdr = IPPROTO_TCP;
	psph->hdrlen = PSP_HDRLEN_NOOPT;
	psph->crypt_offset = 0;
	psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) |
		      FIELD_PREP(PSPHDR_VERFL_ONE, 1);
	psph->spi = spi;
	memset(&psph->iv, 0, sizeof(psph->iv));
}
220 
221 /* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling
222  * them in.
223  */
psp_dev_encapsulate(struct net * net,struct sk_buff * skb,__be32 spi,u8 ver,__be16 sport)224 bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
225 			 u8 ver, __be16 sport)
226 {
227 	u32 network_len = skb_network_header_len(skb);
228 	u32 ethr_len = skb_mac_header_len(skb);
229 	u32 bufflen = ethr_len + network_len;
230 
231 	if (skb_cow_head(skb, PSP_ENCAP_HLEN))
232 		return false;
233 
234 	skb_push(skb, PSP_ENCAP_HLEN);
235 	skb->mac_header		-= PSP_ENCAP_HLEN;
236 	skb->network_header	-= PSP_ENCAP_HLEN;
237 	skb->transport_header	-= PSP_ENCAP_HLEN;
238 	memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen);
239 
240 	if (skb->protocol == htons(ETH_P_IP)) {
241 		ip_hdr(skb)->protocol = IPPROTO_UDP;
242 		be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN);
243 		ip_hdr(skb)->check = 0;
244 		ip_hdr(skb)->check =
245 			ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl);
246 	} else if (skb->protocol == htons(ETH_P_IPV6)) {
247 		ipv6_hdr(skb)->nexthdr = IPPROTO_UDP;
248 		be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN);
249 	} else {
250 		return false;
251 	}
252 
253 	skb_set_inner_ipproto(skb, IPPROTO_TCP);
254 	skb_set_inner_transport_header(skb, skb_transport_offset(skb) +
255 						    PSP_ENCAP_HLEN);
256 	skb->encapsulation = 1;
257 	psp_write_headers(net, skb, spi, ver,
258 			  skb->len - skb_transport_offset(skb), sport);
259 
260 	return true;
261 }
262 EXPORT_SYMBOL(psp_dev_encapsulate);
263 
264 /* Receive handler for PSP packets.
265  *
266  * Accepts only already-authenticated packets. The full PSP header is
267  * stripped according to psph->hdrlen; any optional fields it advertises
268  * (virtualization cookies, etc.) are ignored and discarded along with the
269  * rest of the header. The caller should ensure that skb->data is pointing
270  * to the mac header, and that skb->mac_len is set. This function does not
271  * currently adjust skb->csum (CHECKSUM_COMPLETE is not supported).
272  */
psp_dev_rcv(struct sk_buff * skb,u16 dev_id,u8 generation,bool strip_icv)273 int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv)
274 {
275 	int l2_hlen = 0, l3_hlen, encap, psp_hlen;
276 	struct psp_skb_ext *pse;
277 	struct psphdr *psph;
278 	struct ethhdr *eth;
279 	struct udphdr *uh;
280 	__be16 proto;
281 	bool is_udp;
282 
283 	eth = (struct ethhdr *)skb->data;
284 	proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen);
285 	if (proto == htons(ETH_P_IP))
286 		l3_hlen = sizeof(struct iphdr);
287 	else if (proto == htons(ETH_P_IPV6))
288 		l3_hlen = sizeof(struct ipv6hdr);
289 	else
290 		return -EINVAL;
291 
292 	if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)))
293 		return -EINVAL;
294 
295 	if (proto == htons(ETH_P_IP)) {
296 		struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);
297 
298 		is_udp = iph->protocol == IPPROTO_UDP;
299 		l3_hlen = iph->ihl * 4;
300 		if (l3_hlen != sizeof(struct iphdr) &&
301 		    !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))
302 			return -EINVAL;
303 	} else {
304 		struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);
305 
306 		is_udp = ipv6h->nexthdr == IPPROTO_UDP;
307 	}
308 
309 	if (unlikely(!is_udp))
310 		return -EINVAL;
311 
312 	uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen);
313 	if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT)))
314 		return -EINVAL;
315 
316 	psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen +
317 				 sizeof(struct udphdr));
318 
319 	/* Strip the full PSP header per psph->hdrlen; VC/options are pulled
320 	 * into the linear region only so they can be discarded with the
321 	 * rest of the header.
322 	 */
323 	psp_hlen = (psph->hdrlen + 1) * 8;
324 
325 	if (unlikely(psp_hlen < sizeof(struct psphdr)))
326 		return -EINVAL;
327 
328 	if (psp_hlen > sizeof(struct psphdr) &&
329 	    !pskb_may_pull(skb, l2_hlen + l3_hlen +
330 				sizeof(struct udphdr) + psp_hlen))
331 		return -EINVAL;
332 
333 	psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen +
334 				 sizeof(struct udphdr));
335 
336 	pse = skb_ext_add(skb, SKB_EXT_PSP);
337 	if (!pse)
338 		return -EINVAL;
339 
340 	pse->spi = psph->spi;
341 	pse->dev_id = dev_id;
342 	pse->generation = generation;
343 	pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl);
344 
345 	encap = sizeof(struct udphdr) + psp_hlen;
346 	encap += strip_icv ? PSP_TRL_SIZE : 0;
347 
348 	if (proto == htons(ETH_P_IP)) {
349 		struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);
350 
351 		iph->protocol = psph->nexthdr;
352 		iph->tot_len = htons(ntohs(iph->tot_len) - encap);
353 		iph->check = 0;
354 		iph->check = ip_fast_csum((u8 *)iph, iph->ihl);
355 	} else {
356 		struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);
357 
358 		ipv6h->nexthdr = psph->nexthdr;
359 		ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap);
360 	}
361 
362 	memmove(skb->data + sizeof(struct udphdr) + psp_hlen,
363 		skb->data, l2_hlen + l3_hlen);
364 	skb_pull(skb, sizeof(struct udphdr) + psp_hlen);
365 
366 	if (strip_icv)
367 		pskb_trim(skb, skb->len - PSP_TRL_SIZE);
368 
369 	return 0;
370 }
371 EXPORT_SYMBOL(psp_dev_rcv);
372 
psp_init(void)373 static int __init psp_init(void)
374 {
375 	mutex_init(&psp_devs_lock);
376 
377 	return genl_register_family(&psp_nl_family);
378 }
379 
380 subsys_initcall(psp_init);
381