xref: /linux/tools/testing/selftests/bpf/progs/bpf_smc.c (revision c4dde411bc366f568dbe33366253bbfea049e8ea)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #include "vmlinux.h"
4 
5 #include <bpf/bpf_helpers.h>
6 #include <bpf/bpf_tracing.h>
7 #include "bpf_tracing_net.h"
8 
9 char _license[] SEC("license") = "GPL";
10 
11 #ifndef SMC_HS_CTRL_NAME_MAX
12 #define SMC_HS_CTRL_NAME_MAX 16
13 #endif
14 
15 enum {
16 	BPF_SMC_LISTEN	= 10,
17 };
18 
19 struct smc_sock___local {
20 	struct sock sk;
21 	struct smc_sock *listen_smc;
22 	bool use_fallback;
23 } __attribute__((preserve_access_index));
24 
25 struct smc_hs_ctrl___local {
26 	char name[SMC_HS_CTRL_NAME_MAX];
27 	int (*syn_option)(struct tcp_sock *);
28 	int (*synack_option)(const struct tcp_sock  *, struct inet_request_sock *);
29 } __attribute__((preserve_access_index));
30 
/*
 * Just enough of struct netns_smc and struct net to reach
 * net->smc.hs_ctrl through CO-RE relocations.
 */
struct netns_smc___local {
	struct smc_hs_ctrl___local *hs_ctrl;
} __attribute__((preserve_access_index));

struct net___local {
	struct netns_smc___local smc;
} __attribute__((preserve_access_index));
38 
39 int smc_cnt = 0;
40 int fallback_cnt = 0;
41 
42 SEC("fentry/smc_release")
43 int BPF_PROG(bpf_smc_release, struct socket *sock)
44 {
45 	/* only count from one side (client) */
46 	if (sock->sk->__sk_common.skc_state == BPF_SMC_LISTEN)
47 		return 0;
48 	smc_cnt++;
49 	return 0;
50 }
51 
52 SEC("fentry/smc_switch_to_fallback")
53 int BPF_PROG(bpf_smc_switch_to_fallback, struct smc_sock___local *smc)
54 {
55 	/* only count from one side (client) */
56 	if (smc && !smc->listen_smc)
57 		fallback_cnt++;
58 	return 0;
59 }
60 
61 /* go with default value if no strat was found */
62 bool default_ip_strat_value = true;
63 
64 struct smc_policy_ip_key {
65 	__u32	sip;
66 	__u32	dip;
67 };
68 
69 struct smc_policy_ip_value {
70 	__u8	mode;
71 };
72 
73 struct {
74 	__uint(type, BPF_MAP_TYPE_HASH);
75 	__uint(key_size, sizeof(struct smc_policy_ip_key));
76 	__uint(value_size, sizeof(struct smc_policy_ip_value));
77 	__uint(max_entries, 128);
78 	__uint(map_flags, BPF_F_NO_PREALLOC);
79 } smc_policy_ip SEC(".maps");
80 
81 static bool smc_check(__u32 src, __u32 dst)
82 {
83 	struct smc_policy_ip_value *value;
84 	struct smc_policy_ip_key key = {
85 		.sip = src,
86 		.dip = dst,
87 	};
88 
89 	value = bpf_map_lookup_elem(&smc_policy_ip, &key);
90 	return value ? value->mode : default_ip_strat_value;
91 }
92 
93 SEC("fmod_ret/update_socket_protocol")
94 int BPF_PROG(smc_run, int family, int type, int protocol)
95 {
96 	struct task_struct *task;
97 
98 	if (family != AF_INET && family != AF_INET6)
99 		return protocol;
100 
101 	if ((type & 0xf) != SOCK_STREAM)
102 		return protocol;
103 
104 	if (protocol != 0 && protocol != IPPROTO_TCP)
105 		return protocol;
106 
107 	task = bpf_get_current_task_btf();
108 	/* Prevent from affecting other tests */
109 	if (!task) {
110 		return protocol;
111 	} else {
112 		struct net___local *net = (struct net___local *)task->nsproxy->net_ns;
113 
114 		if (!bpf_core_field_exists(struct net___local, smc) || !net->smc.hs_ctrl)
115 			return protocol;
116 	}
117 
118 	return IPPROTO_SMC;
119 }
120 
121 SEC("struct_ops")
122 int BPF_PROG(bpf_smc_set_tcp_option_cond, const struct tcp_sock *tp,
123 	     struct inet_request_sock *ireq)
124 {
125 	return smc_check(ireq->req.__req_common.skc_daddr,
126 			 ireq->req.__req_common.skc_rcv_saddr);
127 }
128 
129 SEC("struct_ops")
130 int BPF_PROG(bpf_smc_set_tcp_option, struct tcp_sock *tp)
131 {
132 	return smc_check(tp->inet_conn.icsk_inet.sk.__sk_common.skc_rcv_saddr,
133 			 tp->inet_conn.icsk_inet.sk.__sk_common.skc_daddr);
134 }
135 
136 SEC(".struct_ops")
137 struct smc_hs_ctrl___local  linkcheck = {
138 	.name		= "linkcheck",
139 	.syn_option	= (void *)bpf_smc_set_tcp_option,
140 	.synack_option	= (void *)bpf_smc_set_tcp_option_cond,
141 };
142