xref: /linux/include/net/psp/functions.h (revision fc724515741a1b86ca0457825fdb784ab038e92c)
1 /* SPDX-License-Identifier: GPL-2.0-only */
2 
3 #ifndef __NET_PSP_HELPERS_H
4 #define __NET_PSP_HELPERS_H
5 
6 #include <linux/skbuff.h>
7 #include <linux/rcupdate.h>
8 #include <linux/udp.h>
9 #include <net/sock.h>
10 #include <net/tcp.h>
11 #include <net/psp/types.h>
12 
13 struct inet_timewait_sock;
14 
15 /* Driver-facing API */
16 struct psp_dev *
17 psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops,
18 	       struct psp_dev_caps *psd_caps, void *priv_ptr);
19 void psp_dev_unregister(struct psp_dev *psd);
20 bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
21 			 u8 ver, __be16 sport);
22 
23 /* Kernel-facing API */
24 void psp_assoc_put(struct psp_assoc *pas);
25 
26 static inline void *psp_assoc_drv_data(struct psp_assoc *pas)
27 {
28 	return pas->drv_data;
29 }
30 
31 #if IS_ENABLED(CONFIG_INET_PSP)
32 unsigned int psp_key_size(u32 version);
33 void psp_sk_assoc_free(struct sock *sk);
34 void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk);
35 void psp_twsk_assoc_free(struct inet_timewait_sock *tw);
36 void psp_reply_set_decrypted(struct sk_buff *skb);
37 
38 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
39 {
40 	return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk));
41 }
42 
43 static inline void
44 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
45 {
46 	struct psp_assoc *pas;
47 
48 	pas = psp_sk_assoc(sk);
49 	if (pas && pas->tx.spi)
50 		skb->decrypted = 1;
51 }
52 
53 static inline unsigned long
54 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
55 			unsigned long diffs)
56 {
57 	struct psp_skb_ext *a, *b;
58 
59 	a = skb_ext_find(one, SKB_EXT_PSP);
60 	b = skb_ext_find(two, SKB_EXT_PSP);
61 
62 	diffs |= (!!a) ^ (!!b);
63 	if (!diffs && unlikely(a))
64 		diffs |= memcmp(a, b, sizeof(*a));
65 	return diffs;
66 }
67 
68 static inline bool
69 psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas)
70 {
71 	bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN);
72 	u32 end_seq = TCP_SKB_CB(skb)->end_seq;
73 	u32 seq = TCP_SKB_CB(skb)->seq;
74 	bool pure_fin;
75 
76 	pure_fin = fin && end_seq - seq == 1;
77 
78 	return seq == end_seq || (pure_fin && seq == pas->upgrade_seq);
79 }
80 
81 static inline bool
82 psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas)
83 {
84 	return pse && pas->rx.spi == pse->spi &&
85 	       pas->generation == pse->generation &&
86 	       pas->version == pse->version &&
87 	       pas->dev_id == pse->dev_id;
88 }
89 
90 static inline enum skb_drop_reason
91 __psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas)
92 {
93 	struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP);
94 
95 	if (!pas)
96 		return pse ? SKB_DROP_REASON_PSP_INPUT : 0;
97 
98 	if (likely(psp_pse_matches_pas(pse, pas))) {
99 		if (unlikely(!pas->peer_tx))
100 			pas->peer_tx = 1;
101 
102 		return 0;
103 	}
104 
105 	if (!pse) {
106 		if (!pas->tx.spi ||
107 		    (!pas->peer_tx && psp_is_allowed_nondata(skb, pas)))
108 			return 0;
109 	}
110 
111 	return SKB_DROP_REASON_PSP_INPUT;
112 }
113 
114 static inline enum skb_drop_reason
115 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
116 {
117 	return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk));
118 }
119 
120 static inline enum skb_drop_reason
121 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
122 {
123 	return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc));
124 }
125 
126 static inline struct psp_assoc *psp_sk_get_assoc_rcu(struct sock *sk)
127 {
128 	struct inet_timewait_sock *tw;
129 	struct psp_assoc *pas;
130 	int state;
131 
132 	state = 1 << READ_ONCE(sk->sk_state);
133 	if (!sk_is_inet(sk) || state & TCPF_NEW_SYN_RECV)
134 		return NULL;
135 
136 	tw = inet_twsk(sk);
137 	pas = state & TCPF_TIME_WAIT ? rcu_dereference(tw->psp_assoc) :
138 				       rcu_dereference(sk->psp_assoc);
139 	return pas;
140 }
141 
142 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
143 {
144 	if (!skb->decrypted || !skb->sk)
145 		return NULL;
146 
147 	return psp_sk_get_assoc_rcu(skb->sk);
148 }
149 
150 static inline unsigned int psp_sk_overhead(const struct sock *sk)
151 {
152 	int psp_encap = sizeof(struct udphdr) + PSP_HDR_SIZE + PSP_TRL_SIZE;
153 	bool has_psp = rcu_access_pointer(sk->psp_assoc);
154 
155 	return has_psp ? psp_encap : 0;
156 }
157 #else
158 static inline void psp_sk_assoc_free(struct sock *sk) { }
159 static inline void
160 psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) { }
161 static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { }
162 static inline void
163 psp_reply_set_decrypted(struct sk_buff *skb) { }
164 
165 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
166 {
167 	return NULL;
168 }
169 
170 static inline void
171 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) { }
172 
173 static inline unsigned long
174 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
175 			unsigned long diffs)
176 {
177 	return diffs;
178 }
179 
180 static inline enum skb_drop_reason
181 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
182 {
183 	return 0;
184 }
185 
186 static inline enum skb_drop_reason
187 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
188 {
189 	return 0;
190 }
191 
192 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
193 {
194 	return NULL;
195 }
196 
197 static inline unsigned int psp_sk_overhead(const struct sock *sk)
198 {
199 	return 0;
200 }
201 #endif
202 
/*
 * psp_skb_coalesce_diff() - PSP metadata difference between @one and @two,
 * starting from "no differences". Non-zero when they differ; with PSP
 * compiled out the stub __psp_skb_coalesce_diff() makes this always 0.
 */
static inline unsigned long
psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two)
{
	return __psp_skb_coalesce_diff(one, two, 0UL);
}
208 
209 #endif /* __NET_PSP_HELPERS_H */
210