xref: /freebsd/sys/netinet/libalias/alias_proxy.c (revision daf1cffce2e07931f27c6c6998652e90df6ba87e)
1 /* file: alias_proxy.c
2 
3     This file encapsulates special operations related to transparent
4     proxy redirection.  This is where packets with a particular destination,
5     usually tcp port 80, are redirected to a proxy server.
6 
7     When packets are proxied, the destination address and port are
8     modified.  In certain cases, it is necessary to somehow encode
9     the original address/port info into the packet.  Two methods are
10     presently supported: addition of a [DEST addr port] string at the
11     beginning of a tcp stream, or inclusion of an optional field
12     in the IP header.
13 
14     There is one public API function:
15 
16         PacketAliasProxyRule()    -- Adds and deletes proxy
17                                      rules.
18 
19     Rules are stored in a linear linked list, so lookup efficiency
20     won't be too good for large lists.
21 
22 
23     Initial development: April, 1998 (cjm)
24 
25     $FreeBSD$
26 */
27 
28 
29 /* System includes */
30 #include <ctype.h>
31 #include <stdio.h>
32 #include <stdlib.h>
33 #include <string.h>
34 #include <netdb.h>
35 
36 #include <sys/types.h>
37 #include <sys/socket.h>
38 
39 /* BSD IPV4 includes */
40 #include <netinet/in_systm.h>
41 #include <netinet/in.h>
42 #include <netinet/ip.h>
43 #include <netinet/tcp.h>
44 
45 #include <arpa/inet.h>
46 
47 #include "alias_local.h"  /* Functions used by alias*.c */
48 #include "alias.h"        /* Public API functions for libalias */
49 
50 
51 
52 /*
53     Data structures
54  */
55 
56 /*
57  * A linked list of arbitrary length, based on struct proxy_entry is
58  * used to store proxy rules.
59  */
60 struct proxy_entry
61 {
62 #define PROXY_TYPE_ENCODE_NONE      1
63 #define PROXY_TYPE_ENCODE_TCPSTREAM 2
64 #define PROXY_TYPE_ENCODE_IPHDR     3
65     int rule_index;
66     int proxy_type;
67     u_char proto;
68     u_short proxy_port;
69     u_short server_port;
70 
71     struct in_addr server_addr;
72 
73     struct in_addr src_addr;
74     struct in_addr src_mask;
75 
76     struct in_addr dst_addr;
77     struct in_addr dst_mask;
78 
79     struct proxy_entry *next;
80     struct proxy_entry *last;
81 };
82 
83 
84 
85 /*
86     File scope variables
87 */
88 
static struct proxy_entry *proxyList;   /* head of the rule list; NULL when empty */



/*
 * Local (static) functions:
 *
 *   IpMask()               -- build an IP netmask from a prefix length (0-32).
 *   IpAddr()               -- convert a numeric IPv4 address string to an
 *                             address (inet_aton()).
 *   IpPort()               -- convert a numeric port or a service name to a
 *                             port number.
 *   RuleAdd()              -- insert an element into the rule list.
 *   RuleDelete()           -- unlink and free one element of the rule list.
 *   RuleNumberDelete()     -- remove every element of the rule list having
 *                             a given rule number.
 *   ProxyEncodeTcpStream() -- prepend "[DEST x.x.x.x xxxx]" to the data of
 *                             a TCP stream.
 *   ProxyEncodeIpHeader()  -- insert an 8-byte IP option into the header of
 *                             a proxied IP packet.
 */

static int  IpMask(int nbits, struct in_addr *mask);
static int  IpAddr(char *s, struct in_addr *addr);
static int  IpPort(char *s, int proto, int *port);
static void RuleAdd(struct proxy_entry *entry);
static void RuleDelete(struct proxy_entry *entry);
static int  RuleNumberDelete(int rule_index);
static void ProxyEncodeTcpStream(struct alias_link *link, struct ip *pip,
                                 int maxpacketsize);
static void ProxyEncodeIpHeader(struct ip *pip, int maxpacketsize);
119 
/*
 * Store in *mask (network byte order) a netmask whose nbits most
 * significant bits are set.
 * Returns 0 on success, or -1 without touching *mask if nbits is not
 * in the range 0..32.
 */
static int
IpMask(int nbits, struct in_addr *mask)
{
    unsigned int bits;

    if (nbits < 0 || nbits > 32)
        return -1;

    /* nbits == 0 is special-cased: shifting a 32-bit value by 32 is undefined. */
    bits = (nbits == 0) ? 0U : (0xffffffffU << (32 - nbits));
    mask->s_addr = htonl(bits);

    return 0;
}
136 
/*
 * Parse the numeric IPv4 address string s into *addr with inet_aton().
 * Returns 0 on success, -1 if s is not a valid address.
 */
static int
IpAddr(char *s, struct in_addr *addr)
{
    int ok;

    ok = inet_aton(s, addr);
    return ok ? 0 : -1;
}
145 
/*
 * Convert s to a port number in host byte order and store it in *port.
 *
 * s may be a decimal number in the range 0..65535, or, when proto is
 * IPPROTO_TCP or IPPROTO_UDP, a service name resolved with
 * getservbyname().  Returns 0 on success and -1 when s is neither a valid
 * port number nor a known service for proto.  *port is written only on
 * success.
 */
static int
IpPort(char *s, int proto, int *port)
{
    char *end;
    long value;
    struct servent *se;

    /*
     * Treat s as a number only if the whole token (apart from trailing
     * white space) is numeric.  This rejects "80abc" instead of reading
     * it as 80.  Out-of-range values are rejected rather than silently
     * truncated when the caller later passes them through htons().
     */
    value = strtol(s, &end, 10);
    if (end != s)
    {
        while (isspace((unsigned char) *end))
            end++;
        if (*end == '\0')
        {
            if (value < 0 || value > 65535)
                return -1;
            *port = (int) value;
            return 0;
        }
    }

    /* Not a number: look the token up as a service name. */
    if (proto == IPPROTO_TCP)
        se = getservbyname(s, "tcp");
    else if (proto == IPPROTO_UDP)
        se = getservbyname(s, "udp");
    else
        return -1;

    if (se == NULL)
        return -1;

    *port = ntohs(se->s_port);
    return 0;
}
171 
172 void
173 RuleAdd(struct proxy_entry *entry)
174 {
175     int rule_index;
176     struct proxy_entry *ptr;
177     struct proxy_entry *ptr_last;
178 
179     if (proxyList == NULL)
180     {
181         proxyList = entry;
182         entry->last = NULL;
183         entry->next = NULL;
184         return;
185     }
186 
187     rule_index = entry->rule_index;
188     ptr = proxyList;
189     ptr_last = NULL;
190     while (ptr != NULL)
191     {
192         if (ptr->rule_index >= rule_index)
193         {
194             if (ptr_last == NULL)
195             {
196                 entry->next = proxyList;
197                 entry->last = NULL;
198                 proxyList->last = entry;
199                 proxyList = entry;
200                 return;
201             }
202 
203             ptr_last->next = entry;
204             ptr->last = entry;
205             entry->last = ptr->last;
206             entry->next = ptr;
207             return;
208         }
209         ptr_last = ptr;
210         ptr = ptr->next;
211     }
212 
213     ptr_last->next = entry;
214     entry->last = ptr_last;
215     entry->next = NULL;
216 }
217 
218 static void
219 RuleDelete(struct proxy_entry *entry)
220 {
221     if (entry->last != NULL)
222         entry->last->next = entry->next;
223     else
224         proxyList = entry->next;
225 
226     if (entry->next != NULL)
227         entry->next->last = entry->last;
228 
229     free(entry);
230 }
231 
232 static int
233 RuleNumberDelete(int rule_index)
234 {
235     int err;
236     struct proxy_entry *ptr;
237 
238     err = -1;
239     ptr = proxyList;
240     while (ptr != NULL)
241     {
242         struct proxy_entry *ptr_next;
243 
244         ptr_next = ptr->next;
245         if (ptr->rule_index == rule_index)
246         {
247             err = 0;
248             RuleDelete(ptr);
249         }
250 
251         ptr = ptr_next;
252     }
253 
254     return err;
255 }
256 
257 static void
258 ProxyEncodeTcpStream(struct alias_link *link,
259                      struct ip *pip,
260                      int maxpacketsize)
261 {
262     int slen;
263     char buffer[40];
264     struct tcphdr *tc;
265 
266 /* Compute pointer to tcp header */
267     tc = (struct tcphdr *) ((char *) pip + (pip->ip_hl << 2));
268 
269 /* Don't modify if once already modified */
270 
271     if (GetAckModified (link))
272 	return;
273 
274 /* Translate destination address and port to string form */
275     snprintf(buffer, sizeof(buffer) - 2, "[DEST %s %d]",
276         inet_ntoa(GetProxyAddress (link)), (u_int) ntohs(GetProxyPort (link)));
277 
278 /* Pad string out to a multiple of two in length */
279     slen = strlen(buffer);
280     switch (slen % 2)
281     {
282     case 0:
283         strcat(buffer, " \n");
284 	slen += 2;
285         break;
286     case 1:
287         strcat(buffer, "\n");
288 	slen += 1;
289     }
290 
291 /* Check for packet overflow */
292     if ((ntohs(pip->ip_len) + strlen(buffer)) > maxpacketsize)
293         return;
294 
295 /* Shift existing TCP data and insert destination string */
296     {
297         int dlen;
298         int hlen;
299         u_char *p;
300 
301         hlen = (pip->ip_hl + tc->th_off) << 2;
302         dlen = ntohs (pip->ip_len) - hlen;
303 
304 /* Modify first packet that has data in it */
305 
306 	if (dlen == 0)
307 		return;
308 
309         p = (char *) pip;
310         p += hlen;
311 
312         memmove(p + slen, p, dlen);
313         memcpy(p, buffer, slen);
314     }
315 
316 /* Save information about modfied sequence number */
317     {
318         int delta;
319 
320         SetAckModified(link);
321         delta = GetDeltaSeqOut(pip, link);
322         AddSeq(pip, link, delta+slen);
323     }
324 
325 /* Update IP header packet length and checksum */
326     {
327         int accumulate;
328 
329         accumulate  = pip->ip_len;
330         pip->ip_len = htons(ntohs(pip->ip_len) + slen);
331         accumulate -= pip->ip_len;
332 
333         ADJUST_CHECKSUM(accumulate, pip->ip_sum);
334     }
335 
336 /* Update TCP checksum, Use TcpChecksum since so many things have
337    already changed. */
338 
339     tc->th_sum = 0;
340     tc->th_sum = TcpChecksum (pip);
341 }
342 
343 static void
344 ProxyEncodeIpHeader(struct ip *pip,
345                     int maxpacketsize)
346 {
347 #define OPTION_LEN_BYTES  8
348 #define OPTION_LEN_INT16  4
349 #define OPTION_LEN_INT32  2
350     u_char option[OPTION_LEN_BYTES];
351 
352 #ifdef DEBUG
353     fprintf(stdout, " ip cksum 1 = %x\n", (u_int) IpChecksum(pip));
354     fprintf(stdout, "tcp cksum 1 = %x\n", (u_int) TcpChecksum(pip));
355 #endif
356 
357 /* Check to see that there is room to add an IP option */
358     if (pip->ip_hl > (0x0f - OPTION_LEN_INT32))
359         return;
360 
361 /* Build option and copy into packet */
362     {
363         u_char *ptr;
364         struct tcphdr *tc;
365 
366         ptr = (u_char *) pip;
367         ptr += 20;
368         memcpy(ptr + OPTION_LEN_BYTES, ptr, ntohs(pip->ip_len) - 20);
369 
370         option[0] = 0x64; /* class: 3 (reserved), option 4 */
371         option[1] = OPTION_LEN_BYTES;
372 
373         memcpy(&option[2], (u_char *) &pip->ip_dst, 4);
374 
375         tc = (struct tcphdr *) ((char *) pip + (pip->ip_hl << 2));
376         memcpy(&option[6], (u_char *) &tc->th_sport, 2);
377 
378         memcpy(ptr, option, 8);
379     }
380 
381 /* Update checksum, header length and packet length */
382     {
383         int i;
384         int accumulate;
385         u_short *sptr;
386 
387         sptr = (u_short *) option;
388         accumulate = 0;
389         for (i=0; i<OPTION_LEN_INT16; i++)
390             accumulate -= *(sptr++);
391 
392         sptr = (u_short *) pip;
393         accumulate += *sptr;
394         pip->ip_hl += OPTION_LEN_INT32;
395         accumulate -= *sptr;
396 
397         accumulate += pip->ip_len;
398         pip->ip_len = htons(ntohs(pip->ip_len) + OPTION_LEN_BYTES);
399         accumulate -= pip->ip_len;
400 
401         ADJUST_CHECKSUM(accumulate, pip->ip_sum);
402     }
403 #undef OPTION_LEN_BYTES
404 #undef OPTION_LEN_INT16
405 #undef OPTION_LEN_INT32
406 #ifdef DEBUG
407     fprintf(stdout, " ip cksum 2 = %x\n", (u_int) IpChecksum(pip));
408     fprintf(stdout, "tcp cksum 2 = %x\n", (u_int) TcpChecksum(pip));
409 #endif
410 }
411 
412 
413 /* Functions used by other packet alias source files
414 
415     ProxyCheck()         -- Checks whether an outgoing packet should
416                             be proxied.
417     ProxyModify()        -- Encodes the original destination address/port
418                             for a packet which is to be redirected to
419                             a proxy server.
420 */
421 
422 int
423 ProxyCheck(struct ip *pip,
424            struct in_addr *proxy_server_addr,
425            u_short *proxy_server_port)
426 {
427     u_short dst_port;
428     struct in_addr src_addr;
429     struct in_addr dst_addr;
430     struct proxy_entry *ptr;
431 
432     src_addr = pip->ip_src;
433     dst_addr = pip->ip_dst;
434     dst_port = ((struct tcphdr *) ((char *) pip + (pip->ip_hl << 2)))
435         ->th_dport;
436 
437     ptr = proxyList;
438     while (ptr != NULL)
439     {
440         u_short proxy_port;
441 
442         proxy_port = ptr->proxy_port;
443         if ((dst_port == proxy_port || proxy_port == 0)
444          && pip->ip_p == ptr->proto
445          && src_addr.s_addr != ptr->server_addr.s_addr)
446         {
447             struct in_addr src_addr_masked;
448             struct in_addr dst_addr_masked;
449 
450             src_addr_masked.s_addr = src_addr.s_addr & ptr->src_mask.s_addr;
451             dst_addr_masked.s_addr = dst_addr.s_addr & ptr->dst_mask.s_addr;
452 
453             if ((src_addr_masked.s_addr == ptr->src_addr.s_addr)
454              && (dst_addr_masked.s_addr == ptr->dst_addr.s_addr))
455             {
456                 if ((*proxy_server_port = ptr->server_port) == 0)
457                     *proxy_server_port = dst_port;
458                 *proxy_server_addr = ptr->server_addr;
459                 return ptr->proxy_type;
460             }
461         }
462         ptr = ptr->next;
463     }
464 
465     return 0;
466 }
467 
468 void
469 ProxyModify(struct alias_link *link,
470             struct ip *pip,
471             int maxpacketsize,
472             int proxy_type)
473 {
474     switch (proxy_type)
475     {
476     case PROXY_TYPE_ENCODE_IPHDR:
477         ProxyEncodeIpHeader(pip, maxpacketsize);
478         break;
479 
480     case PROXY_TYPE_ENCODE_TCPSTREAM:
481         ProxyEncodeTcpStream(link, pip, maxpacketsize);
482         break;
483     }
484 }
485 
486 
487 /*
488     Public API functions
489 */
490 
491 int
492 PacketAliasProxyRule(const char *cmd)
493 {
494 /*
495  * This function takes command strings of the form:
496  *
497  *   server <addr>[:<port>]
498  *   [port <port>]
499  *   [rule n]
500  *   [proto tcp|udp]
501  *   [src <addr>[/n]]
502  *   [dst <addr>[/n]]
503  *   [type encode_tcp_stream|encode_ip_hdr|no_encode]
504  *
505  *   delete <rule number>
506  *
507  * Subfields can be in arbitrary order.  Port numbers and addresses
508  * must be in either numeric or symbolic form. An optional rule number
509  * is used to control the order in which rules are searched.  If two
510  * rules have the same number, then search order cannot be guaranteed,
511  * and the rules should be disjoint.  If no rule number is specified,
512  * then 0 is used, and group 0 rules are always checked before any
513  * others.
514  */
515     int i, n, len;
516     int cmd_len;
517     int token_count;
518     int state;
519     char *token;
520     char buffer[256];
521     char str_port[sizeof(buffer)];
522     char str_server_port[sizeof(buffer)];
523 
524     int rule_index;
525     int proto;
526     int proxy_type;
527     int proxy_port;
528     int server_port;
529     struct in_addr server_addr;
530     struct in_addr src_addr, src_mask;
531     struct in_addr dst_addr, dst_mask;
532     struct proxy_entry *proxy_entry;
533 
534 /* Copy command line into a buffer */
535     cmd_len = strlen(cmd);
536     if (cmd_len > (sizeof(buffer) - 1))
537         return -1;
538     strcpy(buffer, cmd);
539 
540 /* Convert to lower case */
541     len = strlen(buffer);
542     for (i=0; i<len; i++)
543         buffer[i] = tolower(buffer[i]);
544 
545 /* Set default proxy type */
546 
547 /* Set up default values */
548     rule_index = 0;
549     proxy_type = PROXY_TYPE_ENCODE_NONE;
550     proto = IPPROTO_TCP;
551     proxy_port = 0;
552     server_addr.s_addr = 0;
553     server_port = 0;
554     src_addr.s_addr = 0;
555     IpMask(0, &src_mask);
556     dst_addr.s_addr = 0;
557     IpMask(0, &dst_mask);
558 
559     str_port[0] = 0;
560     str_server_port[0] = 0;
561 
562 /* Parse command string with state machine */
563 #define STATE_READ_KEYWORD    0
564 #define STATE_READ_TYPE       1
565 #define STATE_READ_PORT       2
566 #define STATE_READ_SERVER     3
567 #define STATE_READ_RULE       4
568 #define STATE_READ_DELETE     5
569 #define STATE_READ_PROTO      6
570 #define STATE_READ_SRC        7
571 #define STATE_READ_DST        8
572     state = STATE_READ_KEYWORD;
573     token = strtok(buffer, " \t");
574     token_count = 0;
575     while (token != NULL)
576     {
577         token_count++;
578         switch (state)
579         {
580         case STATE_READ_KEYWORD:
581             if (strcmp(token, "type") == 0)
582                 state = STATE_READ_TYPE;
583             else if (strcmp(token, "port") == 0)
584                 state = STATE_READ_PORT;
585             else if (strcmp(token, "server") == 0)
586                 state = STATE_READ_SERVER;
587             else if (strcmp(token, "rule") == 0)
588                 state = STATE_READ_RULE;
589             else if (strcmp(token, "delete") == 0)
590                 state = STATE_READ_DELETE;
591             else if (strcmp(token, "proto") == 0)
592                 state = STATE_READ_PROTO;
593             else if (strcmp(token, "src") == 0)
594                 state = STATE_READ_SRC;
595             else if (strcmp(token, "dst") == 0)
596                 state = STATE_READ_DST;
597             else
598                 return -1;
599             break;
600 
601         case STATE_READ_TYPE:
602             if (strcmp(token, "encode_ip_hdr") == 0)
603                 proxy_type = PROXY_TYPE_ENCODE_IPHDR;
604             else if (strcmp(token, "encode_tcp_stream") == 0)
605                 proxy_type = PROXY_TYPE_ENCODE_TCPSTREAM;
606             else if (strcmp(token, "no_encode") == 0)
607                 proxy_type = PROXY_TYPE_ENCODE_NONE;
608             else
609                 return -1;
610             state = STATE_READ_KEYWORD;
611             break;
612 
613         case STATE_READ_PORT:
614             strcpy(str_port, token);
615             state = STATE_READ_KEYWORD;
616             break;
617 
618         case STATE_READ_SERVER:
619             {
620                 int err;
621                 char *p;
622                 char s[sizeof(buffer)];
623 
624                 p = token;
625                 while (*p != ':' && *p != 0)
626                     p++;
627 
628                 if (*p != ':')
629                 {
630                     err = IpAddr(token, &server_addr);
631                     if (err)
632                         return -1;
633                 }
634                 else
635                 {
636                     *p = ' ';
637 
638                     n = sscanf(token, "%s %s", s, str_server_port);
639                     if (n != 2)
640                         return -1;
641 
642                     err = IpAddr(s, &server_addr);
643                     if (err)
644                         return -1;
645                 }
646             }
647             state = STATE_READ_KEYWORD;
648             break;
649 
650         case STATE_READ_RULE:
651             n = sscanf(token, "%d", &rule_index);
652             if (n != 1 || rule_index < 0)
653                 return -1;
654             state = STATE_READ_KEYWORD;
655             break;
656 
657         case STATE_READ_DELETE:
658             {
659                 int err;
660                 int rule_to_delete;
661 
662                 if (token_count != 2)
663                     return -1;
664 
665                 n = sscanf(token, "%d", &rule_to_delete);
666                 if (n != 1)
667                     return -1;
668                 err = RuleNumberDelete(rule_to_delete);
669                 if (err)
670                     return -1;
671                 return 0;
672             }
673 
674         case STATE_READ_PROTO:
675             if (strcmp(token, "tcp") == 0)
676                 proto = IPPROTO_TCP;
677             else if (strcmp(token, "udp") == 0)
678                 proto = IPPROTO_UDP;
679             else
680                 return -1;
681             state = STATE_READ_KEYWORD;
682             break;
683 
684         case STATE_READ_SRC:
685         case STATE_READ_DST:
686             {
687                 int err;
688                 char *p;
689                 struct in_addr mask;
690                 struct in_addr addr;
691 
692                 p = token;
693                 while (*p != '/' && *p != 0)
694                     p++;
695 
696                 if (*p != '/')
697                 {
698                      IpMask(32, &mask);
699                      err = IpAddr(token, &addr);
700                      if (err)
701                          return -1;
702                 }
703                 else
704                 {
705                     int n;
706                     int nbits;
707                     char s[sizeof(buffer)];
708 
709                     *p = ' ';
710                     n = sscanf(token, "%s %d", s, &nbits);
711                     if (n != 2)
712                         return -1;
713 
714                     err = IpAddr(s, &addr);
715                     if (err)
716                         return -1;
717 
718                     err = IpMask(nbits, &mask);
719                     if (err)
720                         return -1;
721                 }
722 
723                 if (state == STATE_READ_SRC)
724                 {
725                     src_addr = addr;
726                     src_mask = mask;
727                 }
728                 else
729                 {
730                     dst_addr = addr;
731                     dst_mask = mask;
732                 }
733             }
734             state = STATE_READ_KEYWORD;
735             break;
736 
737         default:
738             return -1;
739             break;
740         }
741 
742         token = strtok(NULL, " \t");
743     }
744 #undef STATE_READ_KEYWORD
745 #undef STATE_READ_TYPE
746 #undef STATE_READ_PORT
747 #undef STATE_READ_SERVER
748 #undef STATE_READ_RULE
749 #undef STATE_READ_DELETE
750 #undef STATE_READ_PROTO
751 #undef STATE_READ_SRC
752 #undef STATE_READ_DST
753 
754 /* Convert port strings to numbers.  This needs to be done after
755    the string is parsed, because the prototype might not be designated
756    before the ports (which might be symbolic entries in /etc/services) */
757 
758     if (strlen(str_port) != 0)
759     {
760         int err;
761 
762         err = IpPort(str_port, proto, &proxy_port);
763         if (err)
764             return -1;
765     }
766     else
767     {
768         proxy_port = 0;
769     }
770 
771     if (strlen(str_server_port) != 0)
772     {
773         int err;
774 
775         err = IpPort(str_server_port, proto, &server_port);
776         if (err)
777             return -1;
778     }
779     else
780     {
781         server_port = 0;
782     }
783 
784 /* Check that at least the server address has been defined */
785     if (server_addr.s_addr == 0)
786         return -1;
787 
788 /* Add to linked list */
789     proxy_entry = malloc(sizeof(struct proxy_entry));
790     if (proxy_entry == NULL)
791         return -1;
792 
793     proxy_entry->proxy_type = proxy_type;
794     proxy_entry->rule_index = rule_index;
795     proxy_entry->proto = proto;
796     proxy_entry->proxy_port = htons(proxy_port);
797     proxy_entry->server_port = htons(server_port);
798     proxy_entry->server_addr = server_addr;
799     proxy_entry->src_addr.s_addr = src_addr.s_addr & src_mask.s_addr;
800     proxy_entry->dst_addr.s_addr = dst_addr.s_addr & dst_mask.s_addr;
801     proxy_entry->src_mask = src_mask;
802     proxy_entry->dst_mask = dst_mask;
803 
804     RuleAdd(proxy_entry);
805 
806     return 0;
807 }
808