xref: /freebsd/sys/netinet/libalias/alias_proxy.c (revision 77a0943ded95b9e6438f7db70c4a28e4d93946d4)
1 /* file: alias_proxy.c
2 
3     This file encapsulates special operations related to transparent
4     proxy redirection.  This is where packets with a particular destination,
5     usually tcp port 80, are redirected to a proxy server.
6 
7     When packets are proxied, the destination address and port are
8     modified.  In certain cases, it is necessary to somehow encode
9     the original address/port info into the packet.  Two methods are
10     presently supported: addition of a [DEST addr port] string at the
    beginning of a tcp stream, or inclusion of an optional field
12     in the IP header.
13 
14     There is one public API function:
15 
16         PacketAliasProxyRule()    -- Adds and deletes proxy
17                                      rules.
18 
19     Rules are stored in a linear linked list, so lookup efficiency
20     won't be too good for large lists.
21 
22 
23     Initial development: April, 1998 (cjm)
24 
25     $FreeBSD$
26 */
27 
28 
29 /* System includes */
30 #include <ctype.h>
31 #include <stdio.h>
32 #include <stdlib.h>
33 #include <string.h>
34 #include <netdb.h>
35 
36 #include <sys/types.h>
37 #include <sys/socket.h>
38 
39 /* BSD IPV4 includes */
40 #include <netinet/in_systm.h>
41 #include <netinet/in.h>
42 #include <netinet/ip.h>
43 #include <netinet/tcp.h>
44 
45 #include <arpa/inet.h>
46 
47 #include "alias_local.h"  /* Functions used by alias*.c */
48 #include "alias.h"        /* Public API functions for libalias */
49 
50 
51 
52 /*
53     Data structures
54  */
55 
/*
 * A linked list of arbitrary length, based on struct proxy_entry, is
 * used to store proxy rules.  Each entry describes one rule: which
 * packets it matches (proto, proxy_port, masked source/destination
 * addresses) and where matching packets are redirected (server_addr,
 * server_port).
 */
struct proxy_entry
{
/* Values for proxy_type: how the original destination is encoded. */
#define PROXY_TYPE_ENCODE_NONE      1
#define PROXY_TYPE_ENCODE_TCPSTREAM 2
#define PROXY_TYPE_ENCODE_IPHDR     3
    int rule_index;             /* search order key; the list is walked
                                   from the head, lower values first */
    int proxy_type;             /* one of PROXY_TYPE_ENCODE_* */
    u_char proto;               /* IPPROTO_TCP or IPPROTO_UDP */
    u_short proxy_port;         /* destination port to match, network
                                   byte order; 0 matches any port */
    u_short server_port;        /* redirect port, network byte order;
                                   0 keeps the packet's destination port */

    struct in_addr server_addr; /* proxy server packets are sent to */

    struct in_addr src_addr;    /* source address, stored pre-masked */
    struct in_addr src_mask;

    struct in_addr dst_addr;    /* destination address, stored pre-masked */
    struct in_addr dst_mask;

    struct proxy_entry *next;   /* following rule, NULL at the tail */
    struct proxy_entry *last;   /* preceding rule, NULL at the head */
};
82 
83 
84 
/*
 * File scope variables
 */

/* Head of the proxy rule list, kept in ascending rule_index order.
   NULL while no rules are installed. */
static struct proxy_entry *proxyList = NULL;
90 
91 
92 
93 /* Local (static) functions:
94 
95     IpMask()                 -- Utility function for creating IP
                                masks from integer (0-32) specification.
97     IpAddr()                 -- Utility function for converting string
98                                 to IP address
99     IpPort()                 -- Utility function for converting string
100                                 to port number
101     RuleAdd()                -- Adds an element to the rule list.
102     RuleDelete()             -- Removes an element from the rule list.
103     RuleNumberDelete()       -- Removes all elements from the rule list
104                                 having a certain rule number.
105     ProxyEncodeTcpStream()   -- Adds [DEST x.x.x.x xxxx] to the beginning
106                                 of a TCP stream.
107     ProxyEncodeIpHeader()    -- Adds an IP option indicating the true
108                                 destination of a proxied IP packet
109 */
110 
/* Parsing helpers */
static int IpMask(int nbits, struct in_addr *mask);
static int IpAddr(char *s, struct in_addr *addr);
static int IpPort(char *s, int proto, int *port);

/* Rule list maintenance */
static void RuleAdd(struct proxy_entry *entry);
static void RuleDelete(struct proxy_entry *entry);
static int RuleNumberDelete(int rule_index);

/* Packet encoders */
static void ProxyEncodeTcpStream(struct alias_link *link, struct ip *pip,
                                 int maxpacketsize);
static void ProxyEncodeIpHeader(struct ip *pip, int maxpacketsize);
119 
/*
 * IpMask -- build a netmask whose `nbits` most significant bits are set.
 *
 * On success the mask is stored in network byte order in *mask and 0 is
 * returned.  If nbits lies outside 0..32, -1 is returned and *mask is
 * not touched.
 */
static int
IpMask(int nbits, struct in_addr *mask)
{
    uint32_t bits;

    if (nbits < 0 || nbits > 32)
        return -1;

    /* A 32-bit shift by 32 is undefined, so a /0 mask is special-cased. */
    bits = (nbits == 0) ? 0 : (uint32_t)0xffffffffu << (32 - nbits);
    mask->s_addr = htonl(bits);
    return 0;
}
136 
/*
 * IpAddr -- parse a numeric IPv4 address string with inet_aton().
 * Returns 0 and fills *addr on success, -1 if `s` is not a valid address.
 */
static int
IpAddr(char *s, struct in_addr *addr)
{
    return inet_aton(s, addr) ? 0 : -1;
}
145 
/*
 * IpPort -- convert a port specification to a port number.
 *
 * `s` is either a decimal number in 0..65535 or a service name looked
 * up with getservbyname() for `proto` (IPPROTO_TCP or IPPROTO_UDP).
 * The numeric form applies only when the whole string is numeric.  This
 * means service names that merely begin with a digit are still looked up
 * by name, and junk such as "80abc" is not taken as port 80.
 *
 * On success the port is stored in host byte order in *port and 0 is
 * returned.  Otherwise -1 is returned.
 */
static int
IpPort(char *s, int proto, int *port)
{
    struct servent *se;
    char *end;
    long value;

    value = strtol(s, &end, 10);
    if (end != s && *end == '\0')
    {
        /* Reject values that htons() would silently truncate.  strtol()
           overflow yields LONG_MIN/LONG_MAX, which fail this test too. */
        if (value < 0 || value > 65535)
            return -1;
        *port = (int) value;
        return 0;
    }

    if (proto == IPPROTO_TCP)
        se = getservbyname(s, "tcp");
    else if (proto == IPPROTO_UDP)
        se = getservbyname(s, "udp");
    else
        return -1;

    if (se == NULL)
        return -1;

    *port = (int) ntohs(se->s_port);
    return 0;
}
171 
172 void
173 RuleAdd(struct proxy_entry *entry)
174 {
175     int rule_index;
176     struct proxy_entry *ptr;
177     struct proxy_entry *ptr_last;
178 
179     if (proxyList == NULL)
180     {
181         proxyList = entry;
182         entry->last = NULL;
183         entry->next = NULL;
184         return;
185     }
186 
187     rule_index = entry->rule_index;
188     ptr = proxyList;
189     ptr_last = NULL;
190     while (ptr != NULL)
191     {
192         if (ptr->rule_index >= rule_index)
193         {
194             if (ptr_last == NULL)
195             {
196                 entry->next = proxyList;
197                 entry->last = NULL;
198                 proxyList->last = entry;
199                 proxyList = entry;
200                 return;
201             }
202 
203             ptr_last->next = entry;
204             ptr->last = entry;
205             entry->last = ptr->last;
206             entry->next = ptr;
207             return;
208         }
209         ptr_last = ptr;
210         ptr = ptr->next;
211     }
212 
213     ptr_last->next = entry;
214     entry->last = ptr_last;
215     entry->next = NULL;
216 }
217 
218 static void
219 RuleDelete(struct proxy_entry *entry)
220 {
221     if (entry->last != NULL)
222         entry->last->next = entry->next;
223     else
224         proxyList = entry->next;
225 
226     if (entry->next != NULL)
227         entry->next->last = entry->last;
228 
229     free(entry);
230 }
231 
232 static int
233 RuleNumberDelete(int rule_index)
234 {
235     int err;
236     struct proxy_entry *ptr;
237 
238     err = -1;
239     ptr = proxyList;
240     while (ptr != NULL)
241     {
242         struct proxy_entry *ptr_next;
243 
244         ptr_next = ptr->next;
245         if (ptr->rule_index == rule_index)
246         {
247             err = 0;
248             RuleDelete(ptr);
249         }
250 
251         ptr = ptr_next;
252     }
253 
254     return err;
255 }
256 
257 static void
258 ProxyEncodeTcpStream(struct alias_link *link,
259                      struct ip *pip,
260                      int maxpacketsize)
261 {
262     int slen;
263     char buffer[40];
264     struct tcphdr *tc;
265 
266 /* Compute pointer to tcp header */
267     tc = (struct tcphdr *) ((char *) pip + (pip->ip_hl << 2));
268 
269 /* Don't modify if once already modified */
270 
271     if (GetAckModified (link))
272 	return;
273 
274 /* Translate destination address and port to string form */
275     snprintf(buffer, sizeof(buffer) - 2, "[DEST %s %d]",
276         inet_ntoa(GetProxyAddress (link)), (u_int) ntohs(GetProxyPort (link)));
277 
278 /* Pad string out to a multiple of two in length */
279     slen = strlen(buffer);
280     switch (slen % 2)
281     {
282     case 0:
283         strcat(buffer, " \n");
284 	slen += 2;
285         break;
286     case 1:
287         strcat(buffer, "\n");
288 	slen += 1;
289     }
290 
291 /* Check for packet overflow */
292     if ((ntohs(pip->ip_len) + strlen(buffer)) > maxpacketsize)
293         return;
294 
295 /* Shift existing TCP data and insert destination string */
296     {
297         int dlen;
298         int hlen;
299         u_char *p;
300 
301         hlen = (pip->ip_hl + tc->th_off) << 2;
302         dlen = ntohs (pip->ip_len) - hlen;
303 
304 /* Modify first packet that has data in it */
305 
306 	if (dlen == 0)
307 		return;
308 
309         p = (char *) pip;
310         p += hlen;
311 
312         memmove(p + slen, p, dlen);
313         memcpy(p, buffer, slen);
314     }
315 
316 /* Save information about modfied sequence number */
317     {
318         int delta;
319 
320         SetAckModified(link);
321         delta = GetDeltaSeqOut(pip, link);
322         AddSeq(pip, link, delta+slen);
323     }
324 
325 /* Update IP header packet length and checksum */
326     {
327         int accumulate;
328 
329         accumulate  = pip->ip_len;
330         pip->ip_len = htons(ntohs(pip->ip_len) + slen);
331         accumulate -= pip->ip_len;
332 
333         ADJUST_CHECKSUM(accumulate, pip->ip_sum);
334     }
335 
336 /* Update TCP checksum, Use TcpChecksum since so many things have
337    already changed. */
338 
339     tc->th_sum = 0;
340     tc->th_sum = TcpChecksum (pip);
341 }
342 
343 static void
344 ProxyEncodeIpHeader(struct ip *pip,
345                     int maxpacketsize)
346 {
347 #define OPTION_LEN_BYTES  8
348 #define OPTION_LEN_INT16  4
349 #define OPTION_LEN_INT32  2
350     u_char option[OPTION_LEN_BYTES];
351 
352 #ifdef DEBUG
353     fprintf(stdout, " ip cksum 1 = %x\n", (u_int) IpChecksum(pip));
354     fprintf(stdout, "tcp cksum 1 = %x\n", (u_int) TcpChecksum(pip));
355 #endif
356 
357 /* Check to see that there is room to add an IP option */
358     if (pip->ip_hl > (0x0f - OPTION_LEN_INT32))
359         return;
360 
361 /* Build option and copy into packet */
362     {
363         u_char *ptr;
364         struct tcphdr *tc;
365 
366         ptr = (u_char *) pip;
367         ptr += 20;
368         memcpy(ptr + OPTION_LEN_BYTES, ptr, ntohs(pip->ip_len) - 20);
369 
370         option[0] = 0x64; /* class: 3 (reserved), option 4 */
371         option[1] = OPTION_LEN_BYTES;
372 
373         memcpy(&option[2], (u_char *) &pip->ip_dst, 4);
374 
375         tc = (struct tcphdr *) ((char *) pip + (pip->ip_hl << 2));
376         memcpy(&option[6], (u_char *) &tc->th_sport, 2);
377 
378         memcpy(ptr, option, 8);
379     }
380 
381 /* Update checksum, header length and packet length */
382     {
383         int i;
384         int accumulate;
385         u_short *sptr;
386 
387         sptr = (u_short *) option;
388         accumulate = 0;
389         for (i=0; i<OPTION_LEN_INT16; i++)
390             accumulate -= *(sptr++);
391 
392         sptr = (u_short *) pip;
393         accumulate += *sptr;
394         pip->ip_hl += OPTION_LEN_INT32;
395         accumulate -= *sptr;
396 
397         accumulate += pip->ip_len;
398         pip->ip_len = htons(ntohs(pip->ip_len) + OPTION_LEN_BYTES);
399         accumulate -= pip->ip_len;
400 
401         ADJUST_CHECKSUM(accumulate, pip->ip_sum);
402     }
403 #undef OPTION_LEN_BYTES
404 #undef OPTION_LEN_INT16
405 #undef OPTION_LEN_INT32
406 #ifdef DEBUG
407     fprintf(stdout, " ip cksum 2 = %x\n", (u_int) IpChecksum(pip));
408     fprintf(stdout, "tcp cksum 2 = %x\n", (u_int) TcpChecksum(pip));
409 #endif
410 }
411 
412 
/* Functions used by other packet alias source files
414 
415     ProxyCheck()         -- Checks whether an outgoing packet should
416                             be proxied.
417     ProxyModify()        -- Encodes the original destination address/port
418                             for a packet which is to be redirected to
419                             a proxy server.
420 */
421 
422 int
423 ProxyCheck(struct ip *pip,
424            struct in_addr *proxy_server_addr,
425            u_short *proxy_server_port)
426 {
427     u_short dst_port;
428     struct in_addr src_addr;
429     struct in_addr dst_addr;
430     struct proxy_entry *ptr;
431 
432     src_addr = pip->ip_src;
433     dst_addr = pip->ip_dst;
434     dst_port = ((struct tcphdr *) ((char *) pip + (pip->ip_hl << 2)))
435         ->th_dport;
436 
437     ptr = proxyList;
438     while (ptr != NULL)
439     {
440         u_short proxy_port;
441 
442         proxy_port = ptr->proxy_port;
443         if ((dst_port == proxy_port || proxy_port == 0)
444          && pip->ip_p == ptr->proto
445          && src_addr.s_addr != ptr->server_addr.s_addr)
446         {
447             struct in_addr src_addr_masked;
448             struct in_addr dst_addr_masked;
449 
450             src_addr_masked.s_addr = src_addr.s_addr & ptr->src_mask.s_addr;
451             dst_addr_masked.s_addr = dst_addr.s_addr & ptr->dst_mask.s_addr;
452 
453             if ((src_addr_masked.s_addr == ptr->src_addr.s_addr)
454              && (dst_addr_masked.s_addr == ptr->dst_addr.s_addr))
455             {
456                 if ((*proxy_server_port = ptr->server_port) == 0)
457                     *proxy_server_port = dst_port;
458                 *proxy_server_addr = ptr->server_addr;
459                 return ptr->proxy_type;
460             }
461         }
462         ptr = ptr->next;
463     }
464 
465     return 0;
466 }
467 
468 void
469 ProxyModify(struct alias_link *link,
470             struct ip *pip,
471             int maxpacketsize,
472             int proxy_type)
473 {
474     switch (proxy_type)
475     {
476     case PROXY_TYPE_ENCODE_IPHDR:
477         ProxyEncodeIpHeader(pip, maxpacketsize);
478         break;
479 
480     case PROXY_TYPE_ENCODE_TCPSTREAM:
481         ProxyEncodeTcpStream(link, pip, maxpacketsize);
482         break;
483     }
484 }
485 
486 
487 /*
488     Public API functions
489 */
490 
491 int
492 PacketAliasProxyRule(const char *cmd)
493 {
494 /*
495  * This function takes command strings of the form:
496  *
497  *   server <addr>[:<port>]
498  *   [port <port>]
499  *   [rule n]
500  *   [proto tcp|udp]
501  *   [src <addr>[/n]]
502  *   [dst <addr>[/n]]
503  *   [type encode_tcp_stream|encode_ip_hdr|no_encode]
504  *
505  *   delete <rule number>
506  *
507  * Subfields can be in arbitrary order.  Port numbers and addresses
508  * must be in either numeric or symbolic form. An optional rule number
509  * is used to control the order in which rules are searched.  If two
510  * rules have the same number, then search order cannot be guaranteed,
511  * and the rules should be disjoint.  If no rule number is specified,
512  * then 0 is used, and group 0 rules are always checked before any
513  * others.
514  */
515     int i, n, len;
516     int cmd_len;
517     int token_count;
518     int state;
519     char *token;
520     char buffer[256];
521     char str_port[sizeof(buffer)];
522     char str_server_port[sizeof(buffer)];
523     char *res = buffer;
524 
525     int rule_index;
526     int proto;
527     int proxy_type;
528     int proxy_port;
529     int server_port;
530     struct in_addr server_addr;
531     struct in_addr src_addr, src_mask;
532     struct in_addr dst_addr, dst_mask;
533     struct proxy_entry *proxy_entry;
534 
535 /* Copy command line into a buffer */
536     cmd_len = strlen(cmd);
537     if (cmd_len > (sizeof(buffer) - 1))
538         return -1;
539     strcpy(buffer, cmd);
540 
541 /* Convert to lower case */
542     len = strlen(buffer);
543     for (i=0; i<len; i++)
544 	buffer[i] = tolower((unsigned char)buffer[i]);
545 
546 /* Set default proxy type */
547 
548 /* Set up default values */
549     rule_index = 0;
550     proxy_type = PROXY_TYPE_ENCODE_NONE;
551     proto = IPPROTO_TCP;
552     proxy_port = 0;
553     server_addr.s_addr = 0;
554     server_port = 0;
555     src_addr.s_addr = 0;
556     IpMask(0, &src_mask);
557     dst_addr.s_addr = 0;
558     IpMask(0, &dst_mask);
559 
560     str_port[0] = 0;
561     str_server_port[0] = 0;
562 
563 /* Parse command string with state machine */
564 #define STATE_READ_KEYWORD    0
565 #define STATE_READ_TYPE       1
566 #define STATE_READ_PORT       2
567 #define STATE_READ_SERVER     3
568 #define STATE_READ_RULE       4
569 #define STATE_READ_DELETE     5
570 #define STATE_READ_PROTO      6
571 #define STATE_READ_SRC        7
572 #define STATE_READ_DST        8
573     state = STATE_READ_KEYWORD;
574     token = strsep(&res, " \t");
575     token_count = 0;
576     while (token != NULL)
577     {
578         token_count++;
579         switch (state)
580         {
581         case STATE_READ_KEYWORD:
582             if (strcmp(token, "type") == 0)
583                 state = STATE_READ_TYPE;
584             else if (strcmp(token, "port") == 0)
585                 state = STATE_READ_PORT;
586             else if (strcmp(token, "server") == 0)
587                 state = STATE_READ_SERVER;
588             else if (strcmp(token, "rule") == 0)
589                 state = STATE_READ_RULE;
590             else if (strcmp(token, "delete") == 0)
591                 state = STATE_READ_DELETE;
592             else if (strcmp(token, "proto") == 0)
593                 state = STATE_READ_PROTO;
594             else if (strcmp(token, "src") == 0)
595                 state = STATE_READ_SRC;
596             else if (strcmp(token, "dst") == 0)
597                 state = STATE_READ_DST;
598             else
599                 return -1;
600             break;
601 
602         case STATE_READ_TYPE:
603             if (strcmp(token, "encode_ip_hdr") == 0)
604                 proxy_type = PROXY_TYPE_ENCODE_IPHDR;
605             else if (strcmp(token, "encode_tcp_stream") == 0)
606                 proxy_type = PROXY_TYPE_ENCODE_TCPSTREAM;
607             else if (strcmp(token, "no_encode") == 0)
608                 proxy_type = PROXY_TYPE_ENCODE_NONE;
609             else
610                 return -1;
611             state = STATE_READ_KEYWORD;
612             break;
613 
614         case STATE_READ_PORT:
615             strcpy(str_port, token);
616             state = STATE_READ_KEYWORD;
617             break;
618 
619         case STATE_READ_SERVER:
620             {
621                 int err;
622                 char *p;
623                 char s[sizeof(buffer)];
624 
625                 p = token;
626                 while (*p != ':' && *p != 0)
627                     p++;
628 
629                 if (*p != ':')
630                 {
631                     err = IpAddr(token, &server_addr);
632                     if (err)
633                         return -1;
634                 }
635                 else
636                 {
637                     *p = ' ';
638 
639                     n = sscanf(token, "%s %s", s, str_server_port);
640                     if (n != 2)
641                         return -1;
642 
643                     err = IpAddr(s, &server_addr);
644                     if (err)
645                         return -1;
646                 }
647             }
648             state = STATE_READ_KEYWORD;
649             break;
650 
651         case STATE_READ_RULE:
652             n = sscanf(token, "%d", &rule_index);
653             if (n != 1 || rule_index < 0)
654                 return -1;
655             state = STATE_READ_KEYWORD;
656             break;
657 
658         case STATE_READ_DELETE:
659             {
660                 int err;
661                 int rule_to_delete;
662 
663                 if (token_count != 2)
664                     return -1;
665 
666                 n = sscanf(token, "%d", &rule_to_delete);
667                 if (n != 1)
668                     return -1;
669                 err = RuleNumberDelete(rule_to_delete);
670                 if (err)
671                     return -1;
672                 return 0;
673             }
674 
675         case STATE_READ_PROTO:
676             if (strcmp(token, "tcp") == 0)
677                 proto = IPPROTO_TCP;
678             else if (strcmp(token, "udp") == 0)
679                 proto = IPPROTO_UDP;
680             else
681                 return -1;
682             state = STATE_READ_KEYWORD;
683             break;
684 
685         case STATE_READ_SRC:
686         case STATE_READ_DST:
687             {
688                 int err;
689                 char *p;
690                 struct in_addr mask;
691                 struct in_addr addr;
692 
693                 p = token;
694                 while (*p != '/' && *p != 0)
695                     p++;
696 
697                 if (*p != '/')
698                 {
699                      IpMask(32, &mask);
700                      err = IpAddr(token, &addr);
701                      if (err)
702                          return -1;
703                 }
704                 else
705                 {
706                     int n;
707                     int nbits;
708                     char s[sizeof(buffer)];
709 
710                     *p = ' ';
711                     n = sscanf(token, "%s %d", s, &nbits);
712                     if (n != 2)
713                         return -1;
714 
715                     err = IpAddr(s, &addr);
716                     if (err)
717                         return -1;
718 
719                     err = IpMask(nbits, &mask);
720                     if (err)
721                         return -1;
722                 }
723 
724                 if (state == STATE_READ_SRC)
725                 {
726                     src_addr = addr;
727                     src_mask = mask;
728                 }
729                 else
730                 {
731                     dst_addr = addr;
732                     dst_mask = mask;
733                 }
734             }
735             state = STATE_READ_KEYWORD;
736             break;
737 
738         default:
739             return -1;
740             break;
741         }
742 
743 	do {
744 		token = strsep(&res, " \t");
745 	} while (token != NULL && !*token);
746     }
747 #undef STATE_READ_KEYWORD
748 #undef STATE_READ_TYPE
749 #undef STATE_READ_PORT
750 #undef STATE_READ_SERVER
751 #undef STATE_READ_RULE
752 #undef STATE_READ_DELETE
753 #undef STATE_READ_PROTO
754 #undef STATE_READ_SRC
755 #undef STATE_READ_DST
756 
757 /* Convert port strings to numbers.  This needs to be done after
758    the string is parsed, because the prototype might not be designated
759    before the ports (which might be symbolic entries in /etc/services) */
760 
761     if (strlen(str_port) != 0)
762     {
763         int err;
764 
765         err = IpPort(str_port, proto, &proxy_port);
766         if (err)
767             return -1;
768     }
769     else
770     {
771         proxy_port = 0;
772     }
773 
774     if (strlen(str_server_port) != 0)
775     {
776         int err;
777 
778         err = IpPort(str_server_port, proto, &server_port);
779         if (err)
780             return -1;
781     }
782     else
783     {
784         server_port = 0;
785     }
786 
787 /* Check that at least the server address has been defined */
788     if (server_addr.s_addr == 0)
789         return -1;
790 
791 /* Add to linked list */
792     proxy_entry = malloc(sizeof(struct proxy_entry));
793     if (proxy_entry == NULL)
794         return -1;
795 
796     proxy_entry->proxy_type = proxy_type;
797     proxy_entry->rule_index = rule_index;
798     proxy_entry->proto = proto;
799     proxy_entry->proxy_port = htons(proxy_port);
800     proxy_entry->server_port = htons(server_port);
801     proxy_entry->server_addr = server_addr;
802     proxy_entry->src_addr.s_addr = src_addr.s_addr & src_mask.s_addr;
803     proxy_entry->dst_addr.s_addr = dst_addr.s_addr & dst_mask.s_addr;
804     proxy_entry->src_mask = src_mask;
805     proxy_entry->dst_mask = dst_mask;
806 
807     RuleAdd(proxy_entry);
808 
809     return 0;
810 }
811