xref: /linux/net/psp/psp_nl.c (revision 117f02a49b7719b210d154a0d0e728001bf4af06)
1 // SPDX-License-Identifier: GPL-2.0-only
2 
3 #include <linux/skbuff.h>
4 #include <linux/xarray.h>
5 #include <net/genetlink.h>
6 #include <net/psp.h>
7 #include <net/sock.h>
8 
9 #include "psp-nl-gen.h"
10 #include "psp.h"
11 
12 /* Netlink helpers */
13 
14 static struct sk_buff *psp_nl_reply_new(struct genl_info *info)
15 {
16 	struct sk_buff *rsp;
17 	void *hdr;
18 
19 	rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
20 	if (!rsp)
21 		return NULL;
22 
23 	hdr = genlmsg_iput(rsp, info);
24 	if (!hdr) {
25 		nlmsg_free(rsp);
26 		return NULL;
27 	}
28 
29 	return rsp;
30 }
31 
32 static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info)
33 {
34 	/* Note that this *only* works with a single message per skb! */
35 	nlmsg_end(rsp, (struct nlmsghdr *)rsp->data);
36 
37 	return genlmsg_reply(rsp, info);
38 }
39 
40 /* Device stuff */
41 
42 static struct psp_dev *
43 psp_device_get_and_lock(struct net *net, struct nlattr *dev_id)
44 {
45 	struct psp_dev *psd;
46 	int err;
47 
48 	mutex_lock(&psp_devs_lock);
49 	psd = xa_load(&psp_devs, nla_get_u32(dev_id));
50 	if (!psd) {
51 		mutex_unlock(&psp_devs_lock);
52 		return ERR_PTR(-ENODEV);
53 	}
54 
55 	mutex_lock(&psd->lock);
56 	mutex_unlock(&psp_devs_lock);
57 
58 	err = psp_dev_check_access(psd, net);
59 	if (err) {
60 		mutex_unlock(&psd->lock);
61 		return ERR_PTR(err);
62 	}
63 
64 	return psd;
65 }
66 
67 int psp_device_get_locked(const struct genl_split_ops *ops,
68 			  struct sk_buff *skb, struct genl_info *info)
69 {
70 	if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_ID))
71 		return -EINVAL;
72 
73 	info->user_ptr[0] = psp_device_get_and_lock(genl_info_net(info),
74 						    info->attrs[PSP_A_DEV_ID]);
75 	return PTR_ERR_OR_ZERO(info->user_ptr[0]);
76 }
77 
78 void
79 psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb,
80 		  struct genl_info *info)
81 {
82 	struct psp_dev *psd = info->user_ptr[0];
83 
84 	mutex_unlock(&psd->lock);
85 }
86 
87 static int
88 psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp,
89 		const struct genl_info *info)
90 {
91 	void *hdr;
92 
93 	hdr = genlmsg_iput(rsp, info);
94 	if (!hdr)
95 		return -EMSGSIZE;
96 
97 	if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
98 	    nla_put_u32(rsp, PSP_A_DEV_IFINDEX, psd->main_netdev->ifindex) ||
99 	    nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_CAP, psd->caps->versions) ||
100 	    nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_ENA, psd->config.versions))
101 		goto err_cancel_msg;
102 
103 	genlmsg_end(rsp, hdr);
104 	return 0;
105 
106 err_cancel_msg:
107 	genlmsg_cancel(rsp, hdr);
108 	return -EMSGSIZE;
109 }
110 
111 void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd)
112 {
113 	struct genl_info info;
114 	struct sk_buff *ntf;
115 
116 	if (!genl_has_listeners(&psp_nl_family, dev_net(psd->main_netdev),
117 				PSP_NLGRP_MGMT))
118 		return;
119 
120 	ntf = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
121 	if (!ntf)
122 		return;
123 
124 	genl_info_init_ntf(&info, &psp_nl_family, cmd);
125 	if (psp_nl_dev_fill(psd, ntf, &info)) {
126 		nlmsg_free(ntf);
127 		return;
128 	}
129 
130 	genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
131 				0, PSP_NLGRP_MGMT, GFP_KERNEL);
132 }
133 
134 int psp_nl_dev_get_doit(struct sk_buff *req, struct genl_info *info)
135 {
136 	struct psp_dev *psd = info->user_ptr[0];
137 	struct sk_buff *rsp;
138 	int err;
139 
140 	rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
141 	if (!rsp)
142 		return -ENOMEM;
143 
144 	err = psp_nl_dev_fill(psd, rsp, info);
145 	if (err)
146 		goto err_free_msg;
147 
148 	return genlmsg_reply(rsp, info);
149 
150 err_free_msg:
151 	nlmsg_free(rsp);
152 	return err;
153 }
154 
155 static int
156 psp_nl_dev_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb,
157 			  struct psp_dev *psd)
158 {
159 	if (psp_dev_check_access(psd, sock_net(rsp->sk)))
160 		return 0;
161 
162 	return psp_nl_dev_fill(psd, rsp, genl_info_dump(cb));
163 }
164 
165 int psp_nl_dev_get_dumpit(struct sk_buff *rsp, struct netlink_callback *cb)
166 {
167 	struct psp_dev *psd;
168 	int err = 0;
169 
170 	mutex_lock(&psp_devs_lock);
171 	xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) {
172 		mutex_lock(&psd->lock);
173 		err = psp_nl_dev_get_dumpit_one(rsp, cb, psd);
174 		mutex_unlock(&psd->lock);
175 		if (err)
176 			break;
177 	}
178 	mutex_unlock(&psp_devs_lock);
179 
180 	return err;
181 }
182 
183 int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info)
184 {
185 	struct psp_dev *psd = info->user_ptr[0];
186 	struct psp_dev_config new_config;
187 	struct sk_buff *rsp;
188 	int err;
189 
190 	memcpy(&new_config, &psd->config, sizeof(new_config));
191 
192 	if (info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]) {
193 		new_config.versions =
194 			nla_get_u32(info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]);
195 		if (new_config.versions & ~psd->caps->versions) {
196 			NL_SET_ERR_MSG(info->extack, "Requested PSP versions not supported by the device");
197 			return -EINVAL;
198 		}
199 	} else {
200 		NL_SET_ERR_MSG(info->extack, "No settings present");
201 		return -EINVAL;
202 	}
203 
204 	rsp = psp_nl_reply_new(info);
205 	if (!rsp)
206 		return -ENOMEM;
207 
208 	if (memcmp(&new_config, &psd->config, sizeof(new_config))) {
209 		err = psd->ops->set_config(psd, &new_config, info->extack);
210 		if (err)
211 			goto err_free_rsp;
212 
213 		memcpy(&psd->config, &new_config, sizeof(new_config));
214 	}
215 
216 	psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF);
217 
218 	return psp_nl_reply_send(rsp, info);
219 
220 err_free_rsp:
221 	nlmsg_free(rsp);
222 	return err;
223 }
224 
225 int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info)
226 {
227 	struct psp_dev *psd = info->user_ptr[0];
228 	struct genl_info ntf_info;
229 	struct sk_buff *ntf, *rsp;
230 	int err;
231 
232 	rsp = psp_nl_reply_new(info);
233 	if (!rsp)
234 		return -ENOMEM;
235 
236 	genl_info_init_ntf(&ntf_info, &psp_nl_family, PSP_CMD_KEY_ROTATE_NTF);
237 	ntf = psp_nl_reply_new(&ntf_info);
238 	if (!ntf) {
239 		err = -ENOMEM;
240 		goto err_free_rsp;
241 	}
242 
243 	if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
244 	    nla_put_u32(ntf, PSP_A_DEV_ID, psd->id)) {
245 		err = -EMSGSIZE;
246 		goto err_free_ntf;
247 	}
248 
249 	err = psd->ops->key_rotate(psd, info->extack);
250 	if (err)
251 		goto err_free_ntf;
252 
253 	nlmsg_end(ntf, (struct nlmsghdr *)ntf->data);
254 	genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
255 				0, PSP_NLGRP_USE, GFP_KERNEL);
256 	return psp_nl_reply_send(rsp, info);
257 
258 err_free_ntf:
259 	nlmsg_free(ntf);
260 err_free_rsp:
261 	nlmsg_free(rsp);
262 	return err;
263 }
264