xref: /linux/drivers/net/netdevsim/psp.c (revision 07fdad3a93756b872da7b53647715c48d0f4a2d0)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #include <linux/ip.h>
4 #include <linux/skbuff.h>
5 #include <net/ip6_checksum.h>
6 #include <net/psp.h>
7 #include <net/sock.h>
8 
9 #include "netdevsim.h"
10 
11 void nsim_psp_handle_ext(struct sk_buff *skb, struct skb_ext *psp_ext)
12 {
13 	if (psp_ext)
14 		__skb_ext_set(skb, SKB_EXT_PSP, psp_ext);
15 }
16 
17 enum skb_drop_reason
18 nsim_do_psp(struct sk_buff *skb, struct netdevsim *ns,
19 	    struct netdevsim *peer_ns, struct skb_ext **psp_ext)
20 {
21 	enum skb_drop_reason rc = 0;
22 	struct psp_assoc *pas;
23 	struct net *net;
24 	void **ptr;
25 
26 	rcu_read_lock();
27 	pas = psp_skb_get_assoc_rcu(skb);
28 	if (!pas) {
29 		rc = SKB_NOT_DROPPED_YET;
30 		goto out_unlock;
31 	}
32 
33 	if (!skb_transport_header_was_set(skb)) {
34 		rc = SKB_DROP_REASON_PSP_OUTPUT;
35 		goto out_unlock;
36 	}
37 
38 	ptr = psp_assoc_drv_data(pas);
39 	if (*ptr != ns) {
40 		rc = SKB_DROP_REASON_PSP_OUTPUT;
41 		goto out_unlock;
42 	}
43 
44 	net = sock_net(skb->sk);
45 	if (!psp_dev_encapsulate(net, skb, pas->tx.spi, pas->version, 0)) {
46 		rc = SKB_DROP_REASON_PSP_OUTPUT;
47 		goto out_unlock;
48 	}
49 
50 	/* Now pretend we just received this frame */
51 	if (peer_ns->psp.dev->config.versions & (1 << pas->version)) {
52 		bool strip_icv = false;
53 		u8 generation;
54 
55 		/* We cheat a bit and put the generation in the key.
56 		 * In real life if generation was too old, then decryption would
57 		 * fail. Here, we just make it so a bad key causes a bad
58 		 * generation too, and psp_sk_rx_policy_check() will fail.
59 		 */
60 		generation = pas->tx.key[0];
61 
62 		skb_ext_reset(skb);
63 		skb->mac_len = ETH_HLEN;
64 		if (psp_dev_rcv(skb, peer_ns->psp.dev->id, generation,
65 				strip_icv)) {
66 			rc = SKB_DROP_REASON_PSP_OUTPUT;
67 			goto out_unlock;
68 		}
69 
70 		*psp_ext = skb->extensions;
71 		refcount_inc(&(*psp_ext)->refcnt);
72 		skb->decrypted = 1;
73 	} else {
74 		struct ipv6hdr *ip6h __maybe_unused;
75 		struct iphdr *iph;
76 		struct udphdr *uh;
77 		__wsum csum;
78 
79 		/* Do not decapsulate. Receive the skb with the udp and psp
80 		 * headers still there as if this is a normal udp packet.
81 		 * psp_dev_encapsulate() sets udp checksum to 0, so we need to
82 		 * provide a valid checksum here, so the skb isn't dropped.
83 		 */
84 		uh = udp_hdr(skb);
85 		csum = skb_checksum(skb, skb_transport_offset(skb),
86 				    ntohs(uh->len), 0);
87 
88 		switch (skb->protocol) {
89 		case htons(ETH_P_IP):
90 			iph = ip_hdr(skb);
91 			uh->check = udp_v4_check(ntohs(uh->len), iph->saddr,
92 						 iph->daddr, csum);
93 			break;
94 #if IS_ENABLED(CONFIG_IPV6)
95 		case htons(ETH_P_IPV6):
96 			ip6h = ipv6_hdr(skb);
97 			uh->check = udp_v6_check(ntohs(uh->len), &ip6h->saddr,
98 						 &ip6h->daddr, csum);
99 			break;
100 #endif
101 		}
102 
103 		uh->check	= uh->check ?: CSUM_MANGLED_0;
104 		skb->ip_summed	= CHECKSUM_NONE;
105 	}
106 
107 out_unlock:
108 	rcu_read_unlock();
109 	return rc;
110 }
111 
/*
 * nsim_psp_set_config() - psp_dev_ops::set_config callback.
 *
 * Accepts every requested configuration and does nothing itself. Presumably
 * the PSP core records @conf in psd->config, since nsim_do_psp() reads
 * psd->config.versions (TODO confirm).
 *
 * Return: always 0.
 */
static int nsim_psp_set_config(struct psp_dev *psd,
			       struct psp_dev_config *conf,
			       struct netlink_ext_ack *extack)
{
	return 0;
}
118 
119 static int
120 nsim_rx_spi_alloc(struct psp_dev *psd, u32 version,
121 		  struct psp_key_parsed *assoc,
122 		  struct netlink_ext_ack *extack)
123 {
124 	struct netdevsim *ns = psd->drv_priv;
125 	unsigned int new;
126 	int i;
127 
128 	new = ++ns->psp.spi & PSP_SPI_KEY_ID;
129 	if (psd->generation & 1)
130 		new |= PSP_SPI_KEY_PHASE;
131 
132 	assoc->spi = cpu_to_be32(new);
133 	assoc->key[0] = psd->generation;
134 	for (i = 1; i < PSP_MAX_KEY; i++)
135 		assoc->key[i] = ns->psp.spi + i;
136 
137 	return 0;
138 }
139 
140 static int nsim_assoc_add(struct psp_dev *psd, struct psp_assoc *pas,
141 			  struct netlink_ext_ack *extack)
142 {
143 	struct netdevsim *ns = psd->drv_priv;
144 	void **ptr = psp_assoc_drv_data(pas);
145 
146 	/* Copy drv_priv from psd to assoc */
147 	*ptr = psd->drv_priv;
148 	ns->psp.assoc_cnt++;
149 
150 	return 0;
151 }
152 
/*
 * nsim_key_rotate() - psp_dev_ops::key_rotate callback.
 *
 * The simulator keeps no device key state, so this does nothing. Any
 * generation bookkeeping is presumably done by the PSP core (TODO confirm).
 *
 * Return: always 0.
 */
static int nsim_key_rotate(struct psp_dev *psd,
			   struct netlink_ext_ack *extack)
{
	return 0;
}
157 
158 static void nsim_assoc_del(struct psp_dev *psd, struct psp_assoc *pas)
159 {
160 	struct netdevsim *ns = psd->drv_priv;
161 	void **ptr = psp_assoc_drv_data(pas);
162 
163 	*ptr = NULL;
164 	ns->psp.assoc_cnt--;
165 }
166 
167 static struct psp_dev_ops nsim_psp_ops = {
168 	.set_config	= nsim_psp_set_config,
169 	.rx_spi_alloc	= nsim_rx_spi_alloc,
170 	.tx_key_add	= nsim_assoc_add,
171 	.tx_key_del	= nsim_assoc_del,
172 	.key_rotate	= nsim_key_rotate,
173 };
174 
175 static struct psp_dev_caps nsim_psp_caps = {
176 	.versions = 1 << PSP_VERSION_HDR0_AES_GCM_128 |
177 		    1 << PSP_VERSION_HDR0_AES_GMAC_128 |
178 		    1 << PSP_VERSION_HDR0_AES_GCM_256 |
179 		    1 << PSP_VERSION_HDR0_AES_GMAC_256,
180 	.assoc_drv_spc = sizeof(void *),
181 };
182 
183 void nsim_psp_uninit(struct netdevsim *ns)
184 {
185 	if (!IS_ERR(ns->psp.dev))
186 		psp_dev_unregister(ns->psp.dev);
187 	WARN_ON(ns->psp.assoc_cnt);
188 }
189 
190 static ssize_t
191 nsim_psp_rereg_write(struct file *file, const char __user *data, size_t count,
192 		     loff_t *ppos)
193 {
194 	struct netdevsim *ns = file->private_data;
195 	int err;
196 
197 	nsim_psp_uninit(ns);
198 
199 	ns->psp.dev = psp_dev_create(ns->netdev, &nsim_psp_ops,
200 				     &nsim_psp_caps, ns);
201 	err = PTR_ERR_OR_ZERO(ns->psp.dev);
202 	return err ?: count;
203 }
204 
205 static const struct file_operations nsim_psp_rereg_fops = {
206 	.open = simple_open,
207 	.write = nsim_psp_rereg_write,
208 	.llseek = generic_file_llseek,
209 	.owner = THIS_MODULE,
210 };
211 
212 int nsim_psp_init(struct netdevsim *ns)
213 {
214 	struct dentry *ddir = ns->nsim_dev_port->ddir;
215 	int err;
216 
217 	ns->psp.dev = psp_dev_create(ns->netdev, &nsim_psp_ops,
218 				     &nsim_psp_caps, ns);
219 	err = PTR_ERR_OR_ZERO(ns->psp.dev);
220 	if (err)
221 		return err;
222 
223 	debugfs_create_file("psp_rereg", 0200, ddir, ns, &nsim_psp_rereg_fops);
224 	return 0;
225 }
226