xref: /linux/tools/testing/selftests/bpf/progs/bpf_smc.c (revision beb3c67297d92f9428484410cf79135d38d0aff3)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #include "vmlinux.h"
4 
5 #include <bpf/bpf_helpers.h>
6 #include <bpf/bpf_tracing.h>
7 #include "bpf_tracing_net.h"
8 
9 char _license[] SEC("license") = "GPL";
10 
/*
 * Socket state value that bpf_smc_release() treats as "listening".
 * NOTE(review): presumably mirrors the kernel's SMC listen state - confirm
 * against the SMC socket state definitions.
 */
enum { BPF_SMC_LISTEN = 10 };
14 
15 struct smc_sock___local {
16 	struct sock sk;
17 	struct smc_sock *listen_smc;
18 	bool use_fallback;
19 } __attribute__((preserve_access_index));
20 
/* incremented by bpf_smc_release() for each non-listening socket released */
int smc_cnt = 0;
/* incremented by bpf_smc_switch_to_fallback() when listen_smc is not set */
int fallback_cnt = 0;
23 
24 SEC("fentry/smc_release")
25 int BPF_PROG(bpf_smc_release, struct socket *sock)
26 {
27 	/* only count from one side (client) */
28 	if (sock->sk->__sk_common.skc_state == BPF_SMC_LISTEN)
29 		return 0;
30 	smc_cnt++;
31 	return 0;
32 }
33 
34 SEC("fentry/smc_switch_to_fallback")
35 int BPF_PROG(bpf_smc_switch_to_fallback, struct smc_sock___local *smc)
36 {
37 	/* only count from one side (client) */
38 	if (smc && !smc->listen_smc)
39 		fallback_cnt++;
40 	return 0;
41 }
42 
/*
 * Result smc_check() falls back to when smc_policy_ip holds no entry
 * (no strategy) for the address pair being looked up.
 */
bool default_ip_strat_value = true;
45 
46 struct smc_policy_ip_key {
47 	__u32	sip;
48 	__u32	dip;
49 };
50 
51 struct smc_policy_ip_value {
52 	__u8	mode;
53 };
54 
55 struct {
56 	__uint(type, BPF_MAP_TYPE_HASH);
57 	__uint(key_size, sizeof(struct smc_policy_ip_key));
58 	__uint(value_size, sizeof(struct smc_policy_ip_value));
59 	__uint(max_entries, 128);
60 	__uint(map_flags, BPF_F_NO_PREALLOC);
61 } smc_policy_ip SEC(".maps");
62 
63 static bool smc_check(__u32 src, __u32 dst)
64 {
65 	struct smc_policy_ip_value *value;
66 	struct smc_policy_ip_key key = {
67 		.sip = src,
68 		.dip = dst,
69 	};
70 
71 	value = bpf_map_lookup_elem(&smc_policy_ip, &key);
72 	return value ? value->mode : default_ip_strat_value;
73 }
74 
75 SEC("fmod_ret/update_socket_protocol")
76 int BPF_PROG(smc_run, int family, int type, int protocol)
77 {
78 	struct task_struct *task;
79 
80 	if (family != AF_INET && family != AF_INET6)
81 		return protocol;
82 
83 	if ((type & 0xf) != SOCK_STREAM)
84 		return protocol;
85 
86 	if (protocol != 0 && protocol != IPPROTO_TCP)
87 		return protocol;
88 
89 	task = bpf_get_current_task_btf();
90 	/* Prevent from affecting other tests */
91 	if (!task || !task->nsproxy->net_ns->smc.hs_ctrl)
92 		return protocol;
93 
94 	return IPPROTO_SMC;
95 }
96 
97 SEC("struct_ops")
98 int BPF_PROG(bpf_smc_set_tcp_option_cond, const struct tcp_sock *tp,
99 	     struct inet_request_sock *ireq)
100 {
101 	return smc_check(ireq->req.__req_common.skc_daddr,
102 			 ireq->req.__req_common.skc_rcv_saddr);
103 }
104 
105 SEC("struct_ops")
106 int BPF_PROG(bpf_smc_set_tcp_option, struct tcp_sock *tp)
107 {
108 	return smc_check(tp->inet_conn.icsk_inet.sk.__sk_common.skc_rcv_saddr,
109 			 tp->inet_conn.icsk_inet.sk.__sk_common.skc_daddr);
110 }
111 
112 SEC(".struct_ops")
113 struct smc_hs_ctrl  linkcheck = {
114 	.name		= "linkcheck",
115 	.syn_option	= (void *)bpf_smc_set_tcp_option,
116 	.synack_option	= (void *)bpf_smc_set_tcp_option_cond,
117 };
118