xref: /linux/drivers/nvme/common/auth.c (revision 55a42f78ffd386e01a5404419f8c5ded7db70a21)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (c) 2020 Hannes Reinecke, SUSE Linux
4  */
5 
6 #include <linux/module.h>
7 #include <linux/crc32.h>
8 #include <linux/base64.h>
9 #include <linux/prandom.h>
10 #include <linux/scatterlist.h>
11 #include <linux/unaligned.h>
12 #include <crypto/hash.h>
13 #include <crypto/dh.h>
14 #include <crypto/hkdf.h>
15 #include <linux/nvme.h>
16 #include <linux/nvme-auth.h>
17 
/*
 * Largest digest size, in bytes, accepted by the TLS PSK derivation in this
 * file; it also sizes the all-zero default salt used there.
 */
#define HKDF_MAX_HASHLEN 64
19 
20 static u32 nvme_dhchap_seqnum;
21 static DEFINE_MUTEX(nvme_dhchap_mutex);
22 
23 u32 nvme_auth_get_seqnum(void)
24 {
25 	u32 seqnum;
26 
27 	mutex_lock(&nvme_dhchap_mutex);
28 	if (!nvme_dhchap_seqnum)
29 		nvme_dhchap_seqnum = get_random_u32();
30 	else {
31 		nvme_dhchap_seqnum++;
32 		if (!nvme_dhchap_seqnum)
33 			nvme_dhchap_seqnum++;
34 	}
35 	seqnum = nvme_dhchap_seqnum;
36 	mutex_unlock(&nvme_dhchap_mutex);
37 	return seqnum;
38 }
39 EXPORT_SYMBOL_GPL(nvme_auth_get_seqnum);
40 
41 static struct nvme_auth_dhgroup_map {
42 	const char name[16];
43 	const char kpp[16];
44 } dhgroup_map[] = {
45 	[NVME_AUTH_DHGROUP_NULL] = {
46 		.name = "null", .kpp = "null" },
47 	[NVME_AUTH_DHGROUP_2048] = {
48 		.name = "ffdhe2048", .kpp = "ffdhe2048(dh)" },
49 	[NVME_AUTH_DHGROUP_3072] = {
50 		.name = "ffdhe3072", .kpp = "ffdhe3072(dh)" },
51 	[NVME_AUTH_DHGROUP_4096] = {
52 		.name = "ffdhe4096", .kpp = "ffdhe4096(dh)" },
53 	[NVME_AUTH_DHGROUP_6144] = {
54 		.name = "ffdhe6144", .kpp = "ffdhe6144(dh)" },
55 	[NVME_AUTH_DHGROUP_8192] = {
56 		.name = "ffdhe8192", .kpp = "ffdhe8192(dh)" },
57 };
58 
59 const char *nvme_auth_dhgroup_name(u8 dhgroup_id)
60 {
61 	if (dhgroup_id >= ARRAY_SIZE(dhgroup_map))
62 		return NULL;
63 	return dhgroup_map[dhgroup_id].name;
64 }
65 EXPORT_SYMBOL_GPL(nvme_auth_dhgroup_name);
66 
67 const char *nvme_auth_dhgroup_kpp(u8 dhgroup_id)
68 {
69 	if (dhgroup_id >= ARRAY_SIZE(dhgroup_map))
70 		return NULL;
71 	return dhgroup_map[dhgroup_id].kpp;
72 }
73 EXPORT_SYMBOL_GPL(nvme_auth_dhgroup_kpp);
74 
75 u8 nvme_auth_dhgroup_id(const char *dhgroup_name)
76 {
77 	int i;
78 
79 	if (!dhgroup_name || !strlen(dhgroup_name))
80 		return NVME_AUTH_DHGROUP_INVALID;
81 	for (i = 0; i < ARRAY_SIZE(dhgroup_map); i++) {
82 		if (!strlen(dhgroup_map[i].name))
83 			continue;
84 		if (!strncmp(dhgroup_map[i].name, dhgroup_name,
85 			     strlen(dhgroup_map[i].name)))
86 			return i;
87 	}
88 	return NVME_AUTH_DHGROUP_INVALID;
89 }
90 EXPORT_SYMBOL_GPL(nvme_auth_dhgroup_id);
91 
92 static struct nvme_dhchap_hash_map {
93 	int len;
94 	const char hmac[15];
95 	const char digest[8];
96 } hash_map[] = {
97 	[NVME_AUTH_HASH_SHA256] = {
98 		.len = 32,
99 		.hmac = "hmac(sha256)",
100 		.digest = "sha256",
101 	},
102 	[NVME_AUTH_HASH_SHA384] = {
103 		.len = 48,
104 		.hmac = "hmac(sha384)",
105 		.digest = "sha384",
106 	},
107 	[NVME_AUTH_HASH_SHA512] = {
108 		.len = 64,
109 		.hmac = "hmac(sha512)",
110 		.digest = "sha512",
111 	},
112 };
113 
114 const char *nvme_auth_hmac_name(u8 hmac_id)
115 {
116 	if (hmac_id >= ARRAY_SIZE(hash_map))
117 		return NULL;
118 	return hash_map[hmac_id].hmac;
119 }
120 EXPORT_SYMBOL_GPL(nvme_auth_hmac_name);
121 
122 const char *nvme_auth_digest_name(u8 hmac_id)
123 {
124 	if (hmac_id >= ARRAY_SIZE(hash_map))
125 		return NULL;
126 	return hash_map[hmac_id].digest;
127 }
128 EXPORT_SYMBOL_GPL(nvme_auth_digest_name);
129 
130 u8 nvme_auth_hmac_id(const char *hmac_name)
131 {
132 	int i;
133 
134 	if (!hmac_name || !strlen(hmac_name))
135 		return NVME_AUTH_HASH_INVALID;
136 
137 	for (i = 0; i < ARRAY_SIZE(hash_map); i++) {
138 		if (!strlen(hash_map[i].hmac))
139 			continue;
140 		if (!strncmp(hash_map[i].hmac, hmac_name,
141 			     strlen(hash_map[i].hmac)))
142 			return i;
143 	}
144 	return NVME_AUTH_HASH_INVALID;
145 }
146 EXPORT_SYMBOL_GPL(nvme_auth_hmac_id);
147 
148 size_t nvme_auth_hmac_hash_len(u8 hmac_id)
149 {
150 	if (hmac_id >= ARRAY_SIZE(hash_map))
151 		return 0;
152 	return hash_map[hmac_id].len;
153 }
154 EXPORT_SYMBOL_GPL(nvme_auth_hmac_hash_len);
155 
156 u32 nvme_auth_key_struct_size(u32 key_len)
157 {
158 	struct nvme_dhchap_key key;
159 
160 	return struct_size(&key, key, key_len);
161 }
162 EXPORT_SYMBOL_GPL(nvme_auth_key_struct_size);
163 
164 struct nvme_dhchap_key *nvme_auth_extract_key(unsigned char *secret,
165 					      u8 key_hash)
166 {
167 	struct nvme_dhchap_key *key;
168 	unsigned char *p;
169 	u32 crc;
170 	int ret, key_len;
171 	size_t allocated_len = strlen(secret);
172 
173 	/* Secret might be affixed with a ':' */
174 	p = strrchr(secret, ':');
175 	if (p)
176 		allocated_len = p - secret;
177 	key = nvme_auth_alloc_key(allocated_len, 0);
178 	if (!key)
179 		return ERR_PTR(-ENOMEM);
180 
181 	key_len = base64_decode(secret, allocated_len, key->key);
182 	if (key_len < 0) {
183 		pr_debug("base64 key decoding error %d\n",
184 			 key_len);
185 		ret = key_len;
186 		goto out_free_secret;
187 	}
188 
189 	if (key_len != 36 && key_len != 52 &&
190 	    key_len != 68) {
191 		pr_err("Invalid key len %d\n", key_len);
192 		ret = -EINVAL;
193 		goto out_free_secret;
194 	}
195 
196 	/* The last four bytes is the CRC in little-endian format */
197 	key_len -= 4;
198 	/*
199 	 * The linux implementation doesn't do pre- and post-increments,
200 	 * so we have to do it manually.
201 	 */
202 	crc = ~crc32(~0, key->key, key_len);
203 
204 	if (get_unaligned_le32(key->key + key_len) != crc) {
205 		pr_err("key crc mismatch (key %08x, crc %08x)\n",
206 		       get_unaligned_le32(key->key + key_len), crc);
207 		ret = -EKEYREJECTED;
208 		goto out_free_secret;
209 	}
210 	key->len = key_len;
211 	key->hash = key_hash;
212 	return key;
213 out_free_secret:
214 	nvme_auth_free_key(key);
215 	return ERR_PTR(ret);
216 }
217 EXPORT_SYMBOL_GPL(nvme_auth_extract_key);
218 
219 struct nvme_dhchap_key *nvme_auth_alloc_key(u32 len, u8 hash)
220 {
221 	u32 num_bytes = nvme_auth_key_struct_size(len);
222 	struct nvme_dhchap_key *key = kzalloc(num_bytes, GFP_KERNEL);
223 
224 	if (key) {
225 		key->len = len;
226 		key->hash = hash;
227 	}
228 	return key;
229 }
230 EXPORT_SYMBOL_GPL(nvme_auth_alloc_key);
231 
/**
 * nvme_auth_free_key - Release a key structure
 * @key: key to release; NULL is accepted and ignored
 *
 * The memory is released through kfree_sensitive().
 */
void nvme_auth_free_key(struct nvme_dhchap_key *key)
{
	if (key)
		kfree_sensitive(key);
}
EXPORT_SYMBOL_GPL(nvme_auth_free_key);
239 
240 struct nvme_dhchap_key *nvme_auth_transform_key(
241 		struct nvme_dhchap_key *key, char *nqn)
242 {
243 	const char *hmac_name;
244 	struct crypto_shash *key_tfm;
245 	SHASH_DESC_ON_STACK(shash, key_tfm);
246 	struct nvme_dhchap_key *transformed_key;
247 	int ret, key_len;
248 
249 	if (!key) {
250 		pr_warn("No key specified\n");
251 		return ERR_PTR(-ENOKEY);
252 	}
253 	if (key->hash == 0) {
254 		key_len = nvme_auth_key_struct_size(key->len);
255 		transformed_key = kmemdup(key, key_len, GFP_KERNEL);
256 		if (!transformed_key)
257 			return ERR_PTR(-ENOMEM);
258 		return transformed_key;
259 	}
260 	hmac_name = nvme_auth_hmac_name(key->hash);
261 	if (!hmac_name) {
262 		pr_warn("Invalid key hash id %d\n", key->hash);
263 		return ERR_PTR(-EINVAL);
264 	}
265 
266 	key_tfm = crypto_alloc_shash(hmac_name, 0, 0);
267 	if (IS_ERR(key_tfm))
268 		return ERR_CAST(key_tfm);
269 
270 	key_len = crypto_shash_digestsize(key_tfm);
271 	transformed_key = nvme_auth_alloc_key(key_len, key->hash);
272 	if (!transformed_key) {
273 		ret = -ENOMEM;
274 		goto out_free_key;
275 	}
276 
277 	shash->tfm = key_tfm;
278 	ret = crypto_shash_setkey(key_tfm, key->key, key->len);
279 	if (ret < 0)
280 		goto out_free_transformed_key;
281 	ret = crypto_shash_init(shash);
282 	if (ret < 0)
283 		goto out_free_transformed_key;
284 	ret = crypto_shash_update(shash, nqn, strlen(nqn));
285 	if (ret < 0)
286 		goto out_free_transformed_key;
287 	ret = crypto_shash_update(shash, "NVMe-over-Fabrics", 17);
288 	if (ret < 0)
289 		goto out_free_transformed_key;
290 	ret = crypto_shash_final(shash, transformed_key->key);
291 	if (ret < 0)
292 		goto out_free_transformed_key;
293 
294 	crypto_free_shash(key_tfm);
295 
296 	return transformed_key;
297 
298 out_free_transformed_key:
299 	nvme_auth_free_key(transformed_key);
300 out_free_key:
301 	crypto_free_shash(key_tfm);
302 
303 	return ERR_PTR(ret);
304 }
305 EXPORT_SYMBOL_GPL(nvme_auth_transform_key);
306 
307 static int nvme_auth_hash_skey(int hmac_id, u8 *skey, size_t skey_len, u8 *hkey)
308 {
309 	const char *digest_name;
310 	struct crypto_shash *tfm;
311 	int ret;
312 
313 	digest_name = nvme_auth_digest_name(hmac_id);
314 	if (!digest_name) {
315 		pr_debug("%s: failed to get digest for %d\n", __func__,
316 			 hmac_id);
317 		return -EINVAL;
318 	}
319 	tfm = crypto_alloc_shash(digest_name, 0, 0);
320 	if (IS_ERR(tfm))
321 		return -ENOMEM;
322 
323 	ret = crypto_shash_tfm_digest(tfm, skey, skey_len, hkey);
324 	if (ret < 0)
325 		pr_debug("%s: Failed to hash digest len %zu\n", __func__,
326 			 skey_len);
327 
328 	crypto_free_shash(tfm);
329 	return ret;
330 }
331 
332 int nvme_auth_augmented_challenge(u8 hmac_id, u8 *skey, size_t skey_len,
333 		u8 *challenge, u8 *aug, size_t hlen)
334 {
335 	struct crypto_shash *tfm;
336 	u8 *hashed_key;
337 	const char *hmac_name;
338 	int ret;
339 
340 	hashed_key = kmalloc(hlen, GFP_KERNEL);
341 	if (!hashed_key)
342 		return -ENOMEM;
343 
344 	ret = nvme_auth_hash_skey(hmac_id, skey,
345 				  skey_len, hashed_key);
346 	if (ret < 0)
347 		goto out_free_key;
348 
349 	hmac_name = nvme_auth_hmac_name(hmac_id);
350 	if (!hmac_name) {
351 		pr_warn("%s: invalid hash algorithm %d\n",
352 			__func__, hmac_id);
353 		ret = -EINVAL;
354 		goto out_free_key;
355 	}
356 
357 	tfm = crypto_alloc_shash(hmac_name, 0, 0);
358 	if (IS_ERR(tfm)) {
359 		ret = PTR_ERR(tfm);
360 		goto out_free_key;
361 	}
362 
363 	ret = crypto_shash_setkey(tfm, hashed_key, hlen);
364 	if (ret)
365 		goto out_free_hash;
366 
367 	ret = crypto_shash_tfm_digest(tfm, challenge, hlen, aug);
368 out_free_hash:
369 	crypto_free_shash(tfm);
370 out_free_key:
371 	kfree_sensitive(hashed_key);
372 	return ret;
373 }
374 EXPORT_SYMBOL_GPL(nvme_auth_augmented_challenge);
375 
376 int nvme_auth_gen_privkey(struct crypto_kpp *dh_tfm, u8 dh_gid)
377 {
378 	int ret;
379 
380 	ret = crypto_kpp_set_secret(dh_tfm, NULL, 0);
381 	if (ret)
382 		pr_debug("failed to set private key, error %d\n", ret);
383 
384 	return ret;
385 }
386 EXPORT_SYMBOL_GPL(nvme_auth_gen_privkey);
387 
388 int nvme_auth_gen_pubkey(struct crypto_kpp *dh_tfm,
389 		u8 *host_key, size_t host_key_len)
390 {
391 	struct kpp_request *req;
392 	struct crypto_wait wait;
393 	struct scatterlist dst;
394 	int ret;
395 
396 	req = kpp_request_alloc(dh_tfm, GFP_KERNEL);
397 	if (!req)
398 		return -ENOMEM;
399 
400 	crypto_init_wait(&wait);
401 	kpp_request_set_input(req, NULL, 0);
402 	sg_init_one(&dst, host_key, host_key_len);
403 	kpp_request_set_output(req, &dst, host_key_len);
404 	kpp_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG,
405 				 crypto_req_done, &wait);
406 
407 	ret = crypto_wait_req(crypto_kpp_generate_public_key(req), &wait);
408 	kpp_request_free(req);
409 	return ret;
410 }
411 EXPORT_SYMBOL_GPL(nvme_auth_gen_pubkey);
412 
413 int nvme_auth_gen_shared_secret(struct crypto_kpp *dh_tfm,
414 		u8 *ctrl_key, size_t ctrl_key_len,
415 		u8 *sess_key, size_t sess_key_len)
416 {
417 	struct kpp_request *req;
418 	struct crypto_wait wait;
419 	struct scatterlist src, dst;
420 	int ret;
421 
422 	req = kpp_request_alloc(dh_tfm, GFP_KERNEL);
423 	if (!req)
424 		return -ENOMEM;
425 
426 	crypto_init_wait(&wait);
427 	sg_init_one(&src, ctrl_key, ctrl_key_len);
428 	kpp_request_set_input(req, &src, ctrl_key_len);
429 	sg_init_one(&dst, sess_key, sess_key_len);
430 	kpp_request_set_output(req, &dst, sess_key_len);
431 	kpp_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG,
432 				 crypto_req_done, &wait);
433 
434 	ret = crypto_wait_req(crypto_kpp_compute_shared_secret(req), &wait);
435 
436 	kpp_request_free(req);
437 	return ret;
438 }
439 EXPORT_SYMBOL_GPL(nvme_auth_gen_shared_secret);
440 
441 int nvme_auth_generate_key(u8 *secret, struct nvme_dhchap_key **ret_key)
442 {
443 	struct nvme_dhchap_key *key;
444 	u8 key_hash;
445 
446 	if (!secret) {
447 		*ret_key = NULL;
448 		return 0;
449 	}
450 
451 	if (sscanf(secret, "DHHC-1:%hhd:%*s:", &key_hash) != 1)
452 		return -EINVAL;
453 
454 	/* Pass in the secret without the 'DHHC-1:XX:' prefix */
455 	key = nvme_auth_extract_key(secret + 10, key_hash);
456 	if (IS_ERR(key)) {
457 		*ret_key = NULL;
458 		return PTR_ERR(key);
459 	}
460 
461 	*ret_key = key;
462 	return 0;
463 }
464 EXPORT_SYMBOL_GPL(nvme_auth_generate_key);
465 
466 /**
467  * nvme_auth_generate_psk - Generate a PSK for TLS
468  * @hmac_id: Hash function identifier
469  * @skey: Session key
470  * @skey_len: Length of @skey
471  * @c1: Value of challenge C1
472  * @c2: Value of challenge C2
473  * @hash_len: Hash length of the hash algorithm
474  * @ret_psk: Pointer to the resulting generated PSK
475  * @ret_len: length of @ret_psk
476  *
477  * Generate a PSK for TLS as specified in NVMe base specification, section
478  * 8.13.5.9: Generated PSK for TLS
479  *
480  * The generated PSK for TLS shall be computed applying the HMAC function
481  * using the hash function H( ) selected by the HashID parameter in the
482  * DH-HMAC-CHAP_Challenge message with the session key KS as key to the
483  * concatenation of the two challenges C1 and C2 (i.e., generated
484  * PSK = HMAC(KS, C1 || C2)).
485  *
486  * Returns 0 on success with a valid generated PSK pointer in @ret_psk and
487  * the length of @ret_psk in @ret_len, or a negative error number otherwise.
488  */
489 int nvme_auth_generate_psk(u8 hmac_id, u8 *skey, size_t skey_len,
490 		u8 *c1, u8 *c2, size_t hash_len, u8 **ret_psk, size_t *ret_len)
491 {
492 	struct crypto_shash *tfm;
493 	SHASH_DESC_ON_STACK(shash, tfm);
494 	u8 *psk;
495 	const char *hmac_name;
496 	int ret, psk_len;
497 
498 	if (!c1 || !c2)
499 		return -EINVAL;
500 
501 	hmac_name = nvme_auth_hmac_name(hmac_id);
502 	if (!hmac_name) {
503 		pr_warn("%s: invalid hash algorithm %d\n",
504 			__func__, hmac_id);
505 		return -EINVAL;
506 	}
507 
508 	tfm = crypto_alloc_shash(hmac_name, 0, 0);
509 	if (IS_ERR(tfm))
510 		return PTR_ERR(tfm);
511 
512 	psk_len = crypto_shash_digestsize(tfm);
513 	psk = kzalloc(psk_len, GFP_KERNEL);
514 	if (!psk) {
515 		ret = -ENOMEM;
516 		goto out_free_tfm;
517 	}
518 
519 	shash->tfm = tfm;
520 	ret = crypto_shash_setkey(tfm, skey, skey_len);
521 	if (ret)
522 		goto out_free_psk;
523 
524 	ret = crypto_shash_init(shash);
525 	if (ret)
526 		goto out_free_psk;
527 
528 	ret = crypto_shash_update(shash, c1, hash_len);
529 	if (ret)
530 		goto out_free_psk;
531 
532 	ret = crypto_shash_update(shash, c2, hash_len);
533 	if (ret)
534 		goto out_free_psk;
535 
536 	ret = crypto_shash_final(shash, psk);
537 	if (!ret) {
538 		*ret_psk = psk;
539 		*ret_len = psk_len;
540 	}
541 
542 out_free_psk:
543 	if (ret)
544 		kfree_sensitive(psk);
545 out_free_tfm:
546 	crypto_free_shash(tfm);
547 
548 	return ret;
549 }
550 EXPORT_SYMBOL_GPL(nvme_auth_generate_psk);
551 
552 /**
553  * nvme_auth_generate_digest - Generate TLS PSK digest
554  * @hmac_id: Hash function identifier
555  * @psk: Generated input PSK
556  * @psk_len: Length of @psk
557  * @subsysnqn: NQN of the subsystem
558  * @hostnqn: NQN of the host
559  * @ret_digest: Pointer to the returned digest
560  *
561  * Generate a TLS PSK digest as specified in TP8018 Section 3.6.1.3:
562  *   TLS PSK and PSK identity Derivation
563  *
564  * The PSK digest shall be computed by encoding in Base64 (refer to RFC
565  * 4648) the result of the application of the HMAC function using the hash
566  * function specified in item 4 above (ie the hash function of the cipher
567  * suite associated with the PSK identity) with the PSK as HMAC key to the
568  * concatenation of:
569  * - the NQN of the host (i.e., NQNh) not including the null terminator;
570  * - a space character;
571  * - the NQN of the NVM subsystem (i.e., NQNc) not including the null
572  *   terminator;
573  * - a space character; and
574  * - the seventeen ASCII characters "NVMe-over-Fabrics"
575  * (i.e., <PSK digest> = Base64(HMAC(PSK, NQNh || " " || NQNc || " " ||
576  *  "NVMe-over-Fabrics"))).
577  * The length of the PSK digest depends on the hash function used to compute
578  * it as follows:
579  * - If the SHA-256 hash function is used, the resulting PSK digest is 44
580  *   characters long; or
581  * - If the SHA-384 hash function is used, the resulting PSK digest is 64
582  *   characters long.
583  *
584  * Returns 0 on success with a valid digest pointer in @ret_digest, or a
585  * negative error number on failure.
586  */
587 int nvme_auth_generate_digest(u8 hmac_id, u8 *psk, size_t psk_len,
588 		char *subsysnqn, char *hostnqn, u8 **ret_digest)
589 {
590 	struct crypto_shash *tfm;
591 	SHASH_DESC_ON_STACK(shash, tfm);
592 	u8 *digest, *enc;
593 	const char *hmac_name;
594 	size_t digest_len, hmac_len;
595 	int ret;
596 
597 	if (WARN_ON(!subsysnqn || !hostnqn))
598 		return -EINVAL;
599 
600 	hmac_name = nvme_auth_hmac_name(hmac_id);
601 	if (!hmac_name) {
602 		pr_warn("%s: invalid hash algorithm %d\n",
603 			__func__, hmac_id);
604 		return -EINVAL;
605 	}
606 
607 	switch (nvme_auth_hmac_hash_len(hmac_id)) {
608 	case 32:
609 		hmac_len = 44;
610 		break;
611 	case 48:
612 		hmac_len = 64;
613 		break;
614 	default:
615 		pr_warn("%s: invalid hash algorithm '%s'\n",
616 			__func__, hmac_name);
617 		return -EINVAL;
618 	}
619 
620 	enc = kzalloc(hmac_len + 1, GFP_KERNEL);
621 	if (!enc)
622 		return -ENOMEM;
623 
624 	tfm = crypto_alloc_shash(hmac_name, 0, 0);
625 	if (IS_ERR(tfm)) {
626 		ret = PTR_ERR(tfm);
627 		goto out_free_enc;
628 	}
629 
630 	digest_len = crypto_shash_digestsize(tfm);
631 	digest = kzalloc(digest_len, GFP_KERNEL);
632 	if (!digest) {
633 		ret = -ENOMEM;
634 		goto out_free_tfm;
635 	}
636 
637 	shash->tfm = tfm;
638 	ret = crypto_shash_setkey(tfm, psk, psk_len);
639 	if (ret)
640 		goto out_free_digest;
641 
642 	ret = crypto_shash_init(shash);
643 	if (ret)
644 		goto out_free_digest;
645 
646 	ret = crypto_shash_update(shash, hostnqn, strlen(hostnqn));
647 	if (ret)
648 		goto out_free_digest;
649 
650 	ret = crypto_shash_update(shash, " ", 1);
651 	if (ret)
652 		goto out_free_digest;
653 
654 	ret = crypto_shash_update(shash, subsysnqn, strlen(subsysnqn));
655 	if (ret)
656 		goto out_free_digest;
657 
658 	ret = crypto_shash_update(shash, " NVMe-over-Fabrics", 18);
659 	if (ret)
660 		goto out_free_digest;
661 
662 	ret = crypto_shash_final(shash, digest);
663 	if (ret)
664 		goto out_free_digest;
665 
666 	ret = base64_encode(digest, digest_len, enc);
667 	if (ret < hmac_len) {
668 		ret = -ENOKEY;
669 		goto out_free_digest;
670 	}
671 	*ret_digest = enc;
672 	ret = 0;
673 
674 out_free_digest:
675 	kfree_sensitive(digest);
676 out_free_tfm:
677 	crypto_free_shash(tfm);
678 out_free_enc:
679 	if (ret)
680 		kfree_sensitive(enc);
681 
682 	return ret;
683 }
684 EXPORT_SYMBOL_GPL(nvme_auth_generate_digest);
685 
686 /**
687  * hkdf_expand_label - HKDF-Expand-Label (RFC 8846 section 7.1)
688  * @hmac_tfm: hash context keyed with pseudorandom key
689  * @label: ASCII label without "tls13 " prefix
690  * @labellen: length of @label
691  * @context: context bytes
692  * @contextlen: length of @context
693  * @okm: output keying material
694  * @okmlen: length of @okm
695  *
696  * Build the TLS 1.3 HkdfLabel structure and invoke hkdf_expand().
697  *
698  * Returns 0 on success with output keying material stored in @okm,
699  * or a negative errno value otherwise.
700  */
701 static int hkdf_expand_label(struct crypto_shash *hmac_tfm,
702 		const u8 *label, unsigned int labellen,
703 		const u8 *context, unsigned int contextlen,
704 		u8 *okm, unsigned int okmlen)
705 {
706 	int err;
707 	u8 *info;
708 	unsigned int infolen;
709 	const char *tls13_prefix = "tls13 ";
710 	unsigned int prefixlen = strlen(tls13_prefix);
711 
712 	if (WARN_ON(labellen > (255 - prefixlen)))
713 		return -EINVAL;
714 	if (WARN_ON(contextlen > 255))
715 		return -EINVAL;
716 
717 	infolen = 2 + (1 + prefixlen + labellen) + (1 + contextlen);
718 	info = kzalloc(infolen, GFP_KERNEL);
719 	if (!info)
720 		return -ENOMEM;
721 
722 	/* HkdfLabel.Length */
723 	put_unaligned_be16(okmlen, info);
724 
725 	/* HkdfLabel.Label */
726 	info[2] = prefixlen + labellen;
727 	memcpy(info + 3, tls13_prefix, prefixlen);
728 	memcpy(info + 3 + prefixlen, label, labellen);
729 
730 	/* HkdfLabel.Context */
731 	info[3 + prefixlen + labellen] = contextlen;
732 	memcpy(info + 4 + prefixlen + labellen, context, contextlen);
733 
734 	err = hkdf_expand(hmac_tfm, info, infolen, okm, okmlen);
735 	kfree_sensitive(info);
736 	return err;
737 }
738 
739 /**
740  * nvme_auth_derive_tls_psk - Derive TLS PSK
741  * @hmac_id: Hash function identifier
742  * @psk: generated input PSK
743  * @psk_len: size of @psk
744  * @psk_digest: TLS PSK digest
745  * @ret_psk: Pointer to the resulting TLS PSK
746  *
747  * Derive a TLS PSK as specified in TP8018 Section 3.6.1.3:
748  *   TLS PSK and PSK identity Derivation
749  *
750  * The TLS PSK shall be derived as follows from an input PSK
751  * (i.e., either a retained PSK or a generated PSK) and a PSK
752  * identity using the HKDF-Extract and HKDF-Expand-Label operations
753  * (refer to RFC 5869 and RFC 8446) where the hash function is the
754  * one specified by the hash specifier of the PSK identity:
755  * 1. PRK = HKDF-Extract(0, Input PSK); and
756  * 2. TLS PSK = HKDF-Expand-Label(PRK, "nvme-tls-psk", PskIdentityContext, L),
757  * where PskIdentityContext is the hash identifier indicated in
758  * the PSK identity concatenated to a space character and to the
759  * Base64 PSK digest (i.e., "<hash> <PSK digest>") and L is the
760  * output size in bytes of the hash function (i.e., 32 for SHA-256
761  * and 48 for SHA-384).
762  *
763  * Returns 0 on success with a valid psk pointer in @ret_psk or a negative
764  * error number otherwise.
765  */
766 int nvme_auth_derive_tls_psk(int hmac_id, u8 *psk, size_t psk_len,
767 		u8 *psk_digest, u8 **ret_psk)
768 {
769 	struct crypto_shash *hmac_tfm;
770 	const char *hmac_name;
771 	const char *label = "nvme-tls-psk";
772 	static const char default_salt[HKDF_MAX_HASHLEN];
773 	size_t prk_len;
774 	const char *ctx;
775 	unsigned char *prk, *tls_key;
776 	int ret;
777 
778 	hmac_name = nvme_auth_hmac_name(hmac_id);
779 	if (!hmac_name) {
780 		pr_warn("%s: invalid hash algorithm %d\n",
781 			__func__, hmac_id);
782 		return -EINVAL;
783 	}
784 	if (hmac_id == NVME_AUTH_HASH_SHA512) {
785 		pr_warn("%s: unsupported hash algorithm %s\n",
786 			__func__, hmac_name);
787 		return -EINVAL;
788 	}
789 
790 	hmac_tfm = crypto_alloc_shash(hmac_name, 0, 0);
791 	if (IS_ERR(hmac_tfm))
792 		return PTR_ERR(hmac_tfm);
793 
794 	prk_len = crypto_shash_digestsize(hmac_tfm);
795 	prk = kzalloc(prk_len, GFP_KERNEL);
796 	if (!prk) {
797 		ret = -ENOMEM;
798 		goto out_free_shash;
799 	}
800 
801 	if (WARN_ON(prk_len > HKDF_MAX_HASHLEN)) {
802 		ret = -EINVAL;
803 		goto out_free_prk;
804 	}
805 	ret = hkdf_extract(hmac_tfm, psk, psk_len,
806 			   default_salt, prk_len, prk);
807 	if (ret)
808 		goto out_free_prk;
809 
810 	ret = crypto_shash_setkey(hmac_tfm, prk, prk_len);
811 	if (ret)
812 		goto out_free_prk;
813 
814 	ctx = kasprintf(GFP_KERNEL, "%02d %s", hmac_id, psk_digest);
815 	if (!ctx) {
816 		ret = -ENOMEM;
817 		goto out_free_prk;
818 	}
819 
820 	tls_key = kzalloc(psk_len, GFP_KERNEL);
821 	if (!tls_key) {
822 		ret = -ENOMEM;
823 		goto out_free_ctx;
824 	}
825 	ret = hkdf_expand_label(hmac_tfm,
826 				label, strlen(label),
827 				ctx, strlen(ctx),
828 				tls_key, psk_len);
829 	if (ret) {
830 		kfree(tls_key);
831 		goto out_free_ctx;
832 	}
833 	*ret_psk = tls_key;
834 
835 out_free_ctx:
836 	kfree(ctx);
837 out_free_prk:
838 	kfree(prk);
839 out_free_shash:
840 	crypto_free_shash(hmac_tfm);
841 
842 	return ret;
843 }
844 EXPORT_SYMBOL_GPL(nvme_auth_derive_tls_psk);
845 
846 MODULE_DESCRIPTION("NVMe Authentication framework");
847 MODULE_LICENSE("GPL v2");
848