xref: /linux/fs/smb/server/auth.c (revision 954d196bebb2b50151cb96454c72dc113b2af1ac)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *   Copyright (C) 2016 Namjae Jeon <linkinjeon@kernel.org>
4  *   Copyright (C) 2018 Samsung Electronics Co., Ltd.
5  */
6 
7 #include <linux/kernel.h>
8 #include <linux/fs.h>
9 #include <linux/uaccess.h>
10 #include <linux/backing-dev.h>
11 #include <linux/writeback.h>
12 #include <linux/uio.h>
13 #include <linux/xattr.h>
14 #include <crypto/aead.h>
15 #include <crypto/aes-cbc-macs.h>
16 #include <crypto/md5.h>
17 #include <crypto/sha2.h>
18 #include <crypto/utils.h>
19 #include <linux/random.h>
20 #include <linux/scatterlist.h>
21 
22 #include "auth.h"
23 #include "glob.h"
24 
25 #include <linux/fips.h>
26 #include <crypto/arc4.h>
27 #include <crypto/des.h>
28 
29 #include "server.h"
30 #include "smb_common.h"
31 #include "connection.h"
32 #include "mgmt/user_session.h"
33 #include "mgmt/user_config.h"
34 #include "crypto_ctx.h"
35 #include "transport_ipc.h"
36 
37 /*
38  * Fixed format data defining GSS header and fixed string
39  * "not_defined_in_RFC4178@please_ignore".
40  * So sec blob data in neg phase could be generated statically.
41  */
42 static char NEGOTIATE_GSS_HEADER[AUTH_GSS_LENGTH] = {
43 #ifdef CONFIG_SMB_SERVER_KERBEROS5
44 	0x60, 0x5e, 0x06, 0x06, 0x2b, 0x06, 0x01, 0x05,
45 	0x05, 0x02, 0xa0, 0x54, 0x30, 0x52, 0xa0, 0x24,
46 	0x30, 0x22, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86,
47 	0xf7, 0x12, 0x01, 0x02, 0x02, 0x06, 0x09, 0x2a,
48 	0x86, 0x48, 0x82, 0xf7, 0x12, 0x01, 0x02, 0x02,
49 	0x06, 0x0a, 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82,
50 	0x37, 0x02, 0x02, 0x0a, 0xa3, 0x2a, 0x30, 0x28,
51 	0xa0, 0x26, 0x1b, 0x24, 0x6e, 0x6f, 0x74, 0x5f,
52 	0x64, 0x65, 0x66, 0x69, 0x6e, 0x65, 0x64, 0x5f,
53 	0x69, 0x6e, 0x5f, 0x52, 0x46, 0x43, 0x34, 0x31,
54 	0x37, 0x38, 0x40, 0x70, 0x6c, 0x65, 0x61, 0x73,
55 	0x65, 0x5f, 0x69, 0x67, 0x6e, 0x6f, 0x72, 0x65
56 #else
57 	0x60, 0x48, 0x06, 0x06, 0x2b, 0x06, 0x01, 0x05,
58 	0x05, 0x02, 0xa0, 0x3e, 0x30, 0x3c, 0xa0, 0x0e,
59 	0x30, 0x0c, 0x06, 0x0a, 0x2b, 0x06, 0x01, 0x04,
60 	0x01, 0x82, 0x37, 0x02, 0x02, 0x0a, 0xa3, 0x2a,
61 	0x30, 0x28, 0xa0, 0x26, 0x1b, 0x24, 0x6e, 0x6f,
62 	0x74, 0x5f, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x65,
63 	0x64, 0x5f, 0x69, 0x6e, 0x5f, 0x52, 0x46, 0x43,
64 	0x34, 0x31, 0x37, 0x38, 0x40, 0x70, 0x6c, 0x65,
65 	0x61, 0x73, 0x65, 0x5f, 0x69, 0x67, 0x6e, 0x6f,
66 	0x72, 0x65
67 #endif
68 };
69 
70 void ksmbd_copy_gss_neg_header(void *buf)
71 {
72 	memcpy(buf, NEGOTIATE_GSS_HEADER, AUTH_GSS_LENGTH);
73 }
74 
75 static int calc_ntlmv2_hash(struct ksmbd_conn *conn, struct ksmbd_session *sess,
76 			    char *ntlmv2_hash, char *dname)
77 {
78 	int ret, len, conv_len;
79 	wchar_t *domain = NULL;
80 	__le16 *uniname = NULL;
81 	struct hmac_md5_ctx ctx;
82 
83 	hmac_md5_init_usingrawkey(&ctx, user_passkey(sess->user),
84 				  CIFS_ENCPWD_SIZE);
85 
86 	/* convert user_name to unicode */
87 	len = strlen(user_name(sess->user));
88 	uniname = kzalloc(2 + UNICODE_LEN(len), KSMBD_DEFAULT_GFP);
89 	if (!uniname) {
90 		ret = -ENOMEM;
91 		goto out;
92 	}
93 
94 	conv_len = smb_strtoUTF16(uniname, user_name(sess->user), len,
95 				  conn->local_nls);
96 	if (conv_len < 0 || conv_len > len) {
97 		ret = -EINVAL;
98 		goto out;
99 	}
100 	UniStrupr(uniname);
101 
102 	hmac_md5_update(&ctx, (const u8 *)uniname, UNICODE_LEN(conv_len));
103 
104 	/* Convert domain name or conn name to unicode and uppercase */
105 	len = strlen(dname);
106 	domain = kzalloc(2 + UNICODE_LEN(len), KSMBD_DEFAULT_GFP);
107 	if (!domain) {
108 		ret = -ENOMEM;
109 		goto out;
110 	}
111 
112 	conv_len = smb_strtoUTF16((__le16 *)domain, dname, len,
113 				  conn->local_nls);
114 	if (conv_len < 0 || conv_len > len) {
115 		ret = -EINVAL;
116 		goto out;
117 	}
118 
119 	hmac_md5_update(&ctx, (const u8 *)domain, UNICODE_LEN(conv_len));
120 	hmac_md5_final(&ctx, ntlmv2_hash);
121 	ret = 0;
122 out:
123 	kfree(uniname);
124 	kfree(domain);
125 	return ret;
126 }
127 
128 /**
129  * ksmbd_auth_ntlmv2() - NTLMv2 authentication handler
130  * @conn:		connection
131  * @sess:		session of connection
132  * @ntlmv2:		NTLMv2 challenge response
133  * @blen:		NTLMv2 blob length
134  * @domain_name:	domain name
135  * @cryptkey:		session crypto key
136  *
137  * Return:	0 on success, error number on error
138  */
139 int ksmbd_auth_ntlmv2(struct ksmbd_conn *conn, struct ksmbd_session *sess,
140 		      struct ntlmv2_resp *ntlmv2, int blen, char *domain_name,
141 		      char *cryptkey)
142 {
143 	char ntlmv2_hash[CIFS_ENCPWD_SIZE];
144 	char ntlmv2_rsp[CIFS_HMAC_MD5_HASH_SIZE];
145 	char sess_key[SMB2_NTLMV2_SESSKEY_SIZE];
146 	struct hmac_md5_ctx ctx;
147 	int rc;
148 
149 	if (fips_enabled) {
150 		ksmbd_debug(AUTH, "NTLMv2 support is disabled due to FIPS\n");
151 		return -EOPNOTSUPP;
152 	}
153 
154 	rc = calc_ntlmv2_hash(conn, sess, ntlmv2_hash, domain_name);
155 	if (rc) {
156 		ksmbd_debug(AUTH, "could not get v2 hash rc %d\n", rc);
157 		return rc;
158 	}
159 
160 	hmac_md5_init_usingrawkey(&ctx, ntlmv2_hash, CIFS_HMAC_MD5_HASH_SIZE);
161 	hmac_md5_update(&ctx, cryptkey, CIFS_CRYPTO_KEY_SIZE);
162 	hmac_md5_update(&ctx, (const u8 *)&ntlmv2->blob_signature, blen);
163 	hmac_md5_final(&ctx, ntlmv2_rsp);
164 
165 	/* Generate the session key */
166 	hmac_md5_usingrawkey(ntlmv2_hash, CIFS_HMAC_MD5_HASH_SIZE,
167 			     ntlmv2_rsp, CIFS_HMAC_MD5_HASH_SIZE,
168 			     sess_key);
169 
170 	if (crypto_memneq(ntlmv2->ntlmv2_hash, ntlmv2_rsp,
171 			  CIFS_HMAC_MD5_HASH_SIZE)) {
172 		rc = -EINVAL;
173 		goto out;
174 	}
175 
176 	memcpy(sess->sess_key, sess_key, sizeof(sess_key));
177 	rc = 0;
178 out:
179 	memzero_explicit(ntlmv2_hash, sizeof(ntlmv2_hash));
180 	memzero_explicit(ntlmv2_rsp, sizeof(ntlmv2_rsp));
181 	memzero_explicit(sess_key, sizeof(sess_key));
182 	return rc;
183 }
184 
185 /**
186  * ksmbd_decode_ntlmssp_auth_blob() - helper function to construct
187  * authenticate blob
188  * @authblob:	authenticate blob source pointer
189  * @blob_len:	length of the @authblob message
190  * @conn:	connection
191  * @sess:	session of connection
192  *
193  * Return:	0 on success, error number on error
194  */
195 int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob,
196 				   int blob_len, struct ksmbd_conn *conn,
197 				   struct ksmbd_session *sess)
198 {
199 	char *domain_name;
200 	unsigned int nt_off, dn_off;
201 	unsigned short nt_len, dn_len;
202 	int ret;
203 
204 	if (blob_len < sizeof(struct authenticate_message)) {
205 		ksmbd_debug(AUTH, "negotiate blob len %d too small\n",
206 			    blob_len);
207 		return -EINVAL;
208 	}
209 
210 	if (memcmp(authblob->Signature, "NTLMSSP", 8)) {
211 		ksmbd_debug(AUTH, "blob signature incorrect %s\n",
212 			    authblob->Signature);
213 		return -EINVAL;
214 	}
215 
216 	nt_off = le32_to_cpu(authblob->NtChallengeResponse.BufferOffset);
217 	nt_len = le16_to_cpu(authblob->NtChallengeResponse.Length);
218 	dn_off = le32_to_cpu(authblob->DomainName.BufferOffset);
219 	dn_len = le16_to_cpu(authblob->DomainName.Length);
220 
221 	if (blob_len < (u64)dn_off + dn_len || blob_len < (u64)nt_off + nt_len ||
222 	    nt_len < CIFS_ENCPWD_SIZE)
223 		return -EINVAL;
224 
225 	/* TODO : use domain name that imported from configuration file */
226 	domain_name = smb_strndup_from_utf16((const char *)authblob + dn_off,
227 					     dn_len, true, conn->local_nls);
228 	if (IS_ERR(domain_name))
229 		return PTR_ERR(domain_name);
230 
231 	/* process NTLMv2 authentication */
232 	ksmbd_debug(AUTH, "decode_ntlmssp_authenticate_blob dname%s\n",
233 		    domain_name);
234 	ret = ksmbd_auth_ntlmv2(conn, sess,
235 				(struct ntlmv2_resp *)((char *)authblob + nt_off),
236 				nt_len - CIFS_ENCPWD_SIZE,
237 				domain_name, conn->ntlmssp.cryptkey);
238 	kfree(domain_name);
239 	if (ret)
240 		return ret;
241 
242 	/* The recovered secondary session key */
243 	if (conn->ntlmssp.client_flags & NTLMSSP_NEGOTIATE_KEY_XCH) {
244 		struct arc4_ctx *ctx_arc4;
245 		unsigned int sess_key_off, sess_key_len;
246 
247 		sess_key_off = le32_to_cpu(authblob->SessionKey.BufferOffset);
248 		sess_key_len = le16_to_cpu(authblob->SessionKey.Length);
249 
250 		if (blob_len < (u64)sess_key_off + sess_key_len)
251 			return -EINVAL;
252 
253 		if (sess_key_len > CIFS_KEY_SIZE)
254 			return -EINVAL;
255 
256 		ctx_arc4 = kmalloc_obj(*ctx_arc4, KSMBD_DEFAULT_GFP);
257 		if (!ctx_arc4)
258 			return -ENOMEM;
259 
260 		arc4_setkey(ctx_arc4, sess->sess_key, SMB2_NTLMV2_SESSKEY_SIZE);
261 		arc4_crypt(ctx_arc4, sess->sess_key,
262 			   (char *)authblob + sess_key_off, sess_key_len);
263 		kfree_sensitive(ctx_arc4);
264 	}
265 
266 	return ret;
267 }
268 
269 /**
270  * ksmbd_decode_ntlmssp_neg_blob() - helper function to construct
271  * negotiate blob
272  * @negblob: negotiate blob source pointer
273  * @blob_len:	length of the @authblob message
274  * @conn:	connection
275  *
276  */
277 int ksmbd_decode_ntlmssp_neg_blob(struct negotiate_message *negblob,
278 				  int blob_len, struct ksmbd_conn *conn)
279 {
280 	if (blob_len < sizeof(struct negotiate_message)) {
281 		ksmbd_debug(AUTH, "negotiate blob len %d too small\n",
282 			    blob_len);
283 		return -EINVAL;
284 	}
285 
286 	if (memcmp(negblob->Signature, "NTLMSSP", 8)) {
287 		ksmbd_debug(AUTH, "blob signature incorrect %s\n",
288 			    negblob->Signature);
289 		return -EINVAL;
290 	}
291 
292 	conn->ntlmssp.client_flags = le32_to_cpu(negblob->NegotiateFlags);
293 	return 0;
294 }
295 
296 /**
297  * ksmbd_build_ntlmssp_challenge_blob() - helper function to construct
298  * challenge blob
299  * @chgblob: challenge blob source pointer to initialize
300  * @conn:	connection
301  *
302  */
303 unsigned int
304 ksmbd_build_ntlmssp_challenge_blob(struct challenge_message *chgblob,
305 				   struct ksmbd_conn *conn)
306 {
307 	struct target_info *tinfo;
308 	wchar_t *name;
309 	__u8 *target_name;
310 	unsigned int flags, blob_off, blob_len, type, target_info_len = 0;
311 	int len, uni_len, conv_len;
312 	int cflags = conn->ntlmssp.client_flags;
313 
314 	memcpy(chgblob->Signature, NTLMSSP_SIGNATURE, 8);
315 	chgblob->MessageType = NtLmChallenge;
316 
317 	flags = NTLMSSP_NEGOTIATE_UNICODE |
318 		NTLMSSP_NEGOTIATE_NTLM | NTLMSSP_TARGET_TYPE_SERVER |
319 		NTLMSSP_NEGOTIATE_TARGET_INFO;
320 
321 	if (cflags & NTLMSSP_NEGOTIATE_SIGN) {
322 		flags |= NTLMSSP_NEGOTIATE_SIGN;
323 		flags |= cflags & (NTLMSSP_NEGOTIATE_128 |
324 				   NTLMSSP_NEGOTIATE_56);
325 	}
326 
327 	if (cflags & NTLMSSP_NEGOTIATE_SEAL && smb3_encryption_negotiated(conn))
328 		flags |= NTLMSSP_NEGOTIATE_SEAL;
329 
330 	if (cflags & NTLMSSP_NEGOTIATE_ALWAYS_SIGN)
331 		flags |= NTLMSSP_NEGOTIATE_ALWAYS_SIGN;
332 
333 	if (cflags & NTLMSSP_REQUEST_TARGET)
334 		flags |= NTLMSSP_REQUEST_TARGET;
335 
336 	if (conn->use_spnego &&
337 	    (cflags & NTLMSSP_NEGOTIATE_EXTENDED_SEC))
338 		flags |= NTLMSSP_NEGOTIATE_EXTENDED_SEC;
339 
340 	if (cflags & NTLMSSP_NEGOTIATE_KEY_XCH)
341 		flags |= NTLMSSP_NEGOTIATE_KEY_XCH;
342 
343 	chgblob->NegotiateFlags = cpu_to_le32(flags);
344 	len = strlen(ksmbd_netbios_name());
345 	name = kmalloc(2 + UNICODE_LEN(len), KSMBD_DEFAULT_GFP);
346 	if (!name)
347 		return -ENOMEM;
348 
349 	conv_len = smb_strtoUTF16((__le16 *)name, ksmbd_netbios_name(), len,
350 				  conn->local_nls);
351 	if (conv_len < 0 || conv_len > len) {
352 		kfree(name);
353 		return -EINVAL;
354 	}
355 
356 	uni_len = UNICODE_LEN(conv_len);
357 
358 	blob_off = sizeof(struct challenge_message);
359 	blob_len = blob_off + uni_len;
360 
361 	chgblob->TargetName.Length = cpu_to_le16(uni_len);
362 	chgblob->TargetName.MaximumLength = cpu_to_le16(uni_len);
363 	chgblob->TargetName.BufferOffset = cpu_to_le32(blob_off);
364 
365 	/* Initialize random conn challenge */
366 	get_random_bytes(conn->ntlmssp.cryptkey, sizeof(__u64));
367 	memcpy(chgblob->Challenge, conn->ntlmssp.cryptkey,
368 	       CIFS_CRYPTO_KEY_SIZE);
369 
370 	/* Add Target Information to security buffer */
371 	chgblob->TargetInfoArray.BufferOffset = cpu_to_le32(blob_len);
372 
373 	target_name = (__u8 *)chgblob + blob_off;
374 	memcpy(target_name, name, uni_len);
375 	tinfo = (struct target_info *)(target_name + uni_len);
376 
377 	chgblob->TargetInfoArray.Length = 0;
378 	/* Add target info list for NetBIOS/DNS settings */
379 	for (type = NTLMSSP_AV_NB_COMPUTER_NAME;
380 	     type <= NTLMSSP_AV_DNS_DOMAIN_NAME; type++) {
381 		tinfo->Type = cpu_to_le16(type);
382 		tinfo->Length = cpu_to_le16(uni_len);
383 		memcpy(tinfo->Content, name, uni_len);
384 		tinfo = (struct target_info *)((char *)tinfo + 4 + uni_len);
385 		target_info_len += 4 + uni_len;
386 	}
387 
388 	/* Add terminator subblock */
389 	tinfo->Type = 0;
390 	tinfo->Length = 0;
391 	target_info_len += 4;
392 
393 	chgblob->TargetInfoArray.Length = cpu_to_le16(target_info_len);
394 	chgblob->TargetInfoArray.MaximumLength = cpu_to_le16(target_info_len);
395 	blob_len += target_info_len;
396 	kfree(name);
397 	ksmbd_debug(AUTH, "NTLMSSP SecurityBufferLength %d\n", blob_len);
398 	return blob_len;
399 }
400 
401 #ifdef CONFIG_SMB_SERVER_KERBEROS5
402 int ksmbd_krb5_authenticate(struct ksmbd_session *sess, char *in_blob,
403 			    int in_len, char *out_blob, int *out_len)
404 {
405 	struct ksmbd_spnego_authen_response *resp;
406 	struct ksmbd_login_response_ext *resp_ext = NULL;
407 	struct ksmbd_user *user = NULL;
408 	int retval;
409 
410 	resp = ksmbd_ipc_spnego_authen_request(in_blob, in_len);
411 	if (!resp) {
412 		ksmbd_debug(AUTH, "SPNEGO_AUTHEN_REQUEST failure\n");
413 		return -EINVAL;
414 	}
415 
416 	if (!(resp->login_response.status & KSMBD_USER_FLAG_OK)) {
417 		ksmbd_debug(AUTH, "krb5 authentication failure\n");
418 		retval = -EPERM;
419 		goto out;
420 	}
421 
422 	if (*out_len <= resp->spnego_blob_len) {
423 		ksmbd_debug(AUTH, "buf len %d, but blob len %d\n",
424 			    *out_len, resp->spnego_blob_len);
425 		retval = -EINVAL;
426 		goto out;
427 	}
428 
429 	if (resp->session_key_len > sizeof(sess->sess_key)) {
430 		ksmbd_debug(AUTH, "session key is too long\n");
431 		retval = -EINVAL;
432 		goto out;
433 	}
434 
435 	if (resp->login_response.status & KSMBD_USER_FLAG_EXTENSION)
436 		resp_ext = ksmbd_ipc_login_request_ext(resp->login_response.account);
437 
438 	user = ksmbd_alloc_user(&resp->login_response, resp_ext);
439 	if (!user) {
440 		ksmbd_debug(AUTH, "login failure\n");
441 		retval = -ENOMEM;
442 		goto out;
443 	}
444 
445 	if (!sess->user) {
446 		/* First successful authentication */
447 		sess->user = user;
448 	} else {
449 		if (!ksmbd_compare_user(sess->user, user)) {
450 			ksmbd_debug(AUTH, "different user tried to reuse session\n");
451 			retval = -EPERM;
452 			ksmbd_free_user(user);
453 			goto out;
454 		}
455 		ksmbd_free_user(user);
456 	}
457 
458 	memcpy(sess->sess_key, resp->payload, resp->session_key_len);
459 	memcpy(out_blob, resp->payload + resp->session_key_len,
460 	       resp->spnego_blob_len);
461 	*out_len = resp->spnego_blob_len;
462 	retval = 0;
463 out:
464 	kvfree(resp);
465 	return retval;
466 }
467 #else
468 int ksmbd_krb5_authenticate(struct ksmbd_session *sess, char *in_blob,
469 			    int in_len, char *out_blob, int *out_len)
470 {
471 	return -EOPNOTSUPP;
472 }
473 #endif
474 
475 /**
476  * ksmbd_sign_smb2_pdu() - function to generate packet signing
477  * @conn:	connection
478  * @key:	signing key
479  * @iov:        buffer iov array
480  * @n_vec:	number of iovecs
481  * @sig:	signature value generated for client request packet
482  *
483  */
484 void ksmbd_sign_smb2_pdu(struct ksmbd_conn *conn, char *key, struct kvec *iov,
485 			 int n_vec, char *sig)
486 {
487 	struct hmac_sha256_ctx ctx;
488 	int i;
489 
490 	hmac_sha256_init_usingrawkey(&ctx, key, SMB2_NTLMV2_SESSKEY_SIZE);
491 	for (i = 0; i < n_vec; i++)
492 		hmac_sha256_update(&ctx, iov[i].iov_base, iov[i].iov_len);
493 	hmac_sha256_final(&ctx, sig);
494 }
495 
496 /**
497  * ksmbd_sign_smb3_pdu() - function to generate packet signing
498  * @conn:	connection
499  * @key:	signing key
500  * @iov:        buffer iov array
501  * @n_vec:	number of iovecs
502  * @sig:	signature value generated for client request packet
503  *
504  */
505 void ksmbd_sign_smb3_pdu(struct ksmbd_conn *conn, char *key, struct kvec *iov,
506 			 int n_vec, char *sig)
507 {
508 	struct aes_cmac_key cmac_key;
509 	struct aes_cmac_ctx cmac_ctx;
510 	int i;
511 
512 	/* This cannot fail, since we always pass a valid key length. */
513 	static_assert(SMB2_CMACAES_SIZE == AES_KEYSIZE_128);
514 	aes_cmac_preparekey(&cmac_key, key, SMB2_CMACAES_SIZE);
515 
516 	aes_cmac_init(&cmac_ctx, &cmac_key);
517 	for (i = 0; i < n_vec; i++)
518 		aes_cmac_update(&cmac_ctx, iov[i].iov_base, iov[i].iov_len);
519 	aes_cmac_final(&cmac_ctx, sig);
520 }
521 
522 struct derivation {
523 	struct kvec label;
524 	struct kvec context;
525 	bool binding;
526 };
527 
528 static void generate_key(struct ksmbd_conn *conn, struct ksmbd_session *sess,
529 			 struct kvec label, struct kvec context, __u8 *key,
530 			 unsigned int key_size)
531 {
532 	unsigned char zero = 0x0;
533 	__u8 i[4] = {0, 0, 0, 1};
534 	__u8 L128[4] = {0, 0, 0, 128};
535 	__u8 L256[4] = {0, 0, 1, 0};
536 	unsigned char prfhash[SMB2_HMACSHA256_SIZE];
537 	struct hmac_sha256_ctx ctx;
538 
539 	hmac_sha256_init_usingrawkey(&ctx, sess->sess_key,
540 				     SMB2_NTLMV2_SESSKEY_SIZE);
541 	hmac_sha256_update(&ctx, i, 4);
542 	hmac_sha256_update(&ctx, label.iov_base, label.iov_len);
543 	hmac_sha256_update(&ctx, &zero, 1);
544 	hmac_sha256_update(&ctx, context.iov_base, context.iov_len);
545 
546 	if (key_size == SMB3_ENC_DEC_KEY_SIZE &&
547 	    (conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM ||
548 	     conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM))
549 		hmac_sha256_update(&ctx, L256, 4);
550 	else
551 		hmac_sha256_update(&ctx, L128, 4);
552 
553 	hmac_sha256_final(&ctx, prfhash);
554 	memcpy(key, prfhash, key_size);
555 }
556 
557 static int generate_smb3signingkey(struct ksmbd_session *sess,
558 				   struct ksmbd_conn *conn,
559 				   const struct derivation *signing)
560 {
561 	struct channel *chann;
562 	char *key;
563 
564 	chann = lookup_chann_list(sess, conn);
565 	if (!chann)
566 		return 0;
567 
568 	if (conn->dialect >= SMB30_PROT_ID && signing->binding)
569 		key = chann->smb3signingkey;
570 	else
571 		key = sess->smb3signingkey;
572 
573 	generate_key(conn, sess, signing->label, signing->context, key,
574 		     SMB3_SIGN_KEY_SIZE);
575 
576 	if (!(conn->dialect >= SMB30_PROT_ID && signing->binding))
577 		memcpy(chann->smb3signingkey, key, SMB3_SIGN_KEY_SIZE);
578 
579 	ksmbd_debug(AUTH, "generated SMB3 signing key\n");
580 	ksmbd_debug(AUTH, "Session Id    %llu\n", sess->id);
581 	return 0;
582 }
583 
584 int ksmbd_gen_smb30_signingkey(struct ksmbd_session *sess,
585 			       struct ksmbd_conn *conn)
586 {
587 	struct derivation d;
588 
589 	d.label.iov_base = "SMB2AESCMAC";
590 	d.label.iov_len = 12;
591 	d.context.iov_base = "SmbSign";
592 	d.context.iov_len = 8;
593 	d.binding = conn->binding;
594 
595 	return generate_smb3signingkey(sess, conn, &d);
596 }
597 
598 int ksmbd_gen_smb311_signingkey(struct ksmbd_session *sess,
599 				struct ksmbd_conn *conn)
600 {
601 	struct derivation d;
602 
603 	d.label.iov_base = "SMBSigningKey";
604 	d.label.iov_len = 14;
605 	if (conn->binding) {
606 		struct preauth_session *preauth_sess;
607 
608 		preauth_sess = ksmbd_preauth_session_lookup(conn, sess->id);
609 		if (!preauth_sess)
610 			return -ENOENT;
611 		d.context.iov_base = preauth_sess->Preauth_HashValue;
612 	} else {
613 		d.context.iov_base = sess->Preauth_HashValue;
614 	}
615 	d.context.iov_len = 64;
616 	d.binding = conn->binding;
617 
618 	return generate_smb3signingkey(sess, conn, &d);
619 }
620 
621 struct derivation_twin {
622 	struct derivation encryption;
623 	struct derivation decryption;
624 };
625 
626 static void generate_smb3encryptionkey(struct ksmbd_conn *conn,
627 				       struct ksmbd_session *sess,
628 				       const struct derivation_twin *ptwin)
629 {
630 	generate_key(conn, sess, ptwin->encryption.label,
631 		     ptwin->encryption.context, sess->smb3encryptionkey,
632 		     SMB3_ENC_DEC_KEY_SIZE);
633 
634 	generate_key(conn, sess, ptwin->decryption.label,
635 		     ptwin->decryption.context,
636 		     sess->smb3decryptionkey, SMB3_ENC_DEC_KEY_SIZE);
637 
638 	ksmbd_debug(AUTH, "generated SMB3 encryption/decryption keys\n");
639 	ksmbd_debug(AUTH, "Cipher type   %d\n", conn->cipher_type);
640 	ksmbd_debug(AUTH, "Session Id    %llu\n", sess->id);
641 }
642 
643 void ksmbd_gen_smb30_encryptionkey(struct ksmbd_conn *conn,
644 				   struct ksmbd_session *sess)
645 {
646 	struct derivation_twin twin;
647 	struct derivation *d;
648 
649 	d = &twin.encryption;
650 	d->label.iov_base = "SMB2AESCCM";
651 	d->label.iov_len = 11;
652 	d->context.iov_base = "ServerOut";
653 	d->context.iov_len = 10;
654 
655 	d = &twin.decryption;
656 	d->label.iov_base = "SMB2AESCCM";
657 	d->label.iov_len = 11;
658 	d->context.iov_base = "ServerIn ";
659 	d->context.iov_len = 10;
660 
661 	generate_smb3encryptionkey(conn, sess, &twin);
662 }
663 
664 void ksmbd_gen_smb311_encryptionkey(struct ksmbd_conn *conn,
665 				    struct ksmbd_session *sess)
666 {
667 	struct derivation_twin twin;
668 	struct derivation *d;
669 
670 	d = &twin.encryption;
671 	d->label.iov_base = "SMBS2CCipherKey";
672 	d->label.iov_len = 16;
673 	d->context.iov_base = sess->Preauth_HashValue;
674 	d->context.iov_len = 64;
675 
676 	d = &twin.decryption;
677 	d->label.iov_base = "SMBC2SCipherKey";
678 	d->label.iov_len = 16;
679 	d->context.iov_base = sess->Preauth_HashValue;
680 	d->context.iov_len = 64;
681 
682 	generate_smb3encryptionkey(conn, sess, &twin);
683 }
684 
685 int ksmbd_gen_preauth_integrity_hash(struct ksmbd_conn *conn, char *buf,
686 				     __u8 *pi_hash)
687 {
688 	struct smb2_hdr *rcv_hdr = smb_get_msg(buf);
689 	char *all_bytes_msg = (char *)&rcv_hdr->ProtocolId;
690 	int msg_size = get_rfc1002_len(buf);
691 	struct sha512_ctx sha_ctx;
692 
693 	if (conn->preauth_info->Preauth_HashId !=
694 	    SMB2_PREAUTH_INTEGRITY_SHA512)
695 		return -EINVAL;
696 
697 	sha512_init(&sha_ctx);
698 	sha512_update(&sha_ctx, pi_hash, 64);
699 	sha512_update(&sha_ctx, all_bytes_msg, msg_size);
700 	sha512_final(&sha_ctx, pi_hash);
701 	return 0;
702 }
703 
704 static int ksmbd_get_encryption_key(struct ksmbd_work *work, __u64 ses_id,
705 				    int enc, u8 *key)
706 {
707 	struct ksmbd_session *sess;
708 	u8 *ses_enc_key;
709 
710 	if (enc)
711 		sess = work->sess;
712 	else
713 		sess = ksmbd_session_lookup_all(work->conn, ses_id);
714 	if (!sess)
715 		return -EINVAL;
716 
717 	ses_enc_key = enc ? sess->smb3encryptionkey :
718 		sess->smb3decryptionkey;
719 	memcpy(key, ses_enc_key, SMB3_ENC_DEC_KEY_SIZE);
720 	if (!enc)
721 		ksmbd_user_session_put(sess);
722 
723 	return 0;
724 }
725 
/*
 * Point @sg at @buflen bytes of @buf. The page is resolved with
 * vmalloc_to_page() for vmalloc addresses and virt_to_page() otherwise.
 */
static inline void smb2_sg_set_buf(struct scatterlist *sg, const void *buf,
				   unsigned int buflen)
{
	struct page *pg = is_vmalloc_addr(buf) ? vmalloc_to_page(buf) :
						 virt_to_page(buf);

	sg_set_page(sg, pg, buflen, offset_in_page(buf));
}
737 
738 static struct scatterlist *ksmbd_init_sg(struct kvec *iov, unsigned int nvec,
739 					 u8 *sign)
740 {
741 	struct scatterlist *sg;
742 	unsigned int assoc_data_len = sizeof(struct smb2_transform_hdr) - 20;
743 	int i, *nr_entries, total_entries = 0, sg_idx = 0;
744 
745 	if (!nvec)
746 		return NULL;
747 
748 	nr_entries = kzalloc_objs(int, nvec, KSMBD_DEFAULT_GFP);
749 	if (!nr_entries)
750 		return NULL;
751 
752 	for (i = 0; i < nvec - 1; i++) {
753 		unsigned long kaddr = (unsigned long)iov[i + 1].iov_base;
754 
755 		if (is_vmalloc_addr(iov[i + 1].iov_base)) {
756 			nr_entries[i] = ((kaddr + iov[i + 1].iov_len +
757 					PAGE_SIZE - 1) >> PAGE_SHIFT) -
758 				(kaddr >> PAGE_SHIFT);
759 		} else {
760 			nr_entries[i]++;
761 		}
762 		total_entries += nr_entries[i];
763 	}
764 
765 	/* Add two entries for transform header and signature */
766 	total_entries += 2;
767 
768 	sg = kmalloc_objs(struct scatterlist, total_entries, KSMBD_DEFAULT_GFP);
769 	if (!sg) {
770 		kfree(nr_entries);
771 		return NULL;
772 	}
773 
774 	sg_init_table(sg, total_entries);
775 	smb2_sg_set_buf(&sg[sg_idx++], iov[0].iov_base + 24, assoc_data_len);
776 	for (i = 0; i < nvec - 1; i++) {
777 		void *data = iov[i + 1].iov_base;
778 		int len = iov[i + 1].iov_len;
779 
780 		if (is_vmalloc_addr(data)) {
781 			int j, offset = offset_in_page(data);
782 
783 			for (j = 0; j < nr_entries[i]; j++) {
784 				unsigned int bytes = PAGE_SIZE - offset;
785 
786 				if (!len)
787 					break;
788 
789 				if (bytes > len)
790 					bytes = len;
791 
792 				sg_set_page(&sg[sg_idx++],
793 					    vmalloc_to_page(data), bytes,
794 					    offset_in_page(data));
795 
796 				data += bytes;
797 				len -= bytes;
798 				offset = 0;
799 			}
800 		} else {
801 			sg_set_page(&sg[sg_idx++], virt_to_page(data), len,
802 				    offset_in_page(data));
803 		}
804 	}
805 	smb2_sg_set_buf(&sg[sg_idx], sign, SMB2_SIGNATURE_SIZE);
806 	kfree(nr_entries);
807 	return sg;
808 }
809 
810 int ksmbd_crypt_message(struct ksmbd_work *work, struct kvec *iov,
811 			unsigned int nvec, int enc)
812 {
813 	struct ksmbd_conn *conn = work->conn;
814 	struct smb2_transform_hdr *tr_hdr = smb_get_msg(iov[0].iov_base);
815 	unsigned int assoc_data_len = sizeof(struct smb2_transform_hdr) - 20;
816 	int rc;
817 	DECLARE_CRYPTO_WAIT(wait);
818 	struct scatterlist *sg;
819 	u8 sign[SMB2_SIGNATURE_SIZE] = {};
820 	u8 key[SMB3_ENC_DEC_KEY_SIZE];
821 	struct aead_request *req;
822 	char *iv;
823 	unsigned int iv_len;
824 	struct crypto_aead *tfm;
825 	unsigned int crypt_len = le32_to_cpu(tr_hdr->OriginalMessageSize);
826 	struct ksmbd_crypto_ctx *ctx;
827 
828 	rc = ksmbd_get_encryption_key(work,
829 				      le64_to_cpu(tr_hdr->SessionId),
830 				      enc,
831 				      key);
832 	if (rc) {
833 		pr_err("Could not get %scryption key\n", enc ? "en" : "de");
834 		return rc;
835 	}
836 
837 	if (conn->cipher_type == SMB2_ENCRYPTION_AES128_GCM ||
838 	    conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM)
839 		ctx = ksmbd_crypto_ctx_find_gcm();
840 	else
841 		ctx = ksmbd_crypto_ctx_find_ccm();
842 	if (!ctx) {
843 		pr_err("crypto alloc failed\n");
844 		return -ENOMEM;
845 	}
846 
847 	if (conn->cipher_type == SMB2_ENCRYPTION_AES128_GCM ||
848 	    conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM)
849 		tfm = CRYPTO_GCM(ctx);
850 	else
851 		tfm = CRYPTO_CCM(ctx);
852 
853 	if (conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM ||
854 	    conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM)
855 		rc = crypto_aead_setkey(tfm, key, SMB3_GCM256_CRYPTKEY_SIZE);
856 	else
857 		rc = crypto_aead_setkey(tfm, key, SMB3_GCM128_CRYPTKEY_SIZE);
858 	if (rc) {
859 		pr_err("Failed to set aead key %d\n", rc);
860 		goto free_ctx;
861 	}
862 
863 	rc = crypto_aead_setauthsize(tfm, SMB2_SIGNATURE_SIZE);
864 	if (rc) {
865 		pr_err("Failed to set authsize %d\n", rc);
866 		goto free_ctx;
867 	}
868 
869 	req = aead_request_alloc(tfm, KSMBD_DEFAULT_GFP);
870 	if (!req) {
871 		rc = -ENOMEM;
872 		goto free_ctx;
873 	}
874 
875 	if (!enc) {
876 		memcpy(sign, &tr_hdr->Signature, SMB2_SIGNATURE_SIZE);
877 		crypt_len += SMB2_SIGNATURE_SIZE;
878 	}
879 
880 	sg = ksmbd_init_sg(iov, nvec, sign);
881 	if (!sg) {
882 		pr_err("Failed to init sg\n");
883 		rc = -ENOMEM;
884 		goto free_req;
885 	}
886 
887 	iv_len = crypto_aead_ivsize(tfm);
888 	iv = kzalloc(iv_len, KSMBD_DEFAULT_GFP);
889 	if (!iv) {
890 		rc = -ENOMEM;
891 		goto free_sg;
892 	}
893 
894 	if (conn->cipher_type == SMB2_ENCRYPTION_AES128_GCM ||
895 	    conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM) {
896 		memcpy(iv, (char *)tr_hdr->Nonce, SMB3_AES_GCM_NONCE);
897 	} else {
898 		iv[0] = 3;
899 		memcpy(iv + 1, (char *)tr_hdr->Nonce, SMB3_AES_CCM_NONCE);
900 	}
901 
902 	aead_request_set_crypt(req, sg, sg, crypt_len, iv);
903 	aead_request_set_ad(req, assoc_data_len);
904 	aead_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG |
905 				  CRYPTO_TFM_REQ_MAY_SLEEP,
906 				  crypto_req_done, &wait);
907 
908 	rc = crypto_wait_req(enc ? crypto_aead_encrypt(req) :
909 			     crypto_aead_decrypt(req), &wait);
910 	if (rc)
911 		goto free_iv;
912 
913 	if (enc)
914 		memcpy(&tr_hdr->Signature, sign, SMB2_SIGNATURE_SIZE);
915 
916 free_iv:
917 	kfree(iv);
918 free_sg:
919 	kfree(sg);
920 free_req:
921 	aead_request_free(req);
922 free_ctx:
923 	ksmbd_release_crypto_ctx(ctx);
924 	return rc;
925 }
926