1 // SPDX-License-Identifier: GPL-2.0 2 3 #include "vmlinux.h" 4 5 #include <bpf/bpf_helpers.h> 6 #include <bpf/bpf_tracing.h> 7 #include "bpf_tracing_net.h" 8 9 char _license[] SEC("license") = "GPL"; 10 11 #ifndef SMC_HS_CTRL_NAME_MAX 12 #define SMC_HS_CTRL_NAME_MAX 16 13 #endif 14 15 enum { 16 BPF_SMC_LISTEN = 10, 17 }; 18 19 struct smc_sock___local { 20 struct sock sk; 21 struct smc_sock *listen_smc; 22 bool use_fallback; 23 } __attribute__((preserve_access_index)); 24 25 struct smc_hs_ctrl___local { 26 char name[SMC_HS_CTRL_NAME_MAX]; 27 int (*syn_option)(struct tcp_sock *); 28 int (*synack_option)(const struct tcp_sock *, struct inet_request_sock *); 29 } __attribute__((preserve_access_index)); 30 31 struct netns_smc___local { 32 struct smc_hs_ctrl___local *hs_ctrl; 33 } __attribute__((preserve_access_index)); 34 35 struct net___local { 36 struct netns_smc___local smc; 37 } __attribute__((preserve_access_index)); 38 39 int smc_cnt = 0; 40 int fallback_cnt = 0; 41 42 SEC("fentry/smc_release") 43 int BPF_PROG(bpf_smc_release, struct socket *sock) 44 { 45 /* only count from one side (client) */ 46 if (sock->sk->__sk_common.skc_state == BPF_SMC_LISTEN) 47 return 0; 48 smc_cnt++; 49 return 0; 50 } 51 52 SEC("fentry/smc_switch_to_fallback") 53 int BPF_PROG(bpf_smc_switch_to_fallback, struct smc_sock___local *smc) 54 { 55 /* only count from one side (client) */ 56 if (smc && !smc->listen_smc) 57 fallback_cnt++; 58 return 0; 59 } 60 61 /* go with default value if no strat was found */ 62 bool default_ip_strat_value = true; 63 64 struct smc_policy_ip_key { 65 __u32 sip; 66 __u32 dip; 67 }; 68 69 struct smc_policy_ip_value { 70 __u8 mode; 71 }; 72 73 struct { 74 __uint(type, BPF_MAP_TYPE_HASH); 75 __uint(key_size, sizeof(struct smc_policy_ip_key)); 76 __uint(value_size, sizeof(struct smc_policy_ip_value)); 77 __uint(max_entries, 128); 78 __uint(map_flags, BPF_F_NO_PREALLOC); 79 } smc_policy_ip SEC(".maps"); 80 81 static bool smc_check(__u32 src, __u32 dst) 82 { 83 struct smc_policy_ip_value *value; 84 struct smc_policy_ip_key key = { 85 .sip = src, 86 .dip = dst, 87 }; 88 89 value = bpf_map_lookup_elem(&smc_policy_ip, &key); 90 return value ? value->mode : default_ip_strat_value; 91 } 92 93 SEC("fmod_ret/update_socket_protocol") 94 int BPF_PROG(smc_run, int family, int type, int protocol) 95 { 96 struct task_struct *task; 97 98 if (family != AF_INET && family != AF_INET6) 99 return protocol; 100 101 if ((type & 0xf) != SOCK_STREAM) 102 return protocol; 103 104 if (protocol != 0 && protocol != IPPROTO_TCP) 105 return protocol; 106 107 task = bpf_get_current_task_btf(); 108 /* Prevent from affecting other tests */ 109 if (!task) { 110 return protocol; 111 } else { 112 struct net___local *net = (struct net___local *)task->nsproxy->net_ns; 113 114 if (!bpf_core_field_exists(struct net___local, smc) || !net->smc.hs_ctrl) 115 return protocol; 116 } 117 118 return IPPROTO_SMC; 119 } 120 121 SEC("struct_ops") 122 int BPF_PROG(bpf_smc_set_tcp_option_cond, const struct tcp_sock *tp, 123 struct inet_request_sock *ireq) 124 { 125 return smc_check(ireq->req.__req_common.skc_daddr, 126 ireq->req.__req_common.skc_rcv_saddr); 127 } 128 129 SEC("struct_ops") 130 int BPF_PROG(bpf_smc_set_tcp_option, struct tcp_sock *tp) 131 { 132 return smc_check(tp->inet_conn.icsk_inet.sk.__sk_common.skc_rcv_saddr, 133 tp->inet_conn.icsk_inet.sk.__sk_common.skc_daddr); 134 } 135 136 SEC(".struct_ops") 137 struct smc_hs_ctrl___local linkcheck = { 138 .name = "linkcheck", 139 .syn_option = (void *)bpf_smc_set_tcp_option, 140 .synack_option = (void *)bpf_smc_set_tcp_option_cond, 141 }; 142