1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * Copyright (c) 2021, 2022 Oracle. All rights reserved. 4 * 5 * The AUTH_TLS credential is used only to probe a remote peer 6 * for RPC-over-TLS support. 7 */ 8 9 #include <linux/types.h> 10 #include <linux/module.h> 11 #include <linux/sunrpc/clnt.h> 12 13 static const char *starttls_token = "STARTTLS"; 14 static const size_t starttls_len = 8; 15 16 static struct rpc_auth tls_auth; 17 static struct rpc_cred tls_cred; 18 19 static void tls_encode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr, 20 const void *obj) 21 { 22 } 23 24 static int tls_decode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr, 25 void *obj) 26 { 27 return 0; 28 } 29 30 static const struct rpc_procinfo rpcproc_tls_probe = { 31 .p_encode = tls_encode_probe, 32 .p_decode = tls_decode_probe, 33 }; 34 35 static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data) 36 { 37 task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT; 38 rpc_call_start(task); 39 } 40 41 static void rpc_tls_probe_call_done(struct rpc_task *task, void *data) 42 { 43 } 44 45 static const struct rpc_call_ops rpc_tls_probe_ops = { 46 .rpc_call_prepare = rpc_tls_probe_call_prepare, 47 .rpc_call_done = rpc_tls_probe_call_done, 48 }; 49 50 static int tls_probe(struct rpc_clnt *clnt) 51 { 52 struct rpc_message msg = { 53 .rpc_proc = &rpcproc_tls_probe, 54 }; 55 struct rpc_task_setup task_setup_data = { 56 .rpc_client = clnt, 57 .rpc_message = &msg, 58 .rpc_op_cred = &tls_cred, 59 .callback_ops = &rpc_tls_probe_ops, 60 .flags = RPC_TASK_SOFT | RPC_TASK_SOFTCONN, 61 }; 62 struct rpc_task *task; 63 int status; 64 65 task = rpc_run_task(&task_setup_data); 66 if (IS_ERR(task)) 67 return PTR_ERR(task); 68 status = task->tk_status; 69 rpc_put_task(task); 70 return status; 71 } 72 73 static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args, 74 struct rpc_clnt *clnt) 75 { 76 refcount_inc(&tls_auth.au_count); 77 return &tls_auth; 78 } 79 80 static void tls_destroy(struct rpc_auth *auth) 81 { 82 } 83 84 static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth, 85 struct auth_cred *acred, int flags) 86 { 87 return get_rpccred(&tls_cred); 88 } 89 90 static void tls_destroy_cred(struct rpc_cred *cred) 91 { 92 } 93 94 static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags) 95 { 96 return 1; 97 } 98 99 static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr) 100 { 101 __be32 *p; 102 103 p = xdr_reserve_space(xdr, 4 * XDR_UNIT); 104 if (!p) 105 return -EMSGSIZE; 106 /* Credential */ 107 *p++ = rpc_auth_tls; 108 *p++ = xdr_zero; 109 /* Verifier */ 110 *p++ = rpc_auth_null; 111 *p = xdr_zero; 112 return 0; 113 } 114 115 static int tls_refresh(struct rpc_task *task) 116 { 117 set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags); 118 return 0; 119 } 120 121 static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr) 122 { 123 __be32 *p; 124 void *str; 125 126 p = xdr_inline_decode(xdr, XDR_UNIT); 127 if (!p) 128 return -EIO; 129 if (*p != rpc_auth_null) 130 return -EIO; 131 if (xdr_stream_decode_opaque_inline(xdr, &str, starttls_len) != starttls_len) 132 return -EIO; 133 if (memcmp(str, starttls_token, starttls_len)) 134 return -EIO; 135 return 0; 136 } 137 138 const struct rpc_authops authtls_ops = { 139 .owner = THIS_MODULE, 140 .au_flavor = RPC_AUTH_TLS, 141 .au_name = "NULL", 142 .create = tls_create, 143 .destroy = tls_destroy, 144 .lookup_cred = tls_lookup_cred, 145 .ping = tls_probe, 146 }; 147 148 static struct rpc_auth tls_auth = { 149 .au_cslack = NUL_CALLSLACK, 150 .au_rslack = NUL_REPLYSLACK, 151 .au_verfsize = NUL_REPLYSLACK, 152 .au_ralign = NUL_REPLYSLACK, 153 .au_ops = &authtls_ops, 154 .au_flavor = RPC_AUTH_TLS, 155 .au_count = REFCOUNT_INIT(1), 156 }; 157 158 static const struct rpc_credops tls_credops = { 159 .cr_name = "AUTH_TLS", 160 .crdestroy = tls_destroy_cred, 161 .crmatch = tls_match, 162 .crmarshal = tls_marshal, 163 .crwrap_req = rpcauth_wrap_req_encode, 164 .crrefresh = tls_refresh, 165 .crvalidate = tls_validate, 166 .crunwrap_resp = rpcauth_unwrap_resp_decode, 167 }; 168 169 static struct rpc_cred tls_cred = { 170 .cr_lru = LIST_HEAD_INIT(tls_cred.cr_lru), 171 .cr_auth = &tls_auth, 172 .cr_ops = &tls_credops, 173 .cr_count = REFCOUNT_INIT(2), 174 .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE, 175 }; 176