1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3 * Generic netlink handshake service
4 *
5 * Author: Chuck Lever <chuck.lever@oracle.com>
6 *
7 * Copyright (c) 2023, Oracle and/or its affiliates.
8 */
9
10 #include <linux/types.h>
11 #include <linux/socket.h>
12 #include <linux/kernel.h>
13 #include <linux/module.h>
14 #include <linux/skbuff.h>
15 #include <linux/mm.h>
16
17 #include <net/sock.h>
18 #include <net/genetlink.h>
19 #include <net/netns/generic.h>
20
21 #include <kunit/visibility.h>
22
23 #include <uapi/linux/handshake.h>
24 #include "handshake.h"
25 #include "genl.h"
26
27 #include <trace/events/handshake.h>
28
29 /**
30 * handshake_genl_notify - Notify handlers that a request is waiting
31 * @net: target network namespace
32 * @proto: handshake protocol
33 * @flags: memory allocation control flags
34 *
35 * Returns zero on success or a negative errno if notification failed.
36 */
handshake_genl_notify(struct net * net,const struct handshake_proto * proto,gfp_t flags)37 int handshake_genl_notify(struct net *net, const struct handshake_proto *proto,
38 gfp_t flags)
39 {
40 struct sk_buff *msg;
41 void *hdr;
42
43 /* Disable notifications during unit testing */
44 if (!test_bit(HANDSHAKE_F_PROTO_NOTIFY, &proto->hp_flags))
45 return 0;
46
47 if (!genl_has_listeners(&handshake_nl_family, net,
48 proto->hp_handler_class))
49 return -ESRCH;
50
51 msg = genlmsg_new(GENLMSG_DEFAULT_SIZE, flags);
52 if (!msg)
53 return -ENOMEM;
54
55 hdr = genlmsg_put(msg, 0, 0, &handshake_nl_family, 0,
56 HANDSHAKE_CMD_READY);
57 if (!hdr)
58 goto out_free;
59
60 if (nla_put_u32(msg, HANDSHAKE_A_ACCEPT_HANDLER_CLASS,
61 proto->hp_handler_class) < 0) {
62 genlmsg_cancel(msg, hdr);
63 goto out_free;
64 }
65
66 genlmsg_end(msg, hdr);
67 return genlmsg_multicast_netns(&handshake_nl_family, net, msg,
68 0, proto->hp_handler_class, flags);
69
70 out_free:
71 nlmsg_free(msg);
72 return -EMSGSIZE;
73 }
74
75 /**
76 * handshake_genl_put - Create a generic netlink message header
77 * @msg: buffer in which to create the header
78 * @info: generic netlink message context
79 *
80 * Returns a ready-to-use header, or NULL.
81 */
handshake_genl_put(struct sk_buff * msg,struct genl_info * info)82 struct nlmsghdr *handshake_genl_put(struct sk_buff *msg,
83 struct genl_info *info)
84 {
85 return genlmsg_put(msg, info->snd_portid, info->snd_seq,
86 &handshake_nl_family, 0, info->genlhdr->cmd);
87 }
88 EXPORT_SYMBOL(handshake_genl_put);
89
/**
 * handshake_nl_accept_doit - Hand the next pending handshake request to a handler
 * @skb: incoming netlink message
 * @info: generic netlink request context
 *
 * Takes the next pending request for the handler class named in
 * HANDSHAKE_A_ACCEPT_HANDLER_CLASS. It then reserves a close-on-exec file
 * descriptor for the request's file and passes that fd to the protocol's
 * ->hp_accept() callback. The fd is published only if ->hp_accept()
 * succeeds.
 *
 * Returns zero on success, or:
 *   %-EOPNOTSUPP: handshake per-net state is unavailable
 *   %-EINVAL: the handler class attribute is missing
 *   %-EAGAIN: no request is pending for that handler class
 *   a negative errno from fd preparation or from ->hp_accept()
 *
 * If a request was dequeued but could not be handed off, it is
 * completed with -EIO.
 */
int handshake_nl_accept_doit(struct sk_buff *skb, struct genl_info *info)
{
	struct net *net = sock_net(skb->sk);
	struct handshake_net *hn = handshake_pernet(net);
	struct handshake_req *req = NULL;
	int class, err;

	/* NULL means pernet registration never completed */
	err = -EOPNOTSUPP;
	if (!hn)
		goto out_status;

	err = -EINVAL;
	if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_ACCEPT_HANDLER_CLASS))
		goto out_status;
	class = nla_get_u32(info->attrs[HANDSHAKE_A_ACCEPT_HANDLER_CLASS]);

	err = -EAGAIN;
	req = handshake_req_next(hn, class);
	if (req) {
		/*
		 * fdf is deliberately scoped to this block. Its automatic
		 * cleanup (see the note on the hp_accept failure path) runs
		 * when the block is left via goto, i.e. before out_complete
		 * executes. Do not hoist it to function scope: the earlier
		 * gotos would then jump past its declaration.
		 */
		FD_PREPARE(fdf, O_CLOEXEC, req->hr_file);
		if (fdf.err) {
			fput(req->hr_file); /* drop ref from handshake_req_next() */
			err = fdf.err;
			goto out_complete;
		}

		err = req->hr_proto->hp_accept(req, info, fd_prepare_fd(fdf));
		if (err)
			goto out_complete; /* Automatic cleanup handles fput */

		trace_handshake_cmd_accept(net, req, req->hr_sk, fd_prepare_fd(fdf));
		/* Publish the fd only after the handler accepted the request */
		fd_publish(fdf);
		return 0;
	}

out_complete:
	/* Reached with req == NULL when nothing was pending (-EAGAIN) */
	if (req)
		handshake_complete(req, -EIO, NULL);
out_status:
	trace_handshake_cmd_accept_err(net, req, NULL, err);
	return err;
}
132
handshake_nl_done_doit(struct sk_buff * skb,struct genl_info * info)133 int handshake_nl_done_doit(struct sk_buff *skb, struct genl_info *info)
134 {
135 struct net *net = sock_net(skb->sk);
136 struct handshake_req *req;
137 struct socket *sock;
138 int fd, status, err;
139
140 if (GENL_REQ_ATTR_CHECK(info, HANDSHAKE_A_DONE_SOCKFD))
141 return -EINVAL;
142 fd = nla_get_s32(info->attrs[HANDSHAKE_A_DONE_SOCKFD]);
143
144 sock = sockfd_lookup(fd, &err);
145 if (!sock)
146 return err;
147
148 req = handshake_req_hash_lookup(sock->sk);
149 if (!req) {
150 err = -EBUSY;
151 trace_handshake_cmd_done_err(net, req, sock->sk, err);
152 sockfd_put(sock);
153 return err;
154 }
155
156 trace_handshake_cmd_done(net, req, sock->sk, fd);
157
158 status = -EIO;
159 if (info->attrs[HANDSHAKE_A_DONE_STATUS])
160 status = -(int)nla_get_u32(info->attrs[HANDSHAKE_A_DONE_STATUS]);
161
162 handshake_complete(req, status, info);
163 sockfd_put(sock);
164 return 0;
165 }
166
/*
 * Identifier of this module's per-net storage (struct handshake_net),
 * used with net_generic(). It stays zero until register_pernet_subsys()
 * succeeds in handshake_init() and is reset to zero in handshake_exit().
 * handshake_pernet() treats zero as "not available".
 */
static unsigned int handshake_net_id;
168
handshake_net_init(struct net * net)169 static int __net_init handshake_net_init(struct net *net)
170 {
171 struct handshake_net *hn = net_generic(net, handshake_net_id);
172 unsigned long tmp;
173 struct sysinfo si;
174
175 /*
176 * Arbitrary limit to prevent handshakes that do not make
177 * progress from clogging up the system. The cap scales up
178 * with the amount of physical memory on the system.
179 */
180 si_meminfo(&si);
181 tmp = si.totalram / (25 * si.mem_unit);
182 hn->hn_pending_max = clamp(tmp, 3UL, 50UL);
183
184 spin_lock_init(&hn->hn_lock);
185 hn->hn_pending = 0;
186 hn->hn_flags = 0;
187 INIT_LIST_HEAD(&hn->hn_requests);
188 return 0;
189 }
190
handshake_net_exit(struct net * net)191 static void __net_exit handshake_net_exit(struct net *net)
192 {
193 struct handshake_net *hn = net_generic(net, handshake_net_id);
194 struct handshake_req *req;
195 LIST_HEAD(requests);
196
197 /*
198 * Drain the net's pending list. Requests that have been
199 * accepted and are in progress will be destroyed when
200 * the socket is closed.
201 */
202 spin_lock_bh(&hn->hn_lock);
203 set_bit(HANDSHAKE_F_NET_DRAINING, &hn->hn_flags);
204 list_splice_init(&hn->hn_requests, &requests);
205 list_for_each_entry(req, &requests, hr_list)
206 get_file(req->hr_file);
207 spin_unlock_bh(&hn->hn_lock);
208
209 while (!list_empty(&requests)) {
210 struct file *file;
211
212 req = list_first_entry(&requests, struct handshake_req, hr_list);
213 file = req->hr_file;
214 list_del_init(&req->hr_list);
215 handshake_complete(req, -ETIMEDOUT, NULL);
216 fput(file);
217 }
218 }
219
220 static struct pernet_operations handshake_genl_net_ops = {
221 .init = handshake_net_init,
222 .exit = handshake_net_exit,
223 .id = &handshake_net_id,
224 .size = sizeof(struct handshake_net),
225 };
226
227 /**
228 * handshake_pernet - Get the handshake private per-net structure
229 * @net: network namespace
230 *
231 * Returns a pointer to the net's private per-net structure for the
232 * handshake module, or NULL if handshake_init() failed.
233 */
handshake_pernet(struct net * net)234 struct handshake_net *handshake_pernet(struct net *net)
235 {
236 return handshake_net_id ?
237 net_generic(net, handshake_net_id) : NULL;
238 }
239 EXPORT_SYMBOL_IF_KUNIT(handshake_pernet);
240
handshake_init(void)241 static int __init handshake_init(void)
242 {
243 int ret;
244
245 ret = handshake_req_hash_init();
246 if (ret) {
247 pr_warn("handshake: hash initialization failed (%d)\n", ret);
248 return ret;
249 }
250
251 ret = genl_register_family(&handshake_nl_family);
252 if (ret) {
253 pr_warn("handshake: netlink registration failed (%d)\n", ret);
254 handshake_req_hash_destroy();
255 return ret;
256 }
257
258 /*
259 * ORDER: register_pernet_subsys must be done last.
260 *
261 * If initialization does not make it past pernet_subsys
262 * registration, then handshake_net_id will remain 0. That
263 * shunts the handshake consumer API to return ENOTSUPP
264 * to prevent it from dereferencing something that hasn't
265 * been allocated.
266 */
267 ret = register_pernet_subsys(&handshake_genl_net_ops);
268 if (ret) {
269 pr_warn("handshake: pernet registration failed (%d)\n", ret);
270 genl_unregister_family(&handshake_nl_family);
271 handshake_req_hash_destroy();
272 }
273
274 return ret;
275 }
276
handshake_exit(void)277 static void __exit handshake_exit(void)
278 {
279 unregister_pernet_subsys(&handshake_genl_net_ops);
280 handshake_net_id = 0;
281
282 handshake_req_hash_destroy();
283 genl_unregister_family(&handshake_nl_family);
284 }
285
286 module_init(handshake_init);
287 module_exit(handshake_exit);
288