xref: /linux/net/ipv4/tcp_ao.c (revision 03cb001ef87b3f8d859cf7f96329acf3d6235d29)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * INET		An implementation of the TCP Authentication Option (TCP-AO).
4  *		See RFC5925.
5  *
6  * Authors:	Dmitry Safonov <dima@arista.com>
7  *		Francesco Ruggeri <fruggeri@arista.com>
8  *		Salam Noureddine <noureddine@arista.com>
9  */
/* Prefix all pr_*() output from this file with "TCP: ". It is defined ahead
 * of the #includes so that it replaces the default pr_fmt().
 */
#define pr_fmt(fmt) "TCP: " fmt
11 
12 #include <crypto/hash.h>
13 #include <crypto/utils.h>
14 #include <linux/inetdevice.h>
15 #include <linux/tcp.h>
16 
17 #include <net/tcp.h>
18 #include <net/ipv6.h>
19 #include <net/icmp.h>
20 #include <trace/events/tcp.h>
21 
/* Static key gating the TCP-AO paths in this file.
 * - It is tested with static_branch_unlikely(&tcp_ao_needed.key).
 * - It is dropped via static_branch_slow_dec_deferred() in tcp_ao_info_free().
 * HZ is presumably the deferral period applied to decrements; confirm
 * against DEFINE_STATIC_KEY_DEFERRED_FALSE(). The matching increment is not
 * visible in this part of the file.
 */
DEFINE_STATIC_KEY_DEFERRED_FALSE(tcp_ao_needed, HZ);
23 
24 int tcp_ao_calc_traffic_key(struct tcp_ao_key *mkt, u8 *key, void *ctx,
25 			    unsigned int len, struct tcp_sigpool *hp)
26 {
27 	struct scatterlist sg;
28 	int ret;
29 
30 	if (crypto_ahash_setkey(crypto_ahash_reqtfm(hp->req),
31 				mkt->key, mkt->keylen))
32 		goto clear_hash;
33 
34 	ret = crypto_ahash_init(hp->req);
35 	if (ret)
36 		goto clear_hash;
37 
38 	sg_init_one(&sg, ctx, len);
39 	ahash_request_set_crypt(hp->req, &sg, key, len);
40 	crypto_ahash_update(hp->req);
41 
42 	ret = crypto_ahash_final(hp->req);
43 	if (ret)
44 		goto clear_hash;
45 
46 	return 0;
47 clear_hash:
48 	memset(key, 0, tcp_ao_digest_size(mkt));
49 	return 1;
50 }
51 
52 bool tcp_ao_ignore_icmp(const struct sock *sk, int family, int type, int code)
53 {
54 	bool ignore_icmp = false;
55 	struct tcp_ao_info *ao;
56 
57 	if (!static_branch_unlikely(&tcp_ao_needed.key))
58 		return false;
59 
60 	/* RFC5925, 7.8:
61 	 * >> A TCP-AO implementation MUST default to ignore incoming ICMPv4
62 	 * messages of Type 3 (destination unreachable), Codes 2-4 (protocol
63 	 * unreachable, port unreachable, and fragmentation needed -- ’hard
64 	 * errors’), and ICMPv6 Type 1 (destination unreachable), Code 1
65 	 * (administratively prohibited) and Code 4 (port unreachable) intended
66 	 * for connections in synchronized states (ESTABLISHED, FIN-WAIT-1, FIN-
67 	 * WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK, TIME-WAIT) that match MKTs.
68 	 */
69 	if (family == AF_INET) {
70 		if (type != ICMP_DEST_UNREACH)
71 			return false;
72 		if (code < ICMP_PROT_UNREACH || code > ICMP_FRAG_NEEDED)
73 			return false;
74 	} else {
75 		if (type != ICMPV6_DEST_UNREACH)
76 			return false;
77 		if (code != ICMPV6_ADM_PROHIBITED && code != ICMPV6_PORT_UNREACH)
78 			return false;
79 	}
80 
81 	rcu_read_lock();
82 	switch (sk->sk_state) {
83 	case TCP_TIME_WAIT:
84 		ao = rcu_dereference(tcp_twsk(sk)->ao_info);
85 		break;
86 	case TCP_SYN_SENT:
87 	case TCP_SYN_RECV:
88 	case TCP_LISTEN:
89 	case TCP_NEW_SYN_RECV:
90 		/* RFC5925 specifies to ignore ICMPs *only* on connections
91 		 * in synchronized states.
92 		 */
93 		rcu_read_unlock();
94 		return false;
95 	default:
96 		ao = rcu_dereference(tcp_sk(sk)->ao_info);
97 	}
98 
99 	if (ao && !ao->accept_icmps) {
100 		ignore_icmp = true;
101 		__NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAODROPPEDICMPS);
102 		atomic64_inc(&ao->counters.dropped_icmp);
103 	}
104 	rcu_read_unlock();
105 
106 	return ignore_icmp;
107 }
108 
109 /* Optimized version of tcp_ao_do_lookup(): only for sockets for which
110  * it's known that the keys in ao_info are matching peer's
111  * family/address/VRF/etc.
112  */
113 struct tcp_ao_key *tcp_ao_established_key(const struct sock *sk,
114 					  struct tcp_ao_info *ao,
115 					  int sndid, int rcvid)
116 {
117 	struct tcp_ao_key *key;
118 
119 	hlist_for_each_entry_rcu(key, &ao->head, node,
120 				 sk_fullsock(sk) && lockdep_sock_is_held(sk)) {
121 		if ((sndid >= 0 && key->sndid != sndid) ||
122 		    (rcvid >= 0 && key->rcvid != rcvid))
123 			continue;
124 		return key;
125 	}
126 
127 	return NULL;
128 }
129 
130 static int ipv4_prefix_cmp(const struct in_addr *addr1,
131 			   const struct in_addr *addr2,
132 			   unsigned int prefixlen)
133 {
134 	__be32 mask = inet_make_mask(prefixlen);
135 	__be32 a1 = addr1->s_addr & mask;
136 	__be32 a2 = addr2->s_addr & mask;
137 
138 	if (a1 == a2)
139 		return 0;
140 	return memcmp(&a1, &a2, sizeof(a1));
141 }
142 
/*
 * __tcp_ao_key_cmp - three-way compare an MKT against lookup criteria.
 * @key:       MKT taken from an ao_info key list
 * @l3index:   L3 master ifindex of the peer; negative matches any
 * @addr:      peer address to compare with @key->addr
 * @prefixlen: number of leading address bits to compare
 * @family:    AF_INET, AF_INET6, or AF_UNSPEC (AF_UNSPEC compares only
 *             the ids and l3index)
 * @sndid:     SendID to match; negative matches any
 * @rcvid:     RecvID to match; negative matches any
 *
 * Return: 0 if @key matches. Otherwise a non-zero value whose sign orders
 * @key relative to the criteria: +/-1, or a memcmp()-style result from the
 * address comparison.
 */
static int __tcp_ao_key_cmp(const struct tcp_ao_key *key, int l3index,
			    const union tcp_ao_addr *addr, u8 prefixlen,
			    int family, int sndid, int rcvid)
{
	if (sndid >= 0 && key->sndid != sndid)
		return (key->sndid > sndid) ? 1 : -1;
	if (rcvid >= 0 && key->rcvid != rcvid)
		return (key->rcvid > rcvid) ? 1 : -1;
	/* l3index is only compared for keys flagged TCP_AO_KEYF_IFINDEX. */
	if (l3index >= 0 && (key->keyflags & TCP_AO_KEYF_IFINDEX)) {
		if (key->l3index != l3index)
			return (key->l3index > l3index) ? 1 : -1;
	}

	if (family == AF_UNSPEC)
		return 0;
	if (key->family != family)
		return (key->family > family) ? 1 : -1;

	if (family == AF_INET) {
		/* A wildcard (INADDR_ANY) address on either side matches. */
		if (ntohl(key->addr.a4.s_addr) == INADDR_ANY)
			return 0;
		if (ntohl(addr->a4.s_addr) == INADDR_ANY)
			return 0;
		return ipv4_prefix_cmp(&key->addr.a4, &addr->a4, prefixlen);
#if IS_ENABLED(CONFIG_IPV6)
	} else {
		/* A wildcard (::) address on either side matches. */
		if (ipv6_addr_any(&key->addr.a6) || ipv6_addr_any(&addr->a6))
			return 0;
		if (ipv6_prefix_equal(&key->addr.a6, &addr->a6, prefixlen))
			return 0;
		return memcmp(&key->addr.a6, &addr->a6, sizeof(addr->a6));
#endif
	}
	/* Reached only without CONFIG_IPV6, for any family but AF_INET. */
	return -1;
}
178 
179 static int tcp_ao_key_cmp(const struct tcp_ao_key *key, int l3index,
180 			  const union tcp_ao_addr *addr, u8 prefixlen,
181 			  int family, int sndid, int rcvid)
182 {
183 #if IS_ENABLED(CONFIG_IPV6)
184 	if (family == AF_INET6 && ipv6_addr_v4mapped(&addr->a6)) {
185 		__be32 addr4 = addr->a6.s6_addr32[3];
186 
187 		return __tcp_ao_key_cmp(key, l3index,
188 					(union tcp_ao_addr *)&addr4,
189 					prefixlen, AF_INET, sndid, rcvid);
190 	}
191 #endif
192 	return __tcp_ao_key_cmp(key, l3index, addr,
193 				prefixlen, family, sndid, rcvid);
194 }
195 
196 static struct tcp_ao_key *__tcp_ao_do_lookup(const struct sock *sk, int l3index,
197 		const union tcp_ao_addr *addr, int family, u8 prefix,
198 		int sndid, int rcvid)
199 {
200 	struct tcp_ao_key *key;
201 	struct tcp_ao_info *ao;
202 
203 	if (!static_branch_unlikely(&tcp_ao_needed.key))
204 		return NULL;
205 
206 	ao = rcu_dereference_check(tcp_sk(sk)->ao_info,
207 				   lockdep_sock_is_held(sk));
208 	if (!ao)
209 		return NULL;
210 
211 	hlist_for_each_entry_rcu(key, &ao->head, node, lockdep_sock_is_held(sk)) {
212 		u8 prefixlen = min(prefix, key->prefixlen);
213 
214 		if (!tcp_ao_key_cmp(key, l3index, addr, prefixlen,
215 				    family, sndid, rcvid))
216 			return key;
217 	}
218 	return NULL;
219 }
220 
221 struct tcp_ao_key *tcp_ao_do_lookup(const struct sock *sk, int l3index,
222 				    const union tcp_ao_addr *addr,
223 				    int family, int sndid, int rcvid)
224 {
225 	return __tcp_ao_do_lookup(sk, l3index, addr, family, U8_MAX, sndid, rcvid);
226 }
227 
228 static struct tcp_ao_info *tcp_ao_alloc_info(gfp_t flags)
229 {
230 	struct tcp_ao_info *ao;
231 
232 	ao = kzalloc_obj(*ao, flags);
233 	if (!ao)
234 		return NULL;
235 	INIT_HLIST_HEAD(&ao->head);
236 	refcount_set(&ao->refcnt, 1);
237 
238 	return ao;
239 }
240 
241 static void tcp_ao_link_mkt(struct tcp_ao_info *ao, struct tcp_ao_key *mkt)
242 {
243 	hlist_add_head_rcu(&mkt->node, &ao->head);
244 }
245 
246 static struct tcp_ao_key *tcp_ao_copy_key(struct sock *sk,
247 					  struct tcp_ao_key *key)
248 {
249 	struct tcp_ao_key *new_key;
250 
251 	new_key = sock_kmalloc(sk, tcp_ao_sizeof_key(key),
252 			       GFP_ATOMIC);
253 	if (!new_key)
254 		return NULL;
255 
256 	*new_key = *key;
257 	INIT_HLIST_NODE(&new_key->node);
258 	tcp_sigpool_get(new_key->tcp_sigpool_id);
259 	atomic64_set(&new_key->pkt_good, 0);
260 	atomic64_set(&new_key->pkt_bad, 0);
261 
262 	return new_key;
263 }
264 
265 static void tcp_ao_key_free_rcu(struct rcu_head *head)
266 {
267 	struct tcp_ao_key *key = container_of(head, struct tcp_ao_key, rcu);
268 
269 	tcp_sigpool_release(key->tcp_sigpool_id);
270 	kfree_sensitive(key);
271 }
272 
273 static void tcp_ao_info_free(struct tcp_ao_info *ao)
274 {
275 	struct tcp_ao_key *key;
276 	struct hlist_node *n;
277 
278 	hlist_for_each_entry_safe(key, n, &ao->head, node) {
279 		hlist_del(&key->node);
280 		tcp_sigpool_release(key->tcp_sigpool_id);
281 		kfree_sensitive(key);
282 	}
283 	kfree(ao);
284 	static_branch_slow_dec_deferred(&tcp_ao_needed);
285 }
286 
287 static void tcp_ao_sk_omem_free(struct sock *sk, struct tcp_ao_info *ao)
288 {
289 	size_t total_ao_sk_mem = 0;
290 	struct tcp_ao_key *key;
291 
292 	hlist_for_each_entry(key,  &ao->head, node)
293 		total_ao_sk_mem += tcp_ao_sizeof_key(key);
294 	atomic_sub(total_ao_sk_mem, &sk->sk_omem_alloc);
295 }
296 
297 void tcp_ao_destroy_sock(struct sock *sk, bool twsk)
298 {
299 	struct tcp_ao_info *ao;
300 
301 	if (twsk) {
302 		ao = rcu_dereference_protected(tcp_twsk(sk)->ao_info, 1);
303 		rcu_assign_pointer(tcp_twsk(sk)->ao_info, NULL);
304 	} else {
305 		ao = rcu_dereference_protected(tcp_sk(sk)->ao_info, 1);
306 		rcu_assign_pointer(tcp_sk(sk)->ao_info, NULL);
307 	}
308 
309 	if (!ao || !refcount_dec_and_test(&ao->refcnt))
310 		return;
311 
312 	if (!twsk)
313 		tcp_ao_sk_omem_free(sk, ao);
314 	tcp_ao_info_free(ao);
315 }
316 
317 void tcp_ao_time_wait(struct tcp_timewait_sock *tcptw, struct tcp_sock *tp)
318 {
319 	struct tcp_ao_info *ao_info = rcu_dereference_protected(tp->ao_info, 1);
320 
321 	if (ao_info) {
322 		struct tcp_ao_key *key;
323 		struct hlist_node *n;
324 		int omem = 0;
325 
326 		hlist_for_each_entry_safe(key, n, &ao_info->head, node) {
327 			omem += tcp_ao_sizeof_key(key);
328 		}
329 
330 		refcount_inc(&ao_info->refcnt);
331 		atomic_sub(omem, &(((struct sock *)tp)->sk_omem_alloc));
332 		rcu_assign_pointer(tcptw->ao_info, ao_info);
333 	} else {
334 		tcptw->ao_info = NULL;
335 	}
336 }
337 
338 /* 4 tuple and ISNs are expected in NBO */
339 static int tcp_v4_ao_calc_key(struct tcp_ao_key *mkt, u8 *key,
340 			      __be32 saddr, __be32 daddr,
341 			      __be16 sport, __be16 dport,
342 			      __be32 sisn,  __be32 disn)
343 {
344 	/* See RFC5926 3.1.1 */
345 	struct kdf_input_block {
346 		u8                      counter;
347 		u8                      label[6];
348 		struct tcp4_ao_context	ctx;
349 		__be16                  outlen;
350 	} __packed * tmp;
351 	struct tcp_sigpool hp;
352 	int err;
353 
354 	err = tcp_sigpool_start(mkt->tcp_sigpool_id, &hp);
355 	if (err)
356 		return err;
357 
358 	tmp = hp.scratch;
359 	tmp->counter	= 1;
360 	memcpy(tmp->label, "TCP-AO", 6);
361 	tmp->ctx.saddr	= saddr;
362 	tmp->ctx.daddr	= daddr;
363 	tmp->ctx.sport	= sport;
364 	tmp->ctx.dport	= dport;
365 	tmp->ctx.sisn	= sisn;
366 	tmp->ctx.disn	= disn;
367 	tmp->outlen	= htons(tcp_ao_digest_size(mkt) * 8); /* in bits */
368 
369 	err = tcp_ao_calc_traffic_key(mkt, key, tmp, sizeof(*tmp), &hp);
370 	tcp_sigpool_end(&hp);
371 
372 	return err;
373 }
374 
375 int tcp_v4_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key,
376 			  const struct sock *sk,
377 			  __be32 sisn, __be32 disn, bool send)
378 {
379 	if (send)
380 		return tcp_v4_ao_calc_key(mkt, key, sk->sk_rcv_saddr,
381 					  sk->sk_daddr, htons(sk->sk_num),
382 					  sk->sk_dport, sisn, disn);
383 	else
384 		return tcp_v4_ao_calc_key(mkt, key, sk->sk_daddr,
385 					  sk->sk_rcv_saddr, sk->sk_dport,
386 					  htons(sk->sk_num), disn, sisn);
387 }
388 
389 static int tcp_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key,
390 			      const struct sock *sk,
391 			      __be32 sisn, __be32 disn, bool send)
392 {
393 	if (mkt->family == AF_INET)
394 		return tcp_v4_ao_calc_key_sk(mkt, key, sk, sisn, disn, send);
395 #if IS_ENABLED(CONFIG_IPV6)
396 	else if (mkt->family == AF_INET6)
397 		return tcp_v6_ao_calc_key_sk(mkt, key, sk, sisn, disn, send);
398 #endif
399 	else
400 		return -EOPNOTSUPP;
401 }
402 
403 int tcp_v4_ao_calc_key_rsk(struct tcp_ao_key *mkt, u8 *key,
404 			   struct request_sock *req)
405 {
406 	struct inet_request_sock *ireq = inet_rsk(req);
407 
408 	return tcp_v4_ao_calc_key(mkt, key,
409 				  ireq->ir_loc_addr, ireq->ir_rmt_addr,
410 				  htons(ireq->ir_num), ireq->ir_rmt_port,
411 				  htonl(tcp_rsk(req)->snt_isn),
412 				  htonl(tcp_rsk(req)->rcv_isn));
413 }
414 
415 static int tcp_v4_ao_calc_key_skb(struct tcp_ao_key *mkt, u8 *key,
416 				  const struct sk_buff *skb,
417 				  __be32 sisn, __be32 disn)
418 {
419 	const struct iphdr *iph = ip_hdr(skb);
420 	const struct tcphdr *th = tcp_hdr(skb);
421 
422 	return tcp_v4_ao_calc_key(mkt, key, iph->saddr, iph->daddr,
423 				  th->source, th->dest, sisn, disn);
424 }
425 
426 static int tcp_ao_calc_key_skb(struct tcp_ao_key *mkt, u8 *key,
427 			       const struct sk_buff *skb,
428 			       __be32 sisn, __be32 disn, int family)
429 {
430 	if (family == AF_INET)
431 		return tcp_v4_ao_calc_key_skb(mkt, key, skb, sisn, disn);
432 #if IS_ENABLED(CONFIG_IPV6)
433 	else if (family == AF_INET6)
434 		return tcp_v6_ao_calc_key_skb(mkt, key, skb, sisn, disn);
435 #endif
436 	return -EAFNOSUPPORT;
437 }
438 
439 static int tcp_v4_ao_hash_pseudoheader(struct tcp_sigpool *hp,
440 				       __be32 daddr, __be32 saddr,
441 				       int nbytes)
442 {
443 	struct tcp4_pseudohdr *bp;
444 	struct scatterlist sg;
445 
446 	bp = hp->scratch;
447 	bp->saddr = saddr;
448 	bp->daddr = daddr;
449 	bp->pad = 0;
450 	bp->protocol = IPPROTO_TCP;
451 	bp->len = cpu_to_be16(nbytes);
452 
453 	sg_init_one(&sg, bp, sizeof(*bp));
454 	ahash_request_set_crypt(hp->req, &sg, NULL, sizeof(*bp));
455 	return crypto_ahash_update(hp->req);
456 }
457 
458 static int tcp_ao_hash_pseudoheader(unsigned short int family,
459 				    const struct sock *sk,
460 				    const struct sk_buff *skb,
461 				    struct tcp_sigpool *hp, int nbytes)
462 {
463 	const struct tcphdr *th = tcp_hdr(skb);
464 
465 	/* TODO: Can we rely on checksum being zero to mean outbound pkt? */
466 	if (!th->check) {
467 		if (family == AF_INET)
468 			return tcp_v4_ao_hash_pseudoheader(hp, sk->sk_daddr,
469 					sk->sk_rcv_saddr, skb->len);
470 #if IS_ENABLED(CONFIG_IPV6)
471 		else if (family == AF_INET6)
472 			return tcp_v6_ao_hash_pseudoheader(hp, &sk->sk_v6_daddr,
473 					&sk->sk_v6_rcv_saddr, skb->len);
474 #endif
475 		else
476 			return -EAFNOSUPPORT;
477 	}
478 
479 	if (family == AF_INET) {
480 		const struct iphdr *iph = ip_hdr(skb);
481 
482 		return tcp_v4_ao_hash_pseudoheader(hp, iph->daddr,
483 				iph->saddr, skb->len);
484 #if IS_ENABLED(CONFIG_IPV6)
485 	} else if (family == AF_INET6) {
486 		const struct ipv6hdr *iph = ipv6_hdr(skb);
487 
488 		return tcp_v6_ao_hash_pseudoheader(hp, &iph->daddr,
489 				&iph->saddr, skb->len);
490 #endif
491 	}
492 	return -EAFNOSUPPORT;
493 }
494 
495 u32 tcp_ao_compute_sne(u32 next_sne, u32 next_seq, u32 seq)
496 {
497 	u32 sne = next_sne;
498 
499 	if (before(seq, next_seq)) {
500 		if (seq > next_seq)
501 			sne--;
502 	} else {
503 		if (seq < next_seq)
504 			sne++;
505 	}
506 
507 	return sne;
508 }
509 
510 /* tcp_ao_hash_sne(struct tcp_sigpool *hp)
511  * @hp	- used for hashing
512  * @sne - sne value
513  */
514 static int tcp_ao_hash_sne(struct tcp_sigpool *hp, u32 sne)
515 {
516 	struct scatterlist sg;
517 	__be32 *bp;
518 
519 	bp = (__be32 *)hp->scratch;
520 	*bp = htonl(sne);
521 
522 	sg_init_one(&sg, bp, sizeof(*bp));
523 	ahash_request_set_crypt(hp->req, &sg, NULL, sizeof(*bp));
524 	return crypto_ahash_update(hp->req);
525 }
526 
/*
 * tcp_ao_hash_header - feed the TCP header into the MAC computation.
 * @hp:              started tcp_sigpool; hp->scratch holds the working copy
 * @th:              TCP header of the segment (never modified)
 * @exclude_options: if true, hash only the base TCP header, the TCP-AO
 *                   option header and a zeroed MAC, skipping other options
 * @hash:            MAC location inside @th (not used by this function)
 * @hash_offset:     offset of the MAC from the start of @th
 * @hash_len:        MAC length in bytes
 *
 * The checksum and the MAC bytes are zeroed in the copy before hashing.
 * NOTE(review): assumes hp->scratch can hold th->doff << 2 bytes; confirm
 * against the sigpool scratch size.
 *
 * Return: result of crypto_ahash_update(). A failure also triggers a
 * one-time WARN.
 */
static int tcp_ao_hash_header(struct tcp_sigpool *hp,
			      const struct tcphdr *th,
			      bool exclude_options, u8 *hash,
			      int hash_offset, int hash_len)
{
	struct scatterlist sg;
	u8 *hdr = hp->scratch;
	int err, len;

	/* We are not allowed to change tcphdr, make a local copy */
	if (exclude_options) {
		len = sizeof(*th) + sizeof(struct tcp_ao_hdr) + hash_len;
		memcpy(hdr, th, sizeof(*th));
		/* The TCP-AO option header sits immediately before the MAC. */
		memcpy(hdr + sizeof(*th),
		       (u8 *)th + hash_offset - sizeof(struct tcp_ao_hdr),
		       sizeof(struct tcp_ao_hdr));
		memset(hdr + sizeof(*th) + sizeof(struct tcp_ao_hdr),
		       0, hash_len);
		((struct tcphdr *)hdr)->check = 0;
	} else {
		len = th->doff << 2;
		memcpy(hdr, th, len);
		/* zero out tcp-ao hash */
		((struct tcphdr *)hdr)->check = 0;
		memset(hdr + hash_offset, 0, hash_len);
	}

	sg_init_one(&sg, hdr, len);
	ahash_request_set_crypt(hp->req, &sg, NULL, len);
	err = crypto_ahash_update(hp->req);
	WARN_ON_ONCE(err != 0);
	return err;
}
560 
/*
 * tcp_ao_hash_hdr - compute the TCP-AO MAC over a bare TCP header.
 * @family:  AF_INET or AF_INET6 (anything else WARNs once and fails)
 * @ao_hash: MAC location inside @th; receives tcp_ao_maclen(@key) bytes
 * @key:     MKT providing the sigpool id, key flags and MAC length
 * @tkey:    traffic key, tcp_ao_digest_size(@key) bytes long
 * @daddr:   destination address for the pseudo-header
 * @saddr:   source address for the pseudo-header
 * @th:      TCP header to authenticate
 * @sne:     Sequence Number Extension
 *
 * The MAC covers, in order: the SNE, the pseudo-header and the TCP header.
 * Unlike tcp_ao_hash_skb(), no payload is hashed, and the pseudo-header
 * length is the TCP header length (th->doff * 4).
 *
 * Return: 0 on success. On failure, returns 1 and zeroes the MAC at
 * @ao_hash.
 */
int tcp_ao_hash_hdr(unsigned short int family, char *ao_hash,
		    struct tcp_ao_key *key, const u8 *tkey,
		    const union tcp_ao_addr *daddr,
		    const union tcp_ao_addr *saddr,
		    const struct tcphdr *th, u32 sne)
{
	int tkey_len = tcp_ao_digest_size(key);
	int hash_offset = ao_hash - (char *)th;
	struct tcp_sigpool hp;
	void *hash_buf = NULL;

	/* Holds the full digest; only maclen bytes of it are copied out. */
	hash_buf = kmalloc(tkey_len, GFP_ATOMIC);
	if (!hash_buf)
		goto clear_hash_noput;

	if (tcp_sigpool_start(key->tcp_sigpool_id, &hp))
		goto clear_hash_noput;

	if (crypto_ahash_setkey(crypto_ahash_reqtfm(hp.req), tkey, tkey_len))
		goto clear_hash;

	if (crypto_ahash_init(hp.req))
		goto clear_hash;

	if (tcp_ao_hash_sne(&hp, sne))
		goto clear_hash;
	if (family == AF_INET) {
		if (tcp_v4_ao_hash_pseudoheader(&hp, daddr->a4.s_addr,
						saddr->a4.s_addr, th->doff * 4))
			goto clear_hash;
#if IS_ENABLED(CONFIG_IPV6)
	} else if (family == AF_INET6) {
		if (tcp_v6_ao_hash_pseudoheader(&hp, &daddr->a6,
						&saddr->a6, th->doff * 4))
			goto clear_hash;
#endif
	} else {
		WARN_ON_ONCE(1);
		goto clear_hash;
	}
	if (tcp_ao_hash_header(&hp, th,
			       !!(key->keyflags & TCP_AO_KEYF_EXCLUDE_OPT),
			       ao_hash, hash_offset, tcp_ao_maclen(key)))
		goto clear_hash;
	ahash_request_set_crypt(hp.req, NULL, hash_buf, 0);
	if (crypto_ahash_final(hp.req))
		goto clear_hash;

	memcpy(ao_hash, hash_buf, tcp_ao_maclen(key));
	tcp_sigpool_end(&hp);
	kfree(hash_buf);
	return 0;

clear_hash:
	tcp_sigpool_end(&hp);
clear_hash_noput:
	/* Never leave a partially computed MAC in the header. */
	memset(ao_hash, 0, tcp_ao_maclen(key));
	kfree(hash_buf);
	return 1;
}
621 
/*
 * tcp_ao_hash_skb - compute the TCP-AO MAC of a whole segment.
 * @family:      address family used for the pseudo-header
 * @ao_hash:     output; receives tcp_ao_maclen(@key) bytes
 * @key:         MKT providing the sigpool id, key flags and MAC length
 * @sk:          socket whose addresses are used for outbound segments
 * @skb:         segment to authenticate
 * @tkey:        traffic key, tcp_ao_digest_size(@key) bytes long
 * @hash_offset: offset of the MAC field from the start of the TCP header
 * @sne:         Sequence Number Extension
 *
 * The MAC covers, in order: the SNE, the pseudo-header (length skb->len),
 * the TCP header, and then the payload starting at th->doff << 2.
 *
 * Return: 0 on success. On failure, returns 1 and zeroes the first
 * tcp_ao_maclen(@key) bytes of @ao_hash.
 */
int tcp_ao_hash_skb(unsigned short int family,
		    char *ao_hash, struct tcp_ao_key *key,
		    const struct sock *sk, const struct sk_buff *skb,
		    const u8 *tkey, int hash_offset, u32 sne)
{
	const struct tcphdr *th = tcp_hdr(skb);
	int tkey_len = tcp_ao_digest_size(key);
	struct tcp_sigpool hp;
	void *hash_buf = NULL;

	hash_buf = kmalloc(tkey_len, GFP_ATOMIC);
	if (!hash_buf)
		goto clear_hash_noput;

	if (tcp_sigpool_start(key->tcp_sigpool_id, &hp))
		goto clear_hash_noput;

	if (crypto_ahash_setkey(crypto_ahash_reqtfm(hp.req), tkey, tkey_len))
		goto clear_hash;

	/* The MAC algorithm is whatever transform the sigpool selected by
	 * key->tcp_sigpool_id provides; nothing here hard-codes it.
	 */
	if (crypto_ahash_init(hp.req))
		goto clear_hash;

	if (tcp_ao_hash_sne(&hp, sne))
		goto clear_hash;
	if (tcp_ao_hash_pseudoheader(family, sk, skb, &hp, skb->len))
		goto clear_hash;
	if (tcp_ao_hash_header(&hp, th,
			       !!(key->keyflags & TCP_AO_KEYF_EXCLUDE_OPT),
			       ao_hash, hash_offset, tcp_ao_maclen(key)))
		goto clear_hash;
	if (tcp_sigpool_hash_skb_data(&hp, skb, th->doff << 2))
		goto clear_hash;
	ahash_request_set_crypt(hp.req, NULL, hash_buf, 0);
	if (crypto_ahash_final(hp.req))
		goto clear_hash;

	memcpy(ao_hash, hash_buf, tcp_ao_maclen(key));
	tcp_sigpool_end(&hp);
	kfree(hash_buf);
	return 0;

clear_hash:
	tcp_sigpool_end(&hp);
clear_hash_noput:
	memset(ao_hash, 0, tcp_ao_maclen(key));
	kfree(hash_buf);
	return 1;
}
672 
673 int tcp_v4_ao_hash_skb(char *ao_hash, struct tcp_ao_key *key,
674 		       const struct sock *sk, const struct sk_buff *skb,
675 		       const u8 *tkey, int hash_offset, u32 sne)
676 {
677 	return tcp_ao_hash_skb(AF_INET, ao_hash, key, sk, skb,
678 			       tkey, hash_offset, sne);
679 }
680 
681 int tcp_v4_ao_synack_hash(char *ao_hash, struct tcp_ao_key *ao_key,
682 			  struct request_sock *req, const struct sk_buff *skb,
683 			  int hash_offset, u32 sne)
684 {
685 	void *hash_buf = NULL;
686 	int err;
687 
688 	hash_buf = kmalloc(tcp_ao_digest_size(ao_key), GFP_ATOMIC);
689 	if (!hash_buf)
690 		return -ENOMEM;
691 
692 	err = tcp_v4_ao_calc_key_rsk(ao_key, hash_buf, req);
693 	if (err)
694 		goto out;
695 
696 	err = tcp_ao_hash_skb(AF_INET, ao_hash, ao_key, req_to_sk(req), skb,
697 			      hash_buf, hash_offset, sne);
698 out:
699 	kfree(hash_buf);
700 	return err;
701 }
702 
703 struct tcp_ao_key *tcp_v4_ao_lookup_rsk(const struct sock *sk,
704 					struct request_sock *req,
705 					int sndid, int rcvid)
706 {
707 	struct inet_request_sock *ireq = inet_rsk(req);
708 	union tcp_ao_addr *addr = (union tcp_ao_addr *)&ireq->ir_rmt_addr;
709 	int l3index;
710 
711 	l3index = l3mdev_master_ifindex_by_index(sock_net(sk), ireq->ir_iif);
712 	return tcp_ao_do_lookup(sk, l3index, addr, AF_INET, sndid, rcvid);
713 }
714 
715 struct tcp_ao_key *tcp_v4_ao_lookup(const struct sock *sk, struct sock *addr_sk,
716 				    int sndid, int rcvid)
717 {
718 	int l3index = l3mdev_master_ifindex_by_index(sock_net(sk),
719 						     addr_sk->sk_bound_dev_if);
720 	union tcp_ao_addr *addr = (union tcp_ao_addr *)&addr_sk->sk_daddr;
721 
722 	return tcp_ao_do_lookup(sk, l3index, addr, AF_INET, sndid, rcvid);
723 }
724 
/*
 * tcp_ao_prepare_reset - pick the key material used to sign an RST.
 * @sk:                    socket the offending segment matched, or NULL
 * @skb:                   the segment being answered
 * @aoh:                   TCP-AO option of @skb
 * @l3index:               L3 master ifindex used for the key lookup
 * @seq:                   sequence number used for the SNE computation
 * @key:                   out: selected MKT
 * @traffic_key:           out: traffic key to sign with
 * @allocated_traffic_key: out: true if *@traffic_key was kmalloc()ed here
 * @keyid:                 out: KeyID to place in the reply
 * @sne:                   out: SNE to hash with (see notes below)
 *
 * Listener and request sockets:
 * - The key is looked up by peer address and aoh->rnext_keyid.
 * - A fresh traffic key is derived into a kmalloc()ed buffer.
 * Other states:
 * - The key is found among established keys.
 * - The cached send traffic key (snd_other_key()) is reused.
 *
 * *@sne is written for TCP_NEW_SYN_RECV and for non-listener states, but
 * left untouched for TCP_LISTEN; presumably callers pre-initialize it, so
 * verify that.
 * NOTE(review): once *@allocated_traffic_key is set, the -1 return path
 * still leaves the buffer allocated. Presumably the caller frees it based
 * on the flag, regardless of the return value; confirm.
 *
 * Return: 0 on success. Otherwise:
 * - -ENOTCONN when there is no socket;
 * - -ENOENT when there is no ao_info or no matching key;
 * - -ENOMEM when allocation fails;
 * - -1 when traffic-key derivation fails.
 */
int tcp_ao_prepare_reset(const struct sock *sk, struct sk_buff *skb,
			 const struct tcp_ao_hdr *aoh, int l3index, u32 seq,
			 struct tcp_ao_key **key, char **traffic_key,
			 bool *allocated_traffic_key, u8 *keyid, u32 *sne)
{
	const struct tcphdr *th = tcp_hdr(skb);
	struct tcp_ao_info *ao_info;

	*allocated_traffic_key = false;
	/* If there's no socket, then the initial sisn/disn are unknown.
	 * Drop the segment. RFC5925 (7.7) advises to require graceful
	 * restart [RFC4724]. Alternatively, RFC5925 advises to
	 * save/restore traffic keys before/after reboot.
	 * Linux TCP-AO support provides TCP_AO_ADD_KEY and TCP_AO_REPAIR
	 * options to restore a socket post-reboot.
	 */
	if (!sk)
		return -ENOTCONN;

	if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_NEW_SYN_RECV)) {
		unsigned int family = READ_ONCE(sk->sk_family);
		union tcp_ao_addr *addr;
		__be32 disn, sisn;

		if (sk->sk_state == TCP_NEW_SYN_RECV) {
			struct request_sock *req = inet_reqsk(sk);

			sisn = htonl(tcp_rsk(req)->rcv_isn);
			disn = htonl(tcp_rsk(req)->snt_isn);
			*sne = tcp_ao_compute_sne(0, tcp_rsk(req)->snt_isn, seq);
		} else {
			/* Listener: only the segment's own ISN is known. */
			sisn = th->seq;
			disn = 0;
		}
		/* The key is looked up by the source address of @skb. */
		if (IS_ENABLED(CONFIG_IPV6) && family == AF_INET6)
			addr = (union tcp_md5_addr *)&ipv6_hdr(skb)->saddr;
		else
			addr = (union tcp_md5_addr *)&ip_hdr(skb)->saddr;
#if IS_ENABLED(CONFIG_IPV6)
		if (family == AF_INET6 && ipv6_addr_v4mapped(&sk->sk_v6_daddr))
			family = AF_INET;
#endif

		sk = sk_const_to_full_sk(sk);
		ao_info = rcu_dereference(tcp_sk(sk)->ao_info);
		if (!ao_info)
			return -ENOENT;
		*key = tcp_ao_do_lookup(sk, l3index, addr, family,
					-1, aoh->rnext_keyid);
		if (!*key)
			return -ENOENT;
		*traffic_key = kmalloc(tcp_ao_digest_size(*key), GFP_ATOMIC);
		if (!*traffic_key)
			return -ENOMEM;
		*allocated_traffic_key = true;
		if (tcp_ao_calc_key_skb(*key, *traffic_key, skb,
					sisn, disn, family))
			return -1;
		*keyid = (*key)->rcvid;
	} else {
		struct tcp_ao_key *rnext_key;
		u32 snd_basis;

		/* Reference point for the SNE: the timewait socket's
		 * tw_snd_nxt, or snd_una for a full socket.
		 */
		if (sk->sk_state == TCP_TIME_WAIT) {
			ao_info = rcu_dereference(tcp_twsk(sk)->ao_info);
			snd_basis = tcp_twsk(sk)->tw_snd_nxt;
		} else {
			ao_info = rcu_dereference(tcp_sk(sk)->ao_info);
			snd_basis = tcp_sk(sk)->snd_una;
		}
		if (!ao_info)
			return -ENOENT;

		*key = tcp_ao_established_key(sk, ao_info, aoh->rnext_keyid, -1);
		if (!*key)
			return -ENOENT;
		*traffic_key = snd_other_key(*key);
		rnext_key = READ_ONCE(ao_info->rnext_key);
		*keyid = rnext_key->rcvid;
		*sne = tcp_ao_compute_sne(READ_ONCE(ao_info->snd_sne),
					  snd_basis, seq);
	}
	return 0;
}
809 
/*
 * tcp_ao_transmit_skb - write the TCP-AO MAC of an outgoing segment.
 * @sk:            sending socket (its lock is expected to be held)
 * @skb:           outgoing segment
 * @key:           MKT to sign with
 * @th:            TCP header of @skb
 * @hash_location: MAC location inside @th
 *
 * Non-SYN segments use the cached send traffic key. For SYN segments the
 * traffic key is derived here first (see below).
 *
 * NOTE(review): the return values of ao_calc_key_sk() and calc_ao_hash()
 * are ignored. If key derivation fails, the segment may be signed with a
 * zeroed or unwritten key. If hashing fails, the MAC may be left zeroed.
 * Consider propagating these errors.
 *
 * Return: 0, or -ENOMEM if the temporary SYN traffic key can't be allocated.
 */
int tcp_ao_transmit_skb(struct sock *sk, struct sk_buff *skb,
			struct tcp_ao_key *key, struct tcphdr *th,
			__u8 *hash_location)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct tcp_sock *tp = tcp_sk(sk);
	struct tcp_ao_info *ao;
	void *tkey_buf = NULL;
	u8 *traffic_key;
	u32 sne;

	ao = rcu_dereference_protected(tcp_sk(sk)->ao_info,
				       lockdep_sock_is_held(sk));
	traffic_key = snd_other_key(key);
	if (unlikely(tcb->tcp_flags & TCPHDR_SYN)) {
		__be32 disn;

		if (!(tcb->tcp_flags & TCPHDR_ACK)) {
			/* Initial SYN: the peer ISN is unknown (0), so derive
			 * a one-off key instead of overwriting the cached one.
			 */
			disn = 0;
			tkey_buf = kmalloc(tcp_ao_digest_size(key), GFP_ATOMIC);
			if (!tkey_buf)
				return -ENOMEM;
			traffic_key = tkey_buf;
		} else {
			/* SYN-ACK: re-derive the cached send key with risn. */
			disn = ao->risn;
		}
		tp->af_specific->ao_calc_key_sk(key, traffic_key,
						sk, ao->lisn, disn, true);
	}
	sne = tcp_ao_compute_sne(READ_ONCE(ao->snd_sne), READ_ONCE(tp->snd_una),
				 ntohl(th->seq));
	tp->af_specific->calc_ao_hash(hash_location, key, sk, skb, traffic_key,
				      hash_location - (u8 *)th, sne);
	kfree(tkey_buf);
	return 0;
}
846 
847 static struct tcp_ao_key *tcp_ao_inbound_lookup(unsigned short int family,
848 		const struct sock *sk, const struct sk_buff *skb,
849 		int sndid, int rcvid, int l3index)
850 {
851 	if (family == AF_INET) {
852 		const struct iphdr *iph = ip_hdr(skb);
853 
854 		return tcp_ao_do_lookup(sk, l3index,
855 					(union tcp_ao_addr *)&iph->saddr,
856 					AF_INET, sndid, rcvid);
857 	} else {
858 		const struct ipv6hdr *iph = ipv6_hdr(skb);
859 
860 		return tcp_ao_do_lookup(sk, l3index,
861 					(union tcp_ao_addr *)&iph->saddr,
862 					AF_INET6, sndid, rcvid);
863 	}
864 }
865 
866 void tcp_ao_syncookie(struct sock *sk, const struct sk_buff *skb,
867 		      struct request_sock *req, unsigned short int family)
868 {
869 	struct tcp_request_sock *treq = tcp_rsk(req);
870 	const struct tcphdr *th = tcp_hdr(skb);
871 	const struct tcp_ao_hdr *aoh;
872 	struct tcp_ao_key *key;
873 	int l3index;
874 
875 	/* treq->af_specific is used to perform TCP_AO lookup
876 	 * in tcp_create_openreq_child().
877 	 */
878 #if IS_ENABLED(CONFIG_IPV6)
879 	if (family == AF_INET6)
880 		treq->af_specific = &tcp_request_sock_ipv6_ops;
881 	else
882 #endif
883 		treq->af_specific = &tcp_request_sock_ipv4_ops;
884 
885 	treq->used_tcp_ao = false;
886 
887 	if (tcp_parse_auth_options(th, NULL, &aoh) || !aoh)
888 		return;
889 
890 	l3index = l3mdev_master_ifindex_by_index(sock_net(sk), inet_rsk(req)->ir_iif);
891 	key = tcp_ao_inbound_lookup(family, sk, skb, -1, aoh->keyid, l3index);
892 	if (!key)
893 		/* Key not found, continue without TCP-AO */
894 		return;
895 
896 	treq->ao_rcv_next = aoh->keyid;
897 	treq->ao_keyid = aoh->rnext_keyid;
898 	treq->used_tcp_ao = true;
899 }
900 
901 static enum skb_drop_reason
902 tcp_ao_verify_hash(const struct sock *sk, const struct sk_buff *skb,
903 		   unsigned short int family, struct tcp_ao_info *info,
904 		   const struct tcp_ao_hdr *aoh, struct tcp_ao_key *key,
905 		   u8 *traffic_key, u8 *phash, u32 sne, int l3index)
906 {
907 	const struct tcphdr *th = tcp_hdr(skb);
908 	u8 maclen = tcp_ao_hdr_maclen(aoh);
909 	void *hash_buf = NULL;
910 
911 	if (maclen != tcp_ao_maclen(key)) {
912 		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAOBAD);
913 		atomic64_inc(&info->counters.pkt_bad);
914 		atomic64_inc(&key->pkt_bad);
915 		trace_tcp_ao_wrong_maclen(sk, skb, aoh->keyid,
916 					  aoh->rnext_keyid, maclen);
917 		return SKB_DROP_REASON_TCP_AOFAILURE;
918 	}
919 
920 	hash_buf = kmalloc(tcp_ao_digest_size(key), GFP_ATOMIC);
921 	if (!hash_buf)
922 		return SKB_DROP_REASON_NOT_SPECIFIED;
923 
924 	/* XXX: make it per-AF callback? */
925 	tcp_ao_hash_skb(family, hash_buf, key, sk, skb, traffic_key,
926 			(phash - (u8 *)th), sne);
927 	if (crypto_memneq(phash, hash_buf, maclen)) {
928 		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAOBAD);
929 		atomic64_inc(&info->counters.pkt_bad);
930 		atomic64_inc(&key->pkt_bad);
931 		trace_tcp_ao_mismatch(sk, skb, aoh->keyid,
932 				      aoh->rnext_keyid, maclen);
933 		kfree(hash_buf);
934 		return SKB_DROP_REASON_TCP_AOFAILURE;
935 	}
936 	NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAOGOOD);
937 	atomic64_inc(&info->counters.pkt_good);
938 	atomic64_inc(&key->pkt_good);
939 	kfree(hash_buf);
940 	return SKB_NOT_DROPPED_YET;
941 }
942 
943 enum skb_drop_reason
944 tcp_inbound_ao_hash(struct sock *sk, const struct sk_buff *skb,
945 		    unsigned short int family, const struct request_sock *req,
946 		    int l3index, const struct tcp_ao_hdr *aoh)
947 {
948 	const struct tcphdr *th = tcp_hdr(skb);
949 	u8 maclen = tcp_ao_hdr_maclen(aoh);
950 	u8 *phash = (u8 *)(aoh + 1); /* hash goes just after the header */
951 	struct tcp_ao_info *info;
952 	enum skb_drop_reason ret;
953 	struct tcp_ao_key *key;
954 	__be32 sisn, disn;
955 	u8 *traffic_key;
956 	int state;
957 	u32 sne = 0;
958 
959 	info = rcu_dereference(tcp_sk(sk)->ao_info);
960 	if (!info) {
961 		NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAOKEYNOTFOUND);
962 		trace_tcp_ao_key_not_found(sk, skb, aoh->keyid,
963 					   aoh->rnext_keyid, maclen);
964 		return SKB_DROP_REASON_TCP_AOUNEXPECTED;
965 	}
966 
967 	if (unlikely(th->syn)) {
968 		sisn = th->seq;
969 		disn = 0;
970 	}
971 
972 	state = READ_ONCE(sk->sk_state);
973 	/* Fast-path */
974 	if (likely((1 << state) & TCP_AO_ESTABLISHED)) {
975 		enum skb_drop_reason err;
976 		struct tcp_ao_key *current_key;
977 
978 		/* Check if this socket's rnext_key matches the keyid in the
979 		 * packet. If not we lookup the key based on the keyid
980 		 * matching the rcvid in the mkt.
981 		 */
982 		key = READ_ONCE(info->rnext_key);
983 		if (key->rcvid != aoh->keyid) {
984 			key = tcp_ao_established_key(sk, info, -1, aoh->keyid);
985 			if (!key)
986 				goto key_not_found;
987 		}
988 
989 		/* Delayed retransmitted SYN */
990 		if (unlikely(th->syn && !th->ack))
991 			goto verify_hash;
992 
993 		sne = tcp_ao_compute_sne(info->rcv_sne, tcp_sk(sk)->rcv_nxt,
994 					 ntohl(th->seq));
995 		/* Established socket, traffic key are cached */
996 		traffic_key = rcv_other_key(key);
997 		err = tcp_ao_verify_hash(sk, skb, family, info, aoh, key,
998 					 traffic_key, phash, sne, l3index);
999 		if (err)
1000 			return err;
1001 		current_key = READ_ONCE(info->current_key);
1002 		/* Key rotation: the peer asks us to use new key (RNext) */
1003 		if (unlikely(aoh->rnext_keyid != current_key->sndid)) {
1004 			trace_tcp_ao_rnext_request(sk, skb, current_key->sndid,
1005 						   aoh->rnext_keyid,
1006 						   tcp_ao_hdr_maclen(aoh));
1007 			/* If the key is not found we do nothing. */
1008 			key = tcp_ao_established_key(sk, info, aoh->rnext_keyid, -1);
1009 			if (key)
1010 				/* pairs with tcp_ao_del_cmd */
1011 				WRITE_ONCE(info->current_key, key);
1012 		}
1013 		return SKB_NOT_DROPPED_YET;
1014 	}
1015 
1016 	if (unlikely(state == TCP_CLOSE))
1017 		return SKB_DROP_REASON_TCP_CLOSE;
1018 
1019 	/* Lookup key based on peer address and keyid.
1020 	 * current_key and rnext_key must not be used on tcp listen
1021 	 * sockets as otherwise:
1022 	 * - request sockets would race on those key pointers
1023 	 * - tcp_ao_del_cmd() allows async key removal
1024 	 */
1025 	key = tcp_ao_inbound_lookup(family, sk, skb, -1, aoh->keyid, l3index);
1026 	if (!key)
1027 		goto key_not_found;
1028 
1029 	if (th->syn && !th->ack)
1030 		goto verify_hash;
1031 
1032 	if ((1 << state) & (TCPF_LISTEN | TCPF_NEW_SYN_RECV)) {
1033 		/* Make the initial syn the likely case here */
1034 		if (unlikely(req)) {
1035 			sne = tcp_ao_compute_sne(0, tcp_rsk(req)->rcv_isn,
1036 						 ntohl(th->seq));
1037 			sisn = htonl(tcp_rsk(req)->rcv_isn);
1038 			disn = htonl(tcp_rsk(req)->snt_isn);
1039 		} else if (unlikely(th->ack && !th->syn)) {
1040 			/* Possible syncookie packet */
1041 			sisn = htonl(ntohl(th->seq) - 1);
1042 			disn = htonl(ntohl(th->ack_seq) - 1);
1043 			sne = tcp_ao_compute_sne(0, ntohl(sisn),
1044 						 ntohl(th->seq));
1045 		} else if (unlikely(!th->syn)) {
1046 			/* no way to figure out initial sisn/disn - drop */
1047 			return SKB_DROP_REASON_TCP_FLAGS;
1048 		}
1049 	} else if ((1 << state) & (TCPF_SYN_SENT | TCPF_SYN_RECV)) {
1050 		disn = info->lisn;
1051 		if (th->syn || th->rst)
1052 			sisn = th->seq;
1053 		else
1054 			sisn = info->risn;
1055 	} else {
1056 		WARN_ONCE(1, "TCP-AO: Unexpected sk_state %d", state);
1057 		return SKB_DROP_REASON_TCP_AOFAILURE;
1058 	}
1059 verify_hash:
1060 	traffic_key = kmalloc(tcp_ao_digest_size(key), GFP_ATOMIC);
1061 	if (!traffic_key)
1062 		return SKB_DROP_REASON_NOT_SPECIFIED;
1063 	tcp_ao_calc_key_skb(key, traffic_key, skb, sisn, disn, family);
1064 	ret = tcp_ao_verify_hash(sk, skb, family, info, aoh, key,
1065 				 traffic_key, phash, sne, l3index);
1066 	kfree(traffic_key);
1067 	return ret;
1068 
1069 key_not_found:
1070 	NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAOKEYNOTFOUND);
1071 	atomic64_inc(&info->counters.key_not_found);
1072 	trace_tcp_ao_key_not_found(sk, skb, aoh->keyid,
1073 				   aoh->rnext_keyid, maclen);
1074 	return SKB_DROP_REASON_TCP_AOKEYNOTFOUND;
1075 }
1076 
1077 static int tcp_ao_cache_traffic_keys(const struct sock *sk,
1078 				     struct tcp_ao_info *ao,
1079 				     struct tcp_ao_key *ao_key)
1080 {
1081 	u8 *traffic_key = snd_other_key(ao_key);
1082 	int ret;
1083 
1084 	ret = tcp_ao_calc_key_sk(ao_key, traffic_key, sk,
1085 				 ao->lisn, ao->risn, true);
1086 	if (ret)
1087 		return ret;
1088 
1089 	traffic_key = rcv_other_key(ao_key);
1090 	ret = tcp_ao_calc_key_sk(ao_key, traffic_key, sk,
1091 				 ao->lisn, ao->risn, false);
1092 	return ret;
1093 }
1094 
1095 void tcp_ao_connect_init(struct sock *sk)
1096 {
1097 	struct tcp_sock *tp = tcp_sk(sk);
1098 	struct tcp_ao_info *ao_info;
1099 	struct hlist_node *next;
1100 	union tcp_ao_addr *addr;
1101 	struct tcp_ao_key *key;
1102 	int family, l3index;
1103 
1104 	ao_info = rcu_dereference_protected(tp->ao_info,
1105 					    lockdep_sock_is_held(sk));
1106 	if (!ao_info)
1107 		return;
1108 
1109 	/* Remove all keys that don't match the peer */
1110 	family = sk->sk_family;
1111 	if (family == AF_INET)
1112 		addr = (union tcp_ao_addr *)&sk->sk_daddr;
1113 #if IS_ENABLED(CONFIG_IPV6)
1114 	else if (family == AF_INET6)
1115 		addr = (union tcp_ao_addr *)&sk->sk_v6_daddr;
1116 #endif
1117 	else
1118 		return;
1119 	l3index = l3mdev_master_ifindex_by_index(sock_net(sk),
1120 						 sk->sk_bound_dev_if);
1121 
1122 	hlist_for_each_entry_safe(key, next, &ao_info->head, node) {
1123 		if (!tcp_ao_key_cmp(key, l3index, addr, key->prefixlen, family, -1, -1))
1124 			continue;
1125 
1126 		if (key == ao_info->current_key)
1127 			ao_info->current_key = NULL;
1128 		if (key == ao_info->rnext_key)
1129 			ao_info->rnext_key = NULL;
1130 		hlist_del_rcu(&key->node);
1131 		atomic_sub(tcp_ao_sizeof_key(key), &sk->sk_omem_alloc);
1132 		call_rcu(&key->rcu, tcp_ao_key_free_rcu);
1133 	}
1134 
1135 	key = tp->af_specific->ao_lookup(sk, sk, -1, -1);
1136 	if (key) {
1137 		/* if current_key or rnext_key were not provided,
1138 		 * use the first key matching the peer
1139 		 */
1140 		if (!ao_info->current_key)
1141 			ao_info->current_key = key;
1142 		if (!ao_info->rnext_key)
1143 			ao_info->rnext_key = key;
1144 		tp->tcp_header_len += tcp_ao_len_aligned(key);
1145 
1146 		ao_info->lisn = htonl(tp->write_seq);
1147 		ao_info->snd_sne = 0;
1148 	} else {
1149 		/* Can't happen: tcp_connect() verifies that there's
1150 		 * at least one tcp-ao key that matches the remote peer.
1151 		 */
1152 		WARN_ON_ONCE(1);
1153 		rcu_assign_pointer(tp->ao_info, NULL);
1154 		kfree(ao_info);
1155 	}
1156 }
1157 
1158 void tcp_ao_established(struct sock *sk)
1159 {
1160 	struct tcp_ao_info *ao;
1161 	struct tcp_ao_key *key;
1162 
1163 	ao = rcu_dereference_protected(tcp_sk(sk)->ao_info,
1164 				       lockdep_sock_is_held(sk));
1165 	if (!ao)
1166 		return;
1167 
1168 	hlist_for_each_entry_rcu(key, &ao->head, node, lockdep_sock_is_held(sk))
1169 		tcp_ao_cache_traffic_keys(sk, ao, key);
1170 }
1171 
1172 void tcp_ao_finish_connect(struct sock *sk, struct sk_buff *skb)
1173 {
1174 	struct tcp_ao_info *ao;
1175 	struct tcp_ao_key *key;
1176 
1177 	ao = rcu_dereference_protected(tcp_sk(sk)->ao_info,
1178 				       lockdep_sock_is_held(sk));
1179 	if (!ao)
1180 		return;
1181 
1182 	/* sk with TCP_REPAIR_ON does not have skb in tcp_finish_connect */
1183 	if (skb)
1184 		WRITE_ONCE(ao->risn, tcp_hdr(skb)->seq);
1185 	ao->rcv_sne = 0;
1186 
1187 	hlist_for_each_entry_rcu(key, &ao->head, node, lockdep_sock_is_held(sk))
1188 		tcp_ao_cache_traffic_keys(sk, ao, key);
1189 }
1190 
1191 int tcp_ao_copy_all_matching(const struct sock *sk, struct sock *newsk,
1192 			     struct request_sock *req, struct sk_buff *skb,
1193 			     int family)
1194 {
1195 	struct tcp_ao_key *key, *new_key, *first_key;
1196 	struct tcp_ao_info *new_ao, *ao;
1197 	struct hlist_node *key_head;
1198 	int l3index, ret = -ENOMEM;
1199 	union tcp_ao_addr *addr;
1200 	bool match = false;
1201 
1202 	ao = rcu_dereference(tcp_sk(sk)->ao_info);
1203 	if (!ao)
1204 		return 0;
1205 
1206 	/* New socket without TCP-AO on it */
1207 	if (!tcp_rsk_used_ao(req))
1208 		return 0;
1209 
1210 	new_ao = tcp_ao_alloc_info(GFP_ATOMIC);
1211 	if (!new_ao)
1212 		return -ENOMEM;
1213 	new_ao->lisn = htonl(tcp_rsk(req)->snt_isn);
1214 	new_ao->risn = htonl(tcp_rsk(req)->rcv_isn);
1215 	new_ao->ao_required = ao->ao_required;
1216 	new_ao->accept_icmps = ao->accept_icmps;
1217 
1218 	if (family == AF_INET) {
1219 		addr = (union tcp_ao_addr *)&newsk->sk_daddr;
1220 #if IS_ENABLED(CONFIG_IPV6)
1221 	} else if (family == AF_INET6) {
1222 		addr = (union tcp_ao_addr *)&newsk->sk_v6_daddr;
1223 #endif
1224 	} else {
1225 		ret = -EAFNOSUPPORT;
1226 		goto free_ao;
1227 	}
1228 	l3index = l3mdev_master_ifindex_by_index(sock_net(newsk),
1229 						 newsk->sk_bound_dev_if);
1230 
1231 	hlist_for_each_entry_rcu(key, &ao->head, node) {
1232 		if (tcp_ao_key_cmp(key, l3index, addr, key->prefixlen, family, -1, -1))
1233 			continue;
1234 
1235 		new_key = tcp_ao_copy_key(newsk, key);
1236 		if (!new_key)
1237 			goto free_and_exit;
1238 
1239 		tcp_ao_cache_traffic_keys(newsk, new_ao, new_key);
1240 		tcp_ao_link_mkt(new_ao, new_key);
1241 		match = true;
1242 	}
1243 
1244 	if (!match) {
1245 		/* RFC5925 (7.4.1) specifies that the TCP-AO status
1246 		 * of a connection is determined on the initial SYN.
1247 		 * At this point the connection was TCP-AO enabled, so
1248 		 * it can't switch to being unsigned if peer's key
1249 		 * disappears on the listening socket.
1250 		 */
1251 		ret = -EKEYREJECTED;
1252 		goto free_and_exit;
1253 	}
1254 
1255 	if (!static_key_fast_inc_not_disabled(&tcp_ao_needed.key.key)) {
1256 		ret = -EUSERS;
1257 		goto free_and_exit;
1258 	}
1259 
1260 	key_head = rcu_dereference(hlist_first_rcu(&new_ao->head));
1261 	first_key = hlist_entry_safe(key_head, struct tcp_ao_key, node);
1262 
1263 	key = tcp_ao_established_key(req_to_sk(req), new_ao, tcp_rsk(req)->ao_keyid, -1);
1264 	if (key)
1265 		new_ao->current_key = key;
1266 	else
1267 		new_ao->current_key = first_key;
1268 
1269 	/* set rnext_key */
1270 	key = tcp_ao_established_key(req_to_sk(req), new_ao, -1, tcp_rsk(req)->ao_rcv_next);
1271 	if (key)
1272 		new_ao->rnext_key = key;
1273 	else
1274 		new_ao->rnext_key = first_key;
1275 
1276 	sk_gso_disable(newsk);
1277 	rcu_assign_pointer(tcp_sk(newsk)->ao_info, new_ao);
1278 
1279 	return 0;
1280 
1281 free_and_exit:
1282 	hlist_for_each_entry_safe(key, key_head, &new_ao->head, node) {
1283 		hlist_del(&key->node);
1284 		tcp_sigpool_release(key->tcp_sigpool_id);
1285 		atomic_sub(tcp_ao_sizeof_key(key), &newsk->sk_omem_alloc);
1286 		kfree_sensitive(key);
1287 	}
1288 free_ao:
1289 	kfree(new_ao);
1290 	return ret;
1291 }
1292 
1293 static bool tcp_ao_can_set_current_rnext(struct sock *sk)
1294 {
1295 	/* There aren't current/rnext keys on TCP_LISTEN sockets */
1296 	if (sk->sk_state == TCP_LISTEN)
1297 		return false;
1298 	return true;
1299 }
1300 
1301 static int tcp_ao_verify_ipv4(struct sock *sk, struct tcp_ao_add *cmd,
1302 			      union tcp_ao_addr **addr)
1303 {
1304 	struct sockaddr_in *sin = (struct sockaddr_in *)&cmd->addr;
1305 	struct inet_sock *inet = inet_sk(sk);
1306 
1307 	if (sin->sin_family != AF_INET)
1308 		return -EINVAL;
1309 
1310 	/* Currently matching is not performed on port (or port ranges) */
1311 	if (sin->sin_port != 0)
1312 		return -EINVAL;
1313 
1314 	/* Check prefix and trailing 0's in addr */
1315 	if (cmd->prefix != 0) {
1316 		__be32 mask;
1317 
1318 		if (ntohl(sin->sin_addr.s_addr) == INADDR_ANY)
1319 			return -EINVAL;
1320 		if (cmd->prefix > 32)
1321 			return -EINVAL;
1322 
1323 		mask = inet_make_mask(cmd->prefix);
1324 		if (sin->sin_addr.s_addr & ~mask)
1325 			return -EINVAL;
1326 
1327 		/* Check that MKT address is consistent with socket */
1328 		if (ntohl(inet->inet_daddr) != INADDR_ANY &&
1329 		    (inet->inet_daddr & mask) != sin->sin_addr.s_addr)
1330 			return -EINVAL;
1331 	} else {
1332 		if (ntohl(sin->sin_addr.s_addr) != INADDR_ANY)
1333 			return -EINVAL;
1334 	}
1335 
1336 	*addr = (union tcp_ao_addr *)&sin->sin_addr;
1337 	return 0;
1338 }
1339 
1340 static int tcp_ao_parse_crypto(struct tcp_ao_add *cmd, struct tcp_ao_key *key)
1341 {
1342 	unsigned int syn_tcp_option_space;
1343 	bool is_kdf_aes_128_cmac = false;
1344 	struct crypto_ahash *tfm;
1345 	struct tcp_sigpool hp;
1346 	void *tmp_key = NULL;
1347 	int err;
1348 
1349 	/* RFC5926, 3.1.1.2. KDF_AES_128_CMAC */
1350 	if (!strcmp("cmac(aes128)", cmd->alg_name)) {
1351 		strscpy(cmd->alg_name, "cmac(aes)", sizeof(cmd->alg_name));
1352 		is_kdf_aes_128_cmac = (cmd->keylen != 16);
1353 		tmp_key = kmalloc(cmd->keylen, GFP_KERNEL);
1354 		if (!tmp_key)
1355 			return -ENOMEM;
1356 	}
1357 
1358 	key->maclen = cmd->maclen ?: 12; /* 12 is the default in RFC5925 */
1359 
1360 	/* Check: maclen + tcp-ao header <= (MAX_TCP_OPTION_SPACE - mss
1361 	 *					- tstamp (including sackperm)
1362 	 *					- wscale),
1363 	 * see tcp_syn_options(), tcp_synack_options(), commit 33ad798c924b.
1364 	 *
1365 	 * In order to allow D-SACK with TCP-AO, the header size should be:
1366 	 * (MAX_TCP_OPTION_SPACE - TCPOLEN_TSTAMP_ALIGNED
1367 	 *			- TCPOLEN_SACK_BASE_ALIGNED
1368 	 *			- 2 * TCPOLEN_SACK_PERBLOCK) = 8 (maclen = 4),
1369 	 * see tcp_established_options().
1370 	 *
1371 	 * RFC5925, 2.2:
1372 	 * Typical MACs are 96-128 bits (12-16 bytes), but any length
1373 	 * that fits in the header of the segment being authenticated
1374 	 * is allowed.
1375 	 *
1376 	 * RFC5925, 7.6:
1377 	 * TCP-AO continues to consume 16 bytes in non-SYN segments,
1378 	 * leaving a total of 24 bytes for other options, of which
1379 	 * the timestamp consumes 10.  This leaves 14 bytes, of which 10
1380 	 * are used for a single SACK block. When two SACK blocks are used,
1381 	 * such as to handle D-SACK, a smaller TCP-AO MAC would be required
1382 	 * to make room for the additional SACK block (i.e., to leave 18
1383 	 * bytes for the D-SACK variant of the SACK option) [RFC2883].
1384 	 * Note that D-SACK is not supportable in TCP MD5 in the presence
1385 	 * of timestamps, because TCP MD5’s MAC length is fixed and too
1386 	 * large to leave sufficient option space.
1387 	 */
1388 	syn_tcp_option_space = MAX_TCP_OPTION_SPACE;
1389 	syn_tcp_option_space -= TCPOLEN_MSS_ALIGNED;
1390 	syn_tcp_option_space -= TCPOLEN_TSTAMP_ALIGNED;
1391 	syn_tcp_option_space -= TCPOLEN_WSCALE_ALIGNED;
1392 	if (tcp_ao_len_aligned(key) > syn_tcp_option_space) {
1393 		err = -EMSGSIZE;
1394 		goto err_kfree;
1395 	}
1396 
1397 	key->keylen = cmd->keylen;
1398 	memcpy(key->key, cmd->key, cmd->keylen);
1399 
1400 	err = tcp_sigpool_start(key->tcp_sigpool_id, &hp);
1401 	if (err)
1402 		goto err_kfree;
1403 
1404 	tfm = crypto_ahash_reqtfm(hp.req);
1405 	if (is_kdf_aes_128_cmac) {
1406 		void *scratch = hp.scratch;
1407 		struct scatterlist sg;
1408 
1409 		memcpy(tmp_key, cmd->key, cmd->keylen);
1410 		sg_init_one(&sg, tmp_key, cmd->keylen);
1411 
1412 		/* Using zero-key of 16 bytes as described in RFC5926 */
1413 		memset(scratch, 0, 16);
1414 		err = crypto_ahash_setkey(tfm, scratch, 16);
1415 		if (err)
1416 			goto err_pool_end;
1417 
1418 		err = crypto_ahash_init(hp.req);
1419 		if (err)
1420 			goto err_pool_end;
1421 
1422 		ahash_request_set_crypt(hp.req, &sg, key->key, cmd->keylen);
1423 		err = crypto_ahash_update(hp.req);
1424 		if (err)
1425 			goto err_pool_end;
1426 
1427 		err |= crypto_ahash_final(hp.req);
1428 		if (err)
1429 			goto err_pool_end;
1430 		key->keylen = 16;
1431 	}
1432 
1433 	err = crypto_ahash_setkey(tfm, key->key, key->keylen);
1434 	if (err)
1435 		goto err_pool_end;
1436 
1437 	tcp_sigpool_end(&hp);
1438 	kfree_sensitive(tmp_key);
1439 
1440 	if (tcp_ao_maclen(key) > key->digest_size)
1441 		return -EINVAL;
1442 
1443 	return 0;
1444 
1445 err_pool_end:
1446 	tcp_sigpool_end(&hp);
1447 err_kfree:
1448 	kfree_sensitive(tmp_key);
1449 	return err;
1450 }
1451 
1452 #if IS_ENABLED(CONFIG_IPV6)
1453 static int tcp_ao_verify_ipv6(struct sock *sk, struct tcp_ao_add *cmd,
1454 			      union tcp_ao_addr **paddr,
1455 			      unsigned short int *family)
1456 {
1457 	struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&cmd->addr;
1458 	struct in6_addr *addr = &sin6->sin6_addr;
1459 	u8 prefix = cmd->prefix;
1460 
1461 	if (sin6->sin6_family != AF_INET6)
1462 		return -EINVAL;
1463 
1464 	/* Currently matching is not performed on port (or port ranges) */
1465 	if (sin6->sin6_port != 0)
1466 		return -EINVAL;
1467 
1468 	/* Check prefix and trailing 0's in addr */
1469 	if (cmd->prefix != 0 && ipv6_addr_v4mapped(addr)) {
1470 		__be32 addr4 = addr->s6_addr32[3];
1471 		__be32 mask;
1472 
1473 		if (prefix > 32 || ntohl(addr4) == INADDR_ANY)
1474 			return -EINVAL;
1475 
1476 		mask = inet_make_mask(prefix);
1477 		if (addr4 & ~mask)
1478 			return -EINVAL;
1479 
1480 		/* Check that MKT address is consistent with socket */
1481 		if (!ipv6_addr_any(&sk->sk_v6_daddr)) {
1482 			__be32 daddr4 = sk->sk_v6_daddr.s6_addr32[3];
1483 
1484 			if (!ipv6_addr_v4mapped(&sk->sk_v6_daddr))
1485 				return -EINVAL;
1486 			if ((daddr4 & mask) != addr4)
1487 				return -EINVAL;
1488 		}
1489 
1490 		*paddr = (union tcp_ao_addr *)&addr->s6_addr32[3];
1491 		*family = AF_INET;
1492 		return 0;
1493 	} else if (cmd->prefix != 0) {
1494 		struct in6_addr pfx;
1495 
1496 		if (ipv6_addr_any(addr) || prefix > 128)
1497 			return -EINVAL;
1498 
1499 		ipv6_addr_prefix(&pfx, addr, prefix);
1500 		if (ipv6_addr_cmp(&pfx, addr))
1501 			return -EINVAL;
1502 
1503 		/* Check that MKT address is consistent with socket */
1504 		if (!ipv6_addr_any(&sk->sk_v6_daddr) &&
1505 		    !ipv6_prefix_equal(&sk->sk_v6_daddr, addr, prefix))
1506 
1507 			return -EINVAL;
1508 	} else {
1509 		if (!ipv6_addr_any(addr))
1510 			return -EINVAL;
1511 	}
1512 
1513 	*paddr = (union tcp_ao_addr *)addr;
1514 	return 0;
1515 }
1516 #else
1517 static int tcp_ao_verify_ipv6(struct sock *sk, struct tcp_ao_add *cmd,
1518 			      union tcp_ao_addr **paddr,
1519 			      unsigned short int *family)
1520 {
1521 	return -EOPNOTSUPP;
1522 }
1523 #endif
1524 
1525 static struct tcp_ao_info *setsockopt_ao_info(struct sock *sk)
1526 {
1527 	if (sk_fullsock(sk)) {
1528 		return rcu_dereference_protected(tcp_sk(sk)->ao_info,
1529 						 lockdep_sock_is_held(sk));
1530 	} else if (sk->sk_state == TCP_TIME_WAIT) {
1531 		return rcu_dereference_protected(tcp_twsk(sk)->ao_info,
1532 						 lockdep_sock_is_held(sk));
1533 	}
1534 	return ERR_PTR(-ESOCKTNOSUPPORT);
1535 }
1536 
1537 static struct tcp_ao_info *getsockopt_ao_info(struct sock *sk)
1538 {
1539 	if (sk_fullsock(sk))
1540 		return rcu_dereference(tcp_sk(sk)->ao_info);
1541 	else if (sk->sk_state == TCP_TIME_WAIT)
1542 		return rcu_dereference(tcp_twsk(sk)->ao_info);
1543 
1544 	return ERR_PTR(-ESOCKTNOSUPPORT);
1545 }
1546 
1547 #define TCP_AO_KEYF_ALL (TCP_AO_KEYF_IFINDEX | TCP_AO_KEYF_EXCLUDE_OPT)
1548 #define TCP_AO_GET_KEYF_VALID	(TCP_AO_KEYF_IFINDEX)
1549 
1550 static struct tcp_ao_key *tcp_ao_key_alloc(struct sock *sk,
1551 					   struct tcp_ao_add *cmd)
1552 {
1553 	const char *algo = cmd->alg_name;
1554 	unsigned int digest_size;
1555 	struct crypto_ahash *tfm;
1556 	struct tcp_ao_key *key;
1557 	struct tcp_sigpool hp;
1558 	int err, pool_id;
1559 	size_t size;
1560 
1561 	/* Force null-termination of alg_name */
1562 	cmd->alg_name[ARRAY_SIZE(cmd->alg_name) - 1] = '\0';
1563 
1564 	/* RFC5926, 3.1.1.2. KDF_AES_128_CMAC */
1565 	if (!strcmp("cmac(aes128)", algo))
1566 		algo = "cmac(aes)";
1567 
1568 	/* Full TCP header (th->doff << 2) should fit into scratch area,
1569 	 * see tcp_ao_hash_header().
1570 	 */
1571 	pool_id = tcp_sigpool_alloc_ahash(algo, 60);
1572 	if (pool_id < 0)
1573 		return ERR_PTR(pool_id);
1574 
1575 	err = tcp_sigpool_start(pool_id, &hp);
1576 	if (err)
1577 		goto err_free_pool;
1578 
1579 	tfm = crypto_ahash_reqtfm(hp.req);
1580 	digest_size = crypto_ahash_digestsize(tfm);
1581 	tcp_sigpool_end(&hp);
1582 
1583 	size = sizeof(struct tcp_ao_key) + (digest_size << 1);
1584 	key = sock_kmalloc(sk, size, GFP_KERNEL);
1585 	if (!key) {
1586 		err = -ENOMEM;
1587 		goto err_free_pool;
1588 	}
1589 
1590 	key->tcp_sigpool_id = pool_id;
1591 	key->digest_size = digest_size;
1592 	return key;
1593 
1594 err_free_pool:
1595 	tcp_sigpool_release(pool_id);
1596 	return ERR_PTR(err);
1597 }
1598 
1599 static int tcp_ao_add_cmd(struct sock *sk, unsigned short int family,
1600 			  sockptr_t optval, int optlen)
1601 {
1602 	struct tcp_ao_info *ao_info;
1603 	union tcp_ao_addr *addr;
1604 	struct tcp_ao_key *key;
1605 	struct tcp_ao_add cmd;
1606 	int ret, l3index = 0;
1607 	bool first = false;
1608 
1609 	if (optlen < sizeof(cmd))
1610 		return -EINVAL;
1611 
1612 	ret = copy_struct_from_sockptr(&cmd, sizeof(cmd), optval, optlen);
1613 	if (ret)
1614 		return ret;
1615 
1616 	if (cmd.keylen > TCP_AO_MAXKEYLEN)
1617 		return -EINVAL;
1618 
1619 	if (cmd.reserved != 0 || cmd.reserved2 != 0)
1620 		return -EINVAL;
1621 
1622 	if (family == AF_INET)
1623 		ret = tcp_ao_verify_ipv4(sk, &cmd, &addr);
1624 	else
1625 		ret = tcp_ao_verify_ipv6(sk, &cmd, &addr, &family);
1626 	if (ret)
1627 		return ret;
1628 
1629 	if (cmd.keyflags & ~TCP_AO_KEYF_ALL)
1630 		return -EINVAL;
1631 
1632 	if (cmd.set_current || cmd.set_rnext) {
1633 		if (!tcp_ao_can_set_current_rnext(sk))
1634 			return -EINVAL;
1635 	}
1636 
1637 	if (cmd.ifindex && !(cmd.keyflags & TCP_AO_KEYF_IFINDEX))
1638 		return -EINVAL;
1639 
1640 	/* For cmd.tcp_ifindex = 0 the key will apply to the default VRF */
1641 	if (cmd.keyflags & TCP_AO_KEYF_IFINDEX && cmd.ifindex) {
1642 		int bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
1643 		struct net_device *dev;
1644 
1645 		rcu_read_lock();
1646 		dev = dev_get_by_index_rcu(sock_net(sk), cmd.ifindex);
1647 		if (dev && netif_is_l3_master(dev))
1648 			l3index = dev->ifindex;
1649 		rcu_read_unlock();
1650 
1651 		if (!dev || !l3index)
1652 			return -EINVAL;
1653 
1654 		if (!bound_dev_if || bound_dev_if != cmd.ifindex) {
1655 			/* tcp_ao_established_key() doesn't expect having
1656 			 * non peer-matching key on an established TCP-AO
1657 			 * connection.
1658 			 */
1659 			if (!((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)))
1660 				return -EINVAL;
1661 		}
1662 
1663 		/* It's still possible to bind after adding keys or even
1664 		 * re-bind to a different dev (with CAP_NET_RAW).
1665 		 * So, no reason to return error here, rather try to be
1666 		 * nice and warn the user.
1667 		 */
1668 		if (bound_dev_if && bound_dev_if != cmd.ifindex)
1669 			net_warn_ratelimited("AO key ifindex %d != sk bound ifindex %d\n",
1670 					     cmd.ifindex, bound_dev_if);
1671 	}
1672 
1673 	/* Don't allow keys for peers that have a matching TCP-MD5 key */
1674 	if (cmd.keyflags & TCP_AO_KEYF_IFINDEX) {
1675 		/* Non-_exact version of tcp_md5_do_lookup() will
1676 		 * as well match keys that aren't bound to a specific VRF
1677 		 * (that will make them match AO key with
1678 		 * sysctl_tcp_l3dev_accept = 1
1679 		 */
1680 		if (tcp_md5_do_lookup(sk, l3index, addr, family))
1681 			return -EKEYREJECTED;
1682 	} else {
1683 		if (tcp_md5_do_lookup_any_l3index(sk, addr, family))
1684 			return -EKEYREJECTED;
1685 	}
1686 
1687 	ao_info = setsockopt_ao_info(sk);
1688 	if (IS_ERR(ao_info))
1689 		return PTR_ERR(ao_info);
1690 
1691 	if (!ao_info) {
1692 		ao_info = tcp_ao_alloc_info(GFP_KERNEL);
1693 		if (!ao_info)
1694 			return -ENOMEM;
1695 		first = true;
1696 	} else {
1697 		/* Check that neither RecvID nor SendID match any
1698 		 * existing key for the peer, RFC5925 3.1:
1699 		 * > The IDs of MKTs MUST NOT overlap where their
1700 		 * > TCP connection identifiers overlap.
1701 		 */
1702 		if (__tcp_ao_do_lookup(sk, l3index, addr, family, cmd.prefix, -1, cmd.rcvid))
1703 			return -EEXIST;
1704 		if (__tcp_ao_do_lookup(sk, l3index, addr, family,
1705 				       cmd.prefix, cmd.sndid, -1))
1706 			return -EEXIST;
1707 	}
1708 
1709 	key = tcp_ao_key_alloc(sk, &cmd);
1710 	if (IS_ERR(key)) {
1711 		ret = PTR_ERR(key);
1712 		goto err_free_ao;
1713 	}
1714 
1715 	INIT_HLIST_NODE(&key->node);
1716 	memcpy(&key->addr, addr, (family == AF_INET) ? sizeof(struct in_addr) :
1717 						       sizeof(struct in6_addr));
1718 	key->prefixlen	= cmd.prefix;
1719 	key->family	= family;
1720 	key->keyflags	= cmd.keyflags;
1721 	key->sndid	= cmd.sndid;
1722 	key->rcvid	= cmd.rcvid;
1723 	key->l3index	= l3index;
1724 	atomic64_set(&key->pkt_good, 0);
1725 	atomic64_set(&key->pkt_bad, 0);
1726 
1727 	ret = tcp_ao_parse_crypto(&cmd, key);
1728 	if (ret < 0)
1729 		goto err_free_sock;
1730 
1731 	if (!((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE))) {
1732 		tcp_ao_cache_traffic_keys(sk, ao_info, key);
1733 		if (first) {
1734 			ao_info->current_key = key;
1735 			ao_info->rnext_key = key;
1736 		}
1737 	}
1738 
1739 	tcp_ao_link_mkt(ao_info, key);
1740 	if (first) {
1741 		if (!static_branch_inc(&tcp_ao_needed.key)) {
1742 			ret = -EUSERS;
1743 			goto err_free_sock;
1744 		}
1745 		sk_gso_disable(sk);
1746 		rcu_assign_pointer(tcp_sk(sk)->ao_info, ao_info);
1747 	}
1748 
1749 	if (cmd.set_current)
1750 		WRITE_ONCE(ao_info->current_key, key);
1751 	if (cmd.set_rnext)
1752 		WRITE_ONCE(ao_info->rnext_key, key);
1753 	return 0;
1754 
1755 err_free_sock:
1756 	atomic_sub(tcp_ao_sizeof_key(key), &sk->sk_omem_alloc);
1757 	tcp_sigpool_release(key->tcp_sigpool_id);
1758 	kfree_sensitive(key);
1759 err_free_ao:
1760 	if (first)
1761 		kfree(ao_info);
1762 	return ret;
1763 }
1764 
1765 static int tcp_ao_delete_key(struct sock *sk, struct tcp_ao_info *ao_info,
1766 			     bool del_async, struct tcp_ao_key *key,
1767 			     struct tcp_ao_key *new_current,
1768 			     struct tcp_ao_key *new_rnext)
1769 {
1770 	int err;
1771 
1772 	hlist_del_rcu(&key->node);
1773 
1774 	/* Support for async delete on listening sockets: as they don't
1775 	 * need current_key/rnext_key maintaining, we don't need to check
1776 	 * them and we can just free all resources in RCU fashion.
1777 	 */
1778 	if (del_async) {
1779 		atomic_sub(tcp_ao_sizeof_key(key), &sk->sk_omem_alloc);
1780 		call_rcu(&key->rcu, tcp_ao_key_free_rcu);
1781 		return 0;
1782 	}
1783 
1784 	/* At this moment another CPU could have looked this key up
1785 	 * while it was unlinked from the list. Wait for RCU grace period,
1786 	 * after which the key is off-list and can't be looked up again;
1787 	 * the rx path [just before RCU came] might have used it and set it
1788 	 * as current_key (very unlikely).
1789 	 * Free the key with next RCU grace period (in case it was
1790 	 * current_key before tcp_ao_current_rnext() might have
1791 	 * changed it in forced-delete).
1792 	 */
1793 	synchronize_rcu();
1794 	if (new_current)
1795 		WRITE_ONCE(ao_info->current_key, new_current);
1796 	if (new_rnext)
1797 		WRITE_ONCE(ao_info->rnext_key, new_rnext);
1798 
1799 	if (unlikely(READ_ONCE(ao_info->current_key) == key ||
1800 		     READ_ONCE(ao_info->rnext_key) == key)) {
1801 		err = -EBUSY;
1802 		goto add_key;
1803 	}
1804 
1805 	atomic_sub(tcp_ao_sizeof_key(key), &sk->sk_omem_alloc);
1806 	call_rcu(&key->rcu, tcp_ao_key_free_rcu);
1807 
1808 	return 0;
1809 add_key:
1810 	hlist_add_head_rcu(&key->node, &ao_info->head);
1811 	return err;
1812 }
1813 
1814 #define TCP_AO_DEL_KEYF_ALL (TCP_AO_KEYF_IFINDEX)
1815 static int tcp_ao_del_cmd(struct sock *sk, unsigned short int family,
1816 			  sockptr_t optval, int optlen)
1817 {
1818 	struct tcp_ao_key *key, *new_current = NULL, *new_rnext = NULL;
1819 	int err, addr_len, l3index = 0;
1820 	struct tcp_ao_info *ao_info;
1821 	union tcp_ao_addr *addr;
1822 	struct tcp_ao_del cmd;
1823 	__u8 prefix;
1824 	u16 port;
1825 
1826 	if (optlen < sizeof(cmd))
1827 		return -EINVAL;
1828 
1829 	err = copy_struct_from_sockptr(&cmd, sizeof(cmd), optval, optlen);
1830 	if (err)
1831 		return err;
1832 
1833 	if (cmd.reserved != 0 || cmd.reserved2 != 0)
1834 		return -EINVAL;
1835 
1836 	if (cmd.set_current || cmd.set_rnext) {
1837 		if (!tcp_ao_can_set_current_rnext(sk))
1838 			return -EINVAL;
1839 	}
1840 
1841 	if (cmd.keyflags & ~TCP_AO_DEL_KEYF_ALL)
1842 		return -EINVAL;
1843 
1844 	/* No sanity check for TCP_AO_KEYF_IFINDEX as if a VRF
1845 	 * was destroyed, there still should be a way to delete keys,
1846 	 * that were bound to that l3intf. So, fail late at lookup stage
1847 	 * if there is no key for that ifindex.
1848 	 */
1849 	if (cmd.ifindex && !(cmd.keyflags & TCP_AO_KEYF_IFINDEX))
1850 		return -EINVAL;
1851 
1852 	ao_info = setsockopt_ao_info(sk);
1853 	if (IS_ERR(ao_info))
1854 		return PTR_ERR(ao_info);
1855 	if (!ao_info)
1856 		return -ENOENT;
1857 
1858 	/* For sockets in TCP_CLOSED it's possible set keys that aren't
1859 	 * matching the future peer (address/VRF/etc),
1860 	 * tcp_ao_connect_init() will choose a correct matching MKT
1861 	 * if there's any.
1862 	 */
1863 	if (cmd.set_current) {
1864 		new_current = tcp_ao_established_key(sk, ao_info, cmd.current_key, -1);
1865 		if (!new_current)
1866 			return -ENOENT;
1867 	}
1868 	if (cmd.set_rnext) {
1869 		new_rnext = tcp_ao_established_key(sk, ao_info, -1, cmd.rnext);
1870 		if (!new_rnext)
1871 			return -ENOENT;
1872 	}
1873 	if (cmd.del_async && sk->sk_state != TCP_LISTEN)
1874 		return -EINVAL;
1875 
1876 	if (family == AF_INET) {
1877 		struct sockaddr_in *sin = (struct sockaddr_in *)&cmd.addr;
1878 
1879 		addr = (union tcp_ao_addr *)&sin->sin_addr;
1880 		addr_len = sizeof(struct in_addr);
1881 		port = ntohs(sin->sin_port);
1882 	} else {
1883 		struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&cmd.addr;
1884 		struct in6_addr *addr6 = &sin6->sin6_addr;
1885 
1886 		if (ipv6_addr_v4mapped(addr6)) {
1887 			addr = (union tcp_ao_addr *)&addr6->s6_addr32[3];
1888 			addr_len = sizeof(struct in_addr);
1889 			family = AF_INET;
1890 		} else {
1891 			addr = (union tcp_ao_addr *)addr6;
1892 			addr_len = sizeof(struct in6_addr);
1893 		}
1894 		port = ntohs(sin6->sin6_port);
1895 	}
1896 	prefix = cmd.prefix;
1897 
1898 	/* Currently matching is not performed on port (or port ranges) */
1899 	if (port != 0)
1900 		return -EINVAL;
1901 
1902 	/* We could choose random present key here for current/rnext
1903 	 * but that's less predictable. Let's be strict and don't
1904 	 * allow removing a key that's in use. RFC5925 doesn't
1905 	 * specify how-to coordinate key removal, but says:
1906 	 * "It is presumed that an MKT affecting a particular
1907 	 * connection cannot be destroyed during an active connection"
1908 	 */
1909 	hlist_for_each_entry_rcu(key, &ao_info->head, node,
1910 				 lockdep_sock_is_held(sk)) {
1911 		if (cmd.sndid != key->sndid ||
1912 		    cmd.rcvid != key->rcvid)
1913 			continue;
1914 
1915 		if (family != key->family ||
1916 		    prefix != key->prefixlen ||
1917 		    memcmp(addr, &key->addr, addr_len))
1918 			continue;
1919 
1920 		if ((cmd.keyflags & TCP_AO_KEYF_IFINDEX) !=
1921 		    (key->keyflags & TCP_AO_KEYF_IFINDEX))
1922 			continue;
1923 
1924 		if (key->l3index != l3index)
1925 			continue;
1926 
1927 		if (key == new_current || key == new_rnext)
1928 			continue;
1929 
1930 		return tcp_ao_delete_key(sk, ao_info, cmd.del_async, key,
1931 					 new_current, new_rnext);
1932 	}
1933 	return -ENOENT;
1934 }
1935 
1936 /* cmd.ao_required makes a socket TCP-AO only.
1937  * Don't allow any md5 keys for any l3intf on the socket together with it.
1938  * Restricting it early in setsockopt() removes a check for
1939  * ao_info->ao_required on inbound tcp segment fast-path.
1940  */
1941 static int tcp_ao_required_verify(struct sock *sk)
1942 {
1943 #ifdef CONFIG_TCP_MD5SIG
1944 	const struct tcp_md5sig_info *md5sig;
1945 
1946 	if (!static_branch_unlikely(&tcp_md5_needed.key))
1947 		return 0;
1948 
1949 	md5sig = rcu_dereference_check(tcp_sk(sk)->md5sig_info,
1950 				       lockdep_sock_is_held(sk));
1951 	if (!md5sig)
1952 		return 0;
1953 
1954 	if (rcu_dereference_check(hlist_first_rcu(&md5sig->head),
1955 				  lockdep_sock_is_held(sk)))
1956 		return 1;
1957 #endif
1958 	return 0;
1959 }
1960 
1961 static int tcp_ao_info_cmd(struct sock *sk, unsigned short int family,
1962 			   sockptr_t optval, int optlen)
1963 {
1964 	struct tcp_ao_key *new_current = NULL, *new_rnext = NULL;
1965 	struct tcp_ao_info *ao_info;
1966 	struct tcp_ao_info_opt cmd;
1967 	bool first = false;
1968 	int err;
1969 
1970 	if (optlen < sizeof(cmd))
1971 		return -EINVAL;
1972 
1973 	err = copy_struct_from_sockptr(&cmd, sizeof(cmd), optval, optlen);
1974 	if (err)
1975 		return err;
1976 
1977 	if (cmd.set_current || cmd.set_rnext) {
1978 		if (!tcp_ao_can_set_current_rnext(sk))
1979 			return -EINVAL;
1980 	}
1981 
1982 	if (cmd.reserved != 0 || cmd.reserved2 != 0)
1983 		return -EINVAL;
1984 
1985 	ao_info = setsockopt_ao_info(sk);
1986 	if (IS_ERR(ao_info))
1987 		return PTR_ERR(ao_info);
1988 	if (!ao_info) {
1989 		if (!((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)))
1990 			return -EINVAL;
1991 		ao_info = tcp_ao_alloc_info(GFP_KERNEL);
1992 		if (!ao_info)
1993 			return -ENOMEM;
1994 		first = true;
1995 	}
1996 
1997 	if (cmd.ao_required && tcp_ao_required_verify(sk)) {
1998 		err = -EKEYREJECTED;
1999 		goto out;
2000 	}
2001 
2002 	/* For sockets in TCP_CLOSED it's possible set keys that aren't
2003 	 * matching the future peer (address/port/VRF/etc),
2004 	 * tcp_ao_connect_init() will choose a correct matching MKT
2005 	 * if there's any.
2006 	 */
2007 	if (cmd.set_current) {
2008 		new_current = tcp_ao_established_key(sk, ao_info, cmd.current_key, -1);
2009 		if (!new_current) {
2010 			err = -ENOENT;
2011 			goto out;
2012 		}
2013 	}
2014 	if (cmd.set_rnext) {
2015 		new_rnext = tcp_ao_established_key(sk, ao_info, -1, cmd.rnext);
2016 		if (!new_rnext) {
2017 			err = -ENOENT;
2018 			goto out;
2019 		}
2020 	}
2021 	if (cmd.set_counters) {
2022 		atomic64_set(&ao_info->counters.pkt_good, cmd.pkt_good);
2023 		atomic64_set(&ao_info->counters.pkt_bad, cmd.pkt_bad);
2024 		atomic64_set(&ao_info->counters.key_not_found, cmd.pkt_key_not_found);
2025 		atomic64_set(&ao_info->counters.ao_required, cmd.pkt_ao_required);
2026 		atomic64_set(&ao_info->counters.dropped_icmp, cmd.pkt_dropped_icmp);
2027 	}
2028 
2029 	ao_info->ao_required = cmd.ao_required;
2030 	ao_info->accept_icmps = cmd.accept_icmps;
2031 	if (new_current)
2032 		WRITE_ONCE(ao_info->current_key, new_current);
2033 	if (new_rnext)
2034 		WRITE_ONCE(ao_info->rnext_key, new_rnext);
2035 	if (first) {
2036 		if (!static_branch_inc(&tcp_ao_needed.key)) {
2037 			err = -EUSERS;
2038 			goto out;
2039 		}
2040 		sk_gso_disable(sk);
2041 		rcu_assign_pointer(tcp_sk(sk)->ao_info, ao_info);
2042 	}
2043 	return 0;
2044 out:
2045 	if (first)
2046 		kfree(ao_info);
2047 	return err;
2048 }
2049 
2050 int tcp_parse_ao(struct sock *sk, int cmd, unsigned short int family,
2051 		 sockptr_t optval, int optlen)
2052 {
2053 	if (WARN_ON_ONCE(family != AF_INET && family != AF_INET6))
2054 		return -EAFNOSUPPORT;
2055 
2056 	switch (cmd) {
2057 	case TCP_AO_ADD_KEY:
2058 		return tcp_ao_add_cmd(sk, family, optval, optlen);
2059 	case TCP_AO_DEL_KEY:
2060 		return tcp_ao_del_cmd(sk, family, optval, optlen);
2061 	case TCP_AO_INFO:
2062 		return tcp_ao_info_cmd(sk, family, optval, optlen);
2063 	default:
2064 		WARN_ON_ONCE(1);
2065 		return -EINVAL;
2066 	}
2067 }
2068 
2069 int tcp_v4_parse_ao(struct sock *sk, int cmd, sockptr_t optval, int optlen)
2070 {
2071 	return tcp_parse_ao(sk, cmd, AF_INET, optval, optlen);
2072 }
2073 
2074 /* tcp_ao_copy_mkts_to_user(ao_info, optval, optlen)
2075  *
2076  * @ao_info:	struct tcp_ao_info on the socket that
2077  *		socket getsockopt(TCP_AO_GET_KEYS) is executed on
2078  * @optval:	pointer to array of tcp_ao_getsockopt structures in user space.
2079  *		Must be != NULL.
2080  * @optlen:	pointer to size of tcp_ao_getsockopt structure.
2081  *		Must be != NULL.
2082  *
2083  * Return value: 0 on success, a negative error number otherwise.
2084  *
2085  * optval points to an array of tcp_ao_getsockopt structures in user space.
2086  * optval[0] is used as both input and output to getsockopt. It determines
2087  * which keys are returned by the kernel.
2088  * optval[0].nkeys is the size of the array in user space. On return it contains
2089  * the number of keys matching the search criteria.
2090  * If tcp_ao_getsockopt::get_all is set, then all keys in the socket are
2091  * returned, otherwise only keys matching <addr, prefix, sndid, rcvid>
2092  * in optval[0] are returned.
2093  * optlen is also used as both input and output. The user provides the size
2094  * of struct tcp_ao_getsockopt in user space, and the kernel returns the size
2095  * of the structure in kernel space.
2096  * The size of struct tcp_ao_getsockopt may differ between user and kernel.
2097  * There are three cases to consider:
2098  *  * If usize == ksize, then keys are copied verbatim.
2099  *  * If usize < ksize, then the userspace has passed an old struct to a
2100  *    newer kernel. The rest of the trailing bytes in optval[0]
2101  *    (ksize - usize) are interpreted as 0 by the kernel.
2102  *  * If usize > ksize, then the userspace has passed a new struct to an
2103  *    older kernel. The trailing bytes unknown to the kernel (usize - ksize)
2104  *    are checked to ensure they are zeroed, otherwise -E2BIG is returned.
2105  * On return the kernel fills in min(usize, ksize) in each entry of the array.
2106  * The layout of the fields in the user and kernel structures is expected to
2107  * be the same (including in the 32bit vs 64bit case).
2108  */
2109 static int tcp_ao_copy_mkts_to_user(const struct sock *sk,
2110 				    struct tcp_ao_info *ao_info,
2111 				    sockptr_t optval, sockptr_t optlen)
2112 {
2113 	struct tcp_ao_getsockopt opt_in, opt_out;
2114 	struct tcp_ao_key *key, *current_key;
2115 	bool do_address_matching = true;
2116 	union tcp_ao_addr *addr = NULL;
2117 	int err, l3index, user_len;
2118 	unsigned int max_keys;	/* maximum number of keys to copy to user */
2119 	size_t out_offset = 0;
2120 	size_t bytes_to_write;	/* number of bytes to write to user level */
2121 	u32 matched_keys;	/* keys from ao_info matched so far */
2122 	int optlen_out;
2123 	__be16 port = 0;
2124 
2125 	if (copy_from_sockptr(&user_len, optlen, sizeof(int)))
2126 		return -EFAULT;
2127 
2128 	if (user_len <= 0)
2129 		return -EINVAL;
2130 
2131 	memset(&opt_in, 0, sizeof(struct tcp_ao_getsockopt));
2132 	err = copy_struct_from_sockptr(&opt_in, sizeof(opt_in),
2133 				       optval, user_len);
2134 	if (err < 0)
2135 		return err;
2136 
2137 	if (opt_in.pkt_good || opt_in.pkt_bad)
2138 		return -EINVAL;
2139 	if (opt_in.keyflags & ~TCP_AO_GET_KEYF_VALID)
2140 		return -EINVAL;
2141 	if (opt_in.ifindex && !(opt_in.keyflags & TCP_AO_KEYF_IFINDEX))
2142 		return -EINVAL;
2143 
2144 	if (opt_in.reserved != 0)
2145 		return -EINVAL;
2146 
2147 	max_keys = opt_in.nkeys;
2148 	l3index = (opt_in.keyflags & TCP_AO_KEYF_IFINDEX) ? opt_in.ifindex : -1;
2149 
2150 	if (opt_in.get_all || opt_in.is_current || opt_in.is_rnext) {
2151 		if (opt_in.get_all && (opt_in.is_current || opt_in.is_rnext))
2152 			return -EINVAL;
2153 		do_address_matching = false;
2154 	}
2155 
2156 	switch (opt_in.addr.ss_family) {
2157 	case AF_INET: {
2158 		struct sockaddr_in *sin;
2159 		__be32 mask;
2160 
2161 		sin = (struct sockaddr_in *)&opt_in.addr;
2162 		port = sin->sin_port;
2163 		addr = (union tcp_ao_addr *)&sin->sin_addr;
2164 
2165 		if (opt_in.prefix > 32)
2166 			return -EINVAL;
2167 
2168 		if (ntohl(sin->sin_addr.s_addr) == INADDR_ANY &&
2169 		    opt_in.prefix != 0)
2170 			return -EINVAL;
2171 
2172 		mask = inet_make_mask(opt_in.prefix);
2173 		if (sin->sin_addr.s_addr & ~mask)
2174 			return -EINVAL;
2175 
2176 		break;
2177 	}
2178 	case AF_INET6: {
2179 		struct sockaddr_in6 *sin6;
2180 		struct in6_addr *addr6;
2181 
2182 		sin6 = (struct sockaddr_in6 *)&opt_in.addr;
2183 		addr = (union tcp_ao_addr *)&sin6->sin6_addr;
2184 		addr6 = &sin6->sin6_addr;
2185 		port = sin6->sin6_port;
2186 
2187 		/* We don't have to change family and @addr here if
2188 		 * ipv6_addr_v4mapped() like in key adding:
2189 		 * tcp_ao_key_cmp() does it. Do the sanity checks though.
2190 		 */
2191 		if (opt_in.prefix != 0) {
2192 			if (ipv6_addr_v4mapped(addr6)) {
2193 				__be32 mask, addr4 = addr6->s6_addr32[3];
2194 
2195 				if (opt_in.prefix > 32 ||
2196 				    ntohl(addr4) == INADDR_ANY)
2197 					return -EINVAL;
2198 				mask = inet_make_mask(opt_in.prefix);
2199 				if (addr4 & ~mask)
2200 					return -EINVAL;
2201 			} else {
2202 				struct in6_addr pfx;
2203 
2204 				if (ipv6_addr_any(addr6) ||
2205 				    opt_in.prefix > 128)
2206 					return -EINVAL;
2207 
2208 				ipv6_addr_prefix(&pfx, addr6, opt_in.prefix);
2209 				if (ipv6_addr_cmp(&pfx, addr6))
2210 					return -EINVAL;
2211 			}
2212 		} else if (!ipv6_addr_any(addr6)) {
2213 			return -EINVAL;
2214 		}
2215 		break;
2216 	}
2217 	case 0:
2218 		if (!do_address_matching)
2219 			break;
2220 		fallthrough;
2221 	default:
2222 		return -EAFNOSUPPORT;
2223 	}
2224 
2225 	if (!do_address_matching) {
2226 		/* We could just ignore those, but let's do stricter checks */
2227 		if (addr || port)
2228 			return -EINVAL;
2229 		if (opt_in.prefix || opt_in.sndid || opt_in.rcvid)
2230 			return -EINVAL;
2231 	}
2232 
2233 	bytes_to_write = min_t(int, user_len, sizeof(struct tcp_ao_getsockopt));
2234 	matched_keys = 0;
2235 	/* May change in RX, while we're dumping, pre-fetch it */
2236 	current_key = READ_ONCE(ao_info->current_key);
2237 
2238 	hlist_for_each_entry_rcu(key, &ao_info->head, node,
2239 				 lockdep_sock_is_held(sk)) {
2240 		if (opt_in.get_all)
2241 			goto match;
2242 
2243 		if (opt_in.is_current || opt_in.is_rnext) {
2244 			if (opt_in.is_current && key == current_key)
2245 				goto match;
2246 			if (opt_in.is_rnext && key == ao_info->rnext_key)
2247 				goto match;
2248 			continue;
2249 		}
2250 
2251 		if (tcp_ao_key_cmp(key, l3index, addr, opt_in.prefix,
2252 				   opt_in.addr.ss_family,
2253 				   opt_in.sndid, opt_in.rcvid) != 0)
2254 			continue;
2255 match:
2256 		matched_keys++;
2257 		if (matched_keys > max_keys)
2258 			continue;
2259 
2260 		memset(&opt_out, 0, sizeof(struct tcp_ao_getsockopt));
2261 
2262 		if (key->family == AF_INET) {
2263 			struct sockaddr_in *sin_out = (struct sockaddr_in *)&opt_out.addr;
2264 
2265 			sin_out->sin_family = key->family;
2266 			sin_out->sin_port = 0;
2267 			memcpy(&sin_out->sin_addr, &key->addr, sizeof(struct in_addr));
2268 		} else {
2269 			struct sockaddr_in6 *sin6_out = (struct sockaddr_in6 *)&opt_out.addr;
2270 
2271 			sin6_out->sin6_family = key->family;
2272 			sin6_out->sin6_port = 0;
2273 			memcpy(&sin6_out->sin6_addr, &key->addr, sizeof(struct in6_addr));
2274 		}
2275 		opt_out.sndid = key->sndid;
2276 		opt_out.rcvid = key->rcvid;
2277 		opt_out.prefix = key->prefixlen;
2278 		opt_out.keyflags = key->keyflags;
2279 		opt_out.is_current = (key == current_key);
2280 		opt_out.is_rnext = (key == ao_info->rnext_key);
2281 		opt_out.nkeys = 0;
2282 		opt_out.maclen = key->maclen;
2283 		opt_out.keylen = key->keylen;
2284 		opt_out.ifindex = key->l3index;
2285 		opt_out.pkt_good = atomic64_read(&key->pkt_good);
2286 		opt_out.pkt_bad = atomic64_read(&key->pkt_bad);
2287 		memcpy(&opt_out.key, key->key, key->keylen);
2288 		tcp_sigpool_algo(key->tcp_sigpool_id, opt_out.alg_name, 64);
2289 
2290 		/* Copy key to user */
2291 		if (copy_to_sockptr_offset(optval, out_offset,
2292 					   &opt_out, bytes_to_write))
2293 			return -EFAULT;
2294 		out_offset += user_len;
2295 	}
2296 
2297 	optlen_out = (int)sizeof(struct tcp_ao_getsockopt);
2298 	if (copy_to_sockptr(optlen, &optlen_out, sizeof(int)))
2299 		return -EFAULT;
2300 
2301 	out_offset = offsetof(struct tcp_ao_getsockopt, nkeys);
2302 	if (copy_to_sockptr_offset(optval, out_offset,
2303 				   &matched_keys, sizeof(u32)))
2304 		return -EFAULT;
2305 
2306 	return 0;
2307 }
2308 
2309 int tcp_ao_get_mkts(struct sock *sk, sockptr_t optval, sockptr_t optlen)
2310 {
2311 	struct tcp_ao_info *ao_info;
2312 
2313 	ao_info = setsockopt_ao_info(sk);
2314 	if (IS_ERR(ao_info))
2315 		return PTR_ERR(ao_info);
2316 	if (!ao_info)
2317 		return -ENOENT;
2318 
2319 	return tcp_ao_copy_mkts_to_user(sk, ao_info, optval, optlen);
2320 }
2321 
2322 int tcp_ao_get_sock_info(struct sock *sk, sockptr_t optval, sockptr_t optlen)
2323 {
2324 	struct tcp_ao_info_opt out, in = {};
2325 	struct tcp_ao_key *current_key;
2326 	struct tcp_ao_info *ao;
2327 	int err, len;
2328 
2329 	if (copy_from_sockptr(&len, optlen, sizeof(int)))
2330 		return -EFAULT;
2331 
2332 	if (len <= 0)
2333 		return -EINVAL;
2334 
2335 	/* Copying this "in" only to check ::reserved, ::reserved2,
2336 	 * that may be needed to extend (struct tcp_ao_info_opt) and
2337 	 * what getsockopt() provides in future.
2338 	 */
2339 	err = copy_struct_from_sockptr(&in, sizeof(in), optval, len);
2340 	if (err)
2341 		return err;
2342 
2343 	if (in.reserved != 0 || in.reserved2 != 0)
2344 		return -EINVAL;
2345 
2346 	ao = setsockopt_ao_info(sk);
2347 	if (IS_ERR(ao))
2348 		return PTR_ERR(ao);
2349 	if (!ao)
2350 		return -ENOENT;
2351 
2352 	memset(&out, 0, sizeof(out));
2353 	out.ao_required		= ao->ao_required;
2354 	out.accept_icmps	= ao->accept_icmps;
2355 	out.pkt_good		= atomic64_read(&ao->counters.pkt_good);
2356 	out.pkt_bad		= atomic64_read(&ao->counters.pkt_bad);
2357 	out.pkt_key_not_found	= atomic64_read(&ao->counters.key_not_found);
2358 	out.pkt_ao_required	= atomic64_read(&ao->counters.ao_required);
2359 	out.pkt_dropped_icmp	= atomic64_read(&ao->counters.dropped_icmp);
2360 
2361 	current_key = READ_ONCE(ao->current_key);
2362 	if (current_key) {
2363 		out.set_current = 1;
2364 		out.current_key = current_key->sndid;
2365 	}
2366 	if (ao->rnext_key) {
2367 		out.set_rnext = 1;
2368 		out.rnext = ao->rnext_key->rcvid;
2369 	}
2370 
2371 	if (copy_to_sockptr(optval, &out, min_t(int, len, sizeof(out))))
2372 		return -EFAULT;
2373 
2374 	return 0;
2375 }
2376 
2377 int tcp_ao_set_repair(struct sock *sk, sockptr_t optval, unsigned int optlen)
2378 {
2379 	struct tcp_sock *tp = tcp_sk(sk);
2380 	struct tcp_ao_repair cmd;
2381 	struct tcp_ao_key *key;
2382 	struct tcp_ao_info *ao;
2383 	int err;
2384 
2385 	if (optlen < sizeof(cmd))
2386 		return -EINVAL;
2387 
2388 	err = copy_struct_from_sockptr(&cmd, sizeof(cmd), optval, optlen);
2389 	if (err)
2390 		return err;
2391 
2392 	if (!tp->repair)
2393 		return -EPERM;
2394 
2395 	ao = setsockopt_ao_info(sk);
2396 	if (IS_ERR(ao))
2397 		return PTR_ERR(ao);
2398 	if (!ao)
2399 		return -ENOENT;
2400 
2401 	WRITE_ONCE(ao->lisn, cmd.snt_isn);
2402 	WRITE_ONCE(ao->risn, cmd.rcv_isn);
2403 	WRITE_ONCE(ao->snd_sne, cmd.snd_sne);
2404 	WRITE_ONCE(ao->rcv_sne, cmd.rcv_sne);
2405 
2406 	hlist_for_each_entry_rcu(key, &ao->head, node, lockdep_sock_is_held(sk))
2407 		tcp_ao_cache_traffic_keys(sk, ao, key);
2408 
2409 	return 0;
2410 }
2411 
2412 int tcp_ao_get_repair(struct sock *sk, sockptr_t optval, sockptr_t optlen)
2413 {
2414 	struct tcp_sock *tp = tcp_sk(sk);
2415 	struct tcp_ao_repair opt;
2416 	struct tcp_ao_info *ao;
2417 	int len;
2418 
2419 	if (copy_from_sockptr(&len, optlen, sizeof(int)))
2420 		return -EFAULT;
2421 
2422 	if (len <= 0)
2423 		return -EINVAL;
2424 
2425 	if (!tp->repair)
2426 		return -EPERM;
2427 
2428 	rcu_read_lock();
2429 	ao = getsockopt_ao_info(sk);
2430 	if (IS_ERR_OR_NULL(ao)) {
2431 		rcu_read_unlock();
2432 		return ao ? PTR_ERR(ao) : -ENOENT;
2433 	}
2434 
2435 	opt.snt_isn	= ao->lisn;
2436 	opt.rcv_isn	= ao->risn;
2437 	opt.snd_sne	= READ_ONCE(ao->snd_sne);
2438 	opt.rcv_sne	= READ_ONCE(ao->rcv_sne);
2439 	rcu_read_unlock();
2440 
2441 	if (copy_to_sockptr(optval, &opt, min_t(int, len, sizeof(opt))))
2442 		return -EFAULT;
2443 	return 0;
2444 }
2445