xref: /linux/net/psp/psp_main.c (revision 672beab06656f2f1bda4708cda2b9af61c58a7ea)
1 // SPDX-License-Identifier: GPL-2.0-only
2 
3 #include <linux/bitfield.h>
4 #include <linux/list.h>
5 #include <linux/netdevice.h>
6 #include <linux/xarray.h>
7 #include <net/net_namespace.h>
8 #include <net/psp.h>
9 #include <net/udp.h>
10 
11 #include "psp.h"
12 #include "psp-nl-gen.h"
13 
14 DEFINE_XARRAY_ALLOC1(psp_devs);
15 struct mutex psp_devs_lock;
16 
17 /**
18  * DOC: PSP locking
19  *
20  * psp_devs_lock protects the psp_devs xarray.
 * Lock ordering: acquire psp_devs_lock first, then the per-instance psd->lock.
22  * Each instance is protected by RCU, and has a refcount.
23  * When driver unregisters the instance gets flushed, but struct sticks around.
24  */
25 
26 /**
27  * psp_dev_check_access() - check if user in a given net ns can access PSP dev
28  * @psd:	PSP device structure user is trying to access
29  * @net:	net namespace user is in
30  *
31  * Return: 0 if PSP device should be visible in @net, errno otherwise.
32  */
33 int psp_dev_check_access(struct psp_dev *psd, struct net *net)
34 {
35 	if (dev_net(psd->main_netdev) == net)
36 		return 0;
37 	return -ENOENT;
38 }
39 
/**
 * psp_dev_create() - create and register PSP device
 * @netdev:	main netdevice
 * @psd_ops:	driver callbacks
 * @psd_caps:	device capabilities
 * @priv_ptr:	back-pointer to driver private data
 *
 * Allocates the instance, inserts it into the psp_devs xarray under a
 * fresh 16-bit id, notifies netlink listeners and publishes the device
 * on @netdev->psp_dev for RCU readers.
 *
 * Return: pointer to allocated PSP device, or ERR_PTR.
 */
struct psp_dev *
psp_dev_create(struct net_device *netdev,
	       struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps,
	       void *priv_ptr)
{
	struct psp_dev *psd;
	static u32 last_id;
	int err;

	/* Every callback is mandatory; reject a partially-filled ops table
	 * (and a device advertising no supported versions) up front.
	 */
	if (WARN_ON(!psd_caps->versions ||
		    !psd_ops->set_config ||
		    !psd_ops->key_rotate ||
		    !psd_ops->rx_spi_alloc ||
		    !psd_ops->tx_key_add ||
		    !psd_ops->tx_key_del))
		return ERR_PTR(-EINVAL);

	psd = kzalloc(sizeof(*psd), GFP_KERNEL);
	if (!psd)
		return ERR_PTR(-ENOMEM);

	psd->main_netdev = netdev;
	psd->ops = psd_ops;
	psd->caps = psd_caps;
	psd->drv_priv = priv_ptr;

	mutex_init(&psd->lock);
	INIT_LIST_HEAD(&psd->active_assocs);
	INIT_LIST_HEAD(&psd->prev_assocs);
	INIT_LIST_HEAD(&psd->stale_assocs);
	/* Initial reference; dropped by psp_dev_put() in psp_dev_unregister(). */
	refcount_set(&psd->refcnt, 1);

	mutex_lock(&psp_devs_lock);
	/* Cyclic id allocation (via @last_id) avoids immediately recycling
	 * a just-freed device id.
	 */
	err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b,
			      &last_id, GFP_KERNEL);
	if (err) {
		mutex_unlock(&psp_devs_lock);
		kfree(psd);
		return ERR_PTR(err);
	}
	/* Lock ordering per the DOC above: take the instance lock while
	 * still holding psp_devs_lock, then release the latter. Holding
	 * psd->lock across the notify + publish keeps unregister out until
	 * the device is fully visible.
	 */
	mutex_lock(&psd->lock);
	mutex_unlock(&psp_devs_lock);

	psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF);

	/* Publish for RCU readers on the fast path. */
	rcu_assign_pointer(netdev->psp_dev, psd);

	mutex_unlock(&psd->lock);

	return psd;
}
EXPORT_SYMBOL(psp_dev_create);
101 
/* Final teardown of a PSP device — presumably reached once the last
 * reference is dropped (see refcount_set() in psp_dev_create()). Only now
 * is the id erased from the xarray, per the comment in psp_dev_unregister(),
 * so the id cannot be handed to a new device while references remain.
 */
void psp_dev_free(struct psp_dev *psd)
{
	mutex_lock(&psp_devs_lock);
	xa_erase(&psp_devs, psd->id);
	mutex_unlock(&psp_devs_lock);

	mutex_destroy(&psd->lock);
	/* RCU readers may still be traversing psd; defer the actual free. */
	kfree_rcu(psd, rcu);
}
111 
/**
 * psp_dev_unregister() - unregister PSP device
 * @psd:	PSP device structure
 *
 * Flushes all associations, unpublishes the device from its netdev and
 * from netlink, and drops the initial reference. The structure itself
 * stays allocated until the last reference is gone (see psp_dev_free()).
 */
void psp_dev_unregister(struct psp_dev *psd)
{
	struct psp_assoc *pas, *next;

	/* Lock ordering: psp_devs_lock first, then the instance lock. */
	mutex_lock(&psp_devs_lock);
	mutex_lock(&psd->lock);

	psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF);

	/* Wait until psp_dev_free() to call xa_erase() to prevent a
	 * different psd from being added to the xarray with this id, while
	 * there are still references to this psd being held.
	 */
	xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL);
	mutex_unlock(&psp_devs_lock);

	/* Funnel every association - active and previous-generation alike -
	 * onto the stale list, then tear down each Tx key.
	 */
	list_splice_init(&psd->active_assocs, &psd->prev_assocs);
	list_splice_init(&psd->prev_assocs, &psd->stale_assocs);
	list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list)
		psp_dev_tx_key_del(psd, pas);

	/* Stop RCU fast-path readers from finding the device via the netdev. */
	rcu_assign_pointer(psd->main_netdev->psp_dev, NULL);

	/* The driver is going away; its callbacks must not be used again. */
	psd->ops = NULL;
	psd->drv_priv = NULL;

	mutex_unlock(&psd->lock);

	/* Drop the initial reference taken in psp_dev_create(). */
	psp_dev_put(psd);
}
EXPORT_SYMBOL(psp_dev_unregister);
147 
148 unsigned int psp_key_size(u32 version)
149 {
150 	switch (version) {
151 	case PSP_VERSION_HDR0_AES_GCM_128:
152 	case PSP_VERSION_HDR0_AES_GMAC_128:
153 		return 16;
154 	case PSP_VERSION_HDR0_AES_GCM_256:
155 	case PSP_VERSION_HDR0_AES_GMAC_256:
156 		return 32;
157 	default:
158 		return 0;
159 	}
160 }
161 EXPORT_SYMBOL(psp_key_size);
162 
163 static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi,
164 			      u8 ver, unsigned int udp_len, __be16 sport)
165 {
166 	struct udphdr *uh = udp_hdr(skb);
167 	struct psphdr *psph = (struct psphdr *)(uh + 1);
168 
169 	uh->dest = htons(PSP_DEFAULT_UDP_PORT);
170 	uh->source = udp_flow_src_port(net, skb, 0, 0, false);
171 	uh->check = 0;
172 	uh->len = htons(udp_len);
173 
174 	psph->nexthdr = IPPROTO_TCP;
175 	psph->hdrlen = PSP_HDRLEN_NOOPT;
176 	psph->crypt_offset = 0;
177 	psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) |
178 		      FIELD_PREP(PSPHDR_VERFL_ONE, 1);
179 	psph->spi = spi;
180 	memset(&psph->iv, 0, sizeof(psph->iv));
181 }
182 
/* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling
 * them in. Only IPv4 and IPv6 outer headers are handled.
 *
 * Return: true on success, false if headroom could not be made or the L3
 * protocol is unsupported.
 * NOTE(review): in the unsupported-protocol branch the skb has already been
 * pushed/shifted — caller is presumably expected to drop the packet; confirm.
 */
bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
			 u8 ver, __be16 sport)
{
	u32 network_len = skb_network_header_len(skb);
	u32 ethr_len = skb_mac_header_len(skb);
	u32 bufflen = ethr_len + network_len;

	if (skb_cow_head(skb, PSP_ENCAP_HLEN))
		return false;

	/* Open PSP_ENCAP_HLEN bytes of space after the L3 header: grow the
	 * skb at the front, shift every header offset down, then slide the
	 * L2+L3 headers back to the start of the buffer.
	 */
	skb_push(skb, PSP_ENCAP_HLEN);
	skb->mac_header		-= PSP_ENCAP_HLEN;
	skb->network_header	-= PSP_ENCAP_HLEN;
	skb->transport_header	-= PSP_ENCAP_HLEN;
	memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen);

	if (skb->protocol == htons(ETH_P_IP)) {
		ip_hdr(skb)->protocol = IPPROTO_UDP;
		be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN);
		/* Total length changed - recompute the IPv4 checksum. */
		ip_hdr(skb)->check = 0;
		ip_hdr(skb)->check =
			ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl);
	} else if (skb->protocol == htons(ETH_P_IPV6)) {
		ipv6_hdr(skb)->nexthdr = IPPROTO_UDP;
		be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN);
	} else {
		return false;
	}

	/* Record the inner (TCP) header location for offload-capable drivers. */
	skb_set_inner_ipproto(skb, IPPROTO_TCP);
	skb_set_inner_transport_header(skb, skb_transport_offset(skb) +
						    PSP_ENCAP_HLEN);
	skb->encapsulation = 1;
	psp_write_headers(net, skb, spi, ver,
			  skb->len - skb_transport_offset(skb), sport);

	return true;
}
EXPORT_SYMBOL(psp_dev_encapsulate);
225 
/* Receive handler for PSP packets.
 *
 * Presently it accepts only already-authenticated packets and does not
 * support optional fields, such as virtualization cookies. The caller should
 * ensure that skb->data is pointing to the mac header, and that skb->mac_len
 * is set.
 *
 * Decapsulates in place: attaches a PSP skb extension carrying the SPI,
 * device id, key generation and version, then removes the UDP+PSP headers
 * (and, optionally, the trailing ICV).
 *
 * Return: 0 on success, -EINVAL on any parse/validation failure.
 */
int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv)
{
	int l2_hlen = 0, l3_hlen, encap;
	struct psp_skb_ext *pse;
	struct psphdr *psph;
	struct ethhdr *eth;
	struct udphdr *uh;
	__be16 proto;
	bool is_udp;

	/* Resolve the real L3 protocol, skipping VLAN tags;
	 * __vlan_get_protocol() also reports the L2 header length.
	 */
	eth = (struct ethhdr *)skb->data;
	proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen);
	if (proto == htons(ETH_P_IP))
		l3_hlen = sizeof(struct iphdr);
	else if (proto == htons(ETH_P_IPV6))
		l3_hlen = sizeof(struct ipv6hdr);
	else
		return -EINVAL;

	/* Ensure L2 + minimal L3 + UDP + PSP headers sit in linear data. */
	if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)))
		return -EINVAL;

	if (proto == htons(ETH_P_IP)) {
		struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);

		is_udp = iph->protocol == IPPROTO_UDP;
		/* IPv4 options can extend the header - re-check the pull. */
		l3_hlen = iph->ihl * 4;
		if (l3_hlen != sizeof(struct iphdr) &&
		    !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))
			return -EINVAL;
	} else {
		struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);

		/* NOTE(review): no IPv6 extension-header walk here, so a
		 * packet with ext headers before UDP is treated as non-UDP.
		 */
		is_udp = ipv6h->nexthdr == IPPROTO_UDP;
	}

	if (unlikely(!is_udp))
		return -EINVAL;

	uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen);
	if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT)))
		return -EINVAL;

	/* Stash PSP metadata in an skb extension for upper layers. */
	pse = skb_ext_add(skb, SKB_EXT_PSP);
	if (!pse)
		return -EINVAL;

	psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen +
				 sizeof(struct udphdr));
	pse->spi = psph->spi;
	pse->dev_id = dev_id;
	pse->generation = generation;
	pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl);

	/* Bytes removed from the L3 payload: UDP+PSP headers, plus the
	 * trailing ICV when we are asked to strip it.
	 */
	encap = PSP_ENCAP_HLEN;
	encap += strip_icv ? PSP_TRL_SIZE : 0;

	if (proto == htons(ETH_P_IP)) {
		struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);

		iph->protocol = psph->nexthdr;
		iph->tot_len = htons(ntohs(iph->tot_len) - encap);
		/* Length and protocol changed - recompute the checksum. */
		iph->check = 0;
		iph->check = ip_fast_csum((u8 *)iph, iph->ihl);
	} else {
		struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);

		ipv6h->nexthdr = psph->nexthdr;
		ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap);
	}

	/* Close the gap: copy the L2+L3 headers forward over the UDP+PSP
	 * headers, then advance skb->data past the stale copy.
	 */
	memmove(skb->data + PSP_ENCAP_HLEN, skb->data, l2_hlen + l3_hlen);
	skb_pull(skb, PSP_ENCAP_HLEN);

	if (strip_icv)
		pskb_trim(skb, skb->len - PSP_TRL_SIZE);

	return 0;
}
EXPORT_SYMBOL(psp_dev_rcv);
313 
314 static int __init psp_init(void)
315 {
316 	mutex_init(&psp_devs_lock);
317 
318 	return genl_register_family(&psp_nl_family);
319 }
320 
321 subsys_initcall(psp_init);
322