xref: /illumos-gate/usr/src/uts/common/io/ppp/spppcomp/vjcompress.c (revision 8119dad84d6416f13557b0ba8e2aaf9064cbcfd3)
1 /*
2  * Copyright (c) 2000 by Sun Microsystems, Inc.
3  * All rights reserved.
4  *
5  * Routines to compress and uncompress tcp packets (for transmission
6  * over low speed serial lines).
7  *
8  * Copyright (c) 1989 Regents of the University of California.
9  * All rights reserved.
10  *
11  * Redistribution and use in source and binary forms are permitted
12  * provided that the above copyright notice and this paragraph are
13  * duplicated in all such forms and that any documentation,
14  * advertising materials, and other materials related to such
15  * distribution and use acknowledge that the software was developed
16  * by the University of California, Berkeley.  The name of the
17  * University may not be used to endorse or promote products derived
18  * from this software without specific prior written permission.
19  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR
20  * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED
21  * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE.
22  *
23  *	Van Jacobson (van@helios.ee.lbl.gov), Dec 31, 1989:
24  *	- Initial distribution.
25  *
26  * Modified June 1993 by Paul Mackerras, paulus@cs.anu.edu.au,
27  * so that the entire packet being decompressed doesn't have
28  * to be in contiguous memory (just the compressed header).
29  */
30 
31 /*
32  * This version is used under STREAMS in Solaris 2
33  *
34  * $Id: vjcompress.c,v 1.10 1999/09/15 23:49:06 masputra Exp $
35  */
36 
37 #include <sys/types.h>
38 #include <sys/param.h>
39 #include <sys/byteorder.h>	/* for ntohl, etc. */
40 #include <sys/systm.h>
41 #include <sys/sysmacros.h>
42 
43 #include <netinet/in.h>
44 #include <netinet/in_systm.h>
45 #include <netinet/ip.h>
46 #include <netinet/tcp.h>
47 
48 #include <net/ppp_defs.h>
49 #include <net/vjcompress.h>
50 
/*
 * Statistics hook.  INCR() bumps a field of comp->stats and relies on a
 * variable named `comp' being in scope where it is used.  Defining
 * VJ_NO_STATS compiles the counters out entirely.
 */
#ifndef VJ_NO_STATS
#define	INCR(counter) ++comp->stats.counter
#else
#define	INCR(counter)
#endif

/* Byte-wise compare/copy wrappers around bcmp()/bcopy(). */
#define	BCMP(p1, p2, n) bcmp((char *)(p1), (char *)(p2), (unsigned int)(n))

#undef  BCOPY
#define	BCOPY(p1, p2, n) bcopy((char *)(p1), (char *)(p2), (unsigned int)(n))

/*
 * I'd like to use offsetof(struct ip, ip_hl) and
 * offsetof(struct tcphdr, th_off), but these are bitfields, so the
 * length nibbles are extracted by hand: ip_hl is the low nibble of byte 0
 * of the IP header and th_off is the high nibble of byte 12 of the TCP
 * header (both counted in 32-bit words).
 *
 * The argument is parenthesized so that a pointer-arithmetic expression
 * such as getth_off(buf + hlen) is evaluated before the cast to
 * uchar_t *, whatever the pointer's original type.
 */
#define	getip_hl(bp)	(((uchar_t *)(bp))[0] & 0x0F)
#define	getth_off(bp)	(((uchar_t *)(bp))[12] >> 4)
#define	getip_p(bp)	(((uchar_t *)(bp))[offsetof(struct ip, ip_p)])
#define	setip_p(bp, v)	(((uchar_t *)(bp))[offsetof(struct ip, ip_p)] = (v))
70 
71 /*
72  * vj_compress_init()
73  */
74 void
75 vj_compress_init(struct vjcompress *comp, int max_state)
76 {
77 	register uint_t		i;
78 	register struct cstate	*tstate = comp->tstate;
79 
80 	if (max_state == -1) {
81 		max_state = MAX_STATES - 1;
82 	}
83 
84 	bzero((char *)comp, sizeof (*comp));
85 
86 	for (i = max_state; i > 0; --i) {
87 		tstate[i].cs_id = i & 0xff;
88 		tstate[i].cs_next = &tstate[i - 1];
89 	}
90 
91 	tstate[0].cs_next = &tstate[max_state];
92 	tstate[0].cs_id = 0;
93 
94 	comp->last_cs = &tstate[0];
95 	comp->last_recv = 255;
96 	comp->last_xmit = 255;
97 	comp->flags = VJF_TOSS;
98 }
99 
/*
 * Field encoding used in the compressed header: a value whose low 16
 * bits are 1..255 is sent as a single byte.  Anything else is sent as a
 * zero byte followed by the low 16 bits of the value, most significant
 * byte first.
 *
 * ENCODE encodes a number that is known to be non-zero.  ENCODEZ
 * checks for zero (since zero has to be encoded in the long, 3 byte
 * form).  Both append through a `cp' pointer that must be in scope, and
 * both evaluate their argument more than once.
 *
 * The DECODE* macros are the inverse and consume bytes from `cp' the
 * same way:
 *   DECODEL adds the decoded delta to a 32-bit network-order field.
 *   DECODES adds it to a 16-bit network-order field.
 *   DECODEU stores the decoded value itself (used for the urgent
 *   pointer, which the compressor sends as an absolute value).
 *
 * Every macro expands to a single statement (do { } while (0)), so it
 * can be used safely as the body of an unbraced if/else.
 */
#define	ENCODE(n) do {						\
	if ((ushort_t)(n) >= 256) {				\
		*cp++ = 0;					\
		cp[1] = (n) & 0xff;				\
		cp[0] = ((n) >> 8) & 0xff;			\
		cp += 2;					\
	} else {						\
		*cp++ = (n) & 0xff;				\
	}							\
} while (0)

#define	ENCODEZ(n) do {						\
	if ((ushort_t)(n) >= 256 || (ushort_t)(n) == 0) {	\
		*cp++ = 0;					\
		cp[1] = (n) & 0xff;				\
		cp[0] = ((n) >> 8) & 0xff;			\
		cp += 2;					\
	} else {						\
		*cp++ = (n) & 0xff;				\
	}							\
} while (0)

#define	DECODEL(f) do {							\
	if (*cp == 0) {							\
		uint32_t tmp = ntohl(f) + ((cp[1] << 8) | cp[2]);	\
		(f) = htonl(tmp);					\
		cp += 3;						\
	} else {							\
		uint32_t tmp = ntohl(f) + (uint32_t)*cp++;		\
		(f) = htonl(tmp);					\
	}								\
} while (0)

#define	DECODES(f) do {							\
	if (*cp == 0) {							\
		ushort_t tmp = ntohs(f) + ((cp[1] << 8) | cp[2]);	\
		(f) = htons(tmp);					\
		cp += 3;						\
	} else {							\
		ushort_t tmp = ntohs(f) + (uint32_t)*cp++;		\
		(f) = htons(tmp);					\
	}								\
} while (0)

#define	DECODEU(f) do {							\
	if (*cp == 0) {							\
		(f) = htons((cp[1] << 8) | cp[2]);			\
		cp += 3;						\
	} else {							\
		(f) = htons((uint32_t)*cp++);				\
	}								\
} while (0)
156 
157 uint_t
158 vj_compress_tcp(register struct ip *ip, uint_t mlen, struct vjcompress *comp,
159 	int compress_cid, uchar_t **vjhdrp)
160 {
161 	register struct cstate	*cs = comp->last_cs->cs_next;
162 	register uint_t		hlen = getip_hl(ip);
163 	register struct tcphdr	*oth;
164 	register struct tcphdr	*th;
165 	register uint_t		deltaS;
166 	register uint_t		deltaA;
167 	register uint_t		changes = 0;
168 	uchar_t			new_seq[16];
169 	register uchar_t	*cp = new_seq;
170 	register uint_t		thlen;
171 
172 	/*
173 	 * Bail if this is an IP fragment or if the TCP packet isn't
174 	 * `compressible' (i.e., ACK isn't set or some other control bit is
175 	 * set).  (We assume that the caller has already made sure the
176 	 * packet is IP proto TCP)
177 	 */
178 	if ((ip->ip_off & htons(0x3fff)) || mlen < 40) {
179 		return (TYPE_IP);
180 	}
181 
182 	th = (struct tcphdr *)&((int *)ip)[hlen];
183 
184 	if ((th->th_flags & (TH_SYN|TH_FIN|TH_RST|TH_ACK)) != TH_ACK) {
185 		return (TYPE_IP);
186 	}
187 
188 	thlen = (hlen + getth_off(th)) << 2;
189 	if (thlen > mlen) {
190 		return (TYPE_IP);
191 	}
192 
193 	/*
194 	 * Packet is compressible -- we're going to send either a
195 	 * COMPRESSED_TCP or UNCOMPRESSED_TCP packet.  Either way we need
196 	 * to locate (or create) the connection state.  Special case the
197 	 * most recently used connection since it's most likely to be used
198 	 * again & we don't have to do any reordering if it's used.
199 	 */
200 	INCR(vjs_packets);
201 
202 	if (ip->ip_src.s_addr != cs->cs_ip.ip_src.s_addr ||
203 		ip->ip_dst.s_addr != cs->cs_ip.ip_dst.s_addr ||
204 		*(int *)th != ((int *)&cs->cs_ip)[getip_hl(&cs->cs_ip)]) {
205 
206 		/*
207 		 * Wasn't the first -- search for it.
208 		 *
209 		 * States are kept in a circularly linked list with
210 		 * last_cs pointing to the end of the list.  The
211 		 * list is kept in lru order by moving a state to the
212 		 * head of the list whenever it is referenced.  Since
213 		 * the list is short and, empirically, the connection
214 		 * we want is almost always near the front, we locate
215 		 * states via linear search.  If we don't find a state
216 		 * for the datagram, the oldest state is (re-)used.
217 		 */
218 		register struct cstate	*lcs;
219 		register struct cstate	*lastcs = comp->last_cs;
220 
221 		do {
222 			lcs = cs; cs = cs->cs_next;
223 
224 			INCR(vjs_searches);
225 
226 			if (ip->ip_src.s_addr == cs->cs_ip.ip_src.s_addr &&
227 				ip->ip_dst.s_addr == cs->cs_ip.ip_dst.s_addr &&
228 				*(int *)th == ((int *)
229 					&cs->cs_ip)[getip_hl(&cs->cs_ip)]) {
230 
231 				goto found;
232 			}
233 
234 		} while (cs != lastcs);
235 
236 		/*
237 		 * Didn't find it -- re-use oldest cstate.  Send an
238 		 * uncompressed packet that tells the other side what
239 		 * connection number we're using for this conversation.
240 		 * Note that since the state list is circular, the oldest
241 		 * state points to the newest and we only need to set
242 		 * last_cs to update the lru linkage.
243 		 */
244 		INCR(vjs_misses);
245 
246 		comp->last_cs = lcs;
247 
248 		goto uncompressed;
249 
250 found:
251 		/*
252 		 * Found it -- move to the front on the connection list.
253 		 */
254 		if (cs == lastcs) {
255 			comp->last_cs = lcs;
256 		} else {
257 			lcs->cs_next = cs->cs_next;
258 			cs->cs_next = lastcs->cs_next;
259 			lastcs->cs_next = cs;
260 		}
261 	}
262 
263 	/*
264 	 * Make sure that only what we expect to change changed. The first
265 	 * line of the `if' checks the IP protocol version, header length &
266 	 * type of service.  The 2nd line checks the "Don't fragment" bit.
267 	 * The 3rd line checks the time-to-live and protocol (the protocol
268 	 * check is unnecessary but costless).  The 4th line checks the TCP
269 	 * header length.  The 5th line checks IP options, if any.  The 6th
270 	 * line checks TCP options, if any.  If any of these things are
271 	 * different between the previous & current datagram, we send the
272 	 * current datagram `uncompressed'.
273 	 */
274 	oth = (struct tcphdr *)&((int *)&cs->cs_ip)[hlen];
275 
276 	/* Used to check for IP options. */
277 	deltaS = hlen;
278 
279 	if (((ushort_t *)ip)[0] != ((ushort_t *)&cs->cs_ip)[0] ||
280 		((ushort_t *)ip)[3] != ((ushort_t *)&cs->cs_ip)[3] ||
281 		((ushort_t *)ip)[4] != ((ushort_t *)&cs->cs_ip)[4] ||
282 		getth_off(th) != getth_off(oth) ||
283 		(deltaS > 5 &&
284 			BCMP(ip + 1, &cs->cs_ip + 1, (deltaS - 5) << 2)) ||
285 		(getth_off(th) > 5 &&
286 			BCMP(th + 1, oth + 1, (getth_off(th) - 5) << 2))) {
287 
288 		goto uncompressed;
289 	}
290 
291 	/*
292 	 * Figure out which of the changing fields changed.  The
293 	 * receiver expects changes in the order: urgent, window,
294 	 * ack, seq (the order minimizes the number of temporaries
295 	 * needed in this section of code).
296 	 */
297 	if (th->th_flags & TH_URG) {
298 
299 		deltaS = ntohs(th->th_urp);
300 
301 		ENCODEZ(deltaS);
302 
303 		changes |= NEW_U;
304 
305 	} else if (th->th_urp != oth->th_urp) {
306 
307 		/*
308 		 * argh! URG not set but urp changed -- a sensible
309 		 * implementation should never do this but RFC793
310 		 * doesn't prohibit the change so we have to deal
311 		 * with it
312 		 */
313 		goto uncompressed;
314 	}
315 
316 	if ((deltaS = (ushort_t)(ntohs(th->th_win) - ntohs(oth->th_win))) > 0) {
317 		ENCODE(deltaS);
318 
319 		changes |= NEW_W;
320 	}
321 
322 	if ((deltaA = ntohl(th->th_ack) - ntohl(oth->th_ack)) > 0) {
323 		if (deltaA > 0xffff) {
324 			goto uncompressed;
325 		}
326 
327 		ENCODE(deltaA);
328 
329 		changes |= NEW_A;
330 	}
331 
332 	if ((deltaS = ntohl(th->th_seq) - ntohl(oth->th_seq)) > 0) {
333 		if (deltaS > 0xffff) {
334 			goto uncompressed;
335 		}
336 
337 		ENCODE(deltaS);
338 
339 		changes |= NEW_S;
340 	}
341 
342 	switch (changes) {
343 
344 	case 0:
345 		/*
346 		 * Nothing changed. If this packet contains data and the
347 		 * last one didn't, this is probably a data packet following
348 		 * an ack (normal on an interactive connection) and we send
349 		 * it compressed.  Otherwise it's probably a retransmit,
350 		 * retransmitted ack or window probe.  Send it uncompressed
351 		 * in case the other side missed the compressed version.
352 		 */
353 		if (ip->ip_len != cs->cs_ip.ip_len &&
354 					ntohs(cs->cs_ip.ip_len) == thlen) {
355 			break;
356 		}
357 
358 		/* (otherwise fall through) */
359 		/* FALLTHRU */
360 
361 	case SPECIAL_I:
362 	case SPECIAL_D:
363 
364 		/*
365 		 * actual changes match one of our special case encodings --
366 		 * send packet uncompressed.
367 		 */
368 		goto uncompressed;
369 
370 	case NEW_S|NEW_A:
371 
372 		if (deltaS == deltaA &&
373 				deltaS == ntohs(cs->cs_ip.ip_len) - thlen) {
374 
375 			/*
376 			 * special case for echoed terminal traffic
377 			 */
378 			changes = SPECIAL_I;
379 			cp = new_seq;
380 		}
381 
382 		break;
383 
384 	case NEW_S:
385 
386 		if (deltaS == ntohs(cs->cs_ip.ip_len) - thlen) {
387 
388 			/*
389 			 * special case for data xfer
390 			 */
391 			changes = SPECIAL_D;
392 			cp = new_seq;
393 		}
394 
395 		break;
396 	}
397 
398 	deltaS = ntohs(ip->ip_id) - ntohs(cs->cs_ip.ip_id);
399 	if (deltaS != 1) {
400 		ENCODEZ(deltaS);
401 
402 		changes |= NEW_I;
403 	}
404 
405 	if (th->th_flags & TH_PUSH) {
406 		changes |= TCP_PUSH_BIT;
407 	}
408 
409 	/*
410 	 * Grab the cksum before we overwrite it below.  Then update our
411 	 * state with this packet's header.
412 	 */
413 	deltaA = ntohs(th->th_sum);
414 
415 	BCOPY(ip, &cs->cs_ip, thlen);
416 
417 	/*
418 	 * We want to use the original packet as our compressed packet.
419 	 * (cp - new_seq) is the number of bytes we need for compressed
420 	 * sequence numbers.  In addition we need one byte for the change
421 	 * mask, one for the connection id and two for the tcp checksum.
422 	 * So, (cp - new_seq) + 4 bytes of header are needed.  thlen is how
423 	 * many bytes of the original packet to toss so subtract the two to
424 	 * get the new packet size.
425 	 */
426 	deltaS = cp - new_seq;
427 
428 	cp = (uchar_t *)ip;
429 
430 	if (compress_cid == 0 || comp->last_xmit != cs->cs_id) {
431 		comp->last_xmit = cs->cs_id;
432 
433 		thlen -= deltaS + 4;
434 
435 		*vjhdrp = (cp += thlen);
436 
437 		*cp++ = changes | NEW_C;
438 		*cp++ = cs->cs_id;
439 	} else {
440 		thlen -= deltaS + 3;
441 
442 		*vjhdrp = (cp += thlen);
443 
444 		*cp++ = changes & 0xff;
445 	}
446 
447 	*cp++ = (deltaA >> 8) & 0xff;
448 	*cp++ = deltaA & 0xff;
449 
450 	BCOPY(new_seq, cp, deltaS);
451 
452 	INCR(vjs_compressed);
453 
454 	return (TYPE_COMPRESSED_TCP);
455 
456 	/*
457 	 * Update connection state cs & send uncompressed packet (that is,
458 	 * a regular ip/tcp packet but with the 'conversation id' we hope
459 	 * to use on future compressed packets in the protocol field).
460 	 */
461 uncompressed:
462 
463 	BCOPY(ip, &cs->cs_ip, thlen);
464 
465 	ip->ip_p = cs->cs_id;
466 	comp->last_xmit = cs->cs_id;
467 
468 	return (TYPE_UNCOMPRESSED_TCP);
469 }
470 
471 /*
472  * vj_uncompress_err()
473  *
474  * Called when we may have missed a packet.
475  */
476 void
477 vj_uncompress_err(struct vjcompress *comp)
478 {
479 	comp->flags |= VJF_TOSS;
480 
481 	INCR(vjs_errorin);
482 }
483 
484 /*
485  * vj_uncompress_uncomp()
486  *
487  * "Uncompress" a packet of type TYPE_UNCOMPRESSED_TCP.
488  */
489 int
490 vj_uncompress_uncomp(uchar_t *buf, int buflen, struct vjcompress *comp)
491 {
492 	register uint_t		hlen;
493 	register struct cstate	*cs;
494 
495 	hlen = getip_hl(buf) << 2;
496 
497 	if (getip_p(buf) >= MAX_STATES ||
498 	    hlen + sizeof (struct tcphdr) > buflen ||
499 	    (hlen += getth_off(buf+hlen) << 2) > buflen || hlen > MAX_HDR) {
500 
501 		comp->flags |= VJF_TOSS;
502 
503 		INCR(vjs_errorin);
504 
505 		return (0);
506 	}
507 
508 	cs = &comp->rstate[comp->last_recv = getip_p(buf)];
509 	comp->flags &= ~VJF_TOSS;
510 	setip_p(buf, IPPROTO_TCP);
511 
512 	BCOPY(buf, &cs->cs_ip, hlen);
513 
514 	cs->cs_hlen = hlen & 0xff;
515 
516 	INCR(vjs_uncompressedin);
517 
518 	return (1);
519 }
520 
521 /*
522  * vj_uncompress_tcp()
523  *
524  * Uncompress a packet of type TYPE_COMPRESSED_TCP.
525  * The packet starts at buf and is of total length total_len.
526  * The first buflen bytes are at buf; this must include the entire
527  * compressed TCP/IP header.  This procedure returns the length
528  * of the VJ header, with a pointer to the uncompressed IP header
529  * in *hdrp and its length in *hlenp.
530  */
531 int
532 vj_uncompress_tcp(uchar_t *buf, int buflen, int total_len,
533 	struct vjcompress *comp, uchar_t **hdrp, uint_t *hlenp)
534 {
535 	register uchar_t	*cp;
536 	register uint_t		hlen;
537 	register uint_t		changes;
538 	register struct tcphdr	*th;
539 	register struct cstate	*cs;
540 	register ushort_t	*bp;
541 	register uint_t		vjlen;
542 	register uint32_t	tmp;
543 
544 	INCR(vjs_compressedin);
545 
546 	cp = buf;
547 	changes = *cp++;
548 
549 	if (changes & NEW_C) {
550 		/*
551 		 * Make sure the state index is in range, then grab the state.
552 		 * If we have a good state index, clear the 'discard' flag.
553 		 */
554 		if (*cp >= MAX_STATES) {
555 			goto bad;
556 		}
557 
558 		comp->flags &= ~VJF_TOSS;
559 		comp->last_recv = *cp++;
560 	} else {
561 		/*
562 		 * this packet has an implicit state index.  If we've
563 		 * had a line error since the last time we got an
564 		 * explicit state index, we have to toss the packet
565 		 */
566 		if (comp->flags & VJF_TOSS) {
567 			INCR(vjs_tossed);
568 			return (-1);
569 		}
570 	}
571 
572 	cs = &comp->rstate[comp->last_recv];
573 	hlen = getip_hl(&cs->cs_ip) << 2;
574 
575 	th = (struct tcphdr *)((uint32_t *)&cs->cs_ip+hlen/sizeof (uint32_t));
576 	th->th_sum = htons((*cp << 8) | cp[1]);
577 
578 	cp += 2;
579 
580 	if (changes & TCP_PUSH_BIT) {
581 		th->th_flags |= TH_PUSH;
582 	} else {
583 		th->th_flags &= ~TH_PUSH;
584 	}
585 
586 	switch (changes & SPECIALS_MASK) {
587 
588 	case SPECIAL_I:
589 
590 		{
591 
592 		register uint32_t	i;
593 
594 		i = ntohs(cs->cs_ip.ip_len) - cs->cs_hlen;
595 
596 		tmp = ntohl(th->th_ack) + i;
597 		th->th_ack = htonl(tmp);
598 
599 		tmp = ntohl(th->th_seq) + i;
600 		th->th_seq = htonl(tmp);
601 
602 		}
603 
604 		break;
605 
606 	case SPECIAL_D:
607 
608 		tmp = ntohl(th->th_seq) + ntohs(cs->cs_ip.ip_len) - cs->cs_hlen;
609 		th->th_seq = htonl(tmp);
610 
611 		break;
612 
613 	default:
614 
615 		if (changes & NEW_U) {
616 			th->th_flags |= TH_URG;
617 			DECODEU(th->th_urp);
618 		} else {
619 			th->th_flags &= ~TH_URG;
620 		}
621 
622 		if (changes & NEW_W) {
623 			DECODES(th->th_win);
624 		}
625 
626 		if (changes & NEW_A) {
627 			DECODEL(th->th_ack);
628 		}
629 
630 		if (changes & NEW_S) {
631 			DECODEL(th->th_seq);
632 		}
633 
634 		break;
635 	}
636 
637 	if (changes & NEW_I) {
638 		DECODES(cs->cs_ip.ip_id);
639 	} else {
640 		cs->cs_ip.ip_id = ntohs(cs->cs_ip.ip_id) + 1;
641 		cs->cs_ip.ip_id = htons(cs->cs_ip.ip_id);
642 	}
643 
644 	/*
645 	 * At this point, cp points to the first byte of data in the
646 	 * packet.  Fill in the IP total length and update the IP
647 	 * header checksum.
648 	 */
649 	vjlen = cp - buf;
650 	buflen -= vjlen;
651 	if (buflen < 0) {
652 		/*
653 		 * we must have dropped some characters (crc should detect
654 		 * this but the old slip framing won't)
655 		 */
656 		goto bad;
657 	}
658 
659 	total_len += cs->cs_hlen - vjlen;
660 	cs->cs_ip.ip_len = htons(total_len);
661 
662 	/*
663 	 * recompute the ip header checksum
664 	 */
665 	bp = (ushort_t *)&cs->cs_ip;
666 	cs->cs_ip.ip_sum = 0;
667 
668 	for (changes = 0; hlen > 0; hlen -= 2) {
669 		changes += *bp++;
670 	}
671 
672 	changes = (changes & 0xffff) + (changes >> 16);
673 	changes = (changes & 0xffff) + (changes >> 16);
674 	cs->cs_ip.ip_sum = ~ changes;
675 
676 	*hdrp = (uchar_t *)&cs->cs_ip;
677 	*hlenp = cs->cs_hlen;
678 
679 	return (vjlen);
680 
681 bad:
682 
683 	comp->flags |= VJF_TOSS;
684 
685 	INCR(vjs_errorin);
686 
687 	return (-1);
688 }
689