100c94ca2SJakub Kicinski /* SPDX-License-Identifier: GPL-2.0-only */ 200c94ca2SJakub Kicinski 300c94ca2SJakub Kicinski #ifndef __NET_PSP_HELPERS_H 400c94ca2SJakub Kicinski #define __NET_PSP_HELPERS_H 500c94ca2SJakub Kicinski 6659a2899SJakub Kicinski #include <linux/skbuff.h> 7*6b46ca26SJakub Kicinski #include <linux/rcupdate.h> 8659a2899SJakub Kicinski #include <net/sock.h> 9*6b46ca26SJakub Kicinski #include <net/tcp.h> 1000c94ca2SJakub Kicinski #include <net/psp/types.h> 1100c94ca2SJakub Kicinski 12ed8a507bSJakub Kicinski struct inet_timewait_sock; 13ed8a507bSJakub Kicinski 1400c94ca2SJakub Kicinski /* Driver-facing API */ 1500c94ca2SJakub Kicinski struct psp_dev * 1600c94ca2SJakub Kicinski psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops, 1700c94ca2SJakub Kicinski struct psp_dev_caps *psd_caps, void *priv_ptr); 1800c94ca2SJakub Kicinski void psp_dev_unregister(struct psp_dev *psd); 1900c94ca2SJakub Kicinski 20ed8a507bSJakub Kicinski /* Kernel-facing API */ 21*6b46ca26SJakub Kicinski void psp_assoc_put(struct psp_assoc *pas); 22*6b46ca26SJakub Kicinski 23*6b46ca26SJakub Kicinski static inline void *psp_assoc_drv_data(struct psp_assoc *pas) 24*6b46ca26SJakub Kicinski { 25*6b46ca26SJakub Kicinski return pas->drv_data; 26*6b46ca26SJakub Kicinski } 27*6b46ca26SJakub Kicinski 28659a2899SJakub Kicinski #if IS_ENABLED(CONFIG_INET_PSP) 29*6b46ca26SJakub Kicinski unsigned int psp_key_size(u32 version); 30*6b46ca26SJakub Kicinski void psp_sk_assoc_free(struct sock *sk); 31*6b46ca26SJakub Kicinski void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk); 32*6b46ca26SJakub Kicinski void psp_twsk_assoc_free(struct inet_timewait_sock *tw); 33*6b46ca26SJakub Kicinski void psp_reply_set_decrypted(struct sk_buff *skb); 34*6b46ca26SJakub Kicinski 35*6b46ca26SJakub Kicinski static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 36*6b46ca26SJakub Kicinski { 37*6b46ca26SJakub Kicinski return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk)); 38*6b46ca26SJakub Kicinski } 39659a2899SJakub Kicinski 40659a2899SJakub Kicinski static inline void 41659a2899SJakub Kicinski psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) 42659a2899SJakub Kicinski { 43*6b46ca26SJakub Kicinski struct psp_assoc *pas; 44*6b46ca26SJakub Kicinski 45*6b46ca26SJakub Kicinski pas = psp_sk_assoc(sk); 46*6b46ca26SJakub Kicinski if (pas && pas->tx.spi) 47*6b46ca26SJakub Kicinski skb->decrypted = 1; 48659a2899SJakub Kicinski } 49659a2899SJakub Kicinski 50659a2899SJakub Kicinski static inline unsigned long 51659a2899SJakub Kicinski __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two, 52659a2899SJakub Kicinski unsigned long diffs) 53659a2899SJakub Kicinski { 54*6b46ca26SJakub Kicinski struct psp_skb_ext *a, *b; 55*6b46ca26SJakub Kicinski 56*6b46ca26SJakub Kicinski a = skb_ext_find(one, SKB_EXT_PSP); 57*6b46ca26SJakub Kicinski b = skb_ext_find(two, SKB_EXT_PSP); 58*6b46ca26SJakub Kicinski 59*6b46ca26SJakub Kicinski diffs |= (!!a) ^ (!!b); 60*6b46ca26SJakub Kicinski if (!diffs && unlikely(a)) 61*6b46ca26SJakub Kicinski diffs |= memcmp(a, b, sizeof(*a)); 62659a2899SJakub Kicinski return diffs; 63659a2899SJakub Kicinski } 64659a2899SJakub Kicinski 65*6b46ca26SJakub Kicinski static inline bool 66*6b46ca26SJakub Kicinski psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas) 67*6b46ca26SJakub Kicinski { 68*6b46ca26SJakub Kicinski bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN); 69*6b46ca26SJakub Kicinski u32 end_seq = TCP_SKB_CB(skb)->end_seq; 70*6b46ca26SJakub Kicinski u32 seq = TCP_SKB_CB(skb)->seq; 71*6b46ca26SJakub Kicinski bool pure_fin; 72*6b46ca26SJakub Kicinski 73*6b46ca26SJakub Kicinski pure_fin = fin && end_seq - seq == 1; 74*6b46ca26SJakub Kicinski 75*6b46ca26SJakub Kicinski return seq == end_seq || (pure_fin && seq == pas->upgrade_seq); 76*6b46ca26SJakub Kicinski } 77*6b46ca26SJakub Kicinski 78*6b46ca26SJakub Kicinski static inline bool 79*6b46ca26SJakub Kicinski psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas) 80*6b46ca26SJakub Kicinski { 81*6b46ca26SJakub Kicinski return pse && pas->rx.spi == pse->spi && 82*6b46ca26SJakub Kicinski pas->generation == pse->generation && 83*6b46ca26SJakub Kicinski pas->version == pse->version && 84*6b46ca26SJakub Kicinski pas->dev_id == pse->dev_id; 85*6b46ca26SJakub Kicinski } 86*6b46ca26SJakub Kicinski 87*6b46ca26SJakub Kicinski static inline enum skb_drop_reason 88*6b46ca26SJakub Kicinski __psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas) 89*6b46ca26SJakub Kicinski { 90*6b46ca26SJakub Kicinski struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP); 91*6b46ca26SJakub Kicinski 92*6b46ca26SJakub Kicinski if (!pas) 93*6b46ca26SJakub Kicinski return pse ? SKB_DROP_REASON_PSP_INPUT : 0; 94*6b46ca26SJakub Kicinski 95*6b46ca26SJakub Kicinski if (likely(psp_pse_matches_pas(pse, pas))) { 96*6b46ca26SJakub Kicinski if (unlikely(!pas->peer_tx)) 97*6b46ca26SJakub Kicinski pas->peer_tx = 1; 98*6b46ca26SJakub Kicinski 99*6b46ca26SJakub Kicinski return 0; 100*6b46ca26SJakub Kicinski } 101*6b46ca26SJakub Kicinski 102*6b46ca26SJakub Kicinski if (!pse) { 103*6b46ca26SJakub Kicinski if (!pas->tx.spi || 104*6b46ca26SJakub Kicinski (!pas->peer_tx && psp_is_allowed_nondata(skb, pas))) 105*6b46ca26SJakub Kicinski return 0; 106*6b46ca26SJakub Kicinski } 107*6b46ca26SJakub Kicinski 108*6b46ca26SJakub Kicinski return SKB_DROP_REASON_PSP_INPUT; 109*6b46ca26SJakub Kicinski } 110*6b46ca26SJakub Kicinski 111659a2899SJakub Kicinski static inline enum skb_drop_reason 112659a2899SJakub Kicinski psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb) 113659a2899SJakub Kicinski { 114*6b46ca26SJakub Kicinski return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk)); 115659a2899SJakub Kicinski } 116659a2899SJakub Kicinski 117659a2899SJakub Kicinski static inline enum skb_drop_reason 118659a2899SJakub Kicinski psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) 119659a2899SJakub Kicinski { 120*6b46ca26SJakub Kicinski return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc)); 121*6b46ca26SJakub Kicinski } 122*6b46ca26SJakub Kicinski 123*6b46ca26SJakub Kicinski static inline struct psp_assoc *psp_sk_get_assoc_rcu(struct sock *sk) 124*6b46ca26SJakub Kicinski { 125*6b46ca26SJakub Kicinski struct inet_timewait_sock *tw; 126*6b46ca26SJakub Kicinski struct psp_assoc *pas; 127*6b46ca26SJakub Kicinski int state; 128*6b46ca26SJakub Kicinski 129*6b46ca26SJakub Kicinski state = 1 << READ_ONCE(sk->sk_state); 130*6b46ca26SJakub Kicinski if (!sk_is_inet(sk) || state & TCPF_NEW_SYN_RECV) 131*6b46ca26SJakub Kicinski return NULL; 132*6b46ca26SJakub Kicinski 133*6b46ca26SJakub Kicinski tw = inet_twsk(sk); 134*6b46ca26SJakub Kicinski pas = state & TCPF_TIME_WAIT ? rcu_dereference(tw->psp_assoc) : 135*6b46ca26SJakub Kicinski rcu_dereference(sk->psp_assoc); 136*6b46ca26SJakub Kicinski return pas; 137659a2899SJakub Kicinski } 138659a2899SJakub Kicinski 139659a2899SJakub Kicinski static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb) 140659a2899SJakub Kicinski { 141*6b46ca26SJakub Kicinski if (!skb->decrypted || !skb->sk) 142659a2899SJakub Kicinski return NULL; 143*6b46ca26SJakub Kicinski 144*6b46ca26SJakub Kicinski return psp_sk_get_assoc_rcu(skb->sk); 145659a2899SJakub Kicinski } 146659a2899SJakub Kicinski #else 147659a2899SJakub Kicinski static inline void psp_sk_assoc_free(struct sock *sk) { } 148659a2899SJakub Kicinski static inline void 149659a2899SJakub Kicinski psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) { } 150659a2899SJakub Kicinski static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { } 151659a2899SJakub Kicinski static inline void 152659a2899SJakub Kicinski psp_reply_set_decrypted(struct sk_buff *skb) { } 153659a2899SJakub Kicinski 154*6b46ca26SJakub Kicinski static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 155*6b46ca26SJakub Kicinski { 156*6b46ca26SJakub Kicinski return NULL; 157*6b46ca26SJakub Kicinski } 158*6b46ca26SJakub Kicinski 159659a2899SJakub Kicinski static inline void 160659a2899SJakub Kicinski psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) { } 161659a2899SJakub Kicinski 162659a2899SJakub Kicinski static inline unsigned long 163659a2899SJakub Kicinski __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two, 164659a2899SJakub Kicinski unsigned long diffs) 165659a2899SJakub Kicinski { 166659a2899SJakub Kicinski return diffs; 167659a2899SJakub Kicinski } 168659a2899SJakub Kicinski 169659a2899SJakub Kicinski static inline enum skb_drop_reason 170659a2899SJakub Kicinski psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb) 171659a2899SJakub Kicinski { 172659a2899SJakub Kicinski return 0; 173659a2899SJakub Kicinski } 174659a2899SJakub Kicinski 175659a2899SJakub Kicinski static inline enum skb_drop_reason 176659a2899SJakub Kicinski psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) 177659a2899SJakub Kicinski { 178659a2899SJakub Kicinski return 0; 179659a2899SJakub Kicinski } 180659a2899SJakub Kicinski 181659a2899SJakub Kicinski static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb) 182659a2899SJakub Kicinski { 183659a2899SJakub Kicinski return NULL; 184659a2899SJakub Kicinski } 185659a2899SJakub Kicinski #endif 186659a2899SJakub Kicinski 187659a2899SJakub Kicinski static inline unsigned long 188659a2899SJakub Kicinski psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two) 189659a2899SJakub Kicinski { 190659a2899SJakub Kicinski return __psp_skb_coalesce_diff(one, two, 0); 191659a2899SJakub Kicinski } 192ed8a507bSJakub Kicinski 19300c94ca2SJakub Kicinski #endif /* __NET_PSP_HELPERS_H */ 194