xref: /linux/drivers/net/netdevsim/psp.c (revision dad4d4b92a9b9f0edb8c66deda049da1b62f6089)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #include <linux/ip.h>
4 #include <linux/skbuff.h>
5 #include <net/ip6_checksum.h>
6 #include <net/psp.h>
7 #include <net/sock.h>
8 
9 #include "netdevsim.h"
10 
11 void nsim_psp_handle_ext(struct sk_buff *skb, struct skb_ext *psp_ext)
12 {
13 	if (psp_ext)
14 		__skb_ext_set(skb, SKB_EXT_PSP, psp_ext);
15 }
16 
17 enum skb_drop_reason
18 nsim_do_psp(struct sk_buff *skb, struct netdevsim *ns,
19 	    struct netdevsim *peer_ns, struct skb_ext **psp_ext)
20 {
21 	enum skb_drop_reason rc = 0;
22 	struct psp_dev *peer_psd;
23 	struct psp_assoc *pas;
24 	struct net *net;
25 	int psp_len;
26 	void **ptr;
27 
28 	rcu_read_lock();
29 	pas = psp_skb_get_assoc_rcu(skb);
30 	if (!pas) {
31 		rc = SKB_NOT_DROPPED_YET;
32 		goto out_unlock;
33 	}
34 
35 	if (!skb_transport_header_was_set(skb)) {
36 		rc = SKB_DROP_REASON_PSP_OUTPUT;
37 		goto out_unlock;
38 	}
39 
40 	ptr = psp_assoc_drv_data(pas);
41 	if (*ptr != ns) {
42 		rc = SKB_DROP_REASON_PSP_OUTPUT;
43 		goto out_unlock;
44 	}
45 
46 	net = sock_net(skb->sk);
47 	if (!psp_dev_encapsulate(net, skb, pas->tx.spi, pas->version, 0)) {
48 		rc = SKB_DROP_REASON_PSP_OUTPUT;
49 		goto out_unlock;
50 	}
51 
52 	psp_len = skb->len - skb_inner_transport_offset(skb);
53 	atomic64_inc(&ns->psp.tx_packets);
54 	atomic64_add(psp_len, &ns->psp.tx_bytes);
55 
56 	/* Now pretend we just received this frame */
57 	peer_psd = rcu_dereference(peer_ns->psp.dev);
58 	if (peer_psd && peer_psd->config.versions & (1 << pas->version)) {
59 		bool strip_icv = false;
60 		u8 generation;
61 
62 		/* We cheat a bit and put the generation in the key.
63 		 * In real life if generation was too old, then decryption would
64 		 * fail. Here, we just make it so a bad key causes a bad
65 		 * generation too, and psp_sk_rx_policy_check() will fail.
66 		 */
67 		generation = pas->tx.key[0];
68 
69 		skb_ext_reset(skb);
70 		skb->mac_len = ETH_HLEN;
71 		if (psp_dev_rcv(skb, peer_psd->id, generation, strip_icv)) {
72 			rc = SKB_DROP_REASON_PSP_OUTPUT;
73 			goto out_unlock;
74 		}
75 
76 		*psp_ext = skb->extensions;
77 		refcount_inc(&(*psp_ext)->refcnt);
78 		skb->decrypted = 1;
79 
80 		atomic64_inc(&peer_ns->psp.rx_packets);
81 		atomic64_add(psp_len, &peer_ns->psp.rx_bytes);
82 	} else {
83 		struct ipv6hdr *ip6h __maybe_unused;
84 		struct iphdr *iph;
85 		struct udphdr *uh;
86 		__wsum csum;
87 
88 		/* Do not decapsulate. Receive the skb with the udp and psp
89 		 * headers still there as if this is a normal udp packet.
90 		 * psp_dev_encapsulate() sets udp checksum to 0, so we need to
91 		 * provide a valid checksum here, so the skb isn't dropped.
92 		 */
93 		uh = udp_hdr(skb);
94 		csum = skb_checksum(skb, skb_transport_offset(skb),
95 				    ntohs(uh->len), 0);
96 
97 		switch (skb->protocol) {
98 		case htons(ETH_P_IP):
99 			iph = ip_hdr(skb);
100 			uh->check = udp_v4_check(ntohs(uh->len), iph->saddr,
101 						 iph->daddr, csum);
102 			break;
103 #if IS_ENABLED(CONFIG_IPV6)
104 		case htons(ETH_P_IPV6):
105 			ip6h = ipv6_hdr(skb);
106 			uh->check = udp_v6_check(ntohs(uh->len), &ip6h->saddr,
107 						 &ip6h->daddr, csum);
108 			break;
109 #endif
110 		}
111 
112 		uh->check	= uh->check ?: CSUM_MANGLED_0;
113 		skb->ip_summed	= CHECKSUM_NONE;
114 	}
115 
116 out_unlock:
117 	rcu_read_unlock();
118 	return rc;
119 }
120 
/* .set_config callback: netdevsim takes no action on a configuration change
 * and always reports success. @psd, @conf and @extack are unused.
 */
static int nsim_psp_set_config(struct psp_dev *psd,
			       struct psp_dev_config *conf,
			       struct netlink_ext_ack *extack)
{
	return 0;
}
127 
128 static int
129 nsim_rx_spi_alloc(struct psp_dev *psd, u32 version,
130 		  struct psp_key_parsed *assoc,
131 		  struct netlink_ext_ack *extack)
132 {
133 	struct netdevsim *ns = psd->drv_priv;
134 	int i;
135 
136 	/* Check if incrementing the spi would change the phase bit */
137 	if ((ns->psp.spi & PSP_SPI_KEY_ID) == PSP_SPI_KEY_ID) {
138 		NL_SET_ERR_MSG(extack, "SPI space exhausted");
139 		return -ENOSPC;
140 	}
141 
142 	assoc->spi = cpu_to_be32(++ns->psp.spi);
143 	assoc->key[0] = psd->generation;
144 	for (i = 1; i < PSP_MAX_KEY; i++)
145 		assoc->key[i] = ns->psp.spi + i;
146 
147 	return 0;
148 }
149 
150 static int nsim_assoc_add(struct psp_dev *psd, struct psp_assoc *pas,
151 			  struct netlink_ext_ack *extack)
152 {
153 	struct netdevsim *ns = psd->drv_priv;
154 	void **ptr = psp_assoc_drv_data(pas);
155 
156 	/* Copy drv_priv from psd to assoc */
157 	*ptr = psd->drv_priv;
158 	ns->psp.assoc_cnt++;
159 
160 	return 0;
161 }
162 
163 static int nsim_key_rotate(struct psp_dev *psd, struct netlink_ext_ack *extack)
164 {
165 	struct netdevsim *ns = psd->drv_priv;
166 
167 	/* Flip key phase and reset SPI to 0 within that space
168 	 * (will be pre-incremented, as 0 is an invalid SPI).
169 	 */
170 	if (ns->psp.spi & PSP_SPI_KEY_PHASE)
171 		ns->psp.spi = 0;
172 	else
173 		ns->psp.spi = PSP_SPI_KEY_PHASE;
174 
175 	return 0;
176 }
177 
178 static void nsim_assoc_del(struct psp_dev *psd, struct psp_assoc *pas)
179 {
180 	struct netdevsim *ns = psd->drv_priv;
181 	void **ptr = psp_assoc_drv_data(pas);
182 
183 	*ptr = NULL;
184 	ns->psp.assoc_cnt--;
185 }
186 
187 static void nsim_get_stats(struct psp_dev *psd, struct psp_dev_stats *stats)
188 {
189 	struct netdevsim *ns = psd->drv_priv;
190 
191 	/* WARNING: do *not* blindly zero stats in real drivers!
192 	 * All required stats must be reported by the device!
193 	 */
194 	memset(stats, 0, sizeof(struct psp_dev_stats));
195 
196 	stats->rx_bytes = atomic64_read(&ns->psp.rx_bytes);
197 	stats->rx_packets = atomic64_read(&ns->psp.rx_packets);
198 	stats->tx_bytes = atomic64_read(&ns->psp.tx_bytes);
199 	stats->tx_packets = atomic64_read(&ns->psp.tx_packets);
200 }
201 
202 static struct psp_dev_ops nsim_psp_ops = {
203 	.set_config	= nsim_psp_set_config,
204 	.rx_spi_alloc	= nsim_rx_spi_alloc,
205 	.tx_key_add	= nsim_assoc_add,
206 	.tx_key_del	= nsim_assoc_del,
207 	.key_rotate	= nsim_key_rotate,
208 	.get_stats	= nsim_get_stats,
209 };
210 
211 static struct psp_dev_caps nsim_psp_caps = {
212 	.versions = 1 << PSP_VERSION_HDR0_AES_GCM_128 |
213 		    1 << PSP_VERSION_HDR0_AES_GMAC_128 |
214 		    1 << PSP_VERSION_HDR0_AES_GCM_256 |
215 		    1 << PSP_VERSION_HDR0_AES_GMAC_256,
216 	.assoc_drv_spc = sizeof(void *),
217 };
218 
219 static void __nsim_psp_uninit(struct netdevsim *ns, bool teardown)
220 {
221 	struct psp_dev *psd;
222 
223 	psd = rcu_dereference_protected(ns->psp.dev,
224 					teardown ||
225 					lockdep_is_held(&ns->psp.rereg_lock));
226 	if (psd) {
227 		rcu_assign_pointer(ns->psp.dev, NULL);
228 		synchronize_rcu();
229 		psp_dev_unregister(psd);
230 	}
231 	WARN_ON(ns->psp.assoc_cnt);
232 }
233 
234 void nsim_psp_uninit(struct netdevsim *ns)
235 {
236 	debugfs_remove(ns->psp.rereg);
237 	mutex_destroy(&ns->psp.rereg_lock);
238 	__nsim_psp_uninit(ns, true);
239 }
240 
241 static ssize_t
242 nsim_psp_rereg_write(struct file *file, const char __user *data, size_t count,
243 		     loff_t *ppos)
244 {
245 	struct netdevsim *ns = file->private_data;
246 	struct psp_dev *psd;
247 	ssize_t ret;
248 
249 	mutex_lock(&ns->psp.rereg_lock);
250 	__nsim_psp_uninit(ns, false);
251 
252 	psd = psp_dev_create(ns->netdev, &nsim_psp_ops, &nsim_psp_caps, ns);
253 	if (IS_ERR(psd)) {
254 		ret = PTR_ERR(psd);
255 		goto out;
256 	}
257 
258 	rcu_assign_pointer(ns->psp.dev, psd);
259 	ret = count;
260 out:
261 	mutex_unlock(&ns->psp.rereg_lock);
262 	return ret;
263 }
264 
265 static const struct file_operations nsim_psp_rereg_fops = {
266 	.open = simple_open,
267 	.write = nsim_psp_rereg_write,
268 	.llseek = generic_file_llseek,
269 	.owner = THIS_MODULE,
270 };
271 
272 int nsim_psp_init(struct netdevsim *ns)
273 {
274 	struct dentry *ddir = ns->nsim_dev_port->ddir;
275 	struct psp_dev *psd;
276 
277 	psd = psp_dev_create(ns->netdev, &nsim_psp_ops, &nsim_psp_caps, ns);
278 	if (IS_ERR(psd))
279 		return PTR_ERR(psd);
280 
281 	rcu_assign_pointer(ns->psp.dev, psd);
282 
283 	mutex_init(&ns->psp.rereg_lock);
284 	ns->psp.rereg = debugfs_create_file("psp_rereg", 0200, ddir, ns,
285 					    &nsim_psp_rereg_fops);
286 	return 0;
287 }
288