xref: /linux/include/net/psp/functions.h (revision f8d2f8205be8cceef2dd3c0e68e7af3c5f83c75c)
1 /* SPDX-License-Identifier: GPL-2.0-only */
2 
3 #ifndef __NET_PSP_HELPERS_H
4 #define __NET_PSP_HELPERS_H
5 
6 #include <linux/skbuff.h>
7 #include <linux/rcupdate.h>
8 #include <linux/udp.h>
9 #include <net/sock.h>
10 #include <net/tcp.h>
11 #include <net/psp/types.h>
12 
13 struct inet_timewait_sock;
14 
/* Driver-facing API */
/*
 * Create a PSP device for @netdev with driver callbacks @psd_ops,
 * capabilities @psd_caps and an opaque driver pointer @priv_ptr.
 * NOTE(review): the failure convention (NULL vs ERR_PTR()) is not visible
 * in this header — check the implementation before testing the result.
 */
struct psp_dev *
psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops,
	       struct psp_dev_caps *psd_caps, void *priv_ptr);
/* Tear down @psd; presumably the counterpart of psp_dev_create() — confirm. */
void psp_dev_unregister(struct psp_dev *psd);
/*
 * Add PSP encapsulation to @skb using SPI @spi, PSP version @ver and
 * (presumably UDP) source port @sport.
 * NOTE(review): the meaning of the bool result is defined by the
 * implementation — presumably "encapsulation succeeded"; confirm.
 */
bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
			 u8 ver, __be16 sport);
/*
 * Driver receive hook for a PSP packet in @skb, tagged with device id
 * @dev_id and key generation @generation; @strip_icv selects whether the
 * integrity check value is removed. NOTE(review): return convention
 * (0 / negative errno?) is not visible here — confirm.
 */
int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv);

/* Kernel-facing API */
/* Drop a reference to @pas (presumably refcounted, freed on last put — confirm). */
void psp_assoc_put(struct psp_assoc *pas);
26 
27 static inline void *psp_assoc_drv_data(struct psp_assoc *pas)
28 {
29 	return pas->drv_data;
30 }
31 
32 #if IS_ENABLED(CONFIG_INET_PSP)
/* Key length used by PSP protocol @version (units per implementation). */
unsigned int psp_key_size(u32 version);
/* Release the PSP association attached to @sk (presumably; confirm). */
void psp_sk_assoc_free(struct sock *sk);
/* Set up PSP state of timewait socket @tw from the socket @sk it replaces. */
void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk);
/* Release the PSP association attached to timewait socket @tw. */
void psp_twsk_assoc_free(struct inet_timewait_sock *tw);
/* Mark reply @skb as decrypted; the criteria live in the implementation. */
void psp_reply_set_decrypted(struct sk_buff *skb);
38 
39 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
40 {
41 	return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk));
42 }
43 
44 static inline void
45 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
46 {
47 	struct psp_assoc *pas;
48 
49 	pas = psp_sk_assoc(sk);
50 	if (pas && pas->tx.spi)
51 		skb->decrypted = 1;
52 }
53 
54 static inline unsigned long
55 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
56 			unsigned long diffs)
57 {
58 	struct psp_skb_ext *a, *b;
59 
60 	a = skb_ext_find(one, SKB_EXT_PSP);
61 	b = skb_ext_find(two, SKB_EXT_PSP);
62 
63 	diffs |= (!!a) ^ (!!b);
64 	if (!diffs && unlikely(a))
65 		diffs |= memcmp(a, b, sizeof(*a));
66 	return diffs;
67 }
68 
69 static inline bool
70 psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas)
71 {
72 	bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN);
73 	u32 end_seq = TCP_SKB_CB(skb)->end_seq;
74 	u32 seq = TCP_SKB_CB(skb)->seq;
75 	bool pure_fin;
76 
77 	pure_fin = fin && end_seq - seq == 1;
78 
79 	return seq == end_seq || (pure_fin && seq == pas->upgrade_seq);
80 }
81 
82 static inline bool
83 psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas)
84 {
85 	return pse && pas->rx.spi == pse->spi &&
86 	       pas->generation == pse->generation &&
87 	       pas->version == pse->version &&
88 	       pas->dev_id == pse->dev_id;
89 }
90 
91 static inline enum skb_drop_reason
92 __psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas)
93 {
94 	struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP);
95 
96 	if (!pas)
97 		return pse ? SKB_DROP_REASON_PSP_INPUT : 0;
98 
99 	if (likely(psp_pse_matches_pas(pse, pas))) {
100 		if (unlikely(!pas->peer_tx))
101 			pas->peer_tx = 1;
102 
103 		return 0;
104 	}
105 
106 	if (!pse) {
107 		if (!pas->tx.spi ||
108 		    (!pas->peer_tx && psp_is_allowed_nondata(skb, pas)))
109 			return 0;
110 	}
111 
112 	return SKB_DROP_REASON_PSP_INPUT;
113 }
114 
115 static inline enum skb_drop_reason
116 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
117 {
118 	return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk));
119 }
120 
121 static inline enum skb_drop_reason
122 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
123 {
124 	return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc));
125 }
126 
127 static inline struct psp_assoc *psp_sk_get_assoc_rcu(const struct sock *sk)
128 {
129 	struct inet_timewait_sock *tw;
130 	struct psp_assoc *pas;
131 	int state;
132 
133 	state = 1 << READ_ONCE(sk->sk_state);
134 	if (!sk_is_inet(sk) || state & TCPF_NEW_SYN_RECV)
135 		return NULL;
136 
137 	tw = inet_twsk(sk);
138 	pas = state & TCPF_TIME_WAIT ? rcu_dereference(tw->psp_assoc) :
139 				       rcu_dereference(sk->psp_assoc);
140 	return pas;
141 }
142 
143 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
144 {
145 	if (!skb->decrypted || !skb->sk)
146 		return NULL;
147 
148 	return psp_sk_get_assoc_rcu(skb->sk);
149 }
150 
151 static inline unsigned int psp_sk_overhead(const struct sock *sk)
152 {
153 	int psp_encap = sizeof(struct udphdr) + PSP_HDR_SIZE + PSP_TRL_SIZE;
154 	bool has_psp = rcu_access_pointer(sk->psp_assoc);
155 
156 	return has_psp ? psp_encap : 0;
157 }
158 #else
159 static inline void psp_sk_assoc_free(struct sock *sk) { }
160 static inline void
161 psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) { }
162 static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { }
163 static inline void
164 psp_reply_set_decrypted(struct sk_buff *skb) { }
165 
166 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
167 {
168 	return NULL;
169 }
170 
171 static inline void
172 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) { }
173 
174 static inline unsigned long
175 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
176 			unsigned long diffs)
177 {
178 	return diffs;
179 }
180 
181 static inline enum skb_drop_reason
182 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
183 {
184 	return 0;
185 }
186 
187 static inline enum skb_drop_reason
188 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
189 {
190 	return 0;
191 }
192 
193 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
194 {
195 	return NULL;
196 }
197 
198 static inline unsigned int psp_sk_overhead(const struct sock *sk)
199 {
200 	return 0;
201 }
202 #endif
203 
/*
 * Compare the PSP skb extensions of @one and @two, starting from no prior
 * difference. Returns zero when they agree and non-zero otherwise.
 */
static inline unsigned long
psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two)
{
	const unsigned long no_prior_diffs = 0;

	return __psp_skb_coalesce_diff(one, two, no_prior_diffs);
}
209 
210 #endif /* __NET_PSP_HELPERS_H */
211