xref: /linux/net/handshake/request.c (revision 364eeb79a213fcf9164208b53764223ad522d6b3)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Handshake request lifetime events
4  *
5  * Author: Chuck Lever <chuck.lever@oracle.com>
6  *
7  * Copyright (c) 2023, Oracle and/or its affiliates.
8  */
9 
10 #include <linux/types.h>
11 #include <linux/socket.h>
12 #include <linux/kernel.h>
13 #include <linux/module.h>
14 #include <linux/skbuff.h>
15 #include <linux/inet.h>
16 #include <linux/rhashtable.h>
17 
18 #include <net/sock.h>
19 #include <net/genetlink.h>
20 #include <net/netns/generic.h>
21 
22 #include <kunit/visibility.h>
23 
24 #include <uapi/linux/handshake.h>
25 #include "handshake.h"
26 
27 #include <trace/events/handshake.h>
28 
/*
 * We need both a handshake_req -> sock mapping, and a sock ->
 * handshake_req mapping. Both are one-to-one.
 *
 * To avoid adding another pointer field to struct sock, net/handshake
 * maintains a hash table, keyed by the memory address of the struct
 * sock, to find the struct handshake_req outstanding for that socket.
 * The reverse direction uses a simple pointer field (hr_sk) in the
 * handshake_req struct.
 */
38  */
39 
40 static struct rhashtable handshake_rhashtbl ____cacheline_aligned_in_smp;
41 
42 static const struct rhashtable_params handshake_rhash_params = {
43 	.key_len		= sizeof_field(struct handshake_req, hr_sk),
44 	.key_offset		= offsetof(struct handshake_req, hr_sk),
45 	.head_offset		= offsetof(struct handshake_req, hr_rhash),
46 	.automatic_shrinking	= true,
47 };
48 
49 int handshake_req_hash_init(void)
50 {
51 	return rhashtable_init(&handshake_rhashtbl, &handshake_rhash_params);
52 }
53 
54 void handshake_req_hash_destroy(void)
55 {
56 	rhashtable_destroy(&handshake_rhashtbl);
57 }
58 
59 struct handshake_req *handshake_req_hash_lookup(struct sock *sk)
60 {
61 	return rhashtable_lookup_fast(&handshake_rhashtbl, &sk,
62 				      handshake_rhash_params);
63 }
64 EXPORT_SYMBOL_IF_KUNIT(handshake_req_hash_lookup);
65 
66 static bool handshake_req_hash_add(struct handshake_req *req)
67 {
68 	int ret;
69 
70 	ret = rhashtable_lookup_insert_fast(&handshake_rhashtbl,
71 					    &req->hr_rhash,
72 					    handshake_rhash_params);
73 	return ret == 0;
74 }
75 
76 static void handshake_req_destroy(struct handshake_req *req)
77 {
78 	if (req->hr_proto->hp_destroy)
79 		req->hr_proto->hp_destroy(req);
80 	rhashtable_remove_fast(&handshake_rhashtbl, &req->hr_rhash,
81 			       handshake_rhash_params);
82 	kfree(req);
83 }
84 
85 static void handshake_sk_destruct(struct sock *sk)
86 {
87 	void (*sk_destruct)(struct sock *sk);
88 	struct handshake_req *req;
89 
90 	req = handshake_req_hash_lookup(sk);
91 	if (!req)
92 		return;
93 
94 	trace_handshake_destruct(sock_net(sk), req, sk);
95 	sk_destruct = req->hr_odestruct;
96 	handshake_req_destroy(req);
97 	if (sk_destruct)
98 		sk_destruct(sk);
99 }
100 
101 /**
102  * handshake_req_alloc - Allocate a handshake request
103  * @proto: security protocol
104  * @flags: memory allocation flags
105  *
106  * Returns an initialized handshake_req or NULL.
107  */
108 struct handshake_req *handshake_req_alloc(const struct handshake_proto *proto,
109 					  gfp_t flags)
110 {
111 	struct handshake_req *req;
112 
113 	if (!proto)
114 		return NULL;
115 	if (proto->hp_handler_class <= HANDSHAKE_HANDLER_CLASS_NONE)
116 		return NULL;
117 	if (proto->hp_handler_class >= HANDSHAKE_HANDLER_CLASS_MAX)
118 		return NULL;
119 	if (!proto->hp_accept || !proto->hp_done)
120 		return NULL;
121 
122 	req = kzalloc(struct_size(req, hr_priv, proto->hp_privsize), flags);
123 	if (!req)
124 		return NULL;
125 
126 	INIT_LIST_HEAD(&req->hr_list);
127 	req->hr_proto = proto;
128 	return req;
129 }
130 EXPORT_SYMBOL(handshake_req_alloc);
131 
132 /**
133  * handshake_req_private - Get per-handshake private data
134  * @req: handshake arguments
135  *
136  */
137 void *handshake_req_private(struct handshake_req *req)
138 {
139 	return (void *)&req->hr_priv;
140 }
141 EXPORT_SYMBOL(handshake_req_private);
142 
143 static bool __add_pending_locked(struct handshake_net *hn,
144 				 struct handshake_req *req)
145 {
146 	if (WARN_ON_ONCE(!list_empty(&req->hr_list)))
147 		return false;
148 	hn->hn_pending++;
149 	list_add_tail(&req->hr_list, &hn->hn_requests);
150 	return true;
151 }
152 
153 static void __remove_pending_locked(struct handshake_net *hn,
154 				    struct handshake_req *req)
155 {
156 	hn->hn_pending--;
157 	list_del_init(&req->hr_list);
158 }
159 
160 /*
161  * Returns %true if the request was found on @net's pending list,
162  * otherwise %false.
163  *
164  * If @req was on a pending list, it has not yet been accepted.
165  */
166 static bool remove_pending(struct handshake_net *hn, struct handshake_req *req)
167 {
168 	bool ret = false;
169 
170 	spin_lock(&hn->hn_lock);
171 	if (!list_empty(&req->hr_list)) {
172 		__remove_pending_locked(hn, req);
173 		ret = true;
174 	}
175 	spin_unlock(&hn->hn_lock);
176 
177 	return ret;
178 }
179 
180 struct handshake_req *handshake_req_next(struct handshake_net *hn, int class)
181 {
182 	struct handshake_req *req, *pos;
183 
184 	req = NULL;
185 	spin_lock(&hn->hn_lock);
186 	list_for_each_entry(pos, &hn->hn_requests, hr_list) {
187 		if (pos->hr_proto->hp_handler_class != class)
188 			continue;
189 		__remove_pending_locked(hn, pos);
190 		req = pos;
191 		break;
192 	}
193 	spin_unlock(&hn->hn_lock);
194 
195 	return req;
196 }
197 EXPORT_SYMBOL_IF_KUNIT(handshake_req_next);
198 
199 /**
200  * handshake_req_submit - Submit a handshake request
201  * @sock: open socket on which to perform the handshake
202  * @req: handshake arguments
203  * @flags: memory allocation flags
204  *
205  * Return values:
206  *   %0: Request queued
207  *   %-EINVAL: Invalid argument
208  *   %-EBUSY: A handshake is already under way for this socket
209  *   %-ESRCH: No handshake agent is available
210  *   %-EAGAIN: Too many pending handshake requests
211  *   %-ENOMEM: Failed to allocate memory
212  *   %-EMSGSIZE: Failed to construct notification message
213  *   %-EOPNOTSUPP: Handshake module not initialized
214  *
215  * A zero return value from handshake_req_submit() means that
216  * exactly one subsequent completion callback is guaranteed.
217  *
218  * A negative return value from handshake_req_submit() means that
219  * no completion callback will be done and that @req has been
220  * destroyed.
221  */
222 int handshake_req_submit(struct socket *sock, struct handshake_req *req,
223 			 gfp_t flags)
224 {
225 	struct handshake_net *hn;
226 	struct net *net;
227 	int ret;
228 
229 	if (!sock || !req || !sock->file) {
230 		kfree(req);
231 		return -EINVAL;
232 	}
233 
234 	req->hr_sk = sock->sk;
235 	if (!req->hr_sk) {
236 		kfree(req);
237 		return -EINVAL;
238 	}
239 	req->hr_odestruct = req->hr_sk->sk_destruct;
240 	req->hr_sk->sk_destruct = handshake_sk_destruct;
241 
242 	ret = -EOPNOTSUPP;
243 	net = sock_net(req->hr_sk);
244 	hn = handshake_pernet(net);
245 	if (!hn)
246 		goto out_err;
247 
248 	ret = -EAGAIN;
249 	if (READ_ONCE(hn->hn_pending) >= hn->hn_pending_max)
250 		goto out_err;
251 
252 	spin_lock(&hn->hn_lock);
253 	ret = -EOPNOTSUPP;
254 	if (test_bit(HANDSHAKE_F_NET_DRAINING, &hn->hn_flags))
255 		goto out_unlock;
256 	ret = -EBUSY;
257 	if (!handshake_req_hash_add(req))
258 		goto out_unlock;
259 	if (!__add_pending_locked(hn, req))
260 		goto out_unlock;
261 	spin_unlock(&hn->hn_lock);
262 
263 	ret = handshake_genl_notify(net, req->hr_proto, flags);
264 	if (ret) {
265 		trace_handshake_notify_err(net, req, req->hr_sk, ret);
266 		if (remove_pending(hn, req))
267 			goto out_err;
268 	}
269 
270 	/* Prevent socket release while a handshake request is pending */
271 	sock_hold(req->hr_sk);
272 
273 	trace_handshake_submit(net, req, req->hr_sk);
274 	return 0;
275 
276 out_unlock:
277 	spin_unlock(&hn->hn_lock);
278 out_err:
279 	trace_handshake_submit_err(net, req, req->hr_sk, ret);
280 	handshake_req_destroy(req);
281 	return ret;
282 }
283 EXPORT_SYMBOL(handshake_req_submit);
284 
285 void handshake_complete(struct handshake_req *req, unsigned int status,
286 			struct genl_info *info)
287 {
288 	struct sock *sk = req->hr_sk;
289 	struct net *net = sock_net(sk);
290 
291 	if (!test_and_set_bit(HANDSHAKE_F_REQ_COMPLETED, &req->hr_flags)) {
292 		trace_handshake_complete(net, req, sk, status);
293 		req->hr_proto->hp_done(req, status, info);
294 
295 		/* Handshake request is no longer pending */
296 		sock_put(sk);
297 	}
298 }
299 EXPORT_SYMBOL_IF_KUNIT(handshake_complete);
300 
301 /**
302  * handshake_req_cancel - Cancel an in-progress handshake
303  * @sk: socket on which there is an ongoing handshake
304  *
305  * Request cancellation races with request completion. To determine
306  * who won, callers examine the return value from this function.
307  *
308  * Return values:
309  *   %true - Uncompleted handshake request was canceled
310  *   %false - Handshake request already completed or not found
311  */
312 bool handshake_req_cancel(struct sock *sk)
313 {
314 	struct handshake_req *req;
315 	struct handshake_net *hn;
316 	struct net *net;
317 
318 	net = sock_net(sk);
319 	req = handshake_req_hash_lookup(sk);
320 	if (!req) {
321 		trace_handshake_cancel_none(net, req, sk);
322 		return false;
323 	}
324 
325 	hn = handshake_pernet(net);
326 	if (hn && remove_pending(hn, req)) {
327 		/* Request hadn't been accepted */
328 		goto out_true;
329 	}
330 	if (test_and_set_bit(HANDSHAKE_F_REQ_COMPLETED, &req->hr_flags)) {
331 		/* Request already completed */
332 		trace_handshake_cancel_busy(net, req, sk);
333 		return false;
334 	}
335 
336 out_true:
337 	trace_handshake_cancel(net, req, sk);
338 
339 	/* Handshake request is no longer pending */
340 	sock_put(sk);
341 	return true;
342 }
343 EXPORT_SYMBOL(handshake_req_cancel);
344