1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * Handle the TLS Alert protocol 4 * 5 * Author: Chuck Lever <chuck.lever@oracle.com> 6 * 7 * Copyright (c) 2023, Oracle and/or its affiliates. 8 */ 9 10 #include <linux/types.h> 11 #include <linux/socket.h> 12 #include <linux/kernel.h> 13 #include <linux/module.h> 14 #include <linux/skbuff.h> 15 #include <linux/inet.h> 16 17 #include <net/sock.h> 18 #include <net/handshake.h> 19 #include <net/tls.h> 20 #include <net/tls_prot.h> 21 22 #include "handshake.h" 23 24 /** 25 * tls_alert_send - send a TLS Alert on a kTLS socket 26 * @sock: open kTLS socket to send on 27 * @level: TLS Alert level 28 * @description: TLS Alert description 29 * 30 * Returns zero on success or a negative errno. 31 */ 32 int tls_alert_send(struct socket *sock, u8 level, u8 description) 33 { 34 u8 record_type = TLS_RECORD_TYPE_ALERT; 35 u8 buf[CMSG_SPACE(sizeof(record_type))]; 36 struct msghdr msg = { 0 }; 37 struct cmsghdr *cmsg; 38 struct kvec iov; 39 u8 alert[2]; 40 int ret; 41 42 alert[0] = level; 43 alert[1] = description; 44 iov.iov_base = alert; 45 iov.iov_len = sizeof(alert); 46 47 memset(buf, 0, sizeof(buf)); 48 msg.msg_control = buf; 49 msg.msg_controllen = sizeof(buf); 50 msg.msg_flags = MSG_DONTWAIT; 51 52 cmsg = CMSG_FIRSTHDR(&msg); 53 cmsg->cmsg_level = SOL_TLS; 54 cmsg->cmsg_type = TLS_SET_RECORD_TYPE; 55 cmsg->cmsg_len = CMSG_LEN(sizeof(record_type)); 56 memcpy(CMSG_DATA(cmsg), &record_type, sizeof(record_type)); 57 58 iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, &iov, 1, iov.iov_len); 59 ret = sock_sendmsg(sock, &msg); 60 return ret < 0 ? ret : 0; 61 } 62 63 /** 64 * tls_get_record_type - Look for TLS RECORD_TYPE information 65 * @sk: socket (for IP address information) 66 * @cmsg: incoming message to be parsed 67 * 68 * Returns zero or a TLS_RECORD_TYPE value. 69 */ 70 u8 tls_get_record_type(const struct sock *sk, const struct cmsghdr *cmsg) 71 { 72 u8 record_type; 73 74 if (cmsg->cmsg_level != SOL_TLS) 75 return 0; 76 if (cmsg->cmsg_type != TLS_GET_RECORD_TYPE) 77 return 0; 78 79 record_type = *((u8 *)CMSG_DATA(cmsg)); 80 return record_type; 81 } 82 EXPORT_SYMBOL(tls_get_record_type); 83 84 /** 85 * tls_alert_recv - Parse TLS Alert messages 86 * @sk: socket (for IP address information) 87 * @msg: incoming message to be parsed 88 * @level: OUT - TLS AlertLevel value 89 * @description: OUT - TLS AlertDescription value 90 * 91 */ 92 void tls_alert_recv(const struct sock *sk, const struct msghdr *msg, 93 u8 *level, u8 *description) 94 { 95 const struct kvec *iov; 96 u8 *data; 97 98 iov = msg->msg_iter.kvec; 99 data = iov->iov_base; 100 *level = data[0]; 101 *description = data[1]; 102 } 103 EXPORT_SYMBOL(tls_alert_recv); 104