xref: /linux/include/net/psp/functions.h (revision 6b46ca260e2290e3453d1355ab5b6d283d73d780)
1 /* SPDX-License-Identifier: GPL-2.0-only */
2 
3 #ifndef __NET_PSP_HELPERS_H
4 #define __NET_PSP_HELPERS_H
5 
6 #include <linux/skbuff.h>
7 #include <linux/rcupdate.h>
8 #include <net/sock.h>
9 #include <net/tcp.h>
10 #include <net/psp/types.h>
11 
12 struct inet_timewait_sock;
13 
14 /* Driver-facing API */
/*
 * psp_dev_create() takes a net_device, a driver ops table, a capabilities
 * structure and an opaque driver pointer, and returns a struct psp_dev *.
 *
 * NOTE(review): only the prototype lives in this header. Presumably this
 * allocates and registers a PSP device for @netdev and reports failure via
 * the returned pointer - confirm the error convention (NULL vs ERR_PTR)
 * against the definition before relying on it.
 */
struct psp_dev *
psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops,
	       struct psp_dev_caps *psd_caps, void *priv_ptr);
/*
 * NOTE(review): presumably the teardown counterpart of psp_dev_create();
 * the definition is not visible from this header - confirm.
 */
void psp_dev_unregister(struct psp_dev *psd);
19 
20 /* Kernel-facing API */
/*
 * NOTE(review): the "_put" suffix suggests this drops a reference on @pas;
 * the definition is not visible from this header - confirm.
 */
void psp_assoc_put(struct psp_assoc *pas);
22 
23 static inline void *psp_assoc_drv_data(struct psp_assoc *pas)
24 {
25 	return pas->drv_data;
26 }
27 
28 #if IS_ENABLED(CONFIG_INET_PSP)
/*
 * Out-of-line helpers, only declared when CONFIG_INET_PSP is enabled.
 * Most of them have empty inline stubs in the #else branch of this header;
 * psp_key_size() has no stub there.
 *
 * NOTE(review): the definitions are not visible from this header, so the
 * per-function notes below are inferred from names and signatures - confirm.
 */

/* Presumably maps a PSP @version to the size of its key - confirm units. */
unsigned int psp_key_size(u32 version);
/* Presumably releases the PSP association attached to socket @sk. */
void psp_sk_assoc_free(struct sock *sk);
/* Presumably initialises the PSP state of timewait socket @tw from @sk. */
void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk);
/* Presumably releases the PSP association attached to timewait socket @tw. */
void psp_twsk_assoc_free(struct inet_timewait_sock *tw);
/* Presumably sets the decrypted marking on the reply packet @skb. */
void psp_reply_set_decrypted(struct sk_buff *skb);
34 
35 static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
36 {
37 	return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk));
38 }
39 
40 static inline void
41 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
42 {
43 	struct psp_assoc *pas;
44 
45 	pas = psp_sk_assoc(sk);
46 	if (pas && pas->tx.spi)
47 		skb->decrypted = 1;
48 }
49 
50 static inline unsigned long
51 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
52 			unsigned long diffs)
53 {
54 	struct psp_skb_ext *a, *b;
55 
56 	a = skb_ext_find(one, SKB_EXT_PSP);
57 	b = skb_ext_find(two, SKB_EXT_PSP);
58 
59 	diffs |= (!!a) ^ (!!b);
60 	if (!diffs && unlikely(a))
61 		diffs |= memcmp(a, b, sizeof(*a));
62 	return diffs;
63 }
64 
65 static inline bool
66 psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas)
67 {
68 	bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN);
69 	u32 end_seq = TCP_SKB_CB(skb)->end_seq;
70 	u32 seq = TCP_SKB_CB(skb)->seq;
71 	bool pure_fin;
72 
73 	pure_fin = fin && end_seq - seq == 1;
74 
75 	return seq == end_seq || (pure_fin && seq == pas->upgrade_seq);
76 }
77 
78 static inline bool
79 psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas)
80 {
81 	return pse && pas->rx.spi == pse->spi &&
82 	       pas->generation == pse->generation &&
83 	       pas->version == pse->version &&
84 	       pas->dev_id == pse->dev_id;
85 }
86 
87 static inline enum skb_drop_reason
88 __psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas)
89 {
90 	struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP);
91 
92 	if (!pas)
93 		return pse ? SKB_DROP_REASON_PSP_INPUT : 0;
94 
95 	if (likely(psp_pse_matches_pas(pse, pas))) {
96 		if (unlikely(!pas->peer_tx))
97 			pas->peer_tx = 1;
98 
99 		return 0;
100 	}
101 
102 	if (!pse) {
103 		if (!pas->tx.spi ||
104 		    (!pas->peer_tx && psp_is_allowed_nondata(skb, pas)))
105 			return 0;
106 	}
107 
108 	return SKB_DROP_REASON_PSP_INPUT;
109 }
110 
111 static inline enum skb_drop_reason
112 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
113 {
114 	return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk));
115 }
116 
117 static inline enum skb_drop_reason
118 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
119 {
120 	return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc));
121 }
122 
123 static inline struct psp_assoc *psp_sk_get_assoc_rcu(struct sock *sk)
124 {
125 	struct inet_timewait_sock *tw;
126 	struct psp_assoc *pas;
127 	int state;
128 
129 	state = 1 << READ_ONCE(sk->sk_state);
130 	if (!sk_is_inet(sk) || state & TCPF_NEW_SYN_RECV)
131 		return NULL;
132 
133 	tw = inet_twsk(sk);
134 	pas = state & TCPF_TIME_WAIT ? rcu_dereference(tw->psp_assoc) :
135 				       rcu_dereference(sk->psp_assoc);
136 	return pas;
137 }
138 
139 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
140 {
141 	if (!skb->decrypted || !skb->sk)
142 		return NULL;
143 
144 	return psp_sk_get_assoc_rcu(skb->sk);
145 }
146 #else
/* Stub used when CONFIG_INET_PSP is disabled: does nothing. */
static inline void psp_sk_assoc_free(struct sock *sk)
{
}
/* Stub used when CONFIG_INET_PSP is disabled: does nothing. */
static inline void
psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk)
{
}
/* Stub used when CONFIG_INET_PSP is disabled: does nothing. */
static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
{
}
/* Stub used when CONFIG_INET_PSP is disabled: @skb is left untouched. */
static inline void psp_reply_set_decrypted(struct sk_buff *skb)
{
}
153 
/* Stub used when CONFIG_INET_PSP is disabled: there is never an association. */
static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
{
	struct psp_assoc *none = NULL;

	return none;
}
158 
/* Stub used when CONFIG_INET_PSP is disabled: @skb is left untouched. */
static inline void
psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
{
}
161 
/*
 * Stub used when CONFIG_INET_PSP is disabled: nothing is compared, so the
 * accumulator @diffs is handed back unchanged.
 */
static inline unsigned long
__psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
			unsigned long diffs)
{
	unsigned long unchanged = diffs;

	return unchanged;
}
168 
169 static inline enum skb_drop_reason
170 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
171 {
172 	return 0;
173 }
174 
175 static inline enum skb_drop_reason
176 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
177 {
178 	return 0;
179 }
180 
/* Stub used when CONFIG_INET_PSP is disabled: there is never an association. */
static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
{
	struct psp_assoc *none = NULL;

	return none;
}
185 #endif
186 
/*
 * Compare the PSP state of two skbs, starting from an empty accumulator.
 * Returns whatever __psp_skb_coalesce_diff() reports for @one and @two
 * with an initial diffs value of 0; non-zero means they differ.
 */
static inline unsigned long
psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two)
{
	unsigned long diffs = 0;

	return __psp_skb_coalesce_diff(one, two, diffs);
}
192 
193 #endif /* __NET_PSP_HELPERS_H */
194