1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright 2025 Google LLC */ 3 4 #include "bpf_tracing_net.h" 5 #include <bpf/bpf_helpers.h> 6 #include <bpf/bpf_tracing.h> 7 #include <errno.h> 8 9 extern int tcp_memory_per_cpu_fw_alloc __ksym; 10 extern int udp_memory_per_cpu_fw_alloc __ksym; 11 12 int nr_cpus; 13 bool tcp_activated, udp_activated; 14 long tcp_memory_allocated, udp_memory_allocated; 15 16 struct sk_prot { 17 long *memory_allocated; 18 int *memory_per_cpu_fw_alloc; 19 }; 20 21 static int drain_memory_per_cpu_fw_alloc(__u32 i, struct sk_prot *sk_prot_ctx) 22 { 23 int *memory_per_cpu_fw_alloc; 24 25 memory_per_cpu_fw_alloc = bpf_per_cpu_ptr(sk_prot_ctx->memory_per_cpu_fw_alloc, i); 26 if (memory_per_cpu_fw_alloc) 27 *sk_prot_ctx->memory_allocated += *memory_per_cpu_fw_alloc; 28 29 return 0; 30 } 31 32 static long get_memory_allocated(struct sock *_sk, int *memory_per_cpu_fw_alloc) 33 { 34 struct sock *sk = bpf_core_cast(_sk, struct sock); 35 struct sk_prot sk_prot_ctx; 36 long memory_allocated; 37 38 /* net_aligned_data.{tcp,udp}_memory_allocated was not available. */ 39 memory_allocated = sk->__sk_common.skc_prot->memory_allocated->counter; 40 41 sk_prot_ctx.memory_allocated = &memory_allocated; 42 sk_prot_ctx.memory_per_cpu_fw_alloc = memory_per_cpu_fw_alloc; 43 44 bpf_loop(nr_cpus, drain_memory_per_cpu_fw_alloc, &sk_prot_ctx, 0); 45 46 return memory_allocated; 47 } 48 49 static void fentry_init_sock(struct sock *sk, bool *activated, 50 long *memory_allocated, int *memory_per_cpu_fw_alloc) 51 { 52 if (!*activated) 53 return; 54 55 *memory_allocated = get_memory_allocated(sk, memory_per_cpu_fw_alloc); 56 *activated = false; 57 } 58 59 SEC("fentry/tcp_init_sock") 60 int BPF_PROG(fentry_tcp_init_sock, struct sock *sk) 61 { 62 fentry_init_sock(sk, &tcp_activated, 63 &tcp_memory_allocated, &tcp_memory_per_cpu_fw_alloc); 64 return 0; 65 } 66 67 SEC("fentry/udp_init_sock") 68 int BPF_PROG(fentry_udp_init_sock, struct sock *sk) 69 { 70 fentry_init_sock(sk, &udp_activated, 71 &udp_memory_allocated, &udp_memory_per_cpu_fw_alloc); 72 return 0; 73 } 74 75 SEC("cgroup/sock_create") 76 int sock_create(struct bpf_sock *ctx) 77 { 78 int err, val = 1; 79 80 err = bpf_setsockopt(ctx, SOL_SOCKET, SK_BPF_BYPASS_PROT_MEM, 81 &val, sizeof(val)); 82 if (err) 83 goto err; 84 85 val = 0; 86 87 err = bpf_getsockopt(ctx, SOL_SOCKET, SK_BPF_BYPASS_PROT_MEM, 88 &val, sizeof(val)); 89 if (err) 90 goto err; 91 92 if (val != 1) { 93 err = -EINVAL; 94 goto err; 95 } 96 97 return 1; 98 99 err: 100 bpf_set_retval(err); 101 return 0; 102 } 103 104 char LICENSE[] SEC("license") = "GPL"; 105