xref: /linux/drivers/net/netdevsim/psp.c (revision 6dfafbd0299a60bfb5d5e277fdf100037c7ded07)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #include <linux/ip.h>
4 #include <linux/skbuff.h>
5 #include <net/ip6_checksum.h>
6 #include <net/psp.h>
7 #include <net/sock.h>
8 
9 #include "netdevsim.h"
10 
11 void nsim_psp_handle_ext(struct sk_buff *skb, struct skb_ext *psp_ext)
12 {
13 	if (psp_ext)
14 		__skb_ext_set(skb, SKB_EXT_PSP, psp_ext);
15 }
16 
17 enum skb_drop_reason
18 nsim_do_psp(struct sk_buff *skb, struct netdevsim *ns,
19 	    struct netdevsim *peer_ns, struct skb_ext **psp_ext)
20 {
21 	enum skb_drop_reason rc = 0;
22 	struct psp_assoc *pas;
23 	struct net *net;
24 	void **ptr;
25 
26 	rcu_read_lock();
27 	pas = psp_skb_get_assoc_rcu(skb);
28 	if (!pas) {
29 		rc = SKB_NOT_DROPPED_YET;
30 		goto out_unlock;
31 	}
32 
33 	if (!skb_transport_header_was_set(skb)) {
34 		rc = SKB_DROP_REASON_PSP_OUTPUT;
35 		goto out_unlock;
36 	}
37 
38 	ptr = psp_assoc_drv_data(pas);
39 	if (*ptr != ns) {
40 		rc = SKB_DROP_REASON_PSP_OUTPUT;
41 		goto out_unlock;
42 	}
43 
44 	net = sock_net(skb->sk);
45 	if (!psp_dev_encapsulate(net, skb, pas->tx.spi, pas->version, 0)) {
46 		rc = SKB_DROP_REASON_PSP_OUTPUT;
47 		goto out_unlock;
48 	}
49 
50 	/* Now pretend we just received this frame */
51 	if (peer_ns->psp.dev->config.versions & (1 << pas->version)) {
52 		bool strip_icv = false;
53 		u8 generation;
54 
55 		/* We cheat a bit and put the generation in the key.
56 		 * In real life if generation was too old, then decryption would
57 		 * fail. Here, we just make it so a bad key causes a bad
58 		 * generation too, and psp_sk_rx_policy_check() will fail.
59 		 */
60 		generation = pas->tx.key[0];
61 
62 		skb_ext_reset(skb);
63 		skb->mac_len = ETH_HLEN;
64 		if (psp_dev_rcv(skb, peer_ns->psp.dev->id, generation,
65 				strip_icv)) {
66 			rc = SKB_DROP_REASON_PSP_OUTPUT;
67 			goto out_unlock;
68 		}
69 
70 		*psp_ext = skb->extensions;
71 		refcount_inc(&(*psp_ext)->refcnt);
72 		skb->decrypted = 1;
73 
74 		u64_stats_update_begin(&ns->psp.syncp);
75 		ns->psp.tx_packets++;
76 		ns->psp.rx_packets++;
77 		ns->psp.tx_bytes += skb->len - skb_inner_transport_offset(skb);
78 		ns->psp.rx_bytes += skb->len - skb_inner_transport_offset(skb);
79 		u64_stats_update_end(&ns->psp.syncp);
80 	} else {
81 		struct ipv6hdr *ip6h __maybe_unused;
82 		struct iphdr *iph;
83 		struct udphdr *uh;
84 		__wsum csum;
85 
86 		/* Do not decapsulate. Receive the skb with the udp and psp
87 		 * headers still there as if this is a normal udp packet.
88 		 * psp_dev_encapsulate() sets udp checksum to 0, so we need to
89 		 * provide a valid checksum here, so the skb isn't dropped.
90 		 */
91 		uh = udp_hdr(skb);
92 		csum = skb_checksum(skb, skb_transport_offset(skb),
93 				    ntohs(uh->len), 0);
94 
95 		switch (skb->protocol) {
96 		case htons(ETH_P_IP):
97 			iph = ip_hdr(skb);
98 			uh->check = udp_v4_check(ntohs(uh->len), iph->saddr,
99 						 iph->daddr, csum);
100 			break;
101 #if IS_ENABLED(CONFIG_IPV6)
102 		case htons(ETH_P_IPV6):
103 			ip6h = ipv6_hdr(skb);
104 			uh->check = udp_v6_check(ntohs(uh->len), &ip6h->saddr,
105 						 &ip6h->daddr, csum);
106 			break;
107 #endif
108 		}
109 
110 		uh->check	= uh->check ?: CSUM_MANGLED_0;
111 		skb->ip_summed	= CHECKSUM_NONE;
112 	}
113 
114 out_unlock:
115 	rcu_read_unlock();
116 	return rc;
117 }
118 
/*
 * psp_dev_ops::set_config for the simulated device. There is no hardware
 * to program, so every configuration is accepted as-is.
 */
static int nsim_psp_set_config(struct psp_dev *psd,
			       struct psp_dev_config *conf,
			       struct netlink_ext_ack *extack)
{
	return 0;
}
125 
126 static int
127 nsim_rx_spi_alloc(struct psp_dev *psd, u32 version,
128 		  struct psp_key_parsed *assoc,
129 		  struct netlink_ext_ack *extack)
130 {
131 	struct netdevsim *ns = psd->drv_priv;
132 	unsigned int new;
133 	int i;
134 
135 	new = ++ns->psp.spi & PSP_SPI_KEY_ID;
136 	if (psd->generation & 1)
137 		new |= PSP_SPI_KEY_PHASE;
138 
139 	assoc->spi = cpu_to_be32(new);
140 	assoc->key[0] = psd->generation;
141 	for (i = 1; i < PSP_MAX_KEY; i++)
142 		assoc->key[i] = ns->psp.spi + i;
143 
144 	return 0;
145 }
146 
147 static int nsim_assoc_add(struct psp_dev *psd, struct psp_assoc *pas,
148 			  struct netlink_ext_ack *extack)
149 {
150 	struct netdevsim *ns = psd->drv_priv;
151 	void **ptr = psp_assoc_drv_data(pas);
152 
153 	/* Copy drv_priv from psd to assoc */
154 	*ptr = psd->drv_priv;
155 	ns->psp.assoc_cnt++;
156 
157 	return 0;
158 }
159 
/* psp_dev_ops::key_rotate: nothing to rotate on the simulated device. */
static int
nsim_key_rotate(struct psp_dev *psd, struct netlink_ext_ack *extack)
{
	return 0;
}
164 
165 static void nsim_assoc_del(struct psp_dev *psd, struct psp_assoc *pas)
166 {
167 	struct netdevsim *ns = psd->drv_priv;
168 	void **ptr = psp_assoc_drv_data(pas);
169 
170 	*ptr = NULL;
171 	ns->psp.assoc_cnt--;
172 }
173 
174 static void nsim_get_stats(struct psp_dev *psd, struct psp_dev_stats *stats)
175 {
176 	struct netdevsim *ns = psd->drv_priv;
177 	unsigned int start;
178 
179 	/* WARNING: do *not* blindly zero stats in real drivers!
180 	 * All required stats must be reported by the device!
181 	 */
182 	memset(stats, 0, sizeof(struct psp_dev_stats));
183 
184 	do {
185 		start = u64_stats_fetch_begin(&ns->psp.syncp);
186 		stats->rx_bytes = ns->psp.rx_bytes;
187 		stats->rx_packets = ns->psp.rx_packets;
188 		stats->tx_bytes = ns->psp.tx_bytes;
189 		stats->tx_packets = ns->psp.tx_packets;
190 	} while (u64_stats_fetch_retry(&ns->psp.syncp, start));
191 }
192 
193 static struct psp_dev_ops nsim_psp_ops = {
194 	.set_config	= nsim_psp_set_config,
195 	.rx_spi_alloc	= nsim_rx_spi_alloc,
196 	.tx_key_add	= nsim_assoc_add,
197 	.tx_key_del	= nsim_assoc_del,
198 	.key_rotate	= nsim_key_rotate,
199 	.get_stats	= nsim_get_stats,
200 };
201 
202 static struct psp_dev_caps nsim_psp_caps = {
203 	.versions = 1 << PSP_VERSION_HDR0_AES_GCM_128 |
204 		    1 << PSP_VERSION_HDR0_AES_GMAC_128 |
205 		    1 << PSP_VERSION_HDR0_AES_GCM_256 |
206 		    1 << PSP_VERSION_HDR0_AES_GMAC_256,
207 	.assoc_drv_spc = sizeof(void *),
208 };
209 
210 void nsim_psp_uninit(struct netdevsim *ns)
211 {
212 	if (!IS_ERR(ns->psp.dev))
213 		psp_dev_unregister(ns->psp.dev);
214 	WARN_ON(ns->psp.assoc_cnt);
215 }
216 
217 static ssize_t
218 nsim_psp_rereg_write(struct file *file, const char __user *data, size_t count,
219 		     loff_t *ppos)
220 {
221 	struct netdevsim *ns = file->private_data;
222 	int err;
223 
224 	nsim_psp_uninit(ns);
225 
226 	ns->psp.dev = psp_dev_create(ns->netdev, &nsim_psp_ops,
227 				     &nsim_psp_caps, ns);
228 	err = PTR_ERR_OR_ZERO(ns->psp.dev);
229 	return err ?: count;
230 }
231 
232 static const struct file_operations nsim_psp_rereg_fops = {
233 	.open = simple_open,
234 	.write = nsim_psp_rereg_write,
235 	.llseek = generic_file_llseek,
236 	.owner = THIS_MODULE,
237 };
238 
239 int nsim_psp_init(struct netdevsim *ns)
240 {
241 	struct dentry *ddir = ns->nsim_dev_port->ddir;
242 	int err;
243 
244 	ns->psp.dev = psp_dev_create(ns->netdev, &nsim_psp_ops,
245 				     &nsim_psp_caps, ns);
246 	err = PTR_ERR_OR_ZERO(ns->psp.dev);
247 	if (err)
248 		return err;
249 
250 	debugfs_create_file("psp_rereg", 0200, ddir, ns, &nsim_psp_rereg_fops);
251 	return 0;
252 }
253