1 // SPDX-License-Identifier: GPL-2.0-or-later 2 /* Asymmetric algorithms supported by virtio crypto device 3 * 4 * Authors: zhenwei pi <pizhenwei@bytedance.com> 5 * lei he <helei.sig11@bytedance.com> 6 * 7 * Copyright 2022 Bytedance CO., LTD. 8 */ 9 10 #include <crypto/engine.h> 11 #include <crypto/internal/akcipher.h> 12 #include <crypto/internal/rsa.h> 13 #include <crypto/scatterwalk.h> 14 #include <linux/err.h> 15 #include <linux/kernel.h> 16 #include <linux/mpi.h> 17 #include <linux/scatterlist.h> 18 #include <linux/slab.h> 19 #include <linux/string.h> 20 #include <uapi/linux/virtio_crypto.h> 21 #include "virtio_crypto_common.h" 22 23 struct virtio_crypto_rsa_ctx { 24 unsigned int key_size; 25 }; 26 27 struct virtio_crypto_akcipher_ctx { 28 struct virtio_crypto *vcrypto; 29 bool session_valid; 30 __u64 session_id; 31 union { 32 struct virtio_crypto_rsa_ctx rsa_ctx; 33 }; 34 }; 35 36 struct virtio_crypto_akcipher_request { 37 struct virtio_crypto_request base; 38 void *src_buf; 39 void *dst_buf; 40 uint32_t opcode; 41 }; 42 43 struct virtio_crypto_akcipher_algo { 44 uint32_t algonum; 45 uint32_t service; 46 unsigned int active_devs; 47 struct akcipher_engine_alg algo; 48 }; 49 50 static DEFINE_MUTEX(algs_lock); 51 52 static void virtio_crypto_akcipher_finalize_req( 53 struct virtio_crypto_akcipher_request *vc_akcipher_req, 54 struct akcipher_request *req, int err) 55 { 56 kfree(vc_akcipher_req->src_buf); 57 kfree(vc_akcipher_req->dst_buf); 58 vc_akcipher_req->src_buf = NULL; 59 vc_akcipher_req->dst_buf = NULL; 60 virtcrypto_clear_request(&vc_akcipher_req->base); 61 62 crypto_finalize_akcipher_request(vc_akcipher_req->base.dataq->engine, req, err); 63 } 64 65 static void virtio_crypto_dataq_akcipher_callback(struct virtio_crypto_request *vc_req, int len) 66 { 67 struct virtio_crypto_akcipher_request *vc_akcipher_req = 68 container_of(vc_req, struct virtio_crypto_akcipher_request, base); 69 struct akcipher_request *akcipher_req = 70 container_of((void *)vc_akcipher_req, struct akcipher_request, 71 __ctx); 72 int error; 73 74 switch (vc_req->status) { 75 case VIRTIO_CRYPTO_OK: 76 error = 0; 77 break; 78 case VIRTIO_CRYPTO_INVSESS: 79 case VIRTIO_CRYPTO_ERR: 80 error = -EINVAL; 81 break; 82 case VIRTIO_CRYPTO_BADMSG: 83 error = -EBADMSG; 84 break; 85 default: 86 error = -EIO; 87 break; 88 } 89 90 /* actual length may be less than dst buffer */ 91 akcipher_req->dst_len = len - sizeof(vc_req->status); 92 sg_copy_from_buffer(akcipher_req->dst, sg_nents(akcipher_req->dst), 93 vc_akcipher_req->dst_buf, akcipher_req->dst_len); 94 virtio_crypto_akcipher_finalize_req(vc_akcipher_req, akcipher_req, error); 95 } 96 97 static int virtio_crypto_alg_akcipher_init_session(struct virtio_crypto_akcipher_ctx *ctx, 98 struct virtio_crypto_ctrl_header *header, 99 struct virtio_crypto_akcipher_session_para *para, 100 const uint8_t *key, unsigned int keylen) 101 { 102 struct scatterlist outhdr_sg, key_sg, inhdr_sg, *sgs[3]; 103 struct virtio_crypto *vcrypto = ctx->vcrypto; 104 uint8_t *pkey; 105 int err; 106 unsigned int num_out = 0, num_in = 0; 107 struct virtio_crypto_op_ctrl_req *ctrl; 108 struct virtio_crypto_session_input *input; 109 struct virtio_crypto_ctrl_request *vc_ctrl_req; 110 111 pkey = kmemdup(key, keylen, GFP_KERNEL); 112 if (!pkey) 113 return -ENOMEM; 114 115 vc_ctrl_req = kzalloc(sizeof(*vc_ctrl_req), GFP_KERNEL); 116 if (!vc_ctrl_req) { 117 err = -ENOMEM; 118 goto out; 119 } 120 121 ctrl = &vc_ctrl_req->ctrl; 122 memcpy(&ctrl->header, header, sizeof(ctrl->header)); 123 memcpy(&ctrl->u.akcipher_create_session.para, para, sizeof(*para)); 124 input = &vc_ctrl_req->input; 125 input->status = cpu_to_le32(VIRTIO_CRYPTO_ERR); 126 127 sg_init_one(&outhdr_sg, ctrl, sizeof(*ctrl)); 128 sgs[num_out++] = &outhdr_sg; 129 130 sg_init_one(&key_sg, pkey, keylen); 131 sgs[num_out++] = &key_sg; 132 133 sg_init_one(&inhdr_sg, input, sizeof(*input)); 134 sgs[num_out + num_in++] = &inhdr_sg; 135 136 err = virtio_crypto_ctrl_vq_request(vcrypto, sgs, num_out, num_in, vc_ctrl_req); 137 if (err < 0) 138 goto out; 139 140 if (le32_to_cpu(input->status) != VIRTIO_CRYPTO_OK) { 141 pr_err("virtio_crypto: Create session failed status: %u\n", 142 le32_to_cpu(input->status)); 143 err = -EINVAL; 144 goto out; 145 } 146 147 ctx->session_id = le64_to_cpu(input->session_id); 148 ctx->session_valid = true; 149 err = 0; 150 151 out: 152 kfree(vc_ctrl_req); 153 kfree_sensitive(pkey); 154 155 return err; 156 } 157 158 static int virtio_crypto_alg_akcipher_close_session(struct virtio_crypto_akcipher_ctx *ctx) 159 { 160 struct scatterlist outhdr_sg, inhdr_sg, *sgs[2]; 161 struct virtio_crypto_destroy_session_req *destroy_session; 162 struct virtio_crypto *vcrypto = ctx->vcrypto; 163 unsigned int num_out = 0, num_in = 0; 164 int err; 165 struct virtio_crypto_op_ctrl_req *ctrl; 166 struct virtio_crypto_inhdr *ctrl_status; 167 struct virtio_crypto_ctrl_request *vc_ctrl_req; 168 169 if (!ctx->session_valid) 170 return 0; 171 172 vc_ctrl_req = kzalloc(sizeof(*vc_ctrl_req), GFP_KERNEL); 173 if (!vc_ctrl_req) 174 return -ENOMEM; 175 176 ctrl_status = &vc_ctrl_req->ctrl_status; 177 ctrl_status->status = VIRTIO_CRYPTO_ERR; 178 ctrl = &vc_ctrl_req->ctrl; 179 ctrl->header.opcode = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_DESTROY_SESSION); 180 ctrl->header.queue_id = 0; 181 182 destroy_session = &ctrl->u.destroy_session; 183 destroy_session->session_id = cpu_to_le64(ctx->session_id); 184 185 sg_init_one(&outhdr_sg, ctrl, sizeof(*ctrl)); 186 sgs[num_out++] = &outhdr_sg; 187 188 sg_init_one(&inhdr_sg, &ctrl_status->status, sizeof(ctrl_status->status)); 189 sgs[num_out + num_in++] = &inhdr_sg; 190 191 err = virtio_crypto_ctrl_vq_request(vcrypto, sgs, num_out, num_in, vc_ctrl_req); 192 if (err < 0) 193 goto out; 194 195 if (ctrl_status->status != VIRTIO_CRYPTO_OK) { 196 pr_err("virtio_crypto: Close session failed status: %u, session_id: 0x%llx\n", 197 ctrl_status->status, destroy_session->session_id); 198 err = -EINVAL; 199 goto out; 200 } 201 202 err = 0; 203 ctx->session_valid = false; 204 205 out: 206 kfree(vc_ctrl_req); 207 208 return err; 209 } 210 211 static int __virtio_crypto_akcipher_do_req(struct virtio_crypto_akcipher_request *vc_akcipher_req, 212 struct akcipher_request *req, struct data_queue *data_vq) 213 { 214 struct crypto_akcipher *atfm = crypto_akcipher_reqtfm(req); 215 struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(atfm); 216 struct virtio_crypto_request *vc_req = &vc_akcipher_req->base; 217 struct virtio_crypto *vcrypto = ctx->vcrypto; 218 struct virtio_crypto_op_data_req *req_data = vc_req->req_data; 219 struct scatterlist *sgs[4], outhdr_sg, inhdr_sg, srcdata_sg, dstdata_sg; 220 void *src_buf, *dst_buf = NULL; 221 unsigned int num_out = 0, num_in = 0; 222 int node = dev_to_node(&vcrypto->vdev->dev); 223 unsigned long flags; 224 int ret; 225 226 /* out header */ 227 sg_init_one(&outhdr_sg, req_data, sizeof(*req_data)); 228 sgs[num_out++] = &outhdr_sg; 229 230 /* src data */ 231 src_buf = kcalloc_node(req->src_len, 1, GFP_KERNEL, node); 232 if (!src_buf) 233 return -ENOMEM; 234 235 sg_copy_to_buffer(req->src, sg_nents(req->src), src_buf, req->src_len); 236 sg_init_one(&srcdata_sg, src_buf, req->src_len); 237 sgs[num_out++] = &srcdata_sg; 238 239 /* dst data */ 240 dst_buf = kcalloc_node(req->dst_len, 1, GFP_KERNEL, node); 241 if (!dst_buf) 242 goto free_src; 243 244 sg_init_one(&dstdata_sg, dst_buf, req->dst_len); 245 sgs[num_out + num_in++] = &dstdata_sg; 246 247 vc_akcipher_req->src_buf = src_buf; 248 vc_akcipher_req->dst_buf = dst_buf; 249 250 /* in header */ 251 sg_init_one(&inhdr_sg, &vc_req->status, sizeof(vc_req->status)); 252 sgs[num_out + num_in++] = &inhdr_sg; 253 254 spin_lock_irqsave(&data_vq->lock, flags); 255 ret = virtqueue_add_sgs(data_vq->vq, sgs, num_out, num_in, vc_req, GFP_ATOMIC); 256 virtqueue_kick(data_vq->vq); 257 spin_unlock_irqrestore(&data_vq->lock, flags); 258 if (ret) 259 goto err; 260 261 return 0; 262 263 err: 264 kfree(dst_buf); 265 free_src: 266 kfree(src_buf); 267 return -ENOMEM; 268 } 269 270 static int virtio_crypto_rsa_do_req(struct crypto_engine *engine, void *vreq) 271 { 272 struct akcipher_request *req = container_of(vreq, struct akcipher_request, base); 273 struct virtio_crypto_akcipher_request *vc_akcipher_req = akcipher_request_ctx(req); 274 struct virtio_crypto_request *vc_req = &vc_akcipher_req->base; 275 struct crypto_akcipher *atfm = crypto_akcipher_reqtfm(req); 276 struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(atfm); 277 struct virtio_crypto *vcrypto = ctx->vcrypto; 278 struct data_queue *data_vq = vc_req->dataq; 279 struct virtio_crypto_op_header *header; 280 struct virtio_crypto_akcipher_data_req *akcipher_req; 281 int ret; 282 283 vc_req->sgs = NULL; 284 vc_req->req_data = kzalloc_node(sizeof(*vc_req->req_data), 285 GFP_KERNEL, dev_to_node(&vcrypto->vdev->dev)); 286 if (!vc_req->req_data) 287 return -ENOMEM; 288 289 /* build request header */ 290 header = &vc_req->req_data->header; 291 header->opcode = cpu_to_le32(vc_akcipher_req->opcode); 292 header->algo = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_RSA); 293 header->session_id = cpu_to_le64(ctx->session_id); 294 295 /* build request akcipher data */ 296 akcipher_req = &vc_req->req_data->u.akcipher_req; 297 akcipher_req->para.src_data_len = cpu_to_le32(req->src_len); 298 akcipher_req->para.dst_data_len = cpu_to_le32(req->dst_len); 299 300 ret = __virtio_crypto_akcipher_do_req(vc_akcipher_req, req, data_vq); 301 if (ret < 0) { 302 kfree_sensitive(vc_req->req_data); 303 vc_req->req_data = NULL; 304 return ret; 305 } 306 307 return 0; 308 } 309 310 static int virtio_crypto_rsa_req(struct akcipher_request *req, uint32_t opcode) 311 { 312 struct crypto_akcipher *atfm = crypto_akcipher_reqtfm(req); 313 struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(atfm); 314 struct virtio_crypto_akcipher_request *vc_akcipher_req = akcipher_request_ctx(req); 315 struct virtio_crypto_request *vc_req = &vc_akcipher_req->base; 316 struct virtio_crypto *vcrypto = ctx->vcrypto; 317 /* Use the first data virtqueue as default */ 318 struct data_queue *data_vq = &vcrypto->data_vq[0]; 319 320 vc_req->dataq = data_vq; 321 vc_req->alg_cb = virtio_crypto_dataq_akcipher_callback; 322 vc_akcipher_req->opcode = opcode; 323 324 return crypto_transfer_akcipher_request_to_engine(data_vq->engine, req); 325 } 326 327 static int virtio_crypto_rsa_encrypt(struct akcipher_request *req) 328 { 329 return virtio_crypto_rsa_req(req, VIRTIO_CRYPTO_AKCIPHER_ENCRYPT); 330 } 331 332 static int virtio_crypto_rsa_decrypt(struct akcipher_request *req) 333 { 334 return virtio_crypto_rsa_req(req, VIRTIO_CRYPTO_AKCIPHER_DECRYPT); 335 } 336 337 static int virtio_crypto_rsa_set_key(struct crypto_akcipher *tfm, 338 const void *key, 339 unsigned int keylen, 340 bool private, 341 int padding_algo, 342 int hash_algo) 343 { 344 struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(tfm); 345 struct virtio_crypto_rsa_ctx *rsa_ctx = &ctx->rsa_ctx; 346 struct virtio_crypto *vcrypto; 347 struct virtio_crypto_ctrl_header header; 348 struct virtio_crypto_akcipher_session_para para; 349 struct rsa_key rsa_key = {0}; 350 int node = virtio_crypto_get_current_node(); 351 uint32_t keytype; 352 int ret; 353 MPI n; 354 355 if (private) { 356 keytype = VIRTIO_CRYPTO_AKCIPHER_KEY_TYPE_PRIVATE; 357 ret = rsa_parse_priv_key(&rsa_key, key, keylen); 358 } else { 359 keytype = VIRTIO_CRYPTO_AKCIPHER_KEY_TYPE_PUBLIC; 360 ret = rsa_parse_pub_key(&rsa_key, key, keylen); 361 } 362 363 if (ret) 364 return ret; 365 366 n = mpi_read_raw_data(rsa_key.n, rsa_key.n_sz); 367 if (!n) 368 return -ENOMEM; 369 370 rsa_ctx->key_size = mpi_get_size(n); 371 mpi_free(n); 372 373 if (!ctx->vcrypto) { 374 vcrypto = virtcrypto_get_dev_node(node, VIRTIO_CRYPTO_SERVICE_AKCIPHER, 375 VIRTIO_CRYPTO_AKCIPHER_RSA); 376 if (!vcrypto) { 377 pr_err("virtio_crypto: Could not find a virtio device in the system or unsupported algo\n"); 378 return -ENODEV; 379 } 380 381 ctx->vcrypto = vcrypto; 382 } else { 383 virtio_crypto_alg_akcipher_close_session(ctx); 384 } 385 386 /* set ctrl header */ 387 header.opcode = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_CREATE_SESSION); 388 header.algo = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_RSA); 389 header.queue_id = 0; 390 391 /* set RSA para */ 392 para.algo = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_RSA); 393 para.keytype = cpu_to_le32(keytype); 394 para.keylen = cpu_to_le32(keylen); 395 para.u.rsa.padding_algo = cpu_to_le32(padding_algo); 396 para.u.rsa.hash_algo = cpu_to_le32(hash_algo); 397 398 return virtio_crypto_alg_akcipher_init_session(ctx, &header, ¶, key, keylen); 399 } 400 401 static int virtio_crypto_rsa_raw_set_priv_key(struct crypto_akcipher *tfm, 402 const void *key, 403 unsigned int keylen) 404 { 405 return virtio_crypto_rsa_set_key(tfm, key, keylen, 1, 406 VIRTIO_CRYPTO_RSA_RAW_PADDING, 407 VIRTIO_CRYPTO_RSA_NO_HASH); 408 } 409 410 411 static int virtio_crypto_p1pad_rsa_sha1_set_priv_key(struct crypto_akcipher *tfm, 412 const void *key, 413 unsigned int keylen) 414 { 415 return virtio_crypto_rsa_set_key(tfm, key, keylen, 1, 416 VIRTIO_CRYPTO_RSA_PKCS1_PADDING, 417 VIRTIO_CRYPTO_RSA_SHA1); 418 } 419 420 static int virtio_crypto_rsa_raw_set_pub_key(struct crypto_akcipher *tfm, 421 const void *key, 422 unsigned int keylen) 423 { 424 return virtio_crypto_rsa_set_key(tfm, key, keylen, 0, 425 VIRTIO_CRYPTO_RSA_RAW_PADDING, 426 VIRTIO_CRYPTO_RSA_NO_HASH); 427 } 428 429 static int virtio_crypto_p1pad_rsa_sha1_set_pub_key(struct crypto_akcipher *tfm, 430 const void *key, 431 unsigned int keylen) 432 { 433 return virtio_crypto_rsa_set_key(tfm, key, keylen, 0, 434 VIRTIO_CRYPTO_RSA_PKCS1_PADDING, 435 VIRTIO_CRYPTO_RSA_SHA1); 436 } 437 438 static unsigned int virtio_crypto_rsa_max_size(struct crypto_akcipher *tfm) 439 { 440 struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(tfm); 441 struct virtio_crypto_rsa_ctx *rsa_ctx = &ctx->rsa_ctx; 442 443 return rsa_ctx->key_size; 444 } 445 446 static int virtio_crypto_rsa_init_tfm(struct crypto_akcipher *tfm) 447 { 448 akcipher_set_reqsize(tfm, 449 sizeof(struct virtio_crypto_akcipher_request)); 450 451 return 0; 452 } 453 454 static void virtio_crypto_rsa_exit_tfm(struct crypto_akcipher *tfm) 455 { 456 struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(tfm); 457 458 virtio_crypto_alg_akcipher_close_session(ctx); 459 virtcrypto_dev_put(ctx->vcrypto); 460 } 461 462 static struct virtio_crypto_akcipher_algo virtio_crypto_akcipher_algs[] = { 463 { 464 .algonum = VIRTIO_CRYPTO_AKCIPHER_RSA, 465 .service = VIRTIO_CRYPTO_SERVICE_AKCIPHER, 466 .algo.base = { 467 .encrypt = virtio_crypto_rsa_encrypt, 468 .decrypt = virtio_crypto_rsa_decrypt, 469 .set_pub_key = virtio_crypto_rsa_raw_set_pub_key, 470 .set_priv_key = virtio_crypto_rsa_raw_set_priv_key, 471 .max_size = virtio_crypto_rsa_max_size, 472 .init = virtio_crypto_rsa_init_tfm, 473 .exit = virtio_crypto_rsa_exit_tfm, 474 .base = { 475 .cra_name = "rsa", 476 .cra_driver_name = "virtio-crypto-rsa", 477 .cra_priority = 150, 478 .cra_module = THIS_MODULE, 479 .cra_ctxsize = sizeof(struct virtio_crypto_akcipher_ctx), 480 }, 481 }, 482 .algo.op = { 483 .do_one_request = virtio_crypto_rsa_do_req, 484 }, 485 }, 486 { 487 .algonum = VIRTIO_CRYPTO_AKCIPHER_RSA, 488 .service = VIRTIO_CRYPTO_SERVICE_AKCIPHER, 489 .algo.base = { 490 .encrypt = virtio_crypto_rsa_encrypt, 491 .decrypt = virtio_crypto_rsa_decrypt, 492 /* 493 * Must specify an arbitrary hash algorithm upon 494 * set_{pub,priv}_key (even though it's not used 495 * by encrypt/decrypt) because qemu checks for it. 496 */ 497 .set_pub_key = virtio_crypto_p1pad_rsa_sha1_set_pub_key, 498 .set_priv_key = virtio_crypto_p1pad_rsa_sha1_set_priv_key, 499 .max_size = virtio_crypto_rsa_max_size, 500 .init = virtio_crypto_rsa_init_tfm, 501 .exit = virtio_crypto_rsa_exit_tfm, 502 .base = { 503 .cra_name = "pkcs1pad(rsa)", 504 .cra_driver_name = "virtio-pkcs1-rsa", 505 .cra_priority = 150, 506 .cra_module = THIS_MODULE, 507 .cra_ctxsize = sizeof(struct virtio_crypto_akcipher_ctx), 508 }, 509 }, 510 .algo.op = { 511 .do_one_request = virtio_crypto_rsa_do_req, 512 }, 513 }, 514 }; 515 516 int virtio_crypto_akcipher_algs_register(struct virtio_crypto *vcrypto) 517 { 518 int ret = 0; 519 int i = 0; 520 521 mutex_lock(&algs_lock); 522 523 for (i = 0; i < ARRAY_SIZE(virtio_crypto_akcipher_algs); i++) { 524 uint32_t service = virtio_crypto_akcipher_algs[i].service; 525 uint32_t algonum = virtio_crypto_akcipher_algs[i].algonum; 526 527 if (!virtcrypto_algo_is_supported(vcrypto, service, algonum)) 528 continue; 529 530 if (virtio_crypto_akcipher_algs[i].active_devs == 0) { 531 ret = crypto_engine_register_akcipher(&virtio_crypto_akcipher_algs[i].algo); 532 if (ret) 533 goto unlock; 534 } 535 536 virtio_crypto_akcipher_algs[i].active_devs++; 537 dev_info(&vcrypto->vdev->dev, "Registered akcipher algo %s\n", 538 virtio_crypto_akcipher_algs[i].algo.base.base.cra_name); 539 } 540 541 unlock: 542 mutex_unlock(&algs_lock); 543 return ret; 544 } 545 546 void virtio_crypto_akcipher_algs_unregister(struct virtio_crypto *vcrypto) 547 { 548 int i = 0; 549 550 mutex_lock(&algs_lock); 551 552 for (i = 0; i < ARRAY_SIZE(virtio_crypto_akcipher_algs); i++) { 553 uint32_t service = virtio_crypto_akcipher_algs[i].service; 554 uint32_t algonum = virtio_crypto_akcipher_algs[i].algonum; 555 556 if (virtio_crypto_akcipher_algs[i].active_devs == 0 || 557 !virtcrypto_algo_is_supported(vcrypto, service, algonum)) 558 continue; 559 560 if (virtio_crypto_akcipher_algs[i].active_devs == 1) 561 crypto_engine_unregister_akcipher(&virtio_crypto_akcipher_algs[i].algo); 562 563 virtio_crypto_akcipher_algs[i].active_devs--; 564 } 565 566 mutex_unlock(&algs_lock); 567 } 568