xref: /linux/net/psp/psp_nl.c (revision 07fdad3a93756b872da7b53647715c48d0f4a2d0)
100c94ca2SJakub Kicinski // SPDX-License-Identifier: GPL-2.0-only
200c94ca2SJakub Kicinski 
300c94ca2SJakub Kicinski #include <linux/skbuff.h>
400c94ca2SJakub Kicinski #include <linux/xarray.h>
500c94ca2SJakub Kicinski #include <net/genetlink.h>
600c94ca2SJakub Kicinski #include <net/psp.h>
700c94ca2SJakub Kicinski #include <net/sock.h>
800c94ca2SJakub Kicinski 
900c94ca2SJakub Kicinski #include "psp-nl-gen.h"
1000c94ca2SJakub Kicinski #include "psp.h"
1100c94ca2SJakub Kicinski 
/* Netlink message helpers */
1300c94ca2SJakub Kicinski 
1400c94ca2SJakub Kicinski static struct sk_buff *psp_nl_reply_new(struct genl_info *info)
1500c94ca2SJakub Kicinski {
1600c94ca2SJakub Kicinski 	struct sk_buff *rsp;
1700c94ca2SJakub Kicinski 	void *hdr;
1800c94ca2SJakub Kicinski 
1900c94ca2SJakub Kicinski 	rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
2000c94ca2SJakub Kicinski 	if (!rsp)
2100c94ca2SJakub Kicinski 		return NULL;
2200c94ca2SJakub Kicinski 
2300c94ca2SJakub Kicinski 	hdr = genlmsg_iput(rsp, info);
2400c94ca2SJakub Kicinski 	if (!hdr) {
2500c94ca2SJakub Kicinski 		nlmsg_free(rsp);
2600c94ca2SJakub Kicinski 		return NULL;
2700c94ca2SJakub Kicinski 	}
2800c94ca2SJakub Kicinski 
2900c94ca2SJakub Kicinski 	return rsp;
3000c94ca2SJakub Kicinski }
3100c94ca2SJakub Kicinski 
3200c94ca2SJakub Kicinski static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info)
3300c94ca2SJakub Kicinski {
3400c94ca2SJakub Kicinski 	/* Note that this *only* works with a single message per skb! */
3500c94ca2SJakub Kicinski 	nlmsg_end(rsp, (struct nlmsghdr *)rsp->data);
3600c94ca2SJakub Kicinski 
3700c94ca2SJakub Kicinski 	return genlmsg_reply(rsp, info);
3800c94ca2SJakub Kicinski }
3900c94ca2SJakub Kicinski 
/* Device management */
4100c94ca2SJakub Kicinski 
4200c94ca2SJakub Kicinski static struct psp_dev *
4300c94ca2SJakub Kicinski psp_device_get_and_lock(struct net *net, struct nlattr *dev_id)
4400c94ca2SJakub Kicinski {
4500c94ca2SJakub Kicinski 	struct psp_dev *psd;
4600c94ca2SJakub Kicinski 	int err;
4700c94ca2SJakub Kicinski 
4800c94ca2SJakub Kicinski 	mutex_lock(&psp_devs_lock);
4900c94ca2SJakub Kicinski 	psd = xa_load(&psp_devs, nla_get_u32(dev_id));
5000c94ca2SJakub Kicinski 	if (!psd) {
5100c94ca2SJakub Kicinski 		mutex_unlock(&psp_devs_lock);
5200c94ca2SJakub Kicinski 		return ERR_PTR(-ENODEV);
5300c94ca2SJakub Kicinski 	}
5400c94ca2SJakub Kicinski 
5500c94ca2SJakub Kicinski 	mutex_lock(&psd->lock);
5600c94ca2SJakub Kicinski 	mutex_unlock(&psp_devs_lock);
5700c94ca2SJakub Kicinski 
5800c94ca2SJakub Kicinski 	err = psp_dev_check_access(psd, net);
5900c94ca2SJakub Kicinski 	if (err) {
6000c94ca2SJakub Kicinski 		mutex_unlock(&psd->lock);
6100c94ca2SJakub Kicinski 		return ERR_PTR(err);
6200c94ca2SJakub Kicinski 	}
6300c94ca2SJakub Kicinski 
6400c94ca2SJakub Kicinski 	return psd;
6500c94ca2SJakub Kicinski }
6600c94ca2SJakub Kicinski 
6700c94ca2SJakub Kicinski int psp_device_get_locked(const struct genl_split_ops *ops,
6800c94ca2SJakub Kicinski 			  struct sk_buff *skb, struct genl_info *info)
6900c94ca2SJakub Kicinski {
7000c94ca2SJakub Kicinski 	if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_ID))
7100c94ca2SJakub Kicinski 		return -EINVAL;
7200c94ca2SJakub Kicinski 
7300c94ca2SJakub Kicinski 	info->user_ptr[0] = psp_device_get_and_lock(genl_info_net(info),
7400c94ca2SJakub Kicinski 						    info->attrs[PSP_A_DEV_ID]);
7500c94ca2SJakub Kicinski 	return PTR_ERR_OR_ZERO(info->user_ptr[0]);
7600c94ca2SJakub Kicinski }
7700c94ca2SJakub Kicinski 
7800c94ca2SJakub Kicinski void
7900c94ca2SJakub Kicinski psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb,
8000c94ca2SJakub Kicinski 		  struct genl_info *info)
8100c94ca2SJakub Kicinski {
826b46ca26SJakub Kicinski 	struct socket *socket = info->user_ptr[1];
8300c94ca2SJakub Kicinski 	struct psp_dev *psd = info->user_ptr[0];
8400c94ca2SJakub Kicinski 
8500c94ca2SJakub Kicinski 	mutex_unlock(&psd->lock);
866b46ca26SJakub Kicinski 	if (socket)
876b46ca26SJakub Kicinski 		sockfd_put(socket);
8800c94ca2SJakub Kicinski }
8900c94ca2SJakub Kicinski 
9000c94ca2SJakub Kicinski static int
9100c94ca2SJakub Kicinski psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp,
9200c94ca2SJakub Kicinski 		const struct genl_info *info)
9300c94ca2SJakub Kicinski {
9400c94ca2SJakub Kicinski 	void *hdr;
9500c94ca2SJakub Kicinski 
9600c94ca2SJakub Kicinski 	hdr = genlmsg_iput(rsp, info);
9700c94ca2SJakub Kicinski 	if (!hdr)
9800c94ca2SJakub Kicinski 		return -EMSGSIZE;
9900c94ca2SJakub Kicinski 
10000c94ca2SJakub Kicinski 	if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
10100c94ca2SJakub Kicinski 	    nla_put_u32(rsp, PSP_A_DEV_IFINDEX, psd->main_netdev->ifindex) ||
10200c94ca2SJakub Kicinski 	    nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_CAP, psd->caps->versions) ||
10300c94ca2SJakub Kicinski 	    nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_ENA, psd->config.versions))
10400c94ca2SJakub Kicinski 		goto err_cancel_msg;
10500c94ca2SJakub Kicinski 
10600c94ca2SJakub Kicinski 	genlmsg_end(rsp, hdr);
10700c94ca2SJakub Kicinski 	return 0;
10800c94ca2SJakub Kicinski 
10900c94ca2SJakub Kicinski err_cancel_msg:
11000c94ca2SJakub Kicinski 	genlmsg_cancel(rsp, hdr);
11100c94ca2SJakub Kicinski 	return -EMSGSIZE;
11200c94ca2SJakub Kicinski }
11300c94ca2SJakub Kicinski 
11400c94ca2SJakub Kicinski void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd)
11500c94ca2SJakub Kicinski {
11600c94ca2SJakub Kicinski 	struct genl_info info;
11700c94ca2SJakub Kicinski 	struct sk_buff *ntf;
11800c94ca2SJakub Kicinski 
11900c94ca2SJakub Kicinski 	if (!genl_has_listeners(&psp_nl_family, dev_net(psd->main_netdev),
12000c94ca2SJakub Kicinski 				PSP_NLGRP_MGMT))
12100c94ca2SJakub Kicinski 		return;
12200c94ca2SJakub Kicinski 
12300c94ca2SJakub Kicinski 	ntf = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
12400c94ca2SJakub Kicinski 	if (!ntf)
12500c94ca2SJakub Kicinski 		return;
12600c94ca2SJakub Kicinski 
12700c94ca2SJakub Kicinski 	genl_info_init_ntf(&info, &psp_nl_family, cmd);
12800c94ca2SJakub Kicinski 	if (psp_nl_dev_fill(psd, ntf, &info)) {
12900c94ca2SJakub Kicinski 		nlmsg_free(ntf);
13000c94ca2SJakub Kicinski 		return;
13100c94ca2SJakub Kicinski 	}
13200c94ca2SJakub Kicinski 
13300c94ca2SJakub Kicinski 	genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
13400c94ca2SJakub Kicinski 				0, PSP_NLGRP_MGMT, GFP_KERNEL);
13500c94ca2SJakub Kicinski }
13600c94ca2SJakub Kicinski 
13700c94ca2SJakub Kicinski int psp_nl_dev_get_doit(struct sk_buff *req, struct genl_info *info)
13800c94ca2SJakub Kicinski {
13900c94ca2SJakub Kicinski 	struct psp_dev *psd = info->user_ptr[0];
14000c94ca2SJakub Kicinski 	struct sk_buff *rsp;
14100c94ca2SJakub Kicinski 	int err;
14200c94ca2SJakub Kicinski 
14300c94ca2SJakub Kicinski 	rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
14400c94ca2SJakub Kicinski 	if (!rsp)
14500c94ca2SJakub Kicinski 		return -ENOMEM;
14600c94ca2SJakub Kicinski 
14700c94ca2SJakub Kicinski 	err = psp_nl_dev_fill(psd, rsp, info);
14800c94ca2SJakub Kicinski 	if (err)
14900c94ca2SJakub Kicinski 		goto err_free_msg;
15000c94ca2SJakub Kicinski 
15100c94ca2SJakub Kicinski 	return genlmsg_reply(rsp, info);
15200c94ca2SJakub Kicinski 
15300c94ca2SJakub Kicinski err_free_msg:
15400c94ca2SJakub Kicinski 	nlmsg_free(rsp);
15500c94ca2SJakub Kicinski 	return err;
15600c94ca2SJakub Kicinski }
15700c94ca2SJakub Kicinski 
15800c94ca2SJakub Kicinski static int
15900c94ca2SJakub Kicinski psp_nl_dev_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb,
16000c94ca2SJakub Kicinski 			  struct psp_dev *psd)
16100c94ca2SJakub Kicinski {
16200c94ca2SJakub Kicinski 	if (psp_dev_check_access(psd, sock_net(rsp->sk)))
16300c94ca2SJakub Kicinski 		return 0;
16400c94ca2SJakub Kicinski 
16500c94ca2SJakub Kicinski 	return psp_nl_dev_fill(psd, rsp, genl_info_dump(cb));
16600c94ca2SJakub Kicinski }
16700c94ca2SJakub Kicinski 
16800c94ca2SJakub Kicinski int psp_nl_dev_get_dumpit(struct sk_buff *rsp, struct netlink_callback *cb)
16900c94ca2SJakub Kicinski {
17000c94ca2SJakub Kicinski 	struct psp_dev *psd;
17100c94ca2SJakub Kicinski 	int err = 0;
17200c94ca2SJakub Kicinski 
17300c94ca2SJakub Kicinski 	mutex_lock(&psp_devs_lock);
17400c94ca2SJakub Kicinski 	xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) {
17500c94ca2SJakub Kicinski 		mutex_lock(&psd->lock);
17600c94ca2SJakub Kicinski 		err = psp_nl_dev_get_dumpit_one(rsp, cb, psd);
17700c94ca2SJakub Kicinski 		mutex_unlock(&psd->lock);
17800c94ca2SJakub Kicinski 		if (err)
17900c94ca2SJakub Kicinski 			break;
18000c94ca2SJakub Kicinski 	}
18100c94ca2SJakub Kicinski 	mutex_unlock(&psp_devs_lock);
18200c94ca2SJakub Kicinski 
18300c94ca2SJakub Kicinski 	return err;
18400c94ca2SJakub Kicinski }
18500c94ca2SJakub Kicinski 
18600c94ca2SJakub Kicinski int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info)
18700c94ca2SJakub Kicinski {
18800c94ca2SJakub Kicinski 	struct psp_dev *psd = info->user_ptr[0];
18900c94ca2SJakub Kicinski 	struct psp_dev_config new_config;
19000c94ca2SJakub Kicinski 	struct sk_buff *rsp;
19100c94ca2SJakub Kicinski 	int err;
19200c94ca2SJakub Kicinski 
19300c94ca2SJakub Kicinski 	memcpy(&new_config, &psd->config, sizeof(new_config));
19400c94ca2SJakub Kicinski 
19500c94ca2SJakub Kicinski 	if (info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]) {
19600c94ca2SJakub Kicinski 		new_config.versions =
19700c94ca2SJakub Kicinski 			nla_get_u32(info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]);
19800c94ca2SJakub Kicinski 		if (new_config.versions & ~psd->caps->versions) {
19900c94ca2SJakub Kicinski 			NL_SET_ERR_MSG(info->extack, "Requested PSP versions not supported by the device");
20000c94ca2SJakub Kicinski 			return -EINVAL;
20100c94ca2SJakub Kicinski 		}
20200c94ca2SJakub Kicinski 	} else {
20300c94ca2SJakub Kicinski 		NL_SET_ERR_MSG(info->extack, "No settings present");
20400c94ca2SJakub Kicinski 		return -EINVAL;
20500c94ca2SJakub Kicinski 	}
20600c94ca2SJakub Kicinski 
20700c94ca2SJakub Kicinski 	rsp = psp_nl_reply_new(info);
20800c94ca2SJakub Kicinski 	if (!rsp)
20900c94ca2SJakub Kicinski 		return -ENOMEM;
21000c94ca2SJakub Kicinski 
21100c94ca2SJakub Kicinski 	if (memcmp(&new_config, &psd->config, sizeof(new_config))) {
21200c94ca2SJakub Kicinski 		err = psd->ops->set_config(psd, &new_config, info->extack);
21300c94ca2SJakub Kicinski 		if (err)
21400c94ca2SJakub Kicinski 			goto err_free_rsp;
21500c94ca2SJakub Kicinski 
21600c94ca2SJakub Kicinski 		memcpy(&psd->config, &new_config, sizeof(new_config));
21700c94ca2SJakub Kicinski 	}
21800c94ca2SJakub Kicinski 
21900c94ca2SJakub Kicinski 	psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF);
22000c94ca2SJakub Kicinski 
22100c94ca2SJakub Kicinski 	return psp_nl_reply_send(rsp, info);
22200c94ca2SJakub Kicinski 
22300c94ca2SJakub Kicinski err_free_rsp:
22400c94ca2SJakub Kicinski 	nlmsg_free(rsp);
22500c94ca2SJakub Kicinski 	return err;
22600c94ca2SJakub Kicinski }
227117f02a4SJakub Kicinski 
228117f02a4SJakub Kicinski int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info)
229117f02a4SJakub Kicinski {
230117f02a4SJakub Kicinski 	struct psp_dev *psd = info->user_ptr[0];
231117f02a4SJakub Kicinski 	struct genl_info ntf_info;
232117f02a4SJakub Kicinski 	struct sk_buff *ntf, *rsp;
233*e7885105SJakub Kicinski 	u8 prev_gen;
234117f02a4SJakub Kicinski 	int err;
235117f02a4SJakub Kicinski 
236117f02a4SJakub Kicinski 	rsp = psp_nl_reply_new(info);
237117f02a4SJakub Kicinski 	if (!rsp)
238117f02a4SJakub Kicinski 		return -ENOMEM;
239117f02a4SJakub Kicinski 
240117f02a4SJakub Kicinski 	genl_info_init_ntf(&ntf_info, &psp_nl_family, PSP_CMD_KEY_ROTATE_NTF);
241117f02a4SJakub Kicinski 	ntf = psp_nl_reply_new(&ntf_info);
242117f02a4SJakub Kicinski 	if (!ntf) {
243117f02a4SJakub Kicinski 		err = -ENOMEM;
244117f02a4SJakub Kicinski 		goto err_free_rsp;
245117f02a4SJakub Kicinski 	}
246117f02a4SJakub Kicinski 
247117f02a4SJakub Kicinski 	if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
248117f02a4SJakub Kicinski 	    nla_put_u32(ntf, PSP_A_DEV_ID, psd->id)) {
249117f02a4SJakub Kicinski 		err = -EMSGSIZE;
250117f02a4SJakub Kicinski 		goto err_free_ntf;
251117f02a4SJakub Kicinski 	}
252117f02a4SJakub Kicinski 
253*e7885105SJakub Kicinski 	/* suggest the next gen number, driver can override */
254*e7885105SJakub Kicinski 	prev_gen = psd->generation;
255*e7885105SJakub Kicinski 	psd->generation = (prev_gen + 1) & PSP_GEN_VALID_MASK;
256*e7885105SJakub Kicinski 
257117f02a4SJakub Kicinski 	err = psd->ops->key_rotate(psd, info->extack);
258117f02a4SJakub Kicinski 	if (err)
259117f02a4SJakub Kicinski 		goto err_free_ntf;
260117f02a4SJakub Kicinski 
261*e7885105SJakub Kicinski 	WARN_ON_ONCE((psd->generation && psd->generation == prev_gen) ||
262*e7885105SJakub Kicinski 		     psd->generation & ~PSP_GEN_VALID_MASK);
263*e7885105SJakub Kicinski 
264*e7885105SJakub Kicinski 	psp_assocs_key_rotated(psd);
265*e7885105SJakub Kicinski 
266117f02a4SJakub Kicinski 	nlmsg_end(ntf, (struct nlmsghdr *)ntf->data);
267117f02a4SJakub Kicinski 	genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
268117f02a4SJakub Kicinski 				0, PSP_NLGRP_USE, GFP_KERNEL);
269117f02a4SJakub Kicinski 	return psp_nl_reply_send(rsp, info);
270117f02a4SJakub Kicinski 
271117f02a4SJakub Kicinski err_free_ntf:
272117f02a4SJakub Kicinski 	nlmsg_free(ntf);
273117f02a4SJakub Kicinski err_free_rsp:
274117f02a4SJakub Kicinski 	nlmsg_free(rsp);
275117f02a4SJakub Kicinski 	return err;
276117f02a4SJakub Kicinski }
2776b46ca26SJakub Kicinski 
/* Association and key management */
2796b46ca26SJakub Kicinski 
2806b46ca26SJakub Kicinski int psp_assoc_device_get_locked(const struct genl_split_ops *ops,
2816b46ca26SJakub Kicinski 				struct sk_buff *skb, struct genl_info *info)
2826b46ca26SJakub Kicinski {
2836b46ca26SJakub Kicinski 	struct socket *socket;
2846b46ca26SJakub Kicinski 	struct psp_dev *psd;
2856b46ca26SJakub Kicinski 	struct nlattr *id;
2866b46ca26SJakub Kicinski 	int fd, err;
2876b46ca26SJakub Kicinski 
2886b46ca26SJakub Kicinski 	if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD))
2896b46ca26SJakub Kicinski 		return -EINVAL;
2906b46ca26SJakub Kicinski 
2916b46ca26SJakub Kicinski 	fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]);
2926b46ca26SJakub Kicinski 	socket = sockfd_lookup(fd, &err);
2936b46ca26SJakub Kicinski 	if (!socket)
2946b46ca26SJakub Kicinski 		return err;
2956b46ca26SJakub Kicinski 
2966b46ca26SJakub Kicinski 	if (!sk_is_tcp(socket->sk)) {
2976b46ca26SJakub Kicinski 		NL_SET_ERR_MSG_ATTR(info->extack,
2986b46ca26SJakub Kicinski 				    info->attrs[PSP_A_ASSOC_SOCK_FD],
2996b46ca26SJakub Kicinski 				    "Unsupported socket family and type");
3006b46ca26SJakub Kicinski 		err = -EOPNOTSUPP;
3016b46ca26SJakub Kicinski 		goto err_sock_put;
3026b46ca26SJakub Kicinski 	}
3036b46ca26SJakub Kicinski 
3046b46ca26SJakub Kicinski 	psd = psp_dev_get_for_sock(socket->sk);
3056b46ca26SJakub Kicinski 	if (psd) {
3066b46ca26SJakub Kicinski 		err = psp_dev_check_access(psd, genl_info_net(info));
3076b46ca26SJakub Kicinski 		if (err) {
3086b46ca26SJakub Kicinski 			psp_dev_put(psd);
3096b46ca26SJakub Kicinski 			psd = NULL;
3106b46ca26SJakub Kicinski 		}
3116b46ca26SJakub Kicinski 	}
3126b46ca26SJakub Kicinski 
3136b46ca26SJakub Kicinski 	if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) {
3146b46ca26SJakub Kicinski 		err = -EINVAL;
3156b46ca26SJakub Kicinski 		goto err_sock_put;
3166b46ca26SJakub Kicinski 	}
3176b46ca26SJakub Kicinski 
3186b46ca26SJakub Kicinski 	id = info->attrs[PSP_A_ASSOC_DEV_ID];
3196b46ca26SJakub Kicinski 	if (psd) {
3206b46ca26SJakub Kicinski 		mutex_lock(&psd->lock);
3216b46ca26SJakub Kicinski 		if (id && psd->id != nla_get_u32(id)) {
3226b46ca26SJakub Kicinski 			mutex_unlock(&psd->lock);
3236b46ca26SJakub Kicinski 			NL_SET_ERR_MSG_ATTR(info->extack, id,
3246b46ca26SJakub Kicinski 					    "Device id vs socket mismatch");
3256b46ca26SJakub Kicinski 			err = -EINVAL;
3266b46ca26SJakub Kicinski 			goto err_psd_put;
3276b46ca26SJakub Kicinski 		}
3286b46ca26SJakub Kicinski 
3296b46ca26SJakub Kicinski 		psp_dev_put(psd);
3306b46ca26SJakub Kicinski 	} else {
3316b46ca26SJakub Kicinski 		psd = psp_device_get_and_lock(genl_info_net(info), id);
3326b46ca26SJakub Kicinski 		if (IS_ERR(psd)) {
3336b46ca26SJakub Kicinski 			err = PTR_ERR(psd);
3346b46ca26SJakub Kicinski 			goto err_sock_put;
3356b46ca26SJakub Kicinski 		}
3366b46ca26SJakub Kicinski 	}
3376b46ca26SJakub Kicinski 
3386b46ca26SJakub Kicinski 	info->user_ptr[0] = psd;
3396b46ca26SJakub Kicinski 	info->user_ptr[1] = socket;
3406b46ca26SJakub Kicinski 
3416b46ca26SJakub Kicinski 	return 0;
3426b46ca26SJakub Kicinski 
3436b46ca26SJakub Kicinski err_psd_put:
3446b46ca26SJakub Kicinski 	psp_dev_put(psd);
3456b46ca26SJakub Kicinski err_sock_put:
3466b46ca26SJakub Kicinski 	sockfd_put(socket);
3476b46ca26SJakub Kicinski 	return err;
3486b46ca26SJakub Kicinski }
3496b46ca26SJakub Kicinski 
3506b46ca26SJakub Kicinski static int
3516b46ca26SJakub Kicinski psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key,
3526b46ca26SJakub Kicinski 		 unsigned int key_sz)
3536b46ca26SJakub Kicinski {
3546b46ca26SJakub Kicinski 	struct nlattr *nest = info->attrs[attr];
3556b46ca26SJakub Kicinski 	struct nlattr *tb[PSP_A_KEYS_SPI + 1];
3566b46ca26SJakub Kicinski 	u32 spi;
3576b46ca26SJakub Kicinski 	int err;
3586b46ca26SJakub Kicinski 
3596b46ca26SJakub Kicinski 	err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest,
3606b46ca26SJakub Kicinski 			       psp_keys_nl_policy, info->extack);
3616b46ca26SJakub Kicinski 	if (err)
3626b46ca26SJakub Kicinski 		return err;
3636b46ca26SJakub Kicinski 
3646b46ca26SJakub Kicinski 	if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) ||
3656b46ca26SJakub Kicinski 	    NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI))
3666b46ca26SJakub Kicinski 		return -EINVAL;
3676b46ca26SJakub Kicinski 
3686b46ca26SJakub Kicinski 	if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) {
3696b46ca26SJakub Kicinski 		NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
3706b46ca26SJakub Kicinski 				    "incorrect key length");
3716b46ca26SJakub Kicinski 		return -EINVAL;
3726b46ca26SJakub Kicinski 	}
3736b46ca26SJakub Kicinski 
3746b46ca26SJakub Kicinski 	spi = nla_get_u32(tb[PSP_A_KEYS_SPI]);
3756b46ca26SJakub Kicinski 	if (!(spi & PSP_SPI_KEY_ID)) {
3766b46ca26SJakub Kicinski 		NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
3776b46ca26SJakub Kicinski 				    "invalid SPI: lower 31b must be non-zero");
3786b46ca26SJakub Kicinski 		return -EINVAL;
3796b46ca26SJakub Kicinski 	}
3806b46ca26SJakub Kicinski 
3816b46ca26SJakub Kicinski 	key->spi = cpu_to_be32(spi);
3826b46ca26SJakub Kicinski 	memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz);
3836b46ca26SJakub Kicinski 
3846b46ca26SJakub Kicinski 	return 0;
3856b46ca26SJakub Kicinski }
3866b46ca26SJakub Kicinski 
3876b46ca26SJakub Kicinski static int
3886b46ca26SJakub Kicinski psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version,
3896b46ca26SJakub Kicinski 	       struct psp_key_parsed *key)
3906b46ca26SJakub Kicinski {
3916b46ca26SJakub Kicinski 	int key_sz = psp_key_size(version);
3926b46ca26SJakub Kicinski 	void *nest;
3936b46ca26SJakub Kicinski 
3946b46ca26SJakub Kicinski 	nest = nla_nest_start(skb, attr);
3956b46ca26SJakub Kicinski 
3966b46ca26SJakub Kicinski 	if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) ||
3976b46ca26SJakub Kicinski 	    nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) {
3986b46ca26SJakub Kicinski 		nla_nest_cancel(skb, nest);
3996b46ca26SJakub Kicinski 		return -EMSGSIZE;
4006b46ca26SJakub Kicinski 	}
4016b46ca26SJakub Kicinski 
4026b46ca26SJakub Kicinski 	nla_nest_end(skb, nest);
4036b46ca26SJakub Kicinski 
4046b46ca26SJakub Kicinski 	return 0;
4056b46ca26SJakub Kicinski }
4066b46ca26SJakub Kicinski 
4076b46ca26SJakub Kicinski int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
4086b46ca26SJakub Kicinski {
4096b46ca26SJakub Kicinski 	struct socket *socket = info->user_ptr[1];
4106b46ca26SJakub Kicinski 	struct psp_dev *psd = info->user_ptr[0];
4116b46ca26SJakub Kicinski 	struct psp_key_parsed key;
4126b46ca26SJakub Kicinski 	struct psp_assoc *pas;
4136b46ca26SJakub Kicinski 	struct sk_buff *rsp;
4146b46ca26SJakub Kicinski 	u32 version;
4156b46ca26SJakub Kicinski 	int err;
4166b46ca26SJakub Kicinski 
4176b46ca26SJakub Kicinski 	if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION))
4186b46ca26SJakub Kicinski 		return -EINVAL;
4196b46ca26SJakub Kicinski 
4206b46ca26SJakub Kicinski 	version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
4216b46ca26SJakub Kicinski 	if (!(psd->caps->versions & (1 << version))) {
4226b46ca26SJakub Kicinski 		NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
4236b46ca26SJakub Kicinski 		return -EOPNOTSUPP;
4246b46ca26SJakub Kicinski 	}
4256b46ca26SJakub Kicinski 
4266b46ca26SJakub Kicinski 	rsp = psp_nl_reply_new(info);
4276b46ca26SJakub Kicinski 	if (!rsp)
4286b46ca26SJakub Kicinski 		return -ENOMEM;
4296b46ca26SJakub Kicinski 
4306b46ca26SJakub Kicinski 	pas = psp_assoc_create(psd);
4316b46ca26SJakub Kicinski 	if (!pas) {
4326b46ca26SJakub Kicinski 		err = -ENOMEM;
4336b46ca26SJakub Kicinski 		goto err_free_rsp;
4346b46ca26SJakub Kicinski 	}
4356b46ca26SJakub Kicinski 	pas->version = version;
4366b46ca26SJakub Kicinski 
4376b46ca26SJakub Kicinski 	err = psd->ops->rx_spi_alloc(psd, version, &key, info->extack);
4386b46ca26SJakub Kicinski 	if (err)
4396b46ca26SJakub Kicinski 		goto err_free_pas;
4406b46ca26SJakub Kicinski 
4416b46ca26SJakub Kicinski 	if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) ||
4426b46ca26SJakub Kicinski 	    psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) {
4436b46ca26SJakub Kicinski 		err = -EMSGSIZE;
4446b46ca26SJakub Kicinski 		goto err_free_pas;
4456b46ca26SJakub Kicinski 	}
4466b46ca26SJakub Kicinski 
4476b46ca26SJakub Kicinski 	err = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack);
4486b46ca26SJakub Kicinski 	if (err) {
4496b46ca26SJakub Kicinski 		NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]);
4506b46ca26SJakub Kicinski 		goto err_free_pas;
4516b46ca26SJakub Kicinski 	}
4526b46ca26SJakub Kicinski 	psp_assoc_put(pas);
4536b46ca26SJakub Kicinski 
4546b46ca26SJakub Kicinski 	return psp_nl_reply_send(rsp, info);
4556b46ca26SJakub Kicinski 
4566b46ca26SJakub Kicinski err_free_pas:
4576b46ca26SJakub Kicinski 	psp_assoc_put(pas);
4586b46ca26SJakub Kicinski err_free_rsp:
4596b46ca26SJakub Kicinski 	nlmsg_free(rsp);
4606b46ca26SJakub Kicinski 	return err;
4616b46ca26SJakub Kicinski }
4626b46ca26SJakub Kicinski 
4636b46ca26SJakub Kicinski int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
4646b46ca26SJakub Kicinski {
4656b46ca26SJakub Kicinski 	struct socket *socket = info->user_ptr[1];
4666b46ca26SJakub Kicinski 	struct psp_dev *psd = info->user_ptr[0];
4676b46ca26SJakub Kicinski 	struct psp_key_parsed key;
4686b46ca26SJakub Kicinski 	struct sk_buff *rsp;
4696b46ca26SJakub Kicinski 	unsigned int key_sz;
4706b46ca26SJakub Kicinski 	u32 version;
4716b46ca26SJakub Kicinski 	int err;
4726b46ca26SJakub Kicinski 
4736b46ca26SJakub Kicinski 	if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) ||
4746b46ca26SJakub Kicinski 	    GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY))
4756b46ca26SJakub Kicinski 		return -EINVAL;
4766b46ca26SJakub Kicinski 
4776b46ca26SJakub Kicinski 	version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
4786b46ca26SJakub Kicinski 	if (!(psd->caps->versions & (1 << version))) {
4796b46ca26SJakub Kicinski 		NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
4806b46ca26SJakub Kicinski 		return -EOPNOTSUPP;
4816b46ca26SJakub Kicinski 	}
4826b46ca26SJakub Kicinski 
4836b46ca26SJakub Kicinski 	key_sz = psp_key_size(version);
4846b46ca26SJakub Kicinski 	if (!key_sz)
4856b46ca26SJakub Kicinski 		return -EINVAL;
4866b46ca26SJakub Kicinski 
4876b46ca26SJakub Kicinski 	err = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz);
4886b46ca26SJakub Kicinski 	if (err < 0)
4896b46ca26SJakub Kicinski 		return err;
4906b46ca26SJakub Kicinski 
4916b46ca26SJakub Kicinski 	rsp = psp_nl_reply_new(info);
4926b46ca26SJakub Kicinski 	if (!rsp)
4936b46ca26SJakub Kicinski 		return -ENOMEM;
4946b46ca26SJakub Kicinski 
4956b46ca26SJakub Kicinski 	err = psp_sock_assoc_set_tx(socket->sk, psd, version, &key,
4966b46ca26SJakub Kicinski 				    info->extack);
4976b46ca26SJakub Kicinski 	if (err)
4986b46ca26SJakub Kicinski 		goto err_free_msg;
4996b46ca26SJakub Kicinski 
5006b46ca26SJakub Kicinski 	return psp_nl_reply_send(rsp, info);
5016b46ca26SJakub Kicinski 
5026b46ca26SJakub Kicinski err_free_msg:
5036b46ca26SJakub Kicinski 	nlmsg_free(rsp);
5046b46ca26SJakub Kicinski 	return err;
5056b46ca26SJakub Kicinski }
506