1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * Handle the TLS Alert protocol 4 * 5 * Author: Chuck Lever <chuck.lever@oracle.com> 6 * 7 * Copyright (c) 2023, Oracle and/or its affiliates. 8 */ 9 10 #include <linux/types.h> 11 #include <linux/socket.h> 12 #include <linux/kernel.h> 13 #include <linux/module.h> 14 #include <linux/skbuff.h> 15 #include <linux/inet.h> 16 17 #include <net/sock.h> 18 #include <net/handshake.h> 19 #include <net/tls.h> 20 #include <net/tls_prot.h> 21 22 #include "handshake.h" 23 24 #include <trace/events/handshake.h> 25 26 /** 27 * tls_alert_send - send a TLS Alert on a kTLS socket 28 * @sock: open kTLS socket to send on 29 * @level: TLS Alert level 30 * @description: TLS Alert description 31 * 32 * Returns zero on success or a negative errno. 33 */ 34 int tls_alert_send(struct socket *sock, u8 level, u8 description) 35 { 36 u8 record_type = TLS_RECORD_TYPE_ALERT; 37 u8 buf[CMSG_SPACE(sizeof(record_type))]; 38 struct msghdr msg = { 0 }; 39 struct cmsghdr *cmsg; 40 struct kvec iov; 41 u8 alert[2]; 42 int ret; 43 44 trace_tls_alert_send(sock->sk, level, description); 45 46 alert[0] = level; 47 alert[1] = description; 48 iov.iov_base = alert; 49 iov.iov_len = sizeof(alert); 50 51 memset(buf, 0, sizeof(buf)); 52 msg.msg_control = buf; 53 msg.msg_controllen = sizeof(buf); 54 msg.msg_flags = MSG_DONTWAIT; 55 56 cmsg = CMSG_FIRSTHDR(&msg); 57 cmsg->cmsg_level = SOL_TLS; 58 cmsg->cmsg_type = TLS_SET_RECORD_TYPE; 59 cmsg->cmsg_len = CMSG_LEN(sizeof(record_type)); 60 memcpy(CMSG_DATA(cmsg), &record_type, sizeof(record_type)); 61 62 iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, &iov, 1, iov.iov_len); 63 ret = sock_sendmsg(sock, &msg); 64 return ret < 0 ? ret : 0; 65 } 66 67 /** 68 * tls_get_record_type - Look for TLS RECORD_TYPE information 69 * @sk: socket (for IP address information) 70 * @cmsg: incoming message to be parsed 71 * 72 * Returns zero or a TLS_RECORD_TYPE value. 73 */ 74 u8 tls_get_record_type(const struct sock *sk, const struct cmsghdr *cmsg) 75 { 76 u8 record_type; 77 78 if (cmsg->cmsg_level != SOL_TLS) 79 return 0; 80 if (cmsg->cmsg_type != TLS_GET_RECORD_TYPE) 81 return 0; 82 83 record_type = *((u8 *)CMSG_DATA(cmsg)); 84 trace_tls_contenttype(sk, record_type); 85 return record_type; 86 } 87 EXPORT_SYMBOL(tls_get_record_type); 88 89 /** 90 * tls_alert_recv - Parse TLS Alert messages 91 * @sk: socket (for IP address information) 92 * @msg: incoming message to be parsed 93 * @level: OUT - TLS AlertLevel value 94 * @description: OUT - TLS AlertDescription value 95 * 96 */ 97 void tls_alert_recv(const struct sock *sk, const struct msghdr *msg, 98 u8 *level, u8 *description) 99 { 100 const struct kvec *iov; 101 u8 *data; 102 103 iov = msg->msg_iter.kvec; 104 data = iov->iov_base; 105 *level = data[0]; 106 *description = data[1]; 107 108 trace_tls_alert_recv(sk, *level, *description); 109 } 110 EXPORT_SYMBOL(tls_alert_recv); 111