1 // SPDX-License-Identifier: GPL-2.0 2 /* 3 * Amazon Nitro Secure Module driver. 4 * 5 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 6 * 7 * The Nitro Secure Module implements commands via CBOR over virtio. 8 * This driver exposes a raw message ioctls on /dev/nsm that user 9 * space can use to issue these commands. 10 */ 11 12 #include <linux/file.h> 13 #include <linux/fs.h> 14 #include <linux/interrupt.h> 15 #include <linux/hw_random.h> 16 #include <linux/miscdevice.h> 17 #include <linux/module.h> 18 #include <linux/mutex.h> 19 #include <linux/slab.h> 20 #include <linux/string.h> 21 #include <linux/uaccess.h> 22 #include <linux/uio.h> 23 #include <linux/virtio_config.h> 24 #include <linux/virtio_ids.h> 25 #include <linux/virtio.h> 26 #include <linux/wait.h> 27 #include <uapi/linux/nsm.h> 28 29 /* Timeout for NSM virtqueue respose in milliseconds. */ 30 #define NSM_DEFAULT_TIMEOUT_MSECS (120000) /* 2 minutes */ 31 32 /* Maximum length input data */ 33 struct nsm_data_req { 34 u32 len; 35 u8 data[NSM_REQUEST_MAX_SIZE]; 36 }; 37 38 /* Maximum length output data */ 39 struct nsm_data_resp { 40 u32 len; 41 u8 data[NSM_RESPONSE_MAX_SIZE]; 42 }; 43 44 /* Full NSM request/response message */ 45 struct nsm_msg { 46 struct nsm_data_req req; 47 struct nsm_data_resp resp; 48 }; 49 50 struct nsm { 51 struct virtio_device *vdev; 52 struct virtqueue *vq; 53 struct mutex lock; 54 struct completion cmd_done; 55 struct miscdevice misc; 56 struct hwrng hwrng; 57 struct work_struct misc_init; 58 struct nsm_msg msg; 59 }; 60 61 /* NSM device ID */ 62 static const struct virtio_device_id id_table[] = { 63 { VIRTIO_ID_NITRO_SEC_MOD, VIRTIO_DEV_ANY_ID }, 64 { 0 }, 65 }; 66 67 static struct nsm *file_to_nsm(struct file *file) 68 { 69 return container_of(file->private_data, struct nsm, misc); 70 } 71 72 static struct nsm *hwrng_to_nsm(struct hwrng *rng) 73 { 74 return container_of(rng, struct nsm, hwrng); 75 } 76 77 #define CBOR_TYPE_MASK 0xE0 78 #define CBOR_TYPE_MAP 0xA0 79 #define CBOR_TYPE_TEXT 0x60 80 #define CBOR_TYPE_ARRAY 0x40 81 #define CBOR_HEADER_SIZE_SHORT 1 82 83 #define CBOR_SHORT_SIZE_MAX_VALUE 23 84 #define CBOR_LONG_SIZE_U8 24 85 #define CBOR_LONG_SIZE_U16 25 86 #define CBOR_LONG_SIZE_U32 26 87 #define CBOR_LONG_SIZE_U64 27 88 89 static bool cbor_object_is_array(const u8 *cbor_object, size_t cbor_object_size) 90 { 91 if (cbor_object_size == 0 || cbor_object == NULL) 92 return false; 93 94 return (cbor_object[0] & CBOR_TYPE_MASK) == CBOR_TYPE_ARRAY; 95 } 96 97 static int cbor_object_get_array(u8 *cbor_object, size_t cbor_object_size, u8 **cbor_array) 98 { 99 u8 cbor_short_size; 100 void *array_len_p; 101 u64 array_len; 102 u64 array_offset; 103 104 if (!cbor_object_is_array(cbor_object, cbor_object_size)) 105 return -EFAULT; 106 107 cbor_short_size = (cbor_object[0] & 0x1F); 108 109 /* Decoding byte array length */ 110 array_offset = CBOR_HEADER_SIZE_SHORT; 111 if (cbor_short_size >= CBOR_LONG_SIZE_U8) 112 array_offset += BIT(cbor_short_size - CBOR_LONG_SIZE_U8); 113 114 if (cbor_object_size < array_offset) 115 return -EFAULT; 116 117 array_len_p = &cbor_object[1]; 118 119 switch (cbor_short_size) { 120 case CBOR_SHORT_SIZE_MAX_VALUE: /* short encoding */ 121 array_len = cbor_short_size; 122 break; 123 case CBOR_LONG_SIZE_U8: 124 array_len = *(u8 *)array_len_p; 125 break; 126 case CBOR_LONG_SIZE_U16: 127 array_len = be16_to_cpup((__be16 *)array_len_p); 128 break; 129 case CBOR_LONG_SIZE_U32: 130 array_len = be32_to_cpup((__be32 *)array_len_p); 131 break; 132 case CBOR_LONG_SIZE_U64: 133 array_len = be64_to_cpup((__be64 *)array_len_p); 134 break; 135 } 136 137 if (cbor_object_size < array_offset) 138 return -EFAULT; 139 140 if (cbor_object_size - array_offset < array_len) 141 return -EFAULT; 142 143 if (array_len > INT_MAX) 144 return -EFAULT; 145 146 *cbor_array = cbor_object + array_offset; 147 return array_len; 148 } 149 150 /* Copy the request of a raw message to kernel space */ 151 static int fill_req_raw(struct nsm *nsm, struct nsm_data_req *req, 152 struct nsm_raw *raw) 153 { 154 /* Verify the user input size. */ 155 if (raw->request.len > sizeof(req->data)) 156 return -EMSGSIZE; 157 158 /* Copy the request payload */ 159 if (copy_from_user(req->data, u64_to_user_ptr(raw->request.addr), 160 raw->request.len)) 161 return -EFAULT; 162 163 req->len = raw->request.len; 164 165 return 0; 166 } 167 168 /* Copy the response of a raw message back to user-space */ 169 static int parse_resp_raw(struct nsm *nsm, struct nsm_data_resp *resp, 170 struct nsm_raw *raw) 171 { 172 /* Truncate any message that does not fit. */ 173 raw->response.len = min_t(u64, raw->response.len, resp->len); 174 175 /* Copy the response content to user space */ 176 if (copy_to_user(u64_to_user_ptr(raw->response.addr), 177 resp->data, raw->response.len)) 178 return -EFAULT; 179 180 return 0; 181 } 182 183 /* Virtqueue interrupt handler */ 184 static void nsm_vq_callback(struct virtqueue *vq) 185 { 186 struct nsm *nsm = vq->vdev->priv; 187 188 complete(&nsm->cmd_done); 189 } 190 191 /* Forward a message to the NSM device and wait for the response from it */ 192 static int nsm_sendrecv_msg_locked(struct nsm *nsm) 193 { 194 struct device *dev = &nsm->vdev->dev; 195 struct scatterlist sg_in, sg_out; 196 struct nsm_msg *msg = &nsm->msg; 197 struct virtqueue *vq = nsm->vq; 198 unsigned int len; 199 void *queue_buf; 200 bool kicked; 201 int rc; 202 203 /* Initialize scatter-gather lists with request and response buffers. */ 204 sg_init_one(&sg_out, msg->req.data, msg->req.len); 205 sg_init_one(&sg_in, msg->resp.data, sizeof(msg->resp.data)); 206 207 init_completion(&nsm->cmd_done); 208 /* Add the request buffer (read by the device). */ 209 rc = virtqueue_add_outbuf(vq, &sg_out, 1, msg->req.data, GFP_KERNEL); 210 if (rc) 211 return rc; 212 213 /* Add the response buffer (written by the device). */ 214 rc = virtqueue_add_inbuf(vq, &sg_in, 1, msg->resp.data, GFP_KERNEL); 215 if (rc) 216 goto cleanup; 217 218 kicked = virtqueue_kick(vq); 219 if (!kicked) { 220 /* Cannot kick the virtqueue. */ 221 rc = -EIO; 222 goto cleanup; 223 } 224 225 /* If the kick succeeded, wait for the device's response. */ 226 if (!wait_for_completion_io_timeout(&nsm->cmd_done, 227 msecs_to_jiffies(NSM_DEFAULT_TIMEOUT_MSECS))) { 228 rc = -ETIMEDOUT; 229 goto cleanup; 230 } 231 232 queue_buf = virtqueue_get_buf(vq, &len); 233 if (!queue_buf || (queue_buf != msg->req.data)) { 234 dev_err(dev, "wrong request buffer."); 235 rc = -ENODATA; 236 goto cleanup; 237 } 238 239 queue_buf = virtqueue_get_buf(vq, &len); 240 if (!queue_buf || (queue_buf != msg->resp.data)) { 241 dev_err(dev, "wrong response buffer."); 242 rc = -ENODATA; 243 goto cleanup; 244 } 245 246 msg->resp.len = len; 247 248 rc = 0; 249 250 cleanup: 251 if (rc) { 252 /* Clean the virtqueue. */ 253 while (virtqueue_get_buf(vq, &len) != NULL) 254 ; 255 } 256 257 return rc; 258 } 259 260 static int fill_req_get_random(struct nsm *nsm, struct nsm_data_req *req) 261 { 262 /* 263 * 69 # text(9) 264 * 47657452616E646F6D # "GetRandom" 265 */ 266 const u8 request[] = { CBOR_TYPE_TEXT + strlen("GetRandom"), 267 'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm' }; 268 269 memcpy(req->data, request, sizeof(request)); 270 req->len = sizeof(request); 271 272 return 0; 273 } 274 275 static int parse_resp_get_random(struct nsm *nsm, struct nsm_data_resp *resp, 276 void *out, size_t max) 277 { 278 /* 279 * A1 # map(1) 280 * 69 # text(9) - Name of field 281 * 47657452616E646F6D # "GetRandom" 282 * A1 # map(1) - The field itself 283 * 66 # text(6) 284 * 72616E646F6D # "random" 285 * # The rest of the response is random data 286 */ 287 const u8 response[] = { CBOR_TYPE_MAP + 1, 288 CBOR_TYPE_TEXT + strlen("GetRandom"), 289 'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm', 290 CBOR_TYPE_MAP + 1, 291 CBOR_TYPE_TEXT + strlen("random"), 292 'r', 'a', 'n', 'd', 'o', 'm' }; 293 struct device *dev = &nsm->vdev->dev; 294 u8 *rand_data = NULL; 295 u8 *resp_ptr = resp->data; 296 u64 resp_len = resp->len; 297 int rc; 298 299 if ((resp->len < sizeof(response) + 1) || 300 (memcmp(resp_ptr, response, sizeof(response)) != 0)) { 301 dev_err(dev, "Invalid response for GetRandom"); 302 return -EFAULT; 303 } 304 305 resp_ptr += sizeof(response); 306 resp_len -= sizeof(response); 307 308 rc = cbor_object_get_array(resp_ptr, resp_len, &rand_data); 309 if (rc < 0) { 310 dev_err(dev, "GetRandom: Invalid CBOR encoding\n"); 311 return rc; 312 } 313 314 rc = min_t(size_t, rc, max); 315 memcpy(out, rand_data, rc); 316 317 return rc; 318 } 319 320 /* 321 * HwRNG implementation 322 */ 323 static int nsm_rng_read(struct hwrng *rng, void *data, size_t max, bool wait) 324 { 325 struct nsm *nsm = hwrng_to_nsm(rng); 326 struct device *dev = &nsm->vdev->dev; 327 int rc = 0; 328 329 /* NSM always needs to wait for a response */ 330 if (!wait) 331 return 0; 332 333 mutex_lock(&nsm->lock); 334 335 rc = fill_req_get_random(nsm, &nsm->msg.req); 336 if (rc != 0) 337 goto out; 338 339 rc = nsm_sendrecv_msg_locked(nsm); 340 if (rc != 0) 341 goto out; 342 343 rc = parse_resp_get_random(nsm, &nsm->msg.resp, data, max); 344 if (rc < 0) 345 goto out; 346 347 dev_dbg(dev, "RNG: returning rand bytes = %d", rc); 348 out: 349 mutex_unlock(&nsm->lock); 350 return rc; 351 } 352 353 static long nsm_dev_ioctl(struct file *file, unsigned int cmd, 354 unsigned long arg) 355 { 356 void __user *argp = u64_to_user_ptr((u64)arg); 357 struct nsm *nsm = file_to_nsm(file); 358 struct nsm_raw raw; 359 int r = 0; 360 361 if (cmd != NSM_IOCTL_RAW) 362 return -EINVAL; 363 364 if (_IOC_SIZE(cmd) != sizeof(raw)) 365 return -EINVAL; 366 367 /* Copy user argument struct to kernel argument struct */ 368 r = -EFAULT; 369 if (copy_from_user(&raw, argp, _IOC_SIZE(cmd))) 370 goto out; 371 372 mutex_lock(&nsm->lock); 373 374 /* Convert kernel argument struct to device request */ 375 r = fill_req_raw(nsm, &nsm->msg.req, &raw); 376 if (r) 377 goto out; 378 379 /* Send message to NSM and read reply */ 380 r = nsm_sendrecv_msg_locked(nsm); 381 if (r) 382 goto out; 383 384 /* Parse device response into kernel argument struct */ 385 r = parse_resp_raw(nsm, &nsm->msg.resp, &raw); 386 if (r) 387 goto out; 388 389 /* Copy kernel argument struct back to user argument struct */ 390 r = -EFAULT; 391 if (copy_to_user(argp, &raw, sizeof(raw))) 392 goto out; 393 394 r = 0; 395 396 out: 397 mutex_unlock(&nsm->lock); 398 return r; 399 } 400 401 static int nsm_device_init_vq(struct virtio_device *vdev) 402 { 403 struct virtqueue *vq = virtio_find_single_vq(vdev, 404 nsm_vq_callback, "nsm.vq.0"); 405 struct nsm *nsm = vdev->priv; 406 407 if (IS_ERR(vq)) 408 return PTR_ERR(vq); 409 410 nsm->vq = vq; 411 412 return 0; 413 } 414 415 static const struct file_operations nsm_dev_fops = { 416 .unlocked_ioctl = nsm_dev_ioctl, 417 .compat_ioctl = compat_ptr_ioctl, 418 }; 419 420 /* Handler for probing the NSM device */ 421 static int nsm_device_probe(struct virtio_device *vdev) 422 { 423 struct device *dev = &vdev->dev; 424 struct nsm *nsm; 425 int rc; 426 427 nsm = devm_kzalloc(&vdev->dev, sizeof(*nsm), GFP_KERNEL); 428 if (!nsm) 429 return -ENOMEM; 430 431 vdev->priv = nsm; 432 nsm->vdev = vdev; 433 434 rc = nsm_device_init_vq(vdev); 435 if (rc) { 436 dev_err(dev, "queue failed to initialize: %d.\n", rc); 437 goto err_init_vq; 438 } 439 440 mutex_init(&nsm->lock); 441 442 /* Register as hwrng provider */ 443 nsm->hwrng = (struct hwrng) { 444 .read = nsm_rng_read, 445 .name = "nsm-hwrng", 446 .quality = 1000, 447 }; 448 449 rc = hwrng_register(&nsm->hwrng); 450 if (rc) { 451 dev_err(dev, "RNG initialization error: %d.\n", rc); 452 goto err_hwrng; 453 } 454 455 /* Register /dev/nsm device node */ 456 nsm->misc = (struct miscdevice) { 457 .minor = MISC_DYNAMIC_MINOR, 458 .name = "nsm", 459 .fops = &nsm_dev_fops, 460 .mode = 0666, 461 }; 462 463 rc = misc_register(&nsm->misc); 464 if (rc) { 465 dev_err(dev, "misc device registration error: %d.\n", rc); 466 goto err_misc; 467 } 468 469 return 0; 470 471 err_misc: 472 hwrng_unregister(&nsm->hwrng); 473 err_hwrng: 474 vdev->config->del_vqs(vdev); 475 err_init_vq: 476 return rc; 477 } 478 479 /* Handler for removing the NSM device */ 480 static void nsm_device_remove(struct virtio_device *vdev) 481 { 482 struct nsm *nsm = vdev->priv; 483 484 hwrng_unregister(&nsm->hwrng); 485 486 vdev->config->del_vqs(vdev); 487 misc_deregister(&nsm->misc); 488 } 489 490 /* NSM device configuration structure */ 491 static struct virtio_driver virtio_nsm_driver = { 492 .feature_table = 0, 493 .feature_table_size = 0, 494 .feature_table_legacy = 0, 495 .feature_table_size_legacy = 0, 496 .driver.name = KBUILD_MODNAME, 497 .driver.owner = THIS_MODULE, 498 .id_table = id_table, 499 .probe = nsm_device_probe, 500 .remove = nsm_device_remove, 501 }; 502 503 module_virtio_driver(virtio_nsm_driver); 504 MODULE_DEVICE_TABLE(virtio, id_table); 505 MODULE_DESCRIPTION("Virtio NSM driver"); 506 MODULE_LICENSE("GPL"); 507