xref: /linux/net/sunrpc/auth_tls.c (revision 1b0975ee3bdd3eb19a47371c26fd7ef8f7f6b599)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (c) 2021, 2022 Oracle.  All rights reserved.
4  *
5  * The AUTH_TLS credential is used only to probe a remote peer
6  * for RPC-over-TLS support.
7  */
8 
9 #include <linux/types.h>
10 #include <linux/module.h>
11 #include <linux/sunrpc/clnt.h>
12 
/*
 * Body of the AUTH_NULL reply verifier that tls_validate() requires
 * before it accepts the peer's answer to the AUTH_TLS probe. The length
 * is derived from the literal (minus the NUL terminator) so the two can
 * never drift apart, and the token itself is a read-only array rather
 * than a writable pointer variable.
 */
static const char starttls_token[] = "STARTTLS";
static const size_t starttls_len = sizeof(starttls_token) - 1;
15 
16 static struct rpc_auth tls_auth;
17 static struct rpc_cred tls_cred;
18 
/*
 * Argument encoder for the probe procedure. The probe call carries no
 * arguments, so nothing is written to the stream.
 */
static void tls_encode_probe(struct rpc_rqst *req, struct xdr_stream *stream,
			     const void *data)
{
	/* intentionally empty */
}
23 
/*
 * Result decoder for the probe procedure. There are no results to read,
 * so decoding always succeeds.
 */
static int tls_decode_probe(struct rpc_rqst *req, struct xdr_stream *stream,
			    void *data)
{
	return 0;
}
29 
30 static const struct rpc_procinfo rpcproc_tls_probe = {
31 	.p_encode	= tls_encode_probe,
32 	.p_decode	= tls_decode_probe,
33 };
34 
35 static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data)
36 {
37 	task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT;
38 	rpc_call_start(task);
39 }
40 
/*
 * rpc_call_done callback for the probe. Nothing to do here: tls_probe()
 * reads the outcome from task->tk_status itself.
 */
static void rpc_tls_probe_call_done(struct rpc_task *task, void *calldata)
{
	/* intentionally empty */
}
44 
45 static const struct rpc_call_ops rpc_tls_probe_ops = {
46 	.rpc_call_prepare	= rpc_tls_probe_call_prepare,
47 	.rpc_call_done		= rpc_tls_probe_call_done,
48 };
49 
50 static int tls_probe(struct rpc_clnt *clnt)
51 {
52 	struct rpc_message msg = {
53 		.rpc_proc	= &rpcproc_tls_probe,
54 	};
55 	struct rpc_task_setup task_setup_data = {
56 		.rpc_client	= clnt,
57 		.rpc_message	= &msg,
58 		.rpc_op_cred	= &tls_cred,
59 		.callback_ops	= &rpc_tls_probe_ops,
60 		.flags		= RPC_TASK_SOFT | RPC_TASK_SOFTCONN,
61 	};
62 	struct rpc_task	*task;
63 	int status;
64 
65 	task = rpc_run_task(&task_setup_data);
66 	if (IS_ERR(task))
67 		return PTR_ERR(task);
68 	status = task->tk_status;
69 	rpc_put_task(task);
70 	return status;
71 }
72 
73 static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args,
74 				   struct rpc_clnt *clnt)
75 {
76 	refcount_inc(&tls_auth.au_count);
77 	return &tls_auth;
78 }
79 
/*
 * rpc_authops.destroy: tls_auth is statically allocated, so there is
 * nothing to tear down.
 */
static void tls_destroy(struct rpc_auth *auth)
{
	/* intentionally empty */
}
83 
84 static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth,
85 					struct auth_cred *acred, int flags)
86 {
87 	return get_rpccred(&tls_cred);
88 }
89 
/*
 * rpc_credops.crdestroy: tls_cred is statically allocated, so nothing
 * is freed here.
 */
static void tls_destroy_cred(struct rpc_cred *cred)
{
	/* intentionally empty */
}
93 
/*
 * rpc_credops.crmatch: there is only one AUTH_TLS credential, so it
 * matches every request unconditionally (returns 1).
 */
static int tls_match(struct auth_cred *acred, struct rpc_cred *cred,
		     int taskflags)
{
	return 1;
}
98 
99 static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr)
100 {
101 	__be32 *p;
102 
103 	p = xdr_reserve_space(xdr, 4 * XDR_UNIT);
104 	if (!p)
105 		return -EMSGSIZE;
106 	/* Credential */
107 	*p++ = rpc_auth_tls;
108 	*p++ = xdr_zero;
109 	/* Verifier */
110 	*p++ = rpc_auth_null;
111 	*p   = xdr_zero;
112 	return 0;
113 }
114 
115 static int tls_refresh(struct rpc_task *task)
116 {
117 	set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags);
118 	return 0;
119 }
120 
121 static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr)
122 {
123 	__be32 *p;
124 	void *str;
125 
126 	p = xdr_inline_decode(xdr, XDR_UNIT);
127 	if (!p)
128 		return -EIO;
129 	if (*p != rpc_auth_null)
130 		return -EIO;
131 	if (xdr_stream_decode_opaque_inline(xdr, &str, starttls_len) != starttls_len)
132 		return -EIO;
133 	if (memcmp(str, starttls_token, starttls_len))
134 		return -EIO;
135 	return 0;
136 }
137 
138 const struct rpc_authops authtls_ops = {
139 	.owner		= THIS_MODULE,
140 	.au_flavor	= RPC_AUTH_TLS,
141 	.au_name	= "NULL",
142 	.create		= tls_create,
143 	.destroy	= tls_destroy,
144 	.lookup_cred	= tls_lookup_cred,
145 	.ping		= tls_probe,
146 };
147 
148 static struct rpc_auth tls_auth = {
149 	.au_cslack	= NUL_CALLSLACK,
150 	.au_rslack	= NUL_REPLYSLACK,
151 	.au_verfsize	= NUL_REPLYSLACK,
152 	.au_ralign	= NUL_REPLYSLACK,
153 	.au_ops		= &authtls_ops,
154 	.au_flavor	= RPC_AUTH_TLS,
155 	.au_count	= REFCOUNT_INIT(1),
156 };
157 
158 static const struct rpc_credops tls_credops = {
159 	.cr_name	= "AUTH_TLS",
160 	.crdestroy	= tls_destroy_cred,
161 	.crmatch	= tls_match,
162 	.crmarshal	= tls_marshal,
163 	.crwrap_req	= rpcauth_wrap_req_encode,
164 	.crrefresh	= tls_refresh,
165 	.crvalidate	= tls_validate,
166 	.crunwrap_resp	= rpcauth_unwrap_resp_decode,
167 };
168 
169 static struct rpc_cred tls_cred = {
170 	.cr_lru		= LIST_HEAD_INIT(tls_cred.cr_lru),
171 	.cr_auth	= &tls_auth,
172 	.cr_ops		= &tls_credops,
173 	.cr_count	= REFCOUNT_INIT(2),
174 	.cr_flags	= 1UL << RPCAUTH_CRED_UPTODATE,
175 };
176