xref: /linux/fs/smb/server/auth.c (revision 869737543b39a145809c41a7253c6ee777e22729)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *   Copyright (C) 2016 Namjae Jeon <linkinjeon@kernel.org>
4  *   Copyright (C) 2018 Samsung Electronics Co., Ltd.
5  */
6 
#include <linux/kernel.h>
#include <linux/fs.h>
#include <linux/uaccess.h>
#include <linux/backing-dev.h>
#include <linux/writeback.h>
#include <linux/uio.h>
#include <linux/xattr.h>
#include <crypto/hash.h>
#include <crypto/aead.h>
#include <crypto/md5.h>
#include <crypto/sha2.h>
#include <crypto/utils.h>
#include <linux/random.h>
#include <linux/scatterlist.h>
#include <linux/string.h>
20 
21 #include "auth.h"
22 #include "glob.h"
23 
24 #include <linux/fips.h>
25 #include <crypto/arc4.h>
26 #include <crypto/des.h>
27 
28 #include "server.h"
29 #include "smb_common.h"
30 #include "connection.h"
31 #include "mgmt/user_session.h"
32 #include "mgmt/user_config.h"
33 #include "crypto_ctx.h"
34 #include "transport_ipc.h"
35 
36 /*
37  * Fixed format data defining GSS header and fixed string
38  * "not_defined_in_RFC4178@please_ignore".
39  * So sec blob data in neg phase could be generated statically.
40  */
41 static char NEGOTIATE_GSS_HEADER[AUTH_GSS_LENGTH] = {
42 #ifdef CONFIG_SMB_SERVER_KERBEROS5
43 	0x60, 0x5e, 0x06, 0x06, 0x2b, 0x06, 0x01, 0x05,
44 	0x05, 0x02, 0xa0, 0x54, 0x30, 0x52, 0xa0, 0x24,
45 	0x30, 0x22, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86,
46 	0xf7, 0x12, 0x01, 0x02, 0x02, 0x06, 0x09, 0x2a,
47 	0x86, 0x48, 0x82, 0xf7, 0x12, 0x01, 0x02, 0x02,
48 	0x06, 0x0a, 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82,
49 	0x37, 0x02, 0x02, 0x0a, 0xa3, 0x2a, 0x30, 0x28,
50 	0xa0, 0x26, 0x1b, 0x24, 0x6e, 0x6f, 0x74, 0x5f,
51 	0x64, 0x65, 0x66, 0x69, 0x6e, 0x65, 0x64, 0x5f,
52 	0x69, 0x6e, 0x5f, 0x52, 0x46, 0x43, 0x34, 0x31,
53 	0x37, 0x38, 0x40, 0x70, 0x6c, 0x65, 0x61, 0x73,
54 	0x65, 0x5f, 0x69, 0x67, 0x6e, 0x6f, 0x72, 0x65
55 #else
56 	0x60, 0x48, 0x06, 0x06, 0x2b, 0x06, 0x01, 0x05,
57 	0x05, 0x02, 0xa0, 0x3e, 0x30, 0x3c, 0xa0, 0x0e,
58 	0x30, 0x0c, 0x06, 0x0a, 0x2b, 0x06, 0x01, 0x04,
59 	0x01, 0x82, 0x37, 0x02, 0x02, 0x0a, 0xa3, 0x2a,
60 	0x30, 0x28, 0xa0, 0x26, 0x1b, 0x24, 0x6e, 0x6f,
61 	0x74, 0x5f, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x65,
62 	0x64, 0x5f, 0x69, 0x6e, 0x5f, 0x52, 0x46, 0x43,
63 	0x34, 0x31, 0x37, 0x38, 0x40, 0x70, 0x6c, 0x65,
64 	0x61, 0x73, 0x65, 0x5f, 0x69, 0x67, 0x6e, 0x6f,
65 	0x72, 0x65
66 #endif
67 };
68 
69 void ksmbd_copy_gss_neg_header(void *buf)
70 {
71 	memcpy(buf, NEGOTIATE_GSS_HEADER, AUTH_GSS_LENGTH);
72 }
73 
74 static int calc_ntlmv2_hash(struct ksmbd_conn *conn, struct ksmbd_session *sess,
75 			    char *ntlmv2_hash, char *dname)
76 {
77 	int ret, len, conv_len;
78 	wchar_t *domain = NULL;
79 	__le16 *uniname = NULL;
80 	struct hmac_md5_ctx ctx;
81 
82 	hmac_md5_init_usingrawkey(&ctx, user_passkey(sess->user),
83 				  CIFS_ENCPWD_SIZE);
84 
85 	/* convert user_name to unicode */
86 	len = strlen(user_name(sess->user));
87 	uniname = kzalloc(2 + UNICODE_LEN(len), KSMBD_DEFAULT_GFP);
88 	if (!uniname) {
89 		ret = -ENOMEM;
90 		goto out;
91 	}
92 
93 	conv_len = smb_strtoUTF16(uniname, user_name(sess->user), len,
94 				  conn->local_nls);
95 	if (conv_len < 0 || conv_len > len) {
96 		ret = -EINVAL;
97 		goto out;
98 	}
99 	UniStrupr(uniname);
100 
101 	hmac_md5_update(&ctx, (const u8 *)uniname, UNICODE_LEN(conv_len));
102 
103 	/* Convert domain name or conn name to unicode and uppercase */
104 	len = strlen(dname);
105 	domain = kzalloc(2 + UNICODE_LEN(len), KSMBD_DEFAULT_GFP);
106 	if (!domain) {
107 		ret = -ENOMEM;
108 		goto out;
109 	}
110 
111 	conv_len = smb_strtoUTF16((__le16 *)domain, dname, len,
112 				  conn->local_nls);
113 	if (conv_len < 0 || conv_len > len) {
114 		ret = -EINVAL;
115 		goto out;
116 	}
117 
118 	hmac_md5_update(&ctx, (const u8 *)domain, UNICODE_LEN(conv_len));
119 	hmac_md5_final(&ctx, ntlmv2_hash);
120 	ret = 0;
121 out:
122 	kfree(uniname);
123 	kfree(domain);
124 	return ret;
125 }
126 
127 /**
128  * ksmbd_auth_ntlmv2() - NTLMv2 authentication handler
129  * @conn:		connection
130  * @sess:		session of connection
131  * @ntlmv2:		NTLMv2 challenge response
132  * @blen:		NTLMv2 blob length
133  * @domain_name:	domain name
134  * @cryptkey:		session crypto key
135  *
136  * Return:	0 on success, error number on error
137  */
138 int ksmbd_auth_ntlmv2(struct ksmbd_conn *conn, struct ksmbd_session *sess,
139 		      struct ntlmv2_resp *ntlmv2, int blen, char *domain_name,
140 		      char *cryptkey)
141 {
142 	char ntlmv2_hash[CIFS_ENCPWD_SIZE];
143 	char ntlmv2_rsp[CIFS_HMAC_MD5_HASH_SIZE];
144 	struct hmac_md5_ctx ctx;
145 	int rc;
146 
147 	if (fips_enabled) {
148 		ksmbd_debug(AUTH, "NTLMv2 support is disabled due to FIPS\n");
149 		return -EOPNOTSUPP;
150 	}
151 
152 	rc = calc_ntlmv2_hash(conn, sess, ntlmv2_hash, domain_name);
153 	if (rc) {
154 		ksmbd_debug(AUTH, "could not get v2 hash rc %d\n", rc);
155 		return rc;
156 	}
157 
158 	hmac_md5_init_usingrawkey(&ctx, ntlmv2_hash, CIFS_HMAC_MD5_HASH_SIZE);
159 	hmac_md5_update(&ctx, cryptkey, CIFS_CRYPTO_KEY_SIZE);
160 	hmac_md5_update(&ctx, (const u8 *)&ntlmv2->blob_signature, blen);
161 	hmac_md5_final(&ctx, ntlmv2_rsp);
162 
163 	/* Generate the session key */
164 	hmac_md5_usingrawkey(ntlmv2_hash, CIFS_HMAC_MD5_HASH_SIZE,
165 			     ntlmv2_rsp, CIFS_HMAC_MD5_HASH_SIZE,
166 			     sess->sess_key);
167 
168 	if (memcmp(ntlmv2->ntlmv2_hash, ntlmv2_rsp, CIFS_HMAC_MD5_HASH_SIZE) != 0)
169 		return -EINVAL;
170 	return 0;
171 }
172 
173 /**
174  * ksmbd_decode_ntlmssp_auth_blob() - helper function to construct
175  * authenticate blob
176  * @authblob:	authenticate blob source pointer
177  * @blob_len:	length of the @authblob message
178  * @conn:	connection
179  * @sess:	session of connection
180  *
181  * Return:	0 on success, error number on error
182  */
183 int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob,
184 				   int blob_len, struct ksmbd_conn *conn,
185 				   struct ksmbd_session *sess)
186 {
187 	char *domain_name;
188 	unsigned int nt_off, dn_off;
189 	unsigned short nt_len, dn_len;
190 	int ret;
191 
192 	if (blob_len < sizeof(struct authenticate_message)) {
193 		ksmbd_debug(AUTH, "negotiate blob len %d too small\n",
194 			    blob_len);
195 		return -EINVAL;
196 	}
197 
198 	if (memcmp(authblob->Signature, "NTLMSSP", 8)) {
199 		ksmbd_debug(AUTH, "blob signature incorrect %s\n",
200 			    authblob->Signature);
201 		return -EINVAL;
202 	}
203 
204 	nt_off = le32_to_cpu(authblob->NtChallengeResponse.BufferOffset);
205 	nt_len = le16_to_cpu(authblob->NtChallengeResponse.Length);
206 	dn_off = le32_to_cpu(authblob->DomainName.BufferOffset);
207 	dn_len = le16_to_cpu(authblob->DomainName.Length);
208 
209 	if (blob_len < (u64)dn_off + dn_len || blob_len < (u64)nt_off + nt_len ||
210 	    nt_len < CIFS_ENCPWD_SIZE)
211 		return -EINVAL;
212 
213 	/* TODO : use domain name that imported from configuration file */
214 	domain_name = smb_strndup_from_utf16((const char *)authblob + dn_off,
215 					     dn_len, true, conn->local_nls);
216 	if (IS_ERR(domain_name))
217 		return PTR_ERR(domain_name);
218 
219 	/* process NTLMv2 authentication */
220 	ksmbd_debug(AUTH, "decode_ntlmssp_authenticate_blob dname%s\n",
221 		    domain_name);
222 	ret = ksmbd_auth_ntlmv2(conn, sess,
223 				(struct ntlmv2_resp *)((char *)authblob + nt_off),
224 				nt_len - CIFS_ENCPWD_SIZE,
225 				domain_name, conn->ntlmssp.cryptkey);
226 	kfree(domain_name);
227 
228 	/* The recovered secondary session key */
229 	if (conn->ntlmssp.client_flags & NTLMSSP_NEGOTIATE_KEY_XCH) {
230 		struct arc4_ctx *ctx_arc4;
231 		unsigned int sess_key_off, sess_key_len;
232 
233 		sess_key_off = le32_to_cpu(authblob->SessionKey.BufferOffset);
234 		sess_key_len = le16_to_cpu(authblob->SessionKey.Length);
235 
236 		if (blob_len < (u64)sess_key_off + sess_key_len)
237 			return -EINVAL;
238 
239 		if (sess_key_len > CIFS_KEY_SIZE)
240 			return -EINVAL;
241 
242 		ctx_arc4 = kmalloc(sizeof(*ctx_arc4), KSMBD_DEFAULT_GFP);
243 		if (!ctx_arc4)
244 			return -ENOMEM;
245 
246 		arc4_setkey(ctx_arc4, sess->sess_key, SMB2_NTLMV2_SESSKEY_SIZE);
247 		arc4_crypt(ctx_arc4, sess->sess_key,
248 			   (char *)authblob + sess_key_off, sess_key_len);
249 		kfree_sensitive(ctx_arc4);
250 	}
251 
252 	return ret;
253 }
254 
255 /**
256  * ksmbd_decode_ntlmssp_neg_blob() - helper function to construct
257  * negotiate blob
258  * @negblob: negotiate blob source pointer
259  * @blob_len:	length of the @authblob message
260  * @conn:	connection
261  *
262  */
263 int ksmbd_decode_ntlmssp_neg_blob(struct negotiate_message *negblob,
264 				  int blob_len, struct ksmbd_conn *conn)
265 {
266 	if (blob_len < sizeof(struct negotiate_message)) {
267 		ksmbd_debug(AUTH, "negotiate blob len %d too small\n",
268 			    blob_len);
269 		return -EINVAL;
270 	}
271 
272 	if (memcmp(negblob->Signature, "NTLMSSP", 8)) {
273 		ksmbd_debug(AUTH, "blob signature incorrect %s\n",
274 			    negblob->Signature);
275 		return -EINVAL;
276 	}
277 
278 	conn->ntlmssp.client_flags = le32_to_cpu(negblob->NegotiateFlags);
279 	return 0;
280 }
281 
282 /**
283  * ksmbd_build_ntlmssp_challenge_blob() - helper function to construct
284  * challenge blob
285  * @chgblob: challenge blob source pointer to initialize
286  * @conn:	connection
287  *
288  */
289 unsigned int
290 ksmbd_build_ntlmssp_challenge_blob(struct challenge_message *chgblob,
291 				   struct ksmbd_conn *conn)
292 {
293 	struct target_info *tinfo;
294 	wchar_t *name;
295 	__u8 *target_name;
296 	unsigned int flags, blob_off, blob_len, type, target_info_len = 0;
297 	int len, uni_len, conv_len;
298 	int cflags = conn->ntlmssp.client_flags;
299 
300 	memcpy(chgblob->Signature, NTLMSSP_SIGNATURE, 8);
301 	chgblob->MessageType = NtLmChallenge;
302 
303 	flags = NTLMSSP_NEGOTIATE_UNICODE |
304 		NTLMSSP_NEGOTIATE_NTLM | NTLMSSP_TARGET_TYPE_SERVER |
305 		NTLMSSP_NEGOTIATE_TARGET_INFO;
306 
307 	if (cflags & NTLMSSP_NEGOTIATE_SIGN) {
308 		flags |= NTLMSSP_NEGOTIATE_SIGN;
309 		flags |= cflags & (NTLMSSP_NEGOTIATE_128 |
310 				   NTLMSSP_NEGOTIATE_56);
311 	}
312 
313 	if (cflags & NTLMSSP_NEGOTIATE_SEAL && smb3_encryption_negotiated(conn))
314 		flags |= NTLMSSP_NEGOTIATE_SEAL;
315 
316 	if (cflags & NTLMSSP_NEGOTIATE_ALWAYS_SIGN)
317 		flags |= NTLMSSP_NEGOTIATE_ALWAYS_SIGN;
318 
319 	if (cflags & NTLMSSP_REQUEST_TARGET)
320 		flags |= NTLMSSP_REQUEST_TARGET;
321 
322 	if (conn->use_spnego &&
323 	    (cflags & NTLMSSP_NEGOTIATE_EXTENDED_SEC))
324 		flags |= NTLMSSP_NEGOTIATE_EXTENDED_SEC;
325 
326 	if (cflags & NTLMSSP_NEGOTIATE_KEY_XCH)
327 		flags |= NTLMSSP_NEGOTIATE_KEY_XCH;
328 
329 	chgblob->NegotiateFlags = cpu_to_le32(flags);
330 	len = strlen(ksmbd_netbios_name());
331 	name = kmalloc(2 + UNICODE_LEN(len), KSMBD_DEFAULT_GFP);
332 	if (!name)
333 		return -ENOMEM;
334 
335 	conv_len = smb_strtoUTF16((__le16 *)name, ksmbd_netbios_name(), len,
336 				  conn->local_nls);
337 	if (conv_len < 0 || conv_len > len) {
338 		kfree(name);
339 		return -EINVAL;
340 	}
341 
342 	uni_len = UNICODE_LEN(conv_len);
343 
344 	blob_off = sizeof(struct challenge_message);
345 	blob_len = blob_off + uni_len;
346 
347 	chgblob->TargetName.Length = cpu_to_le16(uni_len);
348 	chgblob->TargetName.MaximumLength = cpu_to_le16(uni_len);
349 	chgblob->TargetName.BufferOffset = cpu_to_le32(blob_off);
350 
351 	/* Initialize random conn challenge */
352 	get_random_bytes(conn->ntlmssp.cryptkey, sizeof(__u64));
353 	memcpy(chgblob->Challenge, conn->ntlmssp.cryptkey,
354 	       CIFS_CRYPTO_KEY_SIZE);
355 
356 	/* Add Target Information to security buffer */
357 	chgblob->TargetInfoArray.BufferOffset = cpu_to_le32(blob_len);
358 
359 	target_name = (__u8 *)chgblob + blob_off;
360 	memcpy(target_name, name, uni_len);
361 	tinfo = (struct target_info *)(target_name + uni_len);
362 
363 	chgblob->TargetInfoArray.Length = 0;
364 	/* Add target info list for NetBIOS/DNS settings */
365 	for (type = NTLMSSP_AV_NB_COMPUTER_NAME;
366 	     type <= NTLMSSP_AV_DNS_DOMAIN_NAME; type++) {
367 		tinfo->Type = cpu_to_le16(type);
368 		tinfo->Length = cpu_to_le16(uni_len);
369 		memcpy(tinfo->Content, name, uni_len);
370 		tinfo = (struct target_info *)((char *)tinfo + 4 + uni_len);
371 		target_info_len += 4 + uni_len;
372 	}
373 
374 	/* Add terminator subblock */
375 	tinfo->Type = 0;
376 	tinfo->Length = 0;
377 	target_info_len += 4;
378 
379 	chgblob->TargetInfoArray.Length = cpu_to_le16(target_info_len);
380 	chgblob->TargetInfoArray.MaximumLength = cpu_to_le16(target_info_len);
381 	blob_len += target_info_len;
382 	kfree(name);
383 	ksmbd_debug(AUTH, "NTLMSSP SecurityBufferLength %d\n", blob_len);
384 	return blob_len;
385 }
386 
387 #ifdef CONFIG_SMB_SERVER_KERBEROS5
388 int ksmbd_krb5_authenticate(struct ksmbd_session *sess, char *in_blob,
389 			    int in_len, char *out_blob, int *out_len)
390 {
391 	struct ksmbd_spnego_authen_response *resp;
392 	struct ksmbd_login_response_ext *resp_ext = NULL;
393 	struct ksmbd_user *user = NULL;
394 	int retval;
395 
396 	resp = ksmbd_ipc_spnego_authen_request(in_blob, in_len);
397 	if (!resp) {
398 		ksmbd_debug(AUTH, "SPNEGO_AUTHEN_REQUEST failure\n");
399 		return -EINVAL;
400 	}
401 
402 	if (!(resp->login_response.status & KSMBD_USER_FLAG_OK)) {
403 		ksmbd_debug(AUTH, "krb5 authentication failure\n");
404 		retval = -EPERM;
405 		goto out;
406 	}
407 
408 	if (*out_len <= resp->spnego_blob_len) {
409 		ksmbd_debug(AUTH, "buf len %d, but blob len %d\n",
410 			    *out_len, resp->spnego_blob_len);
411 		retval = -EINVAL;
412 		goto out;
413 	}
414 
415 	if (resp->session_key_len > sizeof(sess->sess_key)) {
416 		ksmbd_debug(AUTH, "session key is too long\n");
417 		retval = -EINVAL;
418 		goto out;
419 	}
420 
421 	if (resp->login_response.status & KSMBD_USER_FLAG_EXTENSION)
422 		resp_ext = ksmbd_ipc_login_request_ext(resp->login_response.account);
423 
424 	user = ksmbd_alloc_user(&resp->login_response, resp_ext);
425 	if (!user) {
426 		ksmbd_debug(AUTH, "login failure\n");
427 		retval = -ENOMEM;
428 		goto out;
429 	}
430 
431 	if (!sess->user) {
432 		/* First successful authentication */
433 		sess->user = user;
434 	} else {
435 		if (!ksmbd_compare_user(sess->user, user)) {
436 			ksmbd_debug(AUTH, "different user tried to reuse session\n");
437 			retval = -EPERM;
438 			ksmbd_free_user(user);
439 			goto out;
440 		}
441 		ksmbd_free_user(user);
442 	}
443 
444 	memcpy(sess->sess_key, resp->payload, resp->session_key_len);
445 	memcpy(out_blob, resp->payload + resp->session_key_len,
446 	       resp->spnego_blob_len);
447 	*out_len = resp->spnego_blob_len;
448 	retval = 0;
449 out:
450 	kvfree(resp);
451 	return retval;
452 }
453 #else
454 int ksmbd_krb5_authenticate(struct ksmbd_session *sess, char *in_blob,
455 			    int in_len, char *out_blob, int *out_len)
456 {
457 	return -EOPNOTSUPP;
458 }
459 #endif
460 
461 /**
462  * ksmbd_sign_smb2_pdu() - function to generate packet signing
463  * @conn:	connection
464  * @key:	signing key
465  * @iov:        buffer iov array
466  * @n_vec:	number of iovecs
467  * @sig:	signature value generated for client request packet
468  *
469  */
470 void ksmbd_sign_smb2_pdu(struct ksmbd_conn *conn, char *key, struct kvec *iov,
471 			 int n_vec, char *sig)
472 {
473 	struct hmac_sha256_ctx ctx;
474 	int i;
475 
476 	hmac_sha256_init_usingrawkey(&ctx, key, SMB2_NTLMV2_SESSKEY_SIZE);
477 	for (i = 0; i < n_vec; i++)
478 		hmac_sha256_update(&ctx, iov[i].iov_base, iov[i].iov_len);
479 	hmac_sha256_final(&ctx, sig);
480 }
481 
482 /**
483  * ksmbd_sign_smb3_pdu() - function to generate packet signing
484  * @conn:	connection
485  * @key:	signing key
486  * @iov:        buffer iov array
487  * @n_vec:	number of iovecs
488  * @sig:	signature value generated for client request packet
489  *
490  */
491 int ksmbd_sign_smb3_pdu(struct ksmbd_conn *conn, char *key, struct kvec *iov,
492 			int n_vec, char *sig)
493 {
494 	struct ksmbd_crypto_ctx *ctx;
495 	int rc, i;
496 
497 	ctx = ksmbd_crypto_ctx_find_cmacaes();
498 	if (!ctx) {
499 		ksmbd_debug(AUTH, "could not crypto alloc cmac\n");
500 		return -ENOMEM;
501 	}
502 
503 	rc = crypto_shash_setkey(CRYPTO_CMACAES_TFM(ctx),
504 				 key,
505 				 SMB2_CMACAES_SIZE);
506 	if (rc)
507 		goto out;
508 
509 	rc = crypto_shash_init(CRYPTO_CMACAES(ctx));
510 	if (rc) {
511 		ksmbd_debug(AUTH, "cmaces init error %d\n", rc);
512 		goto out;
513 	}
514 
515 	for (i = 0; i < n_vec; i++) {
516 		rc = crypto_shash_update(CRYPTO_CMACAES(ctx),
517 					 iov[i].iov_base,
518 					 iov[i].iov_len);
519 		if (rc) {
520 			ksmbd_debug(AUTH, "cmaces update error %d\n", rc);
521 			goto out;
522 		}
523 	}
524 
525 	rc = crypto_shash_final(CRYPTO_CMACAES(ctx), sig);
526 	if (rc)
527 		ksmbd_debug(AUTH, "cmaces generation error %d\n", rc);
528 out:
529 	ksmbd_release_crypto_ctx(ctx);
530 	return rc;
531 }
532 
533 struct derivation {
534 	struct kvec label;
535 	struct kvec context;
536 	bool binding;
537 };
538 
539 static void generate_key(struct ksmbd_conn *conn, struct ksmbd_session *sess,
540 			 struct kvec label, struct kvec context, __u8 *key,
541 			 unsigned int key_size)
542 {
543 	unsigned char zero = 0x0;
544 	__u8 i[4] = {0, 0, 0, 1};
545 	__u8 L128[4] = {0, 0, 0, 128};
546 	__u8 L256[4] = {0, 0, 1, 0};
547 	unsigned char prfhash[SMB2_HMACSHA256_SIZE];
548 	struct hmac_sha256_ctx ctx;
549 
550 	hmac_sha256_init_usingrawkey(&ctx, sess->sess_key,
551 				     SMB2_NTLMV2_SESSKEY_SIZE);
552 	hmac_sha256_update(&ctx, i, 4);
553 	hmac_sha256_update(&ctx, label.iov_base, label.iov_len);
554 	hmac_sha256_update(&ctx, &zero, 1);
555 	hmac_sha256_update(&ctx, context.iov_base, context.iov_len);
556 
557 	if (key_size == SMB3_ENC_DEC_KEY_SIZE &&
558 	    (conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM ||
559 	     conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM))
560 		hmac_sha256_update(&ctx, L256, 4);
561 	else
562 		hmac_sha256_update(&ctx, L128, 4);
563 
564 	hmac_sha256_final(&ctx, prfhash);
565 	memcpy(key, prfhash, key_size);
566 }
567 
568 static int generate_smb3signingkey(struct ksmbd_session *sess,
569 				   struct ksmbd_conn *conn,
570 				   const struct derivation *signing)
571 {
572 	struct channel *chann;
573 	char *key;
574 
575 	chann = lookup_chann_list(sess, conn);
576 	if (!chann)
577 		return 0;
578 
579 	if (conn->dialect >= SMB30_PROT_ID && signing->binding)
580 		key = chann->smb3signingkey;
581 	else
582 		key = sess->smb3signingkey;
583 
584 	generate_key(conn, sess, signing->label, signing->context, key,
585 		     SMB3_SIGN_KEY_SIZE);
586 
587 	if (!(conn->dialect >= SMB30_PROT_ID && signing->binding))
588 		memcpy(chann->smb3signingkey, key, SMB3_SIGN_KEY_SIZE);
589 
590 	ksmbd_debug(AUTH, "dumping generated AES signing keys\n");
591 	ksmbd_debug(AUTH, "Session Id    %llu\n", sess->id);
592 	ksmbd_debug(AUTH, "Session Key   %*ph\n",
593 		    SMB2_NTLMV2_SESSKEY_SIZE, sess->sess_key);
594 	ksmbd_debug(AUTH, "Signing Key   %*ph\n",
595 		    SMB3_SIGN_KEY_SIZE, key);
596 	return 0;
597 }
598 
599 int ksmbd_gen_smb30_signingkey(struct ksmbd_session *sess,
600 			       struct ksmbd_conn *conn)
601 {
602 	struct derivation d;
603 
604 	d.label.iov_base = "SMB2AESCMAC";
605 	d.label.iov_len = 12;
606 	d.context.iov_base = "SmbSign";
607 	d.context.iov_len = 8;
608 	d.binding = conn->binding;
609 
610 	return generate_smb3signingkey(sess, conn, &d);
611 }
612 
613 int ksmbd_gen_smb311_signingkey(struct ksmbd_session *sess,
614 				struct ksmbd_conn *conn)
615 {
616 	struct derivation d;
617 
618 	d.label.iov_base = "SMBSigningKey";
619 	d.label.iov_len = 14;
620 	if (conn->binding) {
621 		struct preauth_session *preauth_sess;
622 
623 		preauth_sess = ksmbd_preauth_session_lookup(conn, sess->id);
624 		if (!preauth_sess)
625 			return -ENOENT;
626 		d.context.iov_base = preauth_sess->Preauth_HashValue;
627 	} else {
628 		d.context.iov_base = sess->Preauth_HashValue;
629 	}
630 	d.context.iov_len = 64;
631 	d.binding = conn->binding;
632 
633 	return generate_smb3signingkey(sess, conn, &d);
634 }
635 
636 struct derivation_twin {
637 	struct derivation encryption;
638 	struct derivation decryption;
639 };
640 
641 static void generate_smb3encryptionkey(struct ksmbd_conn *conn,
642 				       struct ksmbd_session *sess,
643 				       const struct derivation_twin *ptwin)
644 {
645 	generate_key(conn, sess, ptwin->encryption.label,
646 		     ptwin->encryption.context, sess->smb3encryptionkey,
647 		     SMB3_ENC_DEC_KEY_SIZE);
648 
649 	generate_key(conn, sess, ptwin->decryption.label,
650 		     ptwin->decryption.context,
651 		     sess->smb3decryptionkey, SMB3_ENC_DEC_KEY_SIZE);
652 
653 	ksmbd_debug(AUTH, "dumping generated AES encryption keys\n");
654 	ksmbd_debug(AUTH, "Cipher type   %d\n", conn->cipher_type);
655 	ksmbd_debug(AUTH, "Session Id    %llu\n", sess->id);
656 	ksmbd_debug(AUTH, "Session Key   %*ph\n",
657 		    SMB2_NTLMV2_SESSKEY_SIZE, sess->sess_key);
658 	if (conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM ||
659 	    conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM) {
660 		ksmbd_debug(AUTH, "ServerIn Key  %*ph\n",
661 			    SMB3_GCM256_CRYPTKEY_SIZE, sess->smb3encryptionkey);
662 		ksmbd_debug(AUTH, "ServerOut Key %*ph\n",
663 			    SMB3_GCM256_CRYPTKEY_SIZE, sess->smb3decryptionkey);
664 	} else {
665 		ksmbd_debug(AUTH, "ServerIn Key  %*ph\n",
666 			    SMB3_GCM128_CRYPTKEY_SIZE, sess->smb3encryptionkey);
667 		ksmbd_debug(AUTH, "ServerOut Key %*ph\n",
668 			    SMB3_GCM128_CRYPTKEY_SIZE, sess->smb3decryptionkey);
669 	}
670 }
671 
672 void ksmbd_gen_smb30_encryptionkey(struct ksmbd_conn *conn,
673 				   struct ksmbd_session *sess)
674 {
675 	struct derivation_twin twin;
676 	struct derivation *d;
677 
678 	d = &twin.encryption;
679 	d->label.iov_base = "SMB2AESCCM";
680 	d->label.iov_len = 11;
681 	d->context.iov_base = "ServerOut";
682 	d->context.iov_len = 10;
683 
684 	d = &twin.decryption;
685 	d->label.iov_base = "SMB2AESCCM";
686 	d->label.iov_len = 11;
687 	d->context.iov_base = "ServerIn ";
688 	d->context.iov_len = 10;
689 
690 	generate_smb3encryptionkey(conn, sess, &twin);
691 }
692 
693 void ksmbd_gen_smb311_encryptionkey(struct ksmbd_conn *conn,
694 				    struct ksmbd_session *sess)
695 {
696 	struct derivation_twin twin;
697 	struct derivation *d;
698 
699 	d = &twin.encryption;
700 	d->label.iov_base = "SMBS2CCipherKey";
701 	d->label.iov_len = 16;
702 	d->context.iov_base = sess->Preauth_HashValue;
703 	d->context.iov_len = 64;
704 
705 	d = &twin.decryption;
706 	d->label.iov_base = "SMBC2SCipherKey";
707 	d->label.iov_len = 16;
708 	d->context.iov_base = sess->Preauth_HashValue;
709 	d->context.iov_len = 64;
710 
711 	generate_smb3encryptionkey(conn, sess, &twin);
712 }
713 
714 int ksmbd_gen_preauth_integrity_hash(struct ksmbd_conn *conn, char *buf,
715 				     __u8 *pi_hash)
716 {
717 	struct smb2_hdr *rcv_hdr = smb2_get_msg(buf);
718 	char *all_bytes_msg = (char *)&rcv_hdr->ProtocolId;
719 	int msg_size = get_rfc1002_len(buf);
720 	struct sha512_ctx sha_ctx;
721 
722 	if (conn->preauth_info->Preauth_HashId !=
723 	    SMB2_PREAUTH_INTEGRITY_SHA512)
724 		return -EINVAL;
725 
726 	sha512_init(&sha_ctx);
727 	sha512_update(&sha_ctx, pi_hash, 64);
728 	sha512_update(&sha_ctx, all_bytes_msg, msg_size);
729 	sha512_final(&sha_ctx, pi_hash);
730 	return 0;
731 }
732 
733 static int ksmbd_get_encryption_key(struct ksmbd_work *work, __u64 ses_id,
734 				    int enc, u8 *key)
735 {
736 	struct ksmbd_session *sess;
737 	u8 *ses_enc_key;
738 
739 	if (enc)
740 		sess = work->sess;
741 	else
742 		sess = ksmbd_session_lookup_all(work->conn, ses_id);
743 	if (!sess)
744 		return -EINVAL;
745 
746 	ses_enc_key = enc ? sess->smb3encryptionkey :
747 		sess->smb3decryptionkey;
748 	memcpy(key, ses_enc_key, SMB3_ENC_DEC_KEY_SIZE);
749 	if (!enc)
750 		ksmbd_user_session_put(sess);
751 
752 	return 0;
753 }
754 
/*
 * smb2_sg_set_buf() - point one scatterlist entry at @buf
 *
 * Works for both linearly mapped and vmalloc'd buffers: vmalloc addresses
 * need vmalloc_to_page(), virt_to_page() is only valid for the linear map.
 * The caller must ensure @buflen does not cross a page boundary for
 * vmalloc'd memory.
 */
static inline void smb2_sg_set_buf(struct scatterlist *sg, const void *buf,
				   unsigned int buflen)
{
	struct page *page;

	if (is_vmalloc_addr(buf))
		page = vmalloc_to_page(buf);
	else
		page = virt_to_page(buf);
	sg_set_page(sg, page, buflen, offset_in_page(buf));
}
766 
767 static struct scatterlist *ksmbd_init_sg(struct kvec *iov, unsigned int nvec,
768 					 u8 *sign)
769 {
770 	struct scatterlist *sg;
771 	unsigned int assoc_data_len = sizeof(struct smb2_transform_hdr) - 20;
772 	int i, *nr_entries, total_entries = 0, sg_idx = 0;
773 
774 	if (!nvec)
775 		return NULL;
776 
777 	nr_entries = kcalloc(nvec, sizeof(int), KSMBD_DEFAULT_GFP);
778 	if (!nr_entries)
779 		return NULL;
780 
781 	for (i = 0; i < nvec - 1; i++) {
782 		unsigned long kaddr = (unsigned long)iov[i + 1].iov_base;
783 
784 		if (is_vmalloc_addr(iov[i + 1].iov_base)) {
785 			nr_entries[i] = ((kaddr + iov[i + 1].iov_len +
786 					PAGE_SIZE - 1) >> PAGE_SHIFT) -
787 				(kaddr >> PAGE_SHIFT);
788 		} else {
789 			nr_entries[i]++;
790 		}
791 		total_entries += nr_entries[i];
792 	}
793 
794 	/* Add two entries for transform header and signature */
795 	total_entries += 2;
796 
797 	sg = kmalloc_array(total_entries, sizeof(struct scatterlist),
798 			   KSMBD_DEFAULT_GFP);
799 	if (!sg) {
800 		kfree(nr_entries);
801 		return NULL;
802 	}
803 
804 	sg_init_table(sg, total_entries);
805 	smb2_sg_set_buf(&sg[sg_idx++], iov[0].iov_base + 24, assoc_data_len);
806 	for (i = 0; i < nvec - 1; i++) {
807 		void *data = iov[i + 1].iov_base;
808 		int len = iov[i + 1].iov_len;
809 
810 		if (is_vmalloc_addr(data)) {
811 			int j, offset = offset_in_page(data);
812 
813 			for (j = 0; j < nr_entries[i]; j++) {
814 				unsigned int bytes = PAGE_SIZE - offset;
815 
816 				if (!len)
817 					break;
818 
819 				if (bytes > len)
820 					bytes = len;
821 
822 				sg_set_page(&sg[sg_idx++],
823 					    vmalloc_to_page(data), bytes,
824 					    offset_in_page(data));
825 
826 				data += bytes;
827 				len -= bytes;
828 				offset = 0;
829 			}
830 		} else {
831 			sg_set_page(&sg[sg_idx++], virt_to_page(data), len,
832 				    offset_in_page(data));
833 		}
834 	}
835 	smb2_sg_set_buf(&sg[sg_idx], sign, SMB2_SIGNATURE_SIZE);
836 	kfree(nr_entries);
837 	return sg;
838 }
839 
840 int ksmbd_crypt_message(struct ksmbd_work *work, struct kvec *iov,
841 			unsigned int nvec, int enc)
842 {
843 	struct ksmbd_conn *conn = work->conn;
844 	struct smb2_transform_hdr *tr_hdr = smb2_get_msg(iov[0].iov_base);
845 	unsigned int assoc_data_len = sizeof(struct smb2_transform_hdr) - 20;
846 	int rc;
847 	struct scatterlist *sg;
848 	u8 sign[SMB2_SIGNATURE_SIZE] = {};
849 	u8 key[SMB3_ENC_DEC_KEY_SIZE];
850 	struct aead_request *req;
851 	char *iv;
852 	unsigned int iv_len;
853 	struct crypto_aead *tfm;
854 	unsigned int crypt_len = le32_to_cpu(tr_hdr->OriginalMessageSize);
855 	struct ksmbd_crypto_ctx *ctx;
856 
857 	rc = ksmbd_get_encryption_key(work,
858 				      le64_to_cpu(tr_hdr->SessionId),
859 				      enc,
860 				      key);
861 	if (rc) {
862 		pr_err("Could not get %scryption key\n", enc ? "en" : "de");
863 		return rc;
864 	}
865 
866 	if (conn->cipher_type == SMB2_ENCRYPTION_AES128_GCM ||
867 	    conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM)
868 		ctx = ksmbd_crypto_ctx_find_gcm();
869 	else
870 		ctx = ksmbd_crypto_ctx_find_ccm();
871 	if (!ctx) {
872 		pr_err("crypto alloc failed\n");
873 		return -ENOMEM;
874 	}
875 
876 	if (conn->cipher_type == SMB2_ENCRYPTION_AES128_GCM ||
877 	    conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM)
878 		tfm = CRYPTO_GCM(ctx);
879 	else
880 		tfm = CRYPTO_CCM(ctx);
881 
882 	if (conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM ||
883 	    conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM)
884 		rc = crypto_aead_setkey(tfm, key, SMB3_GCM256_CRYPTKEY_SIZE);
885 	else
886 		rc = crypto_aead_setkey(tfm, key, SMB3_GCM128_CRYPTKEY_SIZE);
887 	if (rc) {
888 		pr_err("Failed to set aead key %d\n", rc);
889 		goto free_ctx;
890 	}
891 
892 	rc = crypto_aead_setauthsize(tfm, SMB2_SIGNATURE_SIZE);
893 	if (rc) {
894 		pr_err("Failed to set authsize %d\n", rc);
895 		goto free_ctx;
896 	}
897 
898 	req = aead_request_alloc(tfm, KSMBD_DEFAULT_GFP);
899 	if (!req) {
900 		rc = -ENOMEM;
901 		goto free_ctx;
902 	}
903 
904 	if (!enc) {
905 		memcpy(sign, &tr_hdr->Signature, SMB2_SIGNATURE_SIZE);
906 		crypt_len += SMB2_SIGNATURE_SIZE;
907 	}
908 
909 	sg = ksmbd_init_sg(iov, nvec, sign);
910 	if (!sg) {
911 		pr_err("Failed to init sg\n");
912 		rc = -ENOMEM;
913 		goto free_req;
914 	}
915 
916 	iv_len = crypto_aead_ivsize(tfm);
917 	iv = kzalloc(iv_len, KSMBD_DEFAULT_GFP);
918 	if (!iv) {
919 		rc = -ENOMEM;
920 		goto free_sg;
921 	}
922 
923 	if (conn->cipher_type == SMB2_ENCRYPTION_AES128_GCM ||
924 	    conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM) {
925 		memcpy(iv, (char *)tr_hdr->Nonce, SMB3_AES_GCM_NONCE);
926 	} else {
927 		iv[0] = 3;
928 		memcpy(iv + 1, (char *)tr_hdr->Nonce, SMB3_AES_CCM_NONCE);
929 	}
930 
931 	aead_request_set_crypt(req, sg, sg, crypt_len, iv);
932 	aead_request_set_ad(req, assoc_data_len);
933 	aead_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP, NULL, NULL);
934 
935 	if (enc)
936 		rc = crypto_aead_encrypt(req);
937 	else
938 		rc = crypto_aead_decrypt(req);
939 	if (rc)
940 		goto free_iv;
941 
942 	if (enc)
943 		memcpy(&tr_hdr->Signature, sign, SMB2_SIGNATURE_SIZE);
944 
945 free_iv:
946 	kfree(iv);
947 free_sg:
948 	kfree(sg);
949 free_req:
950 	aead_request_free(req);
951 free_ctx:
952 	ksmbd_release_crypto_ctx(ctx);
953 	return rc;
954 }
955