xref: /illumos-gate/usr/src/uts/common/inet/tcp/tcp_sig.c (revision c2cbc6b847acc8ffa2560958fbd72b70f21e5afe)
1 /*
2  * This file and its contents are supplied under the terms of the
3  * Common Development and Distribution License ("CDDL"), version 1.0.
4  * You may only use this file in accordance with the terms of version
5  * 1.0 of the CDDL.
6  *
7  * A full copy of the text of the CDDL should have accompanied this
8  * source.  A copy of the CDDL is also available via the Internet at
9  * http://www.illumos.org/license/CDDL.
10  */
11 
12 /*
13  * Copyright 2024 Oxide Computer Company
14  */
15 
16 /*
17  * RFC 2385 TCP MD5 Signature Option
18  *
19  * A security option commonly used to enhance security for BGP sessions. When a
20  * TCP socket has its TCP_MD5SIG option enabled, an additional TCP option is
21  * added to the header containing an MD5 digest calculated across the pseudo IP
22  * header, part of the TCP header, the data in the segment and a shared secret.
23  * The option is large (18 bytes plus 2 more for padding to a word boundary),
24  * and often /just/ fits in the TCP header -- particularly with SYN packets due
25  * to their additional options such as MSS.
26  *
27  * The socket option is boolean, and it is also necessary to have configured a
28  * security association (SA) to match the traffic that should be signed, and to
29  * provide the signing key. These SAs are configured from userland via
30  * tcpkey(8), use source and destination addresses and ports as criteria, and
31  * are maintained in a per-netstack linked list. The SAs pertaining to a
32  * particular TCP connection, one for each direction, are cached in the
33  * connection's TCP state after the first packet has been processed, and so
34  * using a single list is not a significant overhead, particularly as it is
35  * expected to be short.
36  *
37  * Enabling the socket option has a number of side effects:
38  *
39  *  - TCP fast path is disabled;
40  *  - TCP Fusion is disabled;
41  *  - Outbound packets for which a matching SA cannot be found are silently
42  *    discarded.
43  *  - Inbound packets that DO NOT contain an MD5 option in their TCP header are
44  *    silently discarded.
45  *  - Inbound packets that DO contain an MD5 option but for which the digest
46  *    does not match the locally calculated one are silently discarded.
47  *
48  * An SA is bound to a TCP stream once the first packet is sent or received
49  * following the TCP_MD5SIG socket option being enabled. Typically an
50  * application will enable the socket option immediately after creating the
51  * socket, and before moving on to calling connect() or bind() but it is
52  * necessary to wait for the first packet as that is the point at which the
53  * source and destination addresses and ports are all known, and we need these
54  * to find the SA. Note that if no matching SA is present in the database when
55  * the first packet is sent or received, it will be silently dropped. Due to
56  * the reference counting and tombstone logic, an SA that has been bound to one
57  * or more streams will persist until all of those streams have been torn down.
58  * It is not possible to change the SA for an active connection.
59  *
60  * -------------
61  * Lock Ordering
62  * -------------
63  *
64  * In order to ensure that we don't deadlock, if both are required, the RW lock
65  * across the SADB must be taken before acquiring an individual SA's lock. That
66  * is, locks must be taken in the following order (and released in the opposite
67  * order):
68  *
69  * 0) <tcpstack>->tcps_sigdb->td_lock
70  * 1) <tcpstack>->tcps_sigdb->td_salist-><entry>->ts_lock
71  *
72  * The lock at <tcpstack>->tcps_sigdb_lock is independent and used to
73  * synchronize lazy initialization of the database.
74  */
75 
76 #include <sys/atomic.h>
77 #include <sys/cmn_err.h>
78 #include <sys/cpuvar.h>
79 #include <sys/debug.h>
80 #include <sys/errno.h>
81 #include <sys/kmem.h>
82 #include <sys/list.h>
83 #include <sys/md5.h>
84 #include <sys/stdbool.h>
85 #include <sys/stream.h>
86 #include <sys/stropts.h>
87 #include <sys/strsubr.h>
88 #include <sys/strsun.h>
89 #include <sys/sysmacros.h>
90 #include <sys/types.h>
91 #include <netinet/in.h>
92 #include <netinet/ip6.h>
93 #include <net/pfkeyv2.h>
94 #include <net/pfpolicy.h>
95 #include <inet/common.h>
96 #include <inet/mi.h>
97 #include <inet/ip.h>
98 #include <inet/ip6.h>
99 #include <inet/ip_if.h>
100 #include <inet/tcp_stats.h>
101 #include <inet/keysock.h>
102 #include <inet/sadb.h>
103 #include <inet/tcp_sig.h>
104 
105 static void tcpsig_sa_free(tcpsig_sa_t *);
106 
107 void
tcpsig_init(tcp_stack_t * tcps)108 tcpsig_init(tcp_stack_t *tcps)
109 {
110 	mutex_init(&tcps->tcps_sigdb_lock, NULL, MUTEX_DEFAULT, NULL);
111 }
112 
113 void
tcpsig_fini(tcp_stack_t * tcps)114 tcpsig_fini(tcp_stack_t *tcps)
115 {
116 	tcpsig_db_t *db;
117 
118 	if ((db = tcps->tcps_sigdb) != NULL) {
119 		tcpsig_sa_t *sa;
120 
121 		rw_destroy(&db->td_lock);
122 		while ((sa = list_remove_head(&db->td_salist)) != NULL)
123 			tcpsig_sa_free(sa);
124 		list_destroy(&db->td_salist);
125 		kmem_free(tcps->tcps_sigdb, sizeof (tcpsig_db_t));
126 		tcps->tcps_sigdb = NULL;
127 	}
128 	mutex_destroy(&tcps->tcps_sigdb_lock);
129 }
130 
131 static tcpsig_db_t *
tcpsig_db(tcp_stack_t * tcps)132 tcpsig_db(tcp_stack_t *tcps)
133 {
134 	mutex_enter(&tcps->tcps_sigdb_lock);
135 	if (tcps->tcps_sigdb == NULL) {
136 		tcpsig_db_t *db = kmem_alloc(sizeof (tcpsig_db_t), KM_SLEEP);
137 
138 		rw_init(&db->td_lock, NULL, RW_DEFAULT, 0);
139 		list_create(&db->td_salist, sizeof (tcpsig_sa_t),
140 		    offsetof(tcpsig_sa_t, ts_link));
141 
142 		tcps->tcps_sigdb = db;
143 	}
144 	mutex_exit(&tcps->tcps_sigdb_lock);
145 
146 	return ((tcpsig_db_t *)tcps->tcps_sigdb);
147 }
148 
149 static uint8_t *
tcpsig_make_sa_ext(uint8_t * start,const uint8_t * const end,const tcpsig_sa_t * sa)150 tcpsig_make_sa_ext(uint8_t *start, const uint8_t * const end,
151     const tcpsig_sa_t *sa)
152 {
153 	sadb_sa_t *assoc;
154 
155 	ASSERT3P(end, >, start);
156 
157 	if (start == NULL || end - start < sizeof (*assoc))
158 		return (NULL);
159 
160 	assoc = (sadb_sa_t *)start;
161 	assoc->sadb_sa_exttype = SADB_EXT_SA;
162 	assoc->sadb_sa_len = SADB_8TO64(sizeof (*assoc));
163 	assoc->sadb_sa_auth = sa->ts_key.sak_algid;
164 	assoc->sadb_sa_flags = SADB_X_SAFLAGS_TCPSIG;
165 	assoc->sadb_sa_state = sa->ts_state;
166 
167 	return ((uint8_t *)(assoc + 1));
168 }
169 
170 static size_t
tcpsig_addr_extsize(const tcpsig_sa_t * sa)171 tcpsig_addr_extsize(const tcpsig_sa_t *sa)
172 {
173 	size_t addrsize = 0;
174 
175 	switch (sa->ts_family) {
176 	case AF_INET:
177 		addrsize = roundup(sizeof (sin_t) +
178 		    sizeof (sadb_address_t), sizeof (uint64_t));
179 		break;
180 	case AF_INET6:
181 		addrsize = roundup(sizeof (sin6_t) +
182 		    sizeof (sadb_address_t), sizeof (uint64_t));
183 		break;
184 	}
185 	return (addrsize);
186 }
187 
188 static uint8_t *
tcpsig_make_addr_ext(uint8_t * start,const uint8_t * const end,uint16_t exttype,sa_family_t af,const struct sockaddr_storage * addr)189 tcpsig_make_addr_ext(uint8_t *start, const uint8_t * const end,
190     uint16_t exttype, sa_family_t af, const struct sockaddr_storage *addr)
191 {
192 	uint8_t *cur = start;
193 	unsigned int addrext_len;
194 	sadb_address_t *addrext;
195 
196 	ASSERT(af == AF_INET || af == AF_INET6);
197 	ASSERT3P(end, >, start);
198 
199 	if (cur == NULL)
200 		return (NULL);
201 
202 	if (end - cur < sizeof (*addrext))
203 		return (NULL);
204 
205 	addrext = (sadb_address_t *)cur;
206 	addrext->sadb_address_proto = IPPROTO_TCP;
207 	addrext->sadb_address_reserved = 0;
208 	addrext->sadb_address_prefixlen = 0;
209 	addrext->sadb_address_exttype = exttype;
210 	cur = (uint8_t *)(addrext + 1);
211 
212 	if (af == AF_INET) {
213 		sin_t *sin;
214 
215 		if (end - cur < sizeof (*sin))
216 			return (NULL);
217 		sin = (sin_t *)cur;
218 
219 		*sin = sin_null;
220 		bcopy(addr, sin, sizeof (*sin));
221 		cur = (uint8_t *)(sin + 1);
222 	} else {
223 		sin6_t *sin6;
224 
225 		if (end - cur < sizeof (*sin6))
226 			return (NULL);
227 		sin6 = (sin6_t *)cur;
228 
229 		*sin6 = sin6_null;
230 		bcopy(addr, sin6, sizeof (*sin6));
231 		cur = (uint8_t *)(sin6 + 1);
232 	}
233 
234 	addrext_len = roundup(cur - start, sizeof (uint64_t));
235 	addrext->sadb_address_len = SADB_8TO64(addrext_len);
236 
237 	if (end - start < addrext_len)
238 		return (NULL);
239 	return (start + addrext_len);
240 }
241 
/*
 * Expiry-time helpers. `delta` names a lifetime field and `exp` an
 * expiry-time field of a tcpsig_sa_t, each given without its "ts_" prefix.
 * `sa` is evaluated more than once, so pass only side-effect-free
 * expressions.
 *
 * SET_EXPIRE:    if the lifetime is non-zero, set the expiry time to the SA's
 *                add time plus that lifetime (saturating; see
 *                tcpsig_add_time()).
 * UPDATE_EXPIRE: if the lifetime is non-zero, compute the SA's use time plus
 *                that lifetime. Store it if no expiry time is set yet;
 *                otherwise keep the earlier of the two.
 * EXPIRED:       true if the expiry time is set (non-zero) and is earlier
 *                than `now`.
 */
#define	SET_EXPIRE(sa, delta, exp) do {					\
	if (((sa)->ts_ ## delta) != 0) {				\
		(sa)->ts_ ## exp = tcpsig_add_time((sa)->ts_addtime,	\
		    (sa)->ts_ ## delta);				\
	}								\
} while (0)

#define	UPDATE_EXPIRE(sa, delta, exp) do {				\
	if (((sa)->ts_ ## delta) != 0) {				\
		time_t tcpsig_tmp_exp = tcpsig_add_time(		\
		    (sa)->ts_usetime, (sa)->ts_ ## delta);		\
		if (((sa)->ts_ ## exp) == 0)				\
			(sa)->ts_ ## exp = tcpsig_tmp_exp;		\
		else							\
			(sa)->ts_ ## exp = MIN((sa)->ts_ ## exp,	\
			    tcpsig_tmp_exp);				\
	}								\
} while (0)

/* Both references to `sa` are parenthesised so any expression is safe. */
#define	EXPIRED(sa, exp, now)						\
	((sa)->ts_ ## exp != 0 && (sa)->ts_ ## exp < (now))
262 
/*
 * PF_KEY gives us lifetimes as uint64_t seconds. Someone may ask for an
 * absurdly long SA lifetime. To avoid odd results (negative expiry times or
 * lost high-order bits), expiry times are computed with a saturating add:
 * `delta` is first capped at TIME_MAX, and for a positive `base` a sum that
 * would exceed TIME_MAX yields TIME_MAX instead.
 */
#define	TIME_MAX	INT64_MAX
static time_t
tcpsig_add_time(time_t base, uint64_t delta)
{
	const uint64_t limit = (uint64_t)TIME_MAX;

	if (delta > limit)
		delta = limit;

	if (base > 0 && delta > (uint64_t)(TIME_MAX - base))
		return (TIME_MAX);

	return ((time_t)(base + delta));
}
283 
284 /*
285  * Check hard/soft liftimes and return an appropriate error.
286  */
287 static int
tcpsig_check_lifetimes(sadb_lifetime_t * hard,sadb_lifetime_t * soft)288 tcpsig_check_lifetimes(sadb_lifetime_t *hard, sadb_lifetime_t *soft)
289 {
290 	if (hard == NULL || soft == NULL)
291 		return (SADB_X_DIAGNOSTIC_NONE);
292 
293 	if (hard->sadb_lifetime_addtime != 0 &&
294 	    soft->sadb_lifetime_addtime != 0 &&
295 	    hard->sadb_lifetime_addtime < soft->sadb_lifetime_addtime) {
296 		return (SADB_X_DIAGNOSTIC_ADDTIME_HSERR);
297 	}
298 
299 	if (hard->sadb_lifetime_usetime != 0 &&
300 	    soft->sadb_lifetime_usetime != 0 &&
301 	    hard->sadb_lifetime_usetime < soft->sadb_lifetime_usetime) {
302 		return (SADB_X_DIAGNOSTIC_USETIME_HSERR);
303 	}
304 
305 	return (SADB_X_DIAGNOSTIC_NONE);
306 }
307 
308 /*
309  * Update the lifetime values of an SA.
310  * If the updated lifetimes mean that a previously dying or dead SA should be
311  * promoted back to mature, then do that too. However, if they would mean that
312  * the SA is immediately expired, then that will be handled on the next
313  * aging run.
314  */
315 static void
tcpsig_update_lifetimes(tcpsig_sa_t * sa,sadb_lifetime_t * hard,sadb_lifetime_t * soft)316 tcpsig_update_lifetimes(tcpsig_sa_t *sa, sadb_lifetime_t *hard,
317     sadb_lifetime_t *soft)
318 {
319 	const time_t now = gethrestime_sec();
320 
321 	mutex_enter(&sa->ts_lock);
322 
323 	if (hard != NULL) {
324 		if (hard->sadb_lifetime_usetime != 0)
325 			sa->ts_harduselt = hard->sadb_lifetime_usetime;
326 		if (hard->sadb_lifetime_addtime != 0)
327 			sa->ts_hardaddlt = hard->sadb_lifetime_addtime;
328 		if (sa->ts_hardaddlt != 0)
329 			SET_EXPIRE(sa, hardaddlt, hardexpiretime);
330 		if (sa->ts_harduselt != 0 && sa->ts_usetime != 0)
331 			UPDATE_EXPIRE(sa, harduselt, hardexpiretime);
332 		if (sa->ts_state == SADB_SASTATE_DEAD &&
333 		    !EXPIRED(sa, hardexpiretime, now)) {
334 			sa->ts_state = SADB_SASTATE_MATURE;
335 		}
336 	}
337 
338 	if (soft != NULL) {
339 		if (soft->sadb_lifetime_usetime != 0) {
340 			sa->ts_softuselt = MIN(sa->ts_harduselt,
341 			    soft->sadb_lifetime_usetime);
342 		}
343 		if (soft->sadb_lifetime_addtime != 0) {
344 			sa->ts_softaddlt = MIN(sa->ts_hardaddlt,
345 			    soft->sadb_lifetime_addtime);
346 		}
347 		if (sa->ts_softaddlt != 0)
348 			SET_EXPIRE(sa, softaddlt, softexpiretime);
349 		if (sa->ts_softuselt != 0 && sa->ts_usetime != 0)
350 			UPDATE_EXPIRE(sa, softuselt, softexpiretime);
351 		if (sa->ts_state == SADB_SASTATE_DYING &&
352 		    !EXPIRED(sa, softexpiretime, now)) {
353 			sa->ts_state = SADB_SASTATE_MATURE;
354 		}
355 	}
356 
357 	mutex_exit(&sa->ts_lock);
358 }
359 
360 static void
tcpsig_sa_touch(tcpsig_sa_t * sa)361 tcpsig_sa_touch(tcpsig_sa_t *sa)
362 {
363 	const time_t now = gethrestime_sec();
364 
365 	mutex_enter(&sa->ts_lock);
366 	sa->ts_lastuse = now;
367 
368 	if (sa->ts_usetime == 0) {
369 		sa->ts_usetime = now;
370 		/* Update expiry times following the first use */
371 		UPDATE_EXPIRE(sa, softuselt, softexpiretime);
372 		UPDATE_EXPIRE(sa, harduselt, hardexpiretime);
373 	}
374 	mutex_exit(&sa->ts_lock);
375 }
376 
377 static void
tcpsig_sa_expiremsg(keysock_t * ks,const tcpsig_sa_t * sa,int ltt)378 tcpsig_sa_expiremsg(keysock_t *ks, const tcpsig_sa_t *sa, int ltt)
379 {
380 	size_t alloclen;
381 	sadb_sa_t *assoc;
382 	sadb_msg_t *samsg;
383 	sadb_lifetime_t *lt;
384 	uint8_t *cur, *end;
385 	mblk_t *mp;
386 
387 	alloclen = sizeof (sadb_msg_t) + sizeof (sadb_sa_t) +
388 	    2 * sizeof (sadb_lifetime_t) + 2 * tcpsig_addr_extsize(sa);
389 
390 	mp = allocb(alloclen, BPRI_HI);
391 	if (mp == NULL)
392 		return;
393 
394 	bzero(mp->b_rptr, alloclen);
395 	mp->b_wptr += alloclen;
396 	end = mp->b_wptr;
397 
398 	samsg = (sadb_msg_t *)mp->b_rptr;
399 	samsg->sadb_msg_version = PF_KEY_V2;
400 	samsg->sadb_msg_type = SADB_EXPIRE;
401 	samsg->sadb_msg_errno = 0;
402 	samsg->sadb_msg_satype = SADB_X_SATYPE_TCPSIG;
403 	samsg->sadb_msg_reserved = 0;
404 	samsg->sadb_msg_seq = 0;
405 	samsg->sadb_msg_pid = 0;
406 	samsg->sadb_msg_len = (uint16_t)SADB_8TO64(alloclen);
407 
408 	cur = (uint8_t *)(samsg + 1);
409 	cur = tcpsig_make_sa_ext(cur, end, sa);
410 	cur = tcpsig_make_addr_ext(cur, end, SADB_EXT_ADDRESS_SRC,
411 	    sa->ts_family, &sa->ts_src);
412 	cur = tcpsig_make_addr_ext(cur, end, SADB_EXT_ADDRESS_DST,
413 	    sa->ts_family, &sa->ts_dst);
414 
415 	if (cur == NULL) {
416 		freeb(mp);
417 		return;
418 	}
419 
420 	lt = (sadb_lifetime_t *)cur;
421 	lt->sadb_lifetime_len = SADB_8TO64(sizeof (*lt));
422 	lt->sadb_lifetime_exttype = SADB_EXT_LIFETIME_CURRENT;
423 	lt->sadb_lifetime_allocations = 0;
424 	lt->sadb_lifetime_bytes = 0;
425 	lt->sadb_lifetime_addtime = sa->ts_addtime;
426 	lt->sadb_lifetime_usetime = sa->ts_usetime;
427 
428 	lt++;
429 	lt->sadb_lifetime_len = SADB_8TO64(sizeof (*lt));
430 	lt->sadb_lifetime_exttype = ltt;
431 	lt->sadb_lifetime_allocations = 0;
432 	lt->sadb_lifetime_bytes = 0;
433 	lt->sadb_lifetime_addtime = sa->ts_hardaddlt;
434 	lt->sadb_lifetime_usetime = sa->ts_harduselt;
435 
436 	keysock_passup(mp, (sadb_msg_t *)mp->b_rptr,
437 	    0, NULL, B_TRUE, ks->keysock_keystack);
438 }
439 
440 static void
tcpsig_sa_age(keysock_t * ks,tcp_stack_t * tcps)441 tcpsig_sa_age(keysock_t *ks, tcp_stack_t *tcps)
442 {
443 	tcpsig_db_t *db = tcpsig_db(tcps);
444 	tcpsig_sa_t *nextsa;
445 	const time_t now = gethrestime_sec();
446 
447 	rw_enter(&db->td_lock, RW_WRITER);
448 	nextsa = list_head(&db->td_salist);
449 	while (nextsa != NULL) {
450 		tcpsig_sa_t *sa = nextsa;
451 
452 		nextsa = list_next(&db->td_salist, sa);
453 
454 		mutex_enter(&sa->ts_lock);
455 
456 		if (sa->ts_tombstoned) {
457 			mutex_exit(&sa->ts_lock);
458 			continue;
459 		}
460 
461 		if (EXPIRED(sa, hardexpiretime, now)) {
462 			sa->ts_state = IPSA_STATE_DEAD;
463 			tcpsig_sa_expiremsg(ks, sa, SADB_EXT_LIFETIME_HARD);
464 			if (sa->ts_refcnt > 0) {
465 				sa->ts_tombstoned = true;
466 				mutex_exit(&sa->ts_lock);
467 			} else {
468 				list_remove(&db->td_salist, sa);
469 				mutex_exit(&sa->ts_lock);
470 				tcpsig_sa_free(sa);
471 			}
472 			continue;
473 		}
474 
475 		if (EXPIRED(sa, softexpiretime, now) &&
476 		    sa->ts_state == IPSA_STATE_MATURE) {
477 			sa->ts_state = IPSA_STATE_DYING;
478 			tcpsig_sa_expiremsg(ks, sa, SADB_EXT_LIFETIME_SOFT);
479 		}
480 
481 		mutex_exit(&sa->ts_lock);
482 	}
483 
484 	rw_exit(&db->td_lock);
485 }
486 
487 static void
tcpsig_sa_free(tcpsig_sa_t * sa)488 tcpsig_sa_free(tcpsig_sa_t *sa)
489 {
490 	ASSERT0(sa->ts_refcnt);
491 	mutex_destroy(&sa->ts_lock);
492 	kmem_free(sa->ts_key.sak_key, sa->ts_key.sak_keylen);
493 	kmem_free(sa, sizeof (*sa));
494 }
495 
/*
 * Drop a hold on `sa`, such as one returned by tcpsig_sa_find() or
 * tcpsig_sa_find_held().
 *
 * If this drops the last hold on a tombstoned SA (one that was deleted,
 * flushed or hard-expired while still in use), the SA is also removed from
 * the database list and freed. On that path db->td_lock is acquired here as
 * writer, so callers must not already hold it.
 */
void
tcpsig_sa_rele(tcpsig_sa_t *sa)
{
	mutex_enter(&sa->ts_lock);
	VERIFY3U(sa->ts_refcnt, >, 0);
	sa->ts_refcnt--;
	/*
	 * If we are tombstoned (have been marked as deleted) and the reference
	 * count has now dropped to zero, then we can go ahead and finally
	 * remove this SA from the database.
	 */
	if (sa->ts_tombstoned && sa->ts_refcnt == 0) {
		tcpsig_db_t *db = tcpsig_db(sa->ts_stack);

		/*
		 * To maintain the required lock ordering, we need to drop the
		 * lock on the SA while acquiring the RW lock on the list. Take
		 * an additional hold before doing this dance and drop it once
		 * we have re-gained the lock.
		 *
		 * NOTE(review): this relies on no new hold being taken while
		 * ts_lock is dropped. tcpsig_sa_find_held() skips tombstoned
		 * SAs, which appears to guarantee that; confirm for any other
		 * site that takes holds.
		 */
		sa->ts_refcnt++;
		mutex_exit(&sa->ts_lock);
		rw_enter(&db->td_lock, RW_WRITER);
		mutex_enter(&sa->ts_lock);
		sa->ts_refcnt--;
		mutex_exit(&sa->ts_lock);

		/* Still under the writer lock, so the list is ours to edit. */
		list_remove(&db->td_salist, sa);

		rw_exit(&db->td_lock);
		tcpsig_sa_free(sa);
	} else {
		mutex_exit(&sa->ts_lock);
	}
}
531 
532 static bool
tcpsig_sa_match4(tcpsig_sa_t * sa,struct sockaddr_storage * src_s,struct sockaddr_storage * dst_s)533 tcpsig_sa_match4(tcpsig_sa_t *sa, struct sockaddr_storage *src_s,
534     struct sockaddr_storage *dst_s)
535 {
536 	sin_t msrc, mdst, *src, *dst, *sasrc, *sadst;
537 
538 	if (src_s->ss_family != AF_INET)
539 		return (false);
540 
541 	src = (sin_t *)src_s;
542 	dst = (sin_t *)dst_s;
543 
544 	if (sa->ts_family == AF_INET6) {
545 		sin6_t *sasrc6 = (sin6_t *)&sa->ts_src;
546 		sin6_t *sadst6 = (sin6_t *)&sa->ts_dst;
547 
548 		if (!IN6_IS_ADDR_V4MAPPED(&sasrc6->sin6_addr) ||
549 		    !IN6_IS_ADDR_V4MAPPED(&sadst6->sin6_addr)) {
550 			return (false);
551 		}
552 
553 		msrc = sin_null;
554 		msrc.sin_family = AF_INET;
555 		msrc.sin_port = sasrc6->sin6_port;
556 		IN6_V4MAPPED_TO_INADDR(&sasrc6->sin6_addr, &msrc.sin_addr);
557 		sasrc = &msrc;
558 
559 		mdst = sin_null;
560 		mdst.sin_family = AF_INET;
561 		mdst.sin_port = sadst6->sin6_port;
562 		IN6_V4MAPPED_TO_INADDR(&sadst6->sin6_addr, &mdst.sin_addr);
563 		sadst = &mdst;
564 	} else {
565 		sasrc = (sin_t *)&sa->ts_src;
566 		sadst = (sin_t *)&sa->ts_dst;
567 	}
568 
569 	if (sasrc->sin_port != 0 && sasrc->sin_port != src->sin_port)
570 		return (false);
571 	if (sadst->sin_port != 0 && sadst->sin_port != dst->sin_port)
572 		return (false);
573 
574 	if (sasrc->sin_addr.s_addr != src->sin_addr.s_addr)
575 		return (false);
576 	if (sadst->sin_addr.s_addr != dst->sin_addr.s_addr)
577 		return (false);
578 
579 	return (true);
580 }
581 
582 static bool
tcpsig_sa_match6(tcpsig_sa_t * sa,struct sockaddr_storage * src_s,struct sockaddr_storage * dst_s)583 tcpsig_sa_match6(tcpsig_sa_t *sa, struct sockaddr_storage *src_s,
584     struct sockaddr_storage *dst_s)
585 {
586 	sin6_t *src, *dst, *sasrc, *sadst;
587 
588 	if (src_s->ss_family != AF_INET6 || sa->ts_src.ss_family != AF_INET6)
589 		return (false);
590 
591 	src = (sin6_t *)src_s;
592 	dst = (sin6_t *)dst_s;
593 
594 	sasrc = (sin6_t *)&sa->ts_src;
595 	sadst = (sin6_t *)&sa->ts_dst;
596 
597 	if (sasrc->sin6_port != 0 && sasrc->sin6_port != src->sin6_port)
598 		return (false);
599 	if (sadst->sin6_port != 0 && sadst->sin6_port != dst->sin6_port)
600 		return (false);
601 
602 	if (!IN6_ARE_ADDR_EQUAL(&sasrc->sin6_addr, &src->sin6_addr))
603 		return (false);
604 	if (!IN6_ARE_ADDR_EQUAL(&sadst->sin6_addr, &dst->sin6_addr))
605 		return (false);
606 
607 	return (true);
608 }
609 
610 static tcpsig_sa_t *
tcpsig_sa_find_held(struct sockaddr_storage * src,struct sockaddr_storage * dst,tcp_stack_t * tcps)611 tcpsig_sa_find_held(struct sockaddr_storage *src, struct sockaddr_storage *dst,
612     tcp_stack_t *tcps)
613 {
614 	tcpsig_db_t *db = tcpsig_db(tcps);
615 	tcpsig_sa_t *sa = NULL;
616 	const time_t now = gethrestime_sec();
617 
618 	ASSERT(RW_LOCK_HELD(&db->td_lock));
619 
620 	if (src->ss_family != dst->ss_family)
621 		return (NULL);
622 
623 	for (sa = list_head(&db->td_salist); sa != NULL;
624 	    sa = list_next(&db->td_salist, sa)) {
625 		mutex_enter(&sa->ts_lock);
626 		/*
627 		 * We don't consider tombstoned or hard expired entries as a
628 		 * possible match.
629 		 */
630 		if (sa->ts_tombstoned || EXPIRED(sa, hardexpiretime, now)) {
631 			mutex_exit(&sa->ts_lock);
632 			continue;
633 		}
634 		if (tcpsig_sa_match4(sa, src, dst) ||
635 		    tcpsig_sa_match6(sa, src, dst)) {
636 			sa->ts_refcnt++;
637 			mutex_exit(&sa->ts_lock);
638 			break;
639 		}
640 		mutex_exit(&sa->ts_lock);
641 	}
642 
643 	return (sa);
644 }
645 
646 static tcpsig_sa_t *
tcpsig_sa_find(struct sockaddr_storage * src,struct sockaddr_storage * dst,tcp_stack_t * tcps)647 tcpsig_sa_find(struct sockaddr_storage *src, struct sockaddr_storage *dst,
648     tcp_stack_t *tcps)
649 {
650 	tcpsig_db_t *db = tcpsig_db(tcps);
651 	tcpsig_sa_t *sa;
652 
653 	rw_enter(&db->td_lock, RW_READER);
654 	sa = tcpsig_sa_find_held(src, dst, tcps);
655 	rw_exit(&db->td_lock);
656 
657 	return (sa);
658 }
659 
660 static int
tcpsig_sa_flush(keysock_t * ks,tcp_stack_t * tcps,int * diagp)661 tcpsig_sa_flush(keysock_t *ks, tcp_stack_t *tcps, int *diagp)
662 {
663 	tcpsig_db_t *db = tcpsig_db(tcps);
664 	tcpsig_sa_t *nextsa;
665 
666 	rw_enter(&db->td_lock, RW_WRITER);
667 	nextsa = list_head(&db->td_salist);
668 	while (nextsa != NULL) {
669 		tcpsig_sa_t *sa = nextsa;
670 
671 		nextsa = list_next(&db->td_salist, sa);
672 
673 		mutex_enter(&sa->ts_lock);
674 		if (sa->ts_refcnt > 0) {
675 			sa->ts_tombstoned = true;
676 			mutex_exit(&sa->ts_lock);
677 			continue;
678 		}
679 
680 		list_remove(&db->td_salist, sa);
681 
682 		mutex_exit(&sa->ts_lock);
683 		tcpsig_sa_free(sa);
684 	}
685 
686 	rw_exit(&db->td_lock);
687 
688 	return (0);
689 }
690 
691 static int
tcpsig_sa_add(keysock_t * ks,tcp_stack_t * tcps,keysock_in_t * ksi,sadb_ext_t ** extv,int * diagp)692 tcpsig_sa_add(keysock_t *ks, tcp_stack_t *tcps, keysock_in_t *ksi,
693     sadb_ext_t **extv, int *diagp)
694 {
695 	tcpsig_db_t *db;
696 	sadb_address_t *srcext, *dstext;
697 	sadb_lifetime_t *soft, *hard;
698 	sadb_sa_t *assoc;
699 	struct sockaddr_storage *src, *dst;
700 	sadb_key_t *key;
701 	tcpsig_sa_t *sa, *dupsa;
702 	int ret = 0;
703 
704 	assoc = (sadb_sa_t *)extv[SADB_EXT_SA];
705 	srcext = (sadb_address_t *)extv[SADB_EXT_ADDRESS_SRC];
706 	dstext = (sadb_address_t *)extv[SADB_EXT_ADDRESS_DST];
707 	key = (sadb_key_t *)extv[SADB_X_EXT_STR_AUTH];
708 	soft = (sadb_lifetime_t *)extv[SADB_EXT_LIFETIME_SOFT];
709 	hard = (sadb_lifetime_t *)extv[SADB_EXT_LIFETIME_HARD];
710 
711 	if (assoc == NULL) {
712 		*diagp = SADB_X_DIAGNOSTIC_MISSING_SA;
713 		return (EINVAL);
714 	}
715 
716 	if (srcext == NULL) {
717 		*diagp = SADB_X_DIAGNOSTIC_MISSING_SRC;
718 		return (EINVAL);
719 	}
720 
721 	if (dstext == NULL) {
722 		*diagp = SADB_X_DIAGNOSTIC_MISSING_DST;
723 		return (EINVAL);
724 	}
725 
726 	if (key == NULL) {
727 		*diagp = SADB_X_DIAGNOSTIC_MISSING_ASTR;
728 		return (EINVAL);
729 	}
730 
731 	if ((*diagp = tcpsig_check_lifetimes(hard, soft)) !=
732 	    SADB_X_DIAGNOSTIC_NONE) {
733 		return (EINVAL);
734 	}
735 
736 	src = (struct sockaddr_storage *)(srcext + 1);
737 	dst = (struct sockaddr_storage *)(dstext + 1);
738 
739 	if (src->ss_family != dst->ss_family) {
740 		*diagp = SADB_X_DIAGNOSTIC_AF_MISMATCH;
741 		return (EINVAL);
742 	}
743 
744 	if (src->ss_family != AF_INET && src->ss_family != AF_INET6) {
745 		*diagp = SADB_X_DIAGNOSTIC_BAD_SRC_AF;
746 		return (EINVAL);
747 	}
748 
749 	/* We only support MD5 */
750 	if (assoc->sadb_sa_auth != SADB_AALG_MD5) {
751 		*diagp = SADB_X_DIAGNOSTIC_BAD_AALG;
752 		return (EINVAL);
753 	}
754 
755 	/* The authentication key length must be a multiple of whole bytes */
756 	if ((key->sadb_key_bits & 0x7) != 0) {
757 		*diagp = SADB_X_DIAGNOSTIC_MALFORMED_AKEY;
758 		return (EINVAL);
759 	}
760 
761 	db = tcpsig_db(tcps);
762 
763 	sa = kmem_zalloc(sizeof (*sa), KM_NOSLEEP_LAZY);
764 	if (sa == NULL)
765 		return (ENOMEM);
766 
767 	sa->ts_stack = tcps;
768 	sa->ts_family = src->ss_family;
769 	if (sa->ts_family == AF_INET6) {
770 		bcopy(src, (sin6_t *)&sa->ts_src, sizeof (sin6_t));
771 		bcopy(dst, (sin6_t *)&sa->ts_dst, sizeof (sin6_t));
772 	} else {
773 		bcopy(src, (sin_t *)&sa->ts_src, sizeof (sin_t));
774 		bcopy(dst, (sin_t *)&sa->ts_dst, sizeof (sin_t));
775 	}
776 
777 	sa->ts_key.sak_algid = assoc->sadb_sa_auth;
778 	sa->ts_key.sak_keylen = SADB_1TO8(key->sadb_key_bits);
779 	sa->ts_key.sak_keybits = key->sadb_key_bits;
780 
781 	sa->ts_key.sak_key = kmem_alloc(sa->ts_key.sak_keylen,
782 	    KM_NOSLEEP_LAZY);
783 	if (sa->ts_key.sak_key == NULL) {
784 		kmem_free(sa, sizeof (*sa));
785 		return (ENOMEM);
786 	}
787 	bcopy(key + 1, sa->ts_key.sak_key, sa->ts_key.sak_keylen);
788 	bzero(key + 1, sa->ts_key.sak_keylen);
789 
790 	mutex_init(&sa->ts_lock, NULL, MUTEX_DEFAULT, NULL);
791 
792 	sa->ts_state = SADB_SASTATE_MATURE;
793 	sa->ts_addtime = gethrestime_sec();
794 	sa->ts_usetime = 0;
795 	if (soft != NULL) {
796 		sa->ts_softaddlt = soft->sadb_lifetime_addtime;
797 		sa->ts_softuselt = soft->sadb_lifetime_usetime;
798 		SET_EXPIRE(sa, softaddlt, softexpiretime);
799 	}
800 
801 	if (hard != NULL) {
802 		sa->ts_hardaddlt = hard->sadb_lifetime_addtime;
803 		sa->ts_harduselt = hard->sadb_lifetime_usetime;
804 		SET_EXPIRE(sa, hardaddlt, hardexpiretime);
805 	}
806 
807 	sa->ts_refcnt = 0;
808 	sa->ts_tombstoned = false;
809 
810 	rw_enter(&db->td_lock, RW_WRITER);
811 	if ((dupsa = tcpsig_sa_find_held(src, dst, tcps)) != NULL) {
812 		rw_exit(&db->td_lock);
813 		tcpsig_sa_rele(dupsa);
814 		tcpsig_sa_free(sa);
815 		*diagp = SADB_X_DIAGNOSTIC_DUPLICATE_SA;
816 		ret = EEXIST;
817 	} else {
818 		list_insert_tail(&db->td_salist, sa);
819 		rw_exit(&db->td_lock);
820 	}
821 
822 	return (ret);
823 }
824 
825 /*
826  * Handle an UPDATE message. We only support updating lifetimes.
827  */
828 static int
tcpsig_sa_update(keysock_t * ks,tcp_stack_t * tcps,keysock_in_t * ksi,sadb_ext_t ** extv,int * diagp)829 tcpsig_sa_update(keysock_t *ks, tcp_stack_t *tcps, keysock_in_t *ksi,
830     sadb_ext_t **extv, int *diagp)
831 {
832 	tcpsig_db_t *db;
833 	sadb_address_t *srcext, *dstext;
834 	sadb_lifetime_t *soft, *hard;
835 	struct sockaddr_storage *src, *dst;
836 	tcpsig_sa_t *sa;
837 
838 	srcext = (sadb_address_t *)extv[SADB_EXT_ADDRESS_SRC];
839 	dstext = (sadb_address_t *)extv[SADB_EXT_ADDRESS_DST];
840 	soft = (sadb_lifetime_t *)extv[SADB_EXT_LIFETIME_SOFT];
841 	hard = (sadb_lifetime_t *)extv[SADB_EXT_LIFETIME_HARD];
842 
843 	if (srcext == NULL) {
844 		*diagp = SADB_X_DIAGNOSTIC_MISSING_SRC;
845 		return (EINVAL);
846 	}
847 
848 	if (dstext == NULL) {
849 		*diagp = SADB_X_DIAGNOSTIC_MISSING_DST;
850 		return (EINVAL);
851 	}
852 
853 
854 	if ((*diagp = tcpsig_check_lifetimes(hard, soft)) !=
855 	    SADB_X_DIAGNOSTIC_NONE) {
856 		return (EINVAL);
857 	}
858 
859 	src = (struct sockaddr_storage *)(srcext + 1);
860 	dst = (struct sockaddr_storage *)(dstext + 1);
861 
862 	sa = tcpsig_sa_find(src, dst, tcps);
863 
864 	if (sa == NULL) {
865 		*diagp = SADB_X_DIAGNOSTIC_PAIR_SA_NOTFOUND;
866 		return (ESRCH);
867 	}
868 
869 	tcpsig_update_lifetimes(sa, hard, soft);
870 	tcpsig_sa_rele(sa);
871 
872 	/*
873 	 * Run an aging pass in case updating the SA lifetimes has resulted in
874 	 * the SA now being aged out.
875 	 */
876 	tcpsig_sa_age(ks, tcps);
877 
878 	return (0);
879 }
880 
881 static mblk_t *
tcpsig_dump_one(const tcpsig_sa_t * sa,sadb_msg_t * samsg)882 tcpsig_dump_one(const tcpsig_sa_t *sa, sadb_msg_t *samsg)
883 {
884 	size_t alloclen, keysize;
885 	sadb_sa_t *assoc;
886 	sadb_msg_t *newsamsg;
887 	uint8_t *cur, *end;
888 	sadb_key_t *key;
889 	mblk_t *mp;
890 	bool soft = false, hard = false;
891 
892 	ASSERT(MUTEX_HELD(&sa->ts_lock));
893 
894 	alloclen = sizeof (sadb_msg_t) + sizeof (sadb_sa_t) +
895 	    2 * tcpsig_addr_extsize(sa);
896 
897 	if (sa->ts_softaddlt != 0 || sa->ts_softuselt != 0) {
898 		alloclen += sizeof (sadb_lifetime_t);
899 		soft = true;
900 	}
901 
902 	if (sa->ts_hardaddlt != 0 || sa->ts_harduselt != 0) {
903 		alloclen += sizeof (sadb_lifetime_t);
904 		hard = true;
905 	}
906 
907 	/* Add space for LIFETIME_CURRENT */
908 	if (soft || hard)
909 		alloclen += sizeof (sadb_lifetime_t);
910 
911 	keysize = roundup(sizeof (sadb_key_t) + sa->ts_key.sak_keylen,
912 	    sizeof (uint64_t));
913 
914 	alloclen += keysize;
915 
916 	mp = allocb(alloclen, BPRI_HI);
917 	if (mp == NULL)
918 		return (NULL);
919 
920 	bzero(mp->b_rptr, alloclen);
921 	mp->b_wptr += alloclen;
922 	end = mp->b_wptr;
923 
924 	newsamsg = (sadb_msg_t *)mp->b_rptr;
925 	*newsamsg = *samsg;
926 	newsamsg->sadb_msg_len = (uint16_t)SADB_8TO64(alloclen);
927 
928 	cur = (uint8_t *)(newsamsg + 1);
929 	cur = tcpsig_make_sa_ext(cur, end, sa);
930 	cur = tcpsig_make_addr_ext(cur, end, SADB_EXT_ADDRESS_SRC,
931 	    sa->ts_family, &sa->ts_src);
932 	cur = tcpsig_make_addr_ext(cur, end, SADB_EXT_ADDRESS_DST,
933 	    sa->ts_family, &sa->ts_dst);
934 
935 	if (cur == NULL) {
936 		freeb(mp);
937 		return (NULL);
938 	}
939 
940 	if (soft || hard) {
941 		sadb_lifetime_t *lt = (sadb_lifetime_t *)cur;
942 
943 		lt->sadb_lifetime_len = SADB_8TO64(sizeof (*lt));
944 		lt->sadb_lifetime_exttype = SADB_EXT_LIFETIME_CURRENT;
945 		lt->sadb_lifetime_allocations = 0;
946 		lt->sadb_lifetime_bytes = 0;
947 		lt->sadb_lifetime_addtime = sa->ts_addtime;
948 		lt->sadb_lifetime_usetime = sa->ts_usetime;
949 		lt++;
950 
951 		if (soft) {
952 			lt->sadb_lifetime_len = SADB_8TO64(sizeof (*lt));
953 			lt->sadb_lifetime_exttype = SADB_EXT_LIFETIME_SOFT;
954 			lt->sadb_lifetime_allocations = 0;
955 			lt->sadb_lifetime_bytes = 0;
956 			lt->sadb_lifetime_addtime = sa->ts_softaddlt;
957 			lt->sadb_lifetime_usetime = sa->ts_softuselt;
958 			lt++;
959 		}
960 		if (hard) {
961 			lt->sadb_lifetime_len = SADB_8TO64(sizeof (*lt));
962 			lt->sadb_lifetime_exttype = SADB_EXT_LIFETIME_HARD;
963 			lt->sadb_lifetime_allocations = 0;
964 			lt->sadb_lifetime_bytes = 0;
965 			lt->sadb_lifetime_addtime = sa->ts_hardaddlt;
966 			lt->sadb_lifetime_usetime = sa->ts_harduselt;
967 			lt++;
968 		}
969 
970 		cur = (uint8_t *)lt;
971 	}
972 
973 	key = (sadb_key_t *)cur;
974 	key->sadb_key_exttype = SADB_X_EXT_STR_AUTH;
975 	key->sadb_key_len = SADB_8TO64(keysize);
976 	key->sadb_key_bits = sa->ts_key.sak_keybits;
977 	key->sadb_key_reserved = 0;
978 	bcopy(sa->ts_key.sak_key, (uint8_t *)(key + 1), sa->ts_key.sak_keylen);
979 
980 	return (mp);
981 }
982 
983 static int
tcpsig_sa_dump(keysock_t * ks,tcp_stack_t * tcps,sadb_msg_t * samsg,int * diag)984 tcpsig_sa_dump(keysock_t *ks, tcp_stack_t *tcps, sadb_msg_t *samsg, int *diag)
985 {
986 	tcpsig_db_t *db;
987 	tcpsig_sa_t *sa;
988 
989 	db = tcpsig_db(tcps);
990 	rw_enter(&db->td_lock, RW_READER);
991 
992 	for (sa = list_head(&db->td_salist); sa != NULL;
993 	    sa = list_next(&db->td_salist, sa)) {
994 		mblk_t *mp;
995 
996 		mutex_enter(&sa->ts_lock);
997 		if (sa->ts_tombstoned) {
998 			mutex_exit(&sa->ts_lock);
999 			continue;
1000 		}
1001 		mp = tcpsig_dump_one(sa, samsg);
1002 		mutex_exit(&sa->ts_lock);
1003 
1004 		if (mp == NULL) {
1005 			rw_exit(&db->td_lock);
1006 			return (ENOMEM);
1007 		}
1008 		keysock_passup(mp, (sadb_msg_t *)mp->b_rptr,
1009 		    ks->keysock_serial, NULL, B_TRUE, ks->keysock_keystack);
1010 	}
1011 
1012 	rw_exit(&db->td_lock);
1013 
1014 	/* A sequence number of 0 indicates the end of the list */
1015 	samsg->sadb_msg_seq = 0;
1016 
1017 	return (0);
1018 }
1019 
1020 static int
tcpsig_sa_delget(keysock_t * ks,tcp_stack_t * tcps,sadb_msg_t * samsg,sadb_ext_t ** extv,int * diagp)1021 tcpsig_sa_delget(keysock_t *ks, tcp_stack_t *tcps, sadb_msg_t *samsg,
1022     sadb_ext_t **extv, int *diagp)
1023 {
1024 	sadb_address_t *srcext, *dstext;
1025 	struct sockaddr_storage *src, *dst;
1026 	tcpsig_sa_t *sa;
1027 	mblk_t *mp;
1028 
1029 	srcext = (sadb_address_t *)extv[SADB_EXT_ADDRESS_SRC];
1030 	dstext = (sadb_address_t *)extv[SADB_EXT_ADDRESS_DST];
1031 
1032 	if (srcext == NULL) {
1033 		*diagp = SADB_X_DIAGNOSTIC_MISSING_SRC;
1034 		return (EINVAL);
1035 	}
1036 
1037 	if (dstext == NULL) {
1038 		*diagp = SADB_X_DIAGNOSTIC_MISSING_DST;
1039 		return (EINVAL);
1040 	}
1041 
1042 	src = (struct sockaddr_storage *)(srcext + 1);
1043 	dst = (struct sockaddr_storage *)(dstext + 1);
1044 
1045 	sa = tcpsig_sa_find(src, dst, tcps);
1046 
1047 	if (sa == NULL) {
1048 		*diagp = SADB_X_DIAGNOSTIC_PAIR_SA_NOTFOUND;
1049 		return (ESRCH);
1050 	}
1051 
1052 	if (samsg->sadb_msg_type == SADB_GET) {
1053 		mutex_enter(&sa->ts_lock);
1054 		mp = tcpsig_dump_one(sa, samsg);
1055 		mutex_exit(&sa->ts_lock);
1056 
1057 		if (mp == NULL) {
1058 			tcpsig_sa_rele(sa);
1059 			return (ENOMEM);
1060 		}
1061 		keysock_passup(mp, (sadb_msg_t *)mp->b_rptr,
1062 		    ks->keysock_serial, NULL, B_TRUE, ks->keysock_keystack);
1063 		tcpsig_sa_rele(sa);
1064 
1065 		return (0);
1066 	}
1067 
1068 	/*
1069 	 * Delete the entry.
1070 	 * At this point we still have a hold on the entry from the find call
1071 	 * above, so mark it as tombstoned and then release the hold. If
1072 	 * that causes the reference count to become 0, the entry will be
1073 	 * removed from the database.
1074 	 */
1075 
1076 	mutex_enter(&sa->ts_lock);
1077 	sa->ts_tombstoned = true;
1078 	mutex_exit(&sa->ts_lock);
1079 	tcpsig_sa_rele(sa);
1080 
1081 	return (0);
1082 }
1083 
1084 void
tcpsig_sa_handler(keysock_t * ks,mblk_t * mp,sadb_msg_t * samsg,sadb_ext_t ** extv)1085 tcpsig_sa_handler(keysock_t *ks, mblk_t *mp, sadb_msg_t *samsg,
1086     sadb_ext_t **extv)
1087 {
1088 	keysock_stack_t *keystack = ks->keysock_keystack;
1089 	netstack_t *nst = keystack->keystack_netstack;
1090 	tcp_stack_t *tcps = nst->netstack_tcp;
1091 	keysock_in_t *ksi = (keysock_in_t *)mp->b_rptr;
1092 	int diag = SADB_X_DIAGNOSTIC_NONE;
1093 	int error;
1094 
1095 	tcpsig_sa_age(ks, tcps);
1096 
1097 	switch (samsg->sadb_msg_type) {
1098 	case SADB_ADD:
1099 		error = tcpsig_sa_add(ks, tcps, ksi, extv, &diag);
1100 		keysock_error(ks, mp, error, diag);
1101 		break;
1102 	case SADB_UPDATE:
1103 		error = tcpsig_sa_update(ks, tcps, ksi, extv, &diag);
1104 		keysock_error(ks, mp, error, diag);
1105 		break;
1106 	case SADB_GET:
1107 	case SADB_DELETE:
1108 		error = tcpsig_sa_delget(ks, tcps, samsg, extv, &diag);
1109 		keysock_error(ks, mp, error, diag);
1110 		break;
1111 	case SADB_FLUSH:
1112 		error = tcpsig_sa_flush(ks, tcps, &diag);
1113 		keysock_error(ks, mp, error, diag);
1114 		break;
1115 	case SADB_DUMP:
1116 		error = tcpsig_sa_dump(ks, tcps, samsg, &diag);
1117 		keysock_error(ks, mp, error, diag);
1118 		break;
1119 	default:
1120 		keysock_error(ks, mp, EOPNOTSUPP, diag);
1121 		break;
1122 	}
1123 }
1124 
1125 bool
tcpsig_sa_exists(tcp_t * tcp,bool inbound,tcpsig_sa_t ** sap)1126 tcpsig_sa_exists(tcp_t *tcp, bool inbound, tcpsig_sa_t **sap)
1127 {
1128 	tcp_stack_t *tcps = tcp->tcp_tcps;
1129 	conn_t *connp = tcp->tcp_connp;
1130 	struct sockaddr_storage src, dst;
1131 	tcpsig_sa_t *sa;
1132 
1133 	bzero(&src, sizeof (src));
1134 	bzero(&dst, sizeof (dst));
1135 
1136 	if (connp->conn_ipversion == IPV6_VERSION) {
1137 		sin6_t *sin6;
1138 
1139 		sin6 = (sin6_t *)&src;
1140 		sin6->sin6_family = AF_INET6;
1141 		if (inbound) {
1142 			sin6->sin6_addr = connp->conn_faddr_v6;
1143 			sin6->sin6_port = connp->conn_fport;
1144 		} else {
1145 			sin6->sin6_addr = connp->conn_saddr_v6;
1146 			sin6->sin6_port = connp->conn_lport;
1147 		}
1148 
1149 		sin6 = (sin6_t *)&dst;
1150 		sin6->sin6_family = AF_INET6;
1151 		if (inbound) {
1152 			sin6->sin6_addr = connp->conn_saddr_v6;
1153 			sin6->sin6_port = connp->conn_lport;
1154 		} else {
1155 			sin6->sin6_addr = connp->conn_faddr_v6;
1156 			sin6->sin6_port = connp->conn_fport;
1157 		}
1158 	} else {
1159 		sin_t *sin;
1160 
1161 		sin = (sin_t *)&src;
1162 		sin->sin_family = AF_INET;
1163 		if (inbound) {
1164 			sin->sin_addr.s_addr = connp->conn_faddr_v4;
1165 			sin->sin_port = connp->conn_fport;
1166 		} else {
1167 			sin->sin_addr.s_addr = connp->conn_saddr_v4;
1168 			sin->sin_port = connp->conn_lport;
1169 		}
1170 
1171 		sin = (sin_t *)&dst;
1172 		sin->sin_family = AF_INET;
1173 		if (inbound) {
1174 			sin->sin_addr.s_addr = connp->conn_saddr_v4;
1175 			sin->sin_port = connp->conn_lport;
1176 		} else {
1177 			sin->sin_addr.s_addr = connp->conn_faddr_v4;
1178 			sin->sin_port = connp->conn_fport;
1179 		}
1180 	}
1181 
1182 	sa = tcpsig_sa_find(&src, &dst, tcps);
1183 
1184 	if (sa == NULL)
1185 		return (false);
1186 
1187 	if (sap != NULL)
1188 		*sap = sa;
1189 	else
1190 		tcpsig_sa_rele(sa);
1191 
1192 	return (true);
1193 }
1194 
1195 static void
tcpsig_pseudo_compute4(tcp_t * tcp,int tcplen,MD5_CTX * ctx,bool inbound)1196 tcpsig_pseudo_compute4(tcp_t *tcp, int tcplen, MD5_CTX *ctx, bool inbound)
1197 {
1198 	struct ip_pseudo {
1199 		struct in_addr	ipp_src;
1200 		struct in_addr	ipp_dst;
1201 		uint8_t		ipp_pad;
1202 		uint8_t		ipp_proto;
1203 		uint16_t	ipp_len;
1204 	} ipp;
1205 	conn_t *connp = tcp->tcp_connp;
1206 
1207 	if (inbound) {
1208 		ipp.ipp_src.s_addr = connp->conn_faddr_v4;
1209 		ipp.ipp_dst.s_addr = connp->conn_saddr_v4;
1210 	} else {
1211 		ipp.ipp_src.s_addr = connp->conn_saddr_v4;
1212 		ipp.ipp_dst.s_addr = connp->conn_faddr_v4;
1213 	}
1214 	ipp.ipp_pad = 0;
1215 	ipp.ipp_proto = IPPROTO_TCP;
1216 	ipp.ipp_len = htons(tcplen);
1217 
1218 	DTRACE_PROBE1(ipp4, struct ip_pseudo *, &ipp);
1219 
1220 	MD5Update(ctx, (char *)&ipp, sizeof (ipp));
1221 }
1222 
1223 static void
tcpsig_pseudo_compute6(tcp_t * tcp,int tcplen,MD5_CTX * ctx,bool inbound)1224 tcpsig_pseudo_compute6(tcp_t *tcp, int tcplen, MD5_CTX *ctx, bool inbound)
1225 {
1226 	struct ip6_pseudo {
1227 		struct in6_addr	ipp_src;
1228 		struct in6_addr ipp_dst;
1229 		uint32_t	ipp_len;
1230 		uint32_t	ipp_nxt;
1231 	} ip6p;
1232 	conn_t *connp = tcp->tcp_connp;
1233 
1234 	if (inbound) {
1235 		ip6p.ipp_src = connp->conn_faddr_v6;
1236 		ip6p.ipp_dst = connp->conn_saddr_v6;
1237 	} else {
1238 		ip6p.ipp_src = connp->conn_saddr_v6;
1239 		ip6p.ipp_dst = connp->conn_faddr_v6;
1240 	}
1241 	ip6p.ipp_len = htonl(tcplen);
1242 	ip6p.ipp_nxt = htonl(IPPROTO_TCP);
1243 
1244 	DTRACE_PROBE1(ipp6, struct ip6_pseudo *, &ip6p);
1245 
1246 	MD5Update(ctx, (char *)&ip6p, sizeof (ip6p));
1247 }
1248 
1249 bool
tcpsig_signature(mblk_t * mp,tcp_t * tcp,tcpha_t * tcpha,int tcplen,uint8_t * digest,bool inbound)1250 tcpsig_signature(mblk_t *mp, tcp_t *tcp, tcpha_t *tcpha, int tcplen,
1251     uint8_t *digest, bool inbound)
1252 {
1253 	tcp_stack_t *tcps = tcp->tcp_tcps;
1254 	conn_t *connp = tcp->tcp_connp;
1255 	tcpsig_sa_t *sa;
1256 	MD5_CTX context;
1257 
1258 	/*
1259 	 * The TCP_MD5SIG option is 20 bytes, including padding, which adds 5
1260 	 * 32-bit words to the header's 4-bit field. Check that it can fit in
1261 	 * the current packet.
1262 	 */
1263 	if (!inbound && (tcpha->tha_offset_and_reserved >> 4) > 10) {
1264 		TCP_STAT(tcps, tcp_sig_no_space);
1265 		return (false);
1266 	}
1267 
1268 	sa = inbound ? tcp->tcp_sig_sa_in : tcp->tcp_sig_sa_out;
1269 	if (sa == NULL) {
1270 		if (!tcpsig_sa_exists(tcp, inbound, &sa)) {
1271 			TCP_STAT(tcps, tcp_sig_match_failed);
1272 			return (false);
1273 		}
1274 
1275 		/*
1276 		 * tcpsig_sa_exists() returns a held SA, so we don't need to
1277 		 * take another hold before adding it to tcp.
1278 		 */
1279 		if (inbound)
1280 			tcp->tcp_sig_sa_in = sa;
1281 		else
1282 			tcp->tcp_sig_sa_out = sa;
1283 	}
1284 
1285 	tcpsig_sa_touch(sa);
1286 
1287 	VERIFY3U(sa->ts_key.sak_algid, ==, SADB_AALG_MD5);
1288 
1289 	/* We have a key for this connection, generate the hash */
1290 	MD5Init(&context);
1291 
1292 	/* TCP pseudo-header */
1293 	if (connp->conn_ipversion == IPV6_VERSION)
1294 		tcpsig_pseudo_compute6(tcp, tcplen, &context, inbound);
1295 	else
1296 		tcpsig_pseudo_compute4(tcp, tcplen, &context, inbound);
1297 
1298 	/* TCP header, excluding options and with a zero checksum */
1299 	uint16_t offset = tcpha->tha_offset_and_reserved;
1300 	uint16_t sum = tcpha->tha_sum;
1301 
1302 	if (!inbound) {
1303 		/* Account for the MD5 option we are going to add */
1304 		tcpha->tha_offset_and_reserved += (5 << 4);
1305 	}
1306 	tcpha->tha_sum = 0;
1307 	MD5Update(&context, tcpha, sizeof (*tcpha));
1308 	tcpha->tha_offset_and_reserved = offset;
1309 	tcpha->tha_sum = sum;
1310 
1311 	/* TCP segment data */
1312 	for (; mp != NULL; mp = mp->b_cont)
1313 		MD5Update(&context, mp->b_rptr, mp->b_wptr - mp->b_rptr);
1314 
1315 	/* Connection-specific key */
1316 	MD5Update(&context, sa->ts_key.sak_key, sa->ts_key.sak_keylen);
1317 
1318 	MD5Final(digest, &context);
1319 
1320 	return (true);
1321 }
1322 
1323 bool
tcpsig_verify(mblk_t * mp,tcp_t * tcp,tcpha_t * tcpha,ip_recv_attr_t * ira,uint8_t * digest)1324 tcpsig_verify(mblk_t *mp, tcp_t *tcp, tcpha_t *tcpha, ip_recv_attr_t *ira,
1325     uint8_t *digest)
1326 {
1327 	uint8_t calc_digest[MD5_DIGEST_LENGTH];
1328 
1329 	if (!tcpsig_signature(mp, tcp, tcpha,
1330 	    ira->ira_pktlen - ira->ira_ip_hdr_length, calc_digest, true)) {
1331 		/* The appropriate stat will already have been bumped */
1332 		return (false);
1333 	}
1334 
1335 	if (bcmp(digest, calc_digest, sizeof (calc_digest)) != 0) {
1336 		TCP_STAT(tcp->tcp_tcps, tcp_sig_verify_failed);
1337 		return (false);
1338 	}
1339 
1340 	return (true);
1341 }
1342