xref: /linux/fs/bcachefs/checksum.c (revision 566ab427f827b0256d3e8ce0235d088e6a9c28bd)
1 // SPDX-License-Identifier: GPL-2.0
2 #include "bcachefs.h"
3 #include "checksum.h"
4 #include "errcode.h"
5 #include "super.h"
6 #include "super-io.h"
7 
8 #include <linux/crc32c.h>
9 #include <linux/crypto.h>
10 #include <linux/xxhash.h>
11 #include <linux/key.h>
12 #include <linux/random.h>
13 #include <linux/ratelimit.h>
14 #include <linux/scatterlist.h>
15 #include <crypto/algapi.h>
16 #include <crypto/chacha.h>
17 #include <crypto/hash.h>
18 #include <crypto/poly1305.h>
19 #include <crypto/skcipher.h>
20 #include <keys/user-type.h>
21 
/*
 * struct bch2_checksum_state abstracts the checksum state computed over
 * several, possibly discontiguous, buffers (e.g. pages): data can be fed in
 * piece by piece without the checksum algorithm losing its state.
 * For native checksum algorithms (like CRC), a running seed value suffices;
 * for hash-like algorithms (xxhash), the full hash state must be stored.
 */
28 
29 struct bch2_checksum_state {
30 	union {
31 		u64 seed;
32 		struct xxh64_state h64state;
33 	};
34 	unsigned int type;
35 };
36 
37 static void bch2_checksum_init(struct bch2_checksum_state *state)
38 {
39 	switch (state->type) {
40 	case BCH_CSUM_none:
41 	case BCH_CSUM_crc32c:
42 	case BCH_CSUM_crc64:
43 		state->seed = 0;
44 		break;
45 	case BCH_CSUM_crc32c_nonzero:
46 		state->seed = U32_MAX;
47 		break;
48 	case BCH_CSUM_crc64_nonzero:
49 		state->seed = U64_MAX;
50 		break;
51 	case BCH_CSUM_xxhash:
52 		xxh64_reset(&state->h64state, 0);
53 		break;
54 	default:
55 		BUG();
56 	}
57 }
58 
59 static u64 bch2_checksum_final(const struct bch2_checksum_state *state)
60 {
61 	switch (state->type) {
62 	case BCH_CSUM_none:
63 	case BCH_CSUM_crc32c:
64 	case BCH_CSUM_crc64:
65 		return state->seed;
66 	case BCH_CSUM_crc32c_nonzero:
67 		return state->seed ^ U32_MAX;
68 	case BCH_CSUM_crc64_nonzero:
69 		return state->seed ^ U64_MAX;
70 	case BCH_CSUM_xxhash:
71 		return xxh64_digest(&state->h64state);
72 	default:
73 		BUG();
74 	}
75 }
76 
77 static void bch2_checksum_update(struct bch2_checksum_state *state, const void *data, size_t len)
78 {
79 	switch (state->type) {
80 	case BCH_CSUM_none:
81 		return;
82 	case BCH_CSUM_crc32c_nonzero:
83 	case BCH_CSUM_crc32c:
84 		state->seed = crc32c(state->seed, data, len);
85 		break;
86 	case BCH_CSUM_crc64_nonzero:
87 	case BCH_CSUM_crc64:
88 		state->seed = crc64_be(state->seed, data, len);
89 		break;
90 	case BCH_CSUM_xxhash:
91 		xxh64_update(&state->h64state, data, len);
92 		break;
93 	default:
94 		BUG();
95 	}
96 }
97 
98 static inline int do_encrypt_sg(struct crypto_sync_skcipher *tfm,
99 				struct nonce nonce,
100 				struct scatterlist *sg, size_t len)
101 {
102 	SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm);
103 
104 	skcipher_request_set_sync_tfm(req, tfm);
105 	skcipher_request_set_callback(req, 0, NULL, NULL);
106 	skcipher_request_set_crypt(req, sg, sg, len, nonce.d);
107 
108 	int ret = crypto_skcipher_encrypt(req);
109 	if (ret)
110 		pr_err("got error %i from crypto_skcipher_encrypt()", ret);
111 
112 	return ret;
113 }
114 
115 static inline int do_encrypt(struct crypto_sync_skcipher *tfm,
116 			      struct nonce nonce,
117 			      void *buf, size_t len)
118 {
119 	if (!is_vmalloc_addr(buf)) {
120 		struct scatterlist sg = {};
121 
122 		sg_mark_end(&sg);
123 		sg_set_page(&sg, virt_to_page(buf), len, offset_in_page(buf));
124 		return do_encrypt_sg(tfm, nonce, &sg, len);
125 	} else {
126 		DARRAY_PREALLOCATED(struct scatterlist, 4) sgl;
127 		size_t sgl_len = 0;
128 		int ret;
129 
130 		darray_init(&sgl);
131 
132 		while (len) {
133 			unsigned offset = offset_in_page(buf);
134 			struct scatterlist sg = {
135 				.page_link	= (unsigned long) vmalloc_to_page(buf),
136 				.offset		= offset,
137 				.length		= min(len, PAGE_SIZE - offset),
138 			};
139 
140 			if (darray_push(&sgl, sg)) {
141 				sg_mark_end(&darray_last(sgl));
142 				ret = do_encrypt_sg(tfm, nonce, sgl.data, sgl_len);
143 				if (ret)
144 					goto err;
145 
146 				nonce = nonce_add(nonce, sgl_len);
147 				sgl_len = 0;
148 				sgl.nr = 0;
149 				BUG_ON(darray_push(&sgl, sg));
150 			}
151 
152 			buf += sg.length;
153 			len -= sg.length;
154 			sgl_len += sg.length;
155 		}
156 
157 		sg_mark_end(&darray_last(sgl));
158 		ret = do_encrypt_sg(tfm, nonce, sgl.data, sgl_len);
159 err:
160 		darray_exit(&sgl);
161 		return ret;
162 	}
163 }
164 
165 int bch2_chacha_encrypt_key(struct bch_key *key, struct nonce nonce,
166 			    void *buf, size_t len)
167 {
168 	struct crypto_sync_skcipher *chacha20 =
169 		crypto_alloc_sync_skcipher("chacha20", 0, 0);
170 	int ret;
171 
172 	ret = PTR_ERR_OR_ZERO(chacha20);
173 	if (ret) {
174 		pr_err("error requesting chacha20 cipher: %s", bch2_err_str(ret));
175 		return ret;
176 	}
177 
178 	ret = crypto_skcipher_setkey(&chacha20->base,
179 				     (void *) key, sizeof(*key));
180 	if (ret) {
181 		pr_err("error from crypto_skcipher_setkey(): %s", bch2_err_str(ret));
182 		goto err;
183 	}
184 
185 	ret = do_encrypt(chacha20, nonce, buf, len);
186 err:
187 	crypto_free_sync_skcipher(chacha20);
188 	return ret;
189 }
190 
191 static int gen_poly_key(struct bch_fs *c, struct shash_desc *desc,
192 			struct nonce nonce)
193 {
194 	u8 key[POLY1305_KEY_SIZE];
195 	int ret;
196 
197 	nonce.d[3] ^= BCH_NONCE_POLY;
198 
199 	memset(key, 0, sizeof(key));
200 	ret = do_encrypt(c->chacha20, nonce, key, sizeof(key));
201 	if (ret)
202 		return ret;
203 
204 	desc->tfm = c->poly1305;
205 	crypto_shash_init(desc);
206 	crypto_shash_update(desc, key, sizeof(key));
207 	return 0;
208 }
209 
210 struct bch_csum bch2_checksum(struct bch_fs *c, unsigned type,
211 			      struct nonce nonce, const void *data, size_t len)
212 {
213 	switch (type) {
214 	case BCH_CSUM_none:
215 	case BCH_CSUM_crc32c_nonzero:
216 	case BCH_CSUM_crc64_nonzero:
217 	case BCH_CSUM_crc32c:
218 	case BCH_CSUM_xxhash:
219 	case BCH_CSUM_crc64: {
220 		struct bch2_checksum_state state;
221 
222 		state.type = type;
223 
224 		bch2_checksum_init(&state);
225 		bch2_checksum_update(&state, data, len);
226 
227 		return (struct bch_csum) { .lo = cpu_to_le64(bch2_checksum_final(&state)) };
228 	}
229 
230 	case BCH_CSUM_chacha20_poly1305_80:
231 	case BCH_CSUM_chacha20_poly1305_128: {
232 		SHASH_DESC_ON_STACK(desc, c->poly1305);
233 		u8 digest[POLY1305_DIGEST_SIZE];
234 		struct bch_csum ret = { 0 };
235 
236 		gen_poly_key(c, desc, nonce);
237 
238 		crypto_shash_update(desc, data, len);
239 		crypto_shash_final(desc, digest);
240 
241 		memcpy(&ret, digest, bch_crc_bytes[type]);
242 		return ret;
243 	}
244 	default:
245 		return (struct bch_csum) {};
246 	}
247 }
248 
249 int bch2_encrypt(struct bch_fs *c, unsigned type,
250 		  struct nonce nonce, void *data, size_t len)
251 {
252 	if (!bch2_csum_type_is_encryption(type))
253 		return 0;
254 
255 	return do_encrypt(c->chacha20, nonce, data, len);
256 }
257 
258 static struct bch_csum __bch2_checksum_bio(struct bch_fs *c, unsigned type,
259 					   struct nonce nonce, struct bio *bio,
260 					   struct bvec_iter *iter)
261 {
262 	struct bio_vec bv;
263 
264 	switch (type) {
265 	case BCH_CSUM_none:
266 		return (struct bch_csum) { 0 };
267 	case BCH_CSUM_crc32c_nonzero:
268 	case BCH_CSUM_crc64_nonzero:
269 	case BCH_CSUM_crc32c:
270 	case BCH_CSUM_xxhash:
271 	case BCH_CSUM_crc64: {
272 		struct bch2_checksum_state state;
273 
274 		state.type = type;
275 		bch2_checksum_init(&state);
276 
277 #ifdef CONFIG_HIGHMEM
278 		__bio_for_each_segment(bv, bio, *iter, *iter) {
279 			void *p = kmap_local_page(bv.bv_page) + bv.bv_offset;
280 
281 			bch2_checksum_update(&state, p, bv.bv_len);
282 			kunmap_local(p);
283 		}
284 #else
285 		__bio_for_each_bvec(bv, bio, *iter, *iter)
286 			bch2_checksum_update(&state, page_address(bv.bv_page) + bv.bv_offset,
287 				bv.bv_len);
288 #endif
289 		return (struct bch_csum) { .lo = cpu_to_le64(bch2_checksum_final(&state)) };
290 	}
291 
292 	case BCH_CSUM_chacha20_poly1305_80:
293 	case BCH_CSUM_chacha20_poly1305_128: {
294 		SHASH_DESC_ON_STACK(desc, c->poly1305);
295 		u8 digest[POLY1305_DIGEST_SIZE];
296 		struct bch_csum ret = { 0 };
297 
298 		gen_poly_key(c, desc, nonce);
299 
300 #ifdef CONFIG_HIGHMEM
301 		__bio_for_each_segment(bv, bio, *iter, *iter) {
302 			void *p = kmap_local_page(bv.bv_page) + bv.bv_offset;
303 
304 			crypto_shash_update(desc, p, bv.bv_len);
305 			kunmap_local(p);
306 		}
307 #else
308 		__bio_for_each_bvec(bv, bio, *iter, *iter)
309 			crypto_shash_update(desc,
310 				page_address(bv.bv_page) + bv.bv_offset,
311 				bv.bv_len);
312 #endif
313 		crypto_shash_final(desc, digest);
314 
315 		memcpy(&ret, digest, bch_crc_bytes[type]);
316 		return ret;
317 	}
318 	default:
319 		return (struct bch_csum) {};
320 	}
321 }
322 
323 struct bch_csum bch2_checksum_bio(struct bch_fs *c, unsigned type,
324 				  struct nonce nonce, struct bio *bio)
325 {
326 	struct bvec_iter iter = bio->bi_iter;
327 
328 	return __bch2_checksum_bio(c, type, nonce, bio, &iter);
329 }
330 
331 int __bch2_encrypt_bio(struct bch_fs *c, unsigned type,
332 		     struct nonce nonce, struct bio *bio)
333 {
334 	struct bio_vec bv;
335 	struct bvec_iter iter;
336 	DARRAY_PREALLOCATED(struct scatterlist, 4) sgl;
337 	size_t sgl_len = 0;
338 	int ret = 0;
339 
340 	if (!bch2_csum_type_is_encryption(type))
341 		return 0;
342 
343 	darray_init(&sgl);
344 
345 	bio_for_each_segment(bv, bio, iter) {
346 		struct scatterlist sg = {
347 			.page_link	= (unsigned long) bv.bv_page,
348 			.offset		= bv.bv_offset,
349 			.length		= bv.bv_len,
350 		};
351 
352 		if (darray_push(&sgl, sg)) {
353 			sg_mark_end(&darray_last(sgl));
354 			ret = do_encrypt_sg(c->chacha20, nonce, sgl.data, sgl_len);
355 			if (ret)
356 				goto err;
357 
358 			nonce = nonce_add(nonce, sgl_len);
359 			sgl_len = 0;
360 			sgl.nr = 0;
361 
362 			BUG_ON(darray_push(&sgl, sg));
363 		}
364 
365 		sgl_len += sg.length;
366 	}
367 
368 	sg_mark_end(&darray_last(sgl));
369 	ret = do_encrypt_sg(c->chacha20, nonce, sgl.data, sgl_len);
370 err:
371 	darray_exit(&sgl);
372 	return ret;
373 }
374 
375 struct bch_csum bch2_checksum_merge(unsigned type, struct bch_csum a,
376 				    struct bch_csum b, size_t b_len)
377 {
378 	struct bch2_checksum_state state;
379 
380 	state.type = type;
381 	bch2_checksum_init(&state);
382 	state.seed = le64_to_cpu(a.lo);
383 
384 	BUG_ON(!bch2_checksum_mergeable(type));
385 
386 	while (b_len) {
387 		unsigned page_len = min_t(unsigned, b_len, PAGE_SIZE);
388 
389 		bch2_checksum_update(&state,
390 				page_address(ZERO_PAGE(0)), page_len);
391 		b_len -= page_len;
392 	}
393 	a.lo = cpu_to_le64(bch2_checksum_final(&state));
394 	a.lo ^= b.lo;
395 	a.hi ^= b.hi;
396 	return a;
397 }
398 
399 int bch2_rechecksum_bio(struct bch_fs *c, struct bio *bio,
400 			struct bversion version,
401 			struct bch_extent_crc_unpacked crc_old,
402 			struct bch_extent_crc_unpacked *crc_a,
403 			struct bch_extent_crc_unpacked *crc_b,
404 			unsigned len_a, unsigned len_b,
405 			unsigned new_csum_type)
406 {
407 	struct bvec_iter iter = bio->bi_iter;
408 	struct nonce nonce = extent_nonce(version, crc_old);
409 	struct bch_csum merged = { 0 };
410 	struct crc_split {
411 		struct bch_extent_crc_unpacked	*crc;
412 		unsigned			len;
413 		unsigned			csum_type;
414 		struct bch_csum			csum;
415 	} splits[3] = {
416 		{ crc_a, len_a, new_csum_type, { 0 }},
417 		{ crc_b, len_b, new_csum_type, { 0 } },
418 		{ NULL,	 bio_sectors(bio) - len_a - len_b, new_csum_type, { 0 } },
419 	}, *i;
420 	bool mergeable = crc_old.csum_type == new_csum_type &&
421 		bch2_checksum_mergeable(new_csum_type);
422 	unsigned crc_nonce = crc_old.nonce;
423 
424 	BUG_ON(len_a + len_b > bio_sectors(bio));
425 	BUG_ON(crc_old.uncompressed_size != bio_sectors(bio));
426 	BUG_ON(crc_is_compressed(crc_old));
427 	BUG_ON(bch2_csum_type_is_encryption(crc_old.csum_type) !=
428 	       bch2_csum_type_is_encryption(new_csum_type));
429 
430 	for (i = splits; i < splits + ARRAY_SIZE(splits); i++) {
431 		iter.bi_size = i->len << 9;
432 		if (mergeable || i->crc)
433 			i->csum = __bch2_checksum_bio(c, i->csum_type,
434 						      nonce, bio, &iter);
435 		else
436 			bio_advance_iter(bio, &iter, i->len << 9);
437 		nonce = nonce_add(nonce, i->len << 9);
438 	}
439 
440 	if (mergeable)
441 		for (i = splits; i < splits + ARRAY_SIZE(splits); i++)
442 			merged = bch2_checksum_merge(new_csum_type, merged,
443 						     i->csum, i->len << 9);
444 	else
445 		merged = bch2_checksum_bio(c, crc_old.csum_type,
446 				extent_nonce(version, crc_old), bio);
447 
448 	if (bch2_crc_cmp(merged, crc_old.csum) && !c->opts.no_data_io) {
449 		struct printbuf buf = PRINTBUF;
450 		prt_printf(&buf, "checksum error in %s() (memory corruption or bug?)\n"
451 			   "  expected %0llx:%0llx got %0llx:%0llx (old type ",
452 			   __func__,
453 			   crc_old.csum.hi,
454 			   crc_old.csum.lo,
455 			   merged.hi,
456 			   merged.lo);
457 		bch2_prt_csum_type(&buf, crc_old.csum_type);
458 		prt_str(&buf, " new type ");
459 		bch2_prt_csum_type(&buf, new_csum_type);
460 		prt_str(&buf, ")");
461 		WARN_RATELIMIT(1, "%s", buf.buf);
462 		printbuf_exit(&buf);
463 		return -EIO;
464 	}
465 
466 	for (i = splits; i < splits + ARRAY_SIZE(splits); i++) {
467 		if (i->crc)
468 			*i->crc = (struct bch_extent_crc_unpacked) {
469 				.csum_type		= i->csum_type,
470 				.compression_type	= crc_old.compression_type,
471 				.compressed_size	= i->len,
472 				.uncompressed_size	= i->len,
473 				.offset			= 0,
474 				.live_size		= i->len,
475 				.nonce			= crc_nonce,
476 				.csum			= i->csum,
477 			};
478 
479 		if (bch2_csum_type_is_encryption(new_csum_type))
480 			crc_nonce += i->len;
481 	}
482 
483 	return 0;
484 }
485 
486 /* BCH_SB_FIELD_crypt: */
487 
488 static int bch2_sb_crypt_validate(struct bch_sb *sb, struct bch_sb_field *f,
489 				  enum bch_validate_flags flags, struct printbuf *err)
490 {
491 	struct bch_sb_field_crypt *crypt = field_to_type(f, crypt);
492 
493 	if (vstruct_bytes(&crypt->field) < sizeof(*crypt)) {
494 		prt_printf(err, "wrong size (got %zu should be %zu)",
495 		       vstruct_bytes(&crypt->field), sizeof(*crypt));
496 		return -BCH_ERR_invalid_sb_crypt;
497 	}
498 
499 	if (BCH_CRYPT_KDF_TYPE(crypt)) {
500 		prt_printf(err, "bad kdf type %llu", BCH_CRYPT_KDF_TYPE(crypt));
501 		return -BCH_ERR_invalid_sb_crypt;
502 	}
503 
504 	return 0;
505 }
506 
/*
 * Print a BCH_SB_FIELD_crypt field to @out: the key derivation function
 * (KDF) type followed by the scrypt n/r/p parameters.
 */
static void bch2_sb_crypt_to_text(struct printbuf *out, struct bch_sb *sb,
				  struct bch_sb_field *f)
{
	struct bch_sb_field_crypt *crypt = field_to_type(f, crypt);

	/* "KDF" = key derivation function (matches BCH_CRYPT_KDF_TYPE) */
	prt_printf(out, "KDF:               %llu\n", BCH_CRYPT_KDF_TYPE(crypt));
	prt_printf(out, "scrypt n:          %llu\n", BCH_KDF_SCRYPT_N(crypt));
	prt_printf(out, "scrypt r:          %llu\n", BCH_KDF_SCRYPT_R(crypt));
	prt_printf(out, "scrypt p:          %llu\n", BCH_KDF_SCRYPT_P(crypt));
}
517 
518 const struct bch_sb_field_ops bch_sb_field_ops_crypt = {
519 	.validate	= bch2_sb_crypt_validate,
520 	.to_text	= bch2_sb_crypt_to_text,
521 };
522 
523 #ifdef __KERNEL__
524 static int __bch2_request_key(char *key_description, struct bch_key *key)
525 {
526 	struct key *keyring_key;
527 	const struct user_key_payload *ukp;
528 	int ret;
529 
530 	keyring_key = request_key(&key_type_user, key_description, NULL);
531 	if (IS_ERR(keyring_key))
532 		return PTR_ERR(keyring_key);
533 
534 	down_read(&keyring_key->sem);
535 	ukp = dereference_key_locked(keyring_key);
536 	if (ukp->datalen == sizeof(*key)) {
537 		memcpy(key, ukp->data, ukp->datalen);
538 		ret = 0;
539 	} else {
540 		ret = -EINVAL;
541 	}
542 	up_read(&keyring_key->sem);
543 	key_put(keyring_key);
544 
545 	return ret;
546 }
547 #else
548 #include <keyutils.h>
549 
550 static int __bch2_request_key(char *key_description, struct bch_key *key)
551 {
552 	key_serial_t key_id;
553 
554 	key_id = request_key("user", key_description, NULL,
555 			     KEY_SPEC_SESSION_KEYRING);
556 	if (key_id >= 0)
557 		goto got_key;
558 
559 	key_id = request_key("user", key_description, NULL,
560 			     KEY_SPEC_USER_KEYRING);
561 	if (key_id >= 0)
562 		goto got_key;
563 
564 	key_id = request_key("user", key_description, NULL,
565 			     KEY_SPEC_USER_SESSION_KEYRING);
566 	if (key_id >= 0)
567 		goto got_key;
568 
569 	return -errno;
570 got_key:
571 
572 	if (keyctl_read(key_id, (void *) key, sizeof(*key)) != sizeof(*key))
573 		return -1;
574 
575 	return 0;
576 }
577 
578 #include "crypto.h"
579 #endif
580 
581 int bch2_request_key(struct bch_sb *sb, struct bch_key *key)
582 {
583 	struct printbuf key_description = PRINTBUF;
584 	int ret;
585 
586 	prt_printf(&key_description, "bcachefs:");
587 	pr_uuid(&key_description, sb->user_uuid.b);
588 
589 	ret = __bch2_request_key(key_description.buf, key);
590 	printbuf_exit(&key_description);
591 
592 #ifndef __KERNEL__
593 	if (ret) {
594 		char *passphrase = read_passphrase("Enter passphrase: ");
595 		struct bch_encrypted_key sb_key;
596 
597 		bch2_passphrase_check(sb, passphrase,
598 				      key, &sb_key);
599 		ret = 0;
600 	}
601 #endif
602 
603 	/* stash with memfd, pass memfd fd to mount */
604 
605 	return ret;
606 }
607 
608 #ifndef __KERNEL__
609 int bch2_revoke_key(struct bch_sb *sb)
610 {
611 	key_serial_t key_id;
612 	struct printbuf key_description = PRINTBUF;
613 
614 	prt_printf(&key_description, "bcachefs:");
615 	pr_uuid(&key_description, sb->user_uuid.b);
616 
617 	key_id = request_key("user", key_description.buf, NULL, KEY_SPEC_USER_KEYRING);
618 	printbuf_exit(&key_description);
619 	if (key_id < 0)
620 		return errno;
621 
622 	keyctl_revoke(key_id);
623 
624 	return 0;
625 }
626 #endif
627 
628 int bch2_decrypt_sb_key(struct bch_fs *c,
629 			struct bch_sb_field_crypt *crypt,
630 			struct bch_key *key)
631 {
632 	struct bch_encrypted_key sb_key = crypt->key;
633 	struct bch_key user_key;
634 	int ret = 0;
635 
636 	/* is key encrypted? */
637 	if (!bch2_key_is_encrypted(&sb_key))
638 		goto out;
639 
640 	ret = bch2_request_key(c->disk_sb.sb, &user_key);
641 	if (ret) {
642 		bch_err(c, "error requesting encryption key: %s", bch2_err_str(ret));
643 		goto err;
644 	}
645 
646 	/* decrypt real key: */
647 	ret = bch2_chacha_encrypt_key(&user_key, bch2_sb_key_nonce(c),
648 				      &sb_key, sizeof(sb_key));
649 	if (ret)
650 		goto err;
651 
652 	if (bch2_key_is_encrypted(&sb_key)) {
653 		bch_err(c, "incorrect encryption key");
654 		ret = -EINVAL;
655 		goto err;
656 	}
657 out:
658 	*key = sb_key.key;
659 err:
660 	memzero_explicit(&sb_key, sizeof(sb_key));
661 	memzero_explicit(&user_key, sizeof(user_key));
662 	return ret;
663 }
664 
665 static int bch2_alloc_ciphers(struct bch_fs *c)
666 {
667 	if (c->chacha20)
668 		return 0;
669 
670 	struct crypto_sync_skcipher *chacha20 = crypto_alloc_sync_skcipher("chacha20", 0, 0);
671 	int ret = PTR_ERR_OR_ZERO(chacha20);
672 	if (ret) {
673 		bch_err(c, "error requesting chacha20 module: %s", bch2_err_str(ret));
674 		return ret;
675 	}
676 
677 	struct crypto_shash *poly1305 = crypto_alloc_shash("poly1305", 0, 0);
678 	ret = PTR_ERR_OR_ZERO(poly1305);
679 	if (ret) {
680 		bch_err(c, "error requesting poly1305 module: %s", bch2_err_str(ret));
681 		crypto_free_sync_skcipher(chacha20);
682 		return ret;
683 	}
684 
685 	c->chacha20	= chacha20;
686 	c->poly1305	= poly1305;
687 	return 0;
688 }
689 
690 int bch2_disable_encryption(struct bch_fs *c)
691 {
692 	struct bch_sb_field_crypt *crypt;
693 	struct bch_key key;
694 	int ret = -EINVAL;
695 
696 	mutex_lock(&c->sb_lock);
697 
698 	crypt = bch2_sb_field_get(c->disk_sb.sb, crypt);
699 	if (!crypt)
700 		goto out;
701 
702 	/* is key encrypted? */
703 	ret = 0;
704 	if (bch2_key_is_encrypted(&crypt->key))
705 		goto out;
706 
707 	ret = bch2_decrypt_sb_key(c, crypt, &key);
708 	if (ret)
709 		goto out;
710 
711 	crypt->key.magic	= cpu_to_le64(BCH_KEY_MAGIC);
712 	crypt->key.key		= key;
713 
714 	SET_BCH_SB_ENCRYPTION_TYPE(c->disk_sb.sb, 0);
715 	bch2_write_super(c);
716 out:
717 	mutex_unlock(&c->sb_lock);
718 
719 	return ret;
720 }
721 
722 int bch2_enable_encryption(struct bch_fs *c, bool keyed)
723 {
724 	struct bch_encrypted_key key;
725 	struct bch_key user_key;
726 	struct bch_sb_field_crypt *crypt;
727 	int ret = -EINVAL;
728 
729 	mutex_lock(&c->sb_lock);
730 
731 	/* Do we already have an encryption key? */
732 	if (bch2_sb_field_get(c->disk_sb.sb, crypt))
733 		goto err;
734 
735 	ret = bch2_alloc_ciphers(c);
736 	if (ret)
737 		goto err;
738 
739 	key.magic = cpu_to_le64(BCH_KEY_MAGIC);
740 	get_random_bytes(&key.key, sizeof(key.key));
741 
742 	if (keyed) {
743 		ret = bch2_request_key(c->disk_sb.sb, &user_key);
744 		if (ret) {
745 			bch_err(c, "error requesting encryption key: %s", bch2_err_str(ret));
746 			goto err;
747 		}
748 
749 		ret = bch2_chacha_encrypt_key(&user_key, bch2_sb_key_nonce(c),
750 					      &key, sizeof(key));
751 		if (ret)
752 			goto err;
753 	}
754 
755 	ret = crypto_skcipher_setkey(&c->chacha20->base,
756 			(void *) &key.key, sizeof(key.key));
757 	if (ret)
758 		goto err;
759 
760 	crypt = bch2_sb_field_resize(&c->disk_sb, crypt,
761 				     sizeof(*crypt) / sizeof(u64));
762 	if (!crypt) {
763 		ret = -BCH_ERR_ENOSPC_sb_crypt;
764 		goto err;
765 	}
766 
767 	crypt->key = key;
768 
769 	/* write superblock */
770 	SET_BCH_SB_ENCRYPTION_TYPE(c->disk_sb.sb, 1);
771 	bch2_write_super(c);
772 err:
773 	mutex_unlock(&c->sb_lock);
774 	memzero_explicit(&user_key, sizeof(user_key));
775 	memzero_explicit(&key, sizeof(key));
776 	return ret;
777 }
778 
779 void bch2_fs_encryption_exit(struct bch_fs *c)
780 {
781 	if (c->poly1305)
782 		crypto_free_shash(c->poly1305);
783 	if (c->chacha20)
784 		crypto_free_sync_skcipher(c->chacha20);
785 	if (c->sha256)
786 		crypto_free_shash(c->sha256);
787 }
788 
789 int bch2_fs_encryption_init(struct bch_fs *c)
790 {
791 	struct bch_sb_field_crypt *crypt;
792 	struct bch_key key;
793 	int ret = 0;
794 
795 	c->sha256 = crypto_alloc_shash("sha256", 0, 0);
796 	ret = PTR_ERR_OR_ZERO(c->sha256);
797 	if (ret) {
798 		c->sha256 = NULL;
799 		bch_err(c, "error requesting sha256 module: %s", bch2_err_str(ret));
800 		goto out;
801 	}
802 
803 	crypt = bch2_sb_field_get(c->disk_sb.sb, crypt);
804 	if (!crypt)
805 		goto out;
806 
807 	ret = bch2_alloc_ciphers(c);
808 	if (ret)
809 		goto out;
810 
811 	ret = bch2_decrypt_sb_key(c, crypt, &key);
812 	if (ret)
813 		goto out;
814 
815 	ret = crypto_skcipher_setkey(&c->chacha20->base,
816 			(void *) &key.key, sizeof(key.key));
817 	if (ret)
818 		goto out;
819 out:
820 	memzero_explicit(&key, sizeof(key));
821 	return ret;
822 }
823