xref: /freebsd/sys/netinet/libalias/alias_proxy.c (revision 601752d5a7bef087e755da5a2b158fa35cb51ccb)
1 /* file: alias_proxy.c
2 
3     This file encapsulates special operations related to transparent
4     proxy redirection.  This is where packets with a particular destination,
5     usually tcp port 80, are redirected to a proxy server.
6 
7     When packets are proxied, the destination address and port are
8     modified.  In certain cases, it is necessary to somehow encode
9     the original address/port info into the packet.  Two methods are
10     presently supported: addition of a [DEST addr port] string at the
11     beginning of a tcp stream, or inclusion of an optional field
12     in the IP header.
13 
14     There is one public API function:
15 
16         PacketAliasProxyRule()    -- Adds and deletes proxy
17                                      rules.
18 
19     Rules are stored in a linear linked list, so lookup efficiency
20     won't be too good for large lists.
21 
22 
23     Initial development: April, 1998 (cjm)
24 */
25 
26 
27 /* System includes */
28 #include <ctype.h>
29 #include <stdio.h>
30 #include <stdlib.h>
31 #include <string.h>
32 #include <netdb.h>
33 
34 #include <sys/types.h>
35 #include <sys/socket.h>
36 
37 /* BSD IPV4 includes */
38 #include <netinet/in_systm.h>
39 #include <netinet/in.h>
40 #include <netinet/ip.h>
41 #include <netinet/tcp.h>
42 
43 #include <arpa/inet.h>
44 
45 #include "alias_local.h"  /* Functions used by alias*.c */
46 #include "alias.h"        /* Public API functions for libalias */
47 
48 
49 
50 /*
51     Data structures
52  */
53 
54 /*
55  * A linked list of arbitrary length, based on struct proxy_entry is
56  * used to store proxy rules.
57  */
58 struct proxy_entry
59 {
60 #define PROXY_TYPE_ENCODE_NONE      1
61 #define PROXY_TYPE_ENCODE_TCPSTREAM 2
62 #define PROXY_TYPE_ENCODE_IPHDR     3
63     int rule_index;
64     int proxy_type;
65     u_char proto;
66     u_short proxy_port;
67     u_short server_port;
68 
69     struct in_addr server_addr;
70 
71     struct in_addr src_addr;
72     struct in_addr src_mask;
73 
74     struct in_addr dst_addr;
75     struct in_addr dst_mask;
76 
77     struct proxy_entry *next;
78     struct proxy_entry *last;
79 };
80 
81 
82 
83 /*
84     File scope variables
85 */
86 
/* Head of the rule list; NULL while no rules are installed. */
static struct proxy_entry *proxyList = NULL;
88 
89 
90 
91 /* Local (static) functions:
92 
93     IpMask()                 -- Utility function for creating IP
94                                 masks from integer (0-32) specification.
95     IpAddr()                 -- Utility function for converting string
96                                 to IP address
97     IpPort()                 -- Utility function for converting string
98                                 to port number
99     RuleAdd()                -- Adds an element to the rule list.
100     RuleDelete()             -- Removes an element from the rule list.
101     RuleNumberDelete()       -- Removes all elements from the rule list
102                                 having a certain rule number.
103     ProxyEncodeTcpStream()   -- Adds [DEST x.x.x.x xxxx] to the beginning
104                                 of a TCP stream.
105     ProxyEncodeIpHeader()    -- Adds an IP option indicating the true
106                                 destination of a proxied IP packet
107 */
108 
/* Parsing helpers */
static int IpMask(int nbits, struct in_addr *mask);
static int IpAddr(char *s, struct in_addr *addr);
static int IpPort(char *s, int proto, int *port);

/* Rule list maintenance */
static void RuleAdd(struct proxy_entry *entry);
static void RuleDelete(struct proxy_entry *entry);
static int RuleNumberDelete(int rule_index);

/* Packet encoders */
static void ProxyEncodeTcpStream(struct alias_link *link, struct ip *pip,
                                 int maxpacketsize);
static void ProxyEncodeIpHeader(struct ip *pip, int maxpacketsize);
117 
/*
 * IpMask -- build a netmask with the top `nbits` bits set.
 *
 * nbits must be in 0..32. On success the mask is stored in network byte
 * order in *mask and 0 is returned. Otherwise -1 is returned and *mask
 * is untouched.
 */
static int
IpMask(int nbits, struct in_addr *mask)
{
    unsigned int bits;

    if (nbits < 0 || nbits > 32)
        return -1;

    /* A shift by the full 32-bit width is undefined, so /0 is special-cased. */
    bits = (nbits == 0) ? 0U : (0xffffffffU << (32 - nbits));
    mask->s_addr = htonl(bits);

    return 0;
}
134 
/*
 * IpAddr -- parse a numeric IPv4 address with inet_aton(3).
 * Returns 0 and fills *addr on success, or -1 if `s` is not a valid address.
 */
static int
IpAddr(char *s, struct in_addr *addr)
{
    return (inet_aton(s, addr) != 0) ? 0 : -1;
}
143 
/*
 * IpPort -- convert a port specification to a host-order port number.
 *
 * `s` is either a plain decimal number in 0..65535, or a service name
 * looked up in the services database for `proto`. Only IPPROTO_TCP and
 * IPPROTO_UDP are accepted for name lookup.
 *
 * On success the port is stored in *port and 0 is returned. On failure -1
 * is returned and *port is not modified. A string is treated as numeric
 * only if it is entirely digits, so "80abc" is no longer read as port 80.
 */
static int
IpPort(char *s, int proto, int *port)
{
    char *end;
    long value;
    struct servent *se;

    value = strtol(s, &end, 10);
    if (end != s && *end == '\0')
    {
        /* Fully numeric: reject values that cannot be a 16-bit port. */
        if (value < 0 || value > 65535)
            return -1;
        *port = (int) value;
        return 0;
    }

    /* Not a plain number: treat it as a service name. */
    if (proto == IPPROTO_TCP)
        se = getservbyname(s, "tcp");
    else if (proto == IPPROTO_UDP)
        se = getservbyname(s, "udp");
    else
        return -1;

    if (se == NULL)
        return -1;

    *port = (int) ntohs(se->s_port);
    return 0;
}
169 
170 void
171 RuleAdd(struct proxy_entry *entry)
172 {
173     int rule_index;
174     struct proxy_entry *ptr;
175     struct proxy_entry *ptr_last;
176 
177     if (proxyList == NULL)
178     {
179         proxyList = entry;
180         entry->last = NULL;
181         entry->next = NULL;
182         return;
183     }
184 
185     rule_index = entry->rule_index;
186     ptr = proxyList;
187     ptr_last = NULL;
188     while (ptr != NULL)
189     {
190         if (ptr->rule_index >= rule_index)
191         {
192             if (ptr_last == NULL)
193             {
194                 entry->next = proxyList;
195                 entry->last = NULL;
196                 proxyList->last = entry;
197                 proxyList = entry;
198                 return;
199             }
200 
201             ptr_last->next = entry;
202             ptr->last = entry;
203             entry->last = ptr->last;
204             entry->next = ptr;
205             return;
206         }
207         ptr_last = ptr;
208         ptr = ptr->next;
209     }
210 
211     ptr_last->next = entry;
212     entry->last = ptr_last;
213     entry->next = NULL;
214 }
215 
216 static void
217 RuleDelete(struct proxy_entry *entry)
218 {
219     if (entry->last != NULL)
220         entry->last->next = entry->next;
221     else
222         proxyList = entry->next;
223 
224     if (entry->next != NULL)
225         entry->next->last = entry->last;
226 
227     free(entry);
228 }
229 
230 static int
231 RuleNumberDelete(int rule_index)
232 {
233     int err;
234     struct proxy_entry *ptr;
235 
236     err = -1;
237     ptr = proxyList;
238     while (ptr != NULL)
239     {
240         struct proxy_entry *ptr_next;
241 
242         ptr_next = ptr->next;
243         if (ptr->rule_index == rule_index)
244         {
245             err = 0;
246             RuleDelete(ptr);
247         }
248 
249         ptr = ptr_next;
250     }
251 
252     return err;
253 }
254 
255 static void
256 ProxyEncodeTcpStream(struct alias_link *link,
257                      struct ip *pip,
258                      int maxpacketsize)
259 {
260     int slen;
261     char buffer[40];
262     struct tcphdr *tc;
263 
264 /* Compute pointer to tcp header */
265     tc = (struct tcphdr *) ((char *) pip + (pip->ip_hl << 2));
266 
267 /* Don't modify if once already modified */
268 
269     if (GetAckModified (link))
270 	return;
271 
272 /* Translate destination address and port to string form */
273     snprintf(buffer, sizeof(buffer) - 2, "[DEST %s %d]",
274         inet_ntoa(GetProxyAddress (link)), (u_int) ntohs(GetProxyPort (link)));
275 
276 /* Pad string out to a multiple of two in length */
277     slen = strlen(buffer);
278     switch (slen % 2)
279     {
280     case 0:
281         strcat(buffer, " \n");
282 	slen += 2;
283         break;
284     case 1:
285         strcat(buffer, "\n");
286 	slen += 1;
287     }
288 
289 /* Check for packet overflow */
290     if ((ntohs(pip->ip_len) + strlen(buffer)) > maxpacketsize)
291         return;
292 
293 /* Shift existing TCP data and insert destination string */
294     {
295         int dlen;
296         int hlen;
297         u_char *p;
298 
299         hlen = (pip->ip_hl + tc->th_off) << 2;
300         dlen = ntohs (pip->ip_len) - hlen;
301 
302 /* Modify first packet that has data in it */
303 
304 	if (dlen == 0)
305 		return;
306 
307         p = (char *) pip;
308         p += hlen;
309 
310         memmove(p + slen, p, dlen);
311         memcpy(p, buffer, slen);
312     }
313 
314 /* Save information about modfied sequence number */
315     {
316         int delta;
317 
318         SetAckModified(link);
319         delta = GetDeltaSeqOut(pip, link);
320         AddSeq(pip, link, delta+slen);
321     }
322 
323 /* Update IP header packet length and checksum */
324     {
325         int accumulate;
326 
327         accumulate  = pip->ip_len;
328         pip->ip_len = htons(ntohs(pip->ip_len) + slen);
329         accumulate -= pip->ip_len;
330 
331         ADJUST_CHECKSUM(accumulate, pip->ip_sum);
332     }
333 
334 /* Update TCP checksum, Use TcpChecksum since so many things have
335    already changed. */
336 
337     tc->th_sum = 0;
338     tc->th_sum = TcpChecksum (pip);
339 }
340 
341 static void
342 ProxyEncodeIpHeader(struct ip *pip,
343                     int maxpacketsize)
344 {
345 #define OPTION_LEN_BYTES  8
346 #define OPTION_LEN_INT16  4
347 #define OPTION_LEN_INT32  2
348     u_char option[OPTION_LEN_BYTES];
349 
350 #ifdef DEBUG
351     fprintf(stdout, " ip cksum 1 = %x\n", (u_int) IpChecksum(pip));
352     fprintf(stdout, "tcp cksum 1 = %x\n", (u_int) TcpChecksum(pip));
353 #endif
354 
355 /* Check to see that there is room to add an IP option */
356     if (pip->ip_hl > (0x0f - OPTION_LEN_INT32))
357         return;
358 
359 /* Build option and copy into packet */
360     {
361         u_char *ptr;
362         struct tcphdr *tc;
363 
364         ptr = (u_char *) pip;
365         ptr += 20;
366         memcpy(ptr + OPTION_LEN_BYTES, ptr, ntohs(pip->ip_len) - 20);
367 
368         option[0] = 0x64; /* class: 3 (reserved), option 4 */
369         option[1] = OPTION_LEN_BYTES;
370 
371         memcpy(&option[2], (u_char *) &pip->ip_dst, 4);
372 
373         tc = (struct tcphdr *) ((char *) pip + (pip->ip_hl << 2));
374         memcpy(&option[6], (u_char *) &tc->th_sport, 2);
375 
376         memcpy(ptr, option, 8);
377     }
378 
379 /* Update checksum, header length and packet length */
380     {
381         int i;
382         int accumulate;
383         u_short *sptr;
384 
385         sptr = (u_short *) option;
386         accumulate = 0;
387         for (i=0; i<OPTION_LEN_INT16; i++)
388             accumulate -= *(sptr++);
389 
390         sptr = (u_short *) pip;
391         accumulate += *sptr;
392         pip->ip_hl += OPTION_LEN_INT32;
393         accumulate -= *sptr;
394 
395         accumulate += pip->ip_len;
396         pip->ip_len = htons(ntohs(pip->ip_len) + OPTION_LEN_BYTES);
397         accumulate -= pip->ip_len;
398 
399         ADJUST_CHECKSUM(accumulate, pip->ip_sum);
400     }
401 #undef OPTION_LEN_BYTES
402 #undef OPTION_LEN_INT16
403 #undef OPTION_LEN_INT32
404 #ifdef DEBUG
405     fprintf(stdout, " ip cksum 2 = %x\n", (u_int) IpChecksum(pip));
406     fprintf(stdout, "tcp cksum 2 = %x\n", (u_int) TcpChecksum(pip));
407 #endif
408 }
409 
410 
411 /* Functions used by other packet alias source files
412 
413     ProxyCheck()         -- Checks whether an outgoing packet should
414                             be proxied.
415     ProxyModify()        -- Encodes the original destination address/port
416                             for a packet which is to be redirected to
417                             a proxy server.
418 */
419 
420 int
421 ProxyCheck(struct ip *pip,
422            struct in_addr *proxy_server_addr,
423            u_short *proxy_server_port)
424 {
425     u_short dst_port;
426     struct in_addr src_addr;
427     struct in_addr dst_addr;
428     struct proxy_entry *ptr;
429 
430     src_addr = pip->ip_src;
431     dst_addr = pip->ip_dst;
432     dst_port = ((struct tcphdr *) ((char *) pip + (pip->ip_hl << 2)))
433         ->th_dport;
434 
435     ptr = proxyList;
436     while (ptr != NULL)
437     {
438         u_short proxy_port;
439 
440         proxy_port = ptr->proxy_port;
441         if ((dst_port == proxy_port || proxy_port == 0)
442          && pip->ip_p == ptr->proto
443          && src_addr.s_addr != ptr->server_addr.s_addr)
444         {
445             struct in_addr src_addr_masked;
446             struct in_addr dst_addr_masked;
447 
448             src_addr_masked.s_addr = src_addr.s_addr & ptr->src_mask.s_addr;
449             dst_addr_masked.s_addr = dst_addr.s_addr & ptr->dst_mask.s_addr;
450 
451             if ((src_addr_masked.s_addr == ptr->src_addr.s_addr)
452              && (dst_addr_masked.s_addr == ptr->dst_addr.s_addr))
453             {
454                 if ((*proxy_server_port = ptr->server_port) == 0)
455                     *proxy_server_port = dst_port;
456                 *proxy_server_addr = ptr->server_addr;
457                 return ptr->proxy_type;
458             }
459         }
460         ptr = ptr->next;
461     }
462 
463     return 0;
464 }
465 
466 void
467 ProxyModify(struct alias_link *link,
468             struct ip *pip,
469             int maxpacketsize,
470             int proxy_type)
471 {
472     switch (proxy_type)
473     {
474     case PROXY_TYPE_ENCODE_IPHDR:
475         ProxyEncodeIpHeader(pip, maxpacketsize);
476         break;
477 
478     case PROXY_TYPE_ENCODE_TCPSTREAM:
479         ProxyEncodeTcpStream(link, pip, maxpacketsize);
480         break;
481     }
482 }
483 
484 
485 /*
486     Public API functions
487 */
488 
489 int
490 PacketAliasProxyRule(const char *cmd)
491 {
492 /*
493  * This function takes command strings of the form:
494  *
495  *   server <addr>[:<port>]
496  *   [port <port>]
497  *   [rule n]
498  *   [proto tcp|udp]
499  *   [src <addr>[/n]]
500  *   [dst <addr>[/n]]
501  *   [type encode_tcp_stream|encode_ip_hdr|no_encode]
502  *
503  *   delete <rule number>
504  *
505  * Subfields can be in arbitrary order.  Port numbers and addresses
506  * must be in either numeric or symbolic form. An optional rule number
507  * is used to control the order in which rules are searched.  If two
508  * rules have the same number, then search order cannot be guaranteed,
509  * and the rules should be disjoint.  If no rule number is specified,
510  * then 0 is used, and group 0 rules are always checked before any
511  * others.
512  */
513     int i, n, len;
514     int cmd_len;
515     int token_count;
516     int state;
517     char *token;
518     char buffer[256];
519     char str_port[sizeof(buffer)];
520     char str_server_port[sizeof(buffer)];
521 
522     int rule_index;
523     int proto;
524     int proxy_type;
525     int proxy_port;
526     int server_port;
527     struct in_addr server_addr;
528     struct in_addr src_addr, src_mask;
529     struct in_addr dst_addr, dst_mask;
530     struct proxy_entry *proxy_entry;
531 
532 /* Copy command line into a buffer */
533     cmd_len = strlen(cmd);
534     if (cmd_len > (sizeof(buffer) - 1))
535         return -1;
536     strcpy(buffer, cmd);
537 
538 /* Convert to lower case */
539     len = strlen(buffer);
540     for (i=0; i<len; i++)
541         buffer[i] = tolower(buffer[i]);
542 
543 /* Set default proxy type */
544 
545 /* Set up default values */
546     rule_index = 0;
547     proxy_type = PROXY_TYPE_ENCODE_NONE;
548     proto = IPPROTO_TCP;
549     proxy_port = 0;
550     server_addr.s_addr = 0;
551     server_port = 0;
552     src_addr.s_addr = 0;
553     IpMask(0, &src_mask);
554     dst_addr.s_addr = 0;
555     IpMask(0, &dst_mask);
556 
557     str_port[0] = 0;
558     str_server_port[0] = 0;
559 
560 /* Parse command string with state machine */
561 #define STATE_READ_KEYWORD    0
562 #define STATE_READ_TYPE       1
563 #define STATE_READ_PORT       2
564 #define STATE_READ_SERVER     3
565 #define STATE_READ_RULE       4
566 #define STATE_READ_DELETE     5
567 #define STATE_READ_PROTO      6
568 #define STATE_READ_SRC        7
569 #define STATE_READ_DST        8
570     state = STATE_READ_KEYWORD;
571     token = strtok(buffer, " \t");
572     token_count = 0;
573     while (token != NULL)
574     {
575         token_count++;
576         switch (state)
577         {
578         case STATE_READ_KEYWORD:
579             if (strcmp(token, "type") == 0)
580                 state = STATE_READ_TYPE;
581             else if (strcmp(token, "port") == 0)
582                 state = STATE_READ_PORT;
583             else if (strcmp(token, "server") == 0)
584                 state = STATE_READ_SERVER;
585             else if (strcmp(token, "rule") == 0)
586                 state = STATE_READ_RULE;
587             else if (strcmp(token, "delete") == 0)
588                 state = STATE_READ_DELETE;
589             else if (strcmp(token, "proto") == 0)
590                 state = STATE_READ_PROTO;
591             else if (strcmp(token, "src") == 0)
592                 state = STATE_READ_SRC;
593             else if (strcmp(token, "dst") == 0)
594                 state = STATE_READ_DST;
595             else
596                 return -1;
597             break;
598 
599         case STATE_READ_TYPE:
600             if (strcmp(token, "encode_ip_hdr") == 0)
601                 proxy_type = PROXY_TYPE_ENCODE_IPHDR;
602             else if (strcmp(token, "encode_tcp_stream") == 0)
603                 proxy_type = PROXY_TYPE_ENCODE_TCPSTREAM;
604             else if (strcmp(token, "no_encode") == 0)
605                 proxy_type = PROXY_TYPE_ENCODE_NONE;
606             else
607                 return -1;
608             state = STATE_READ_KEYWORD;
609             break;
610 
611         case STATE_READ_PORT:
612             strcpy(str_port, token);
613             state = STATE_READ_KEYWORD;
614             break;
615 
616         case STATE_READ_SERVER:
617             {
618                 int err;
619                 char *p;
620                 char s[sizeof(buffer)];
621 
622                 p = token;
623                 while (*p != ':' && *p != 0)
624                     p++;
625 
626                 if (*p != ':')
627                 {
628                     err = IpAddr(token, &server_addr);
629                     if (err)
630                         return -1;
631                 }
632                 else
633                 {
634                     *p = ' ';
635 
636                     n = sscanf(token, "%s %s", s, str_server_port);
637                     if (n != 2)
638                         return -1;
639 
640                     err = IpAddr(s, &server_addr);
641                     if (err)
642                         return -1;
643                 }
644             }
645             state = STATE_READ_KEYWORD;
646             break;
647 
648         case STATE_READ_RULE:
649             n = sscanf(token, "%d", &rule_index);
650             if (n != 1 || rule_index < 0)
651                 return -1;
652             state = STATE_READ_KEYWORD;
653             break;
654 
655         case STATE_READ_DELETE:
656             {
657                 int err;
658                 int rule_to_delete;
659 
660                 if (token_count != 2)
661                     return -1;
662 
663                 n = sscanf(token, "%d", &rule_to_delete);
664                 if (n != 1)
665                     return -1;
666                 err = RuleNumberDelete(rule_to_delete);
667                 if (err)
668                     return -1;
669                 return 0;
670             }
671 
672         case STATE_READ_PROTO:
673             if (strcmp(token, "tcp") == 0)
674                 proto = IPPROTO_TCP;
675             else if (strcmp(token, "udp") == 0)
676                 proto = IPPROTO_UDP;
677             else
678                 return -1;
679             state = STATE_READ_KEYWORD;
680             break;
681 
682         case STATE_READ_SRC:
683         case STATE_READ_DST:
684             {
685                 int err;
686                 char *p;
687                 struct in_addr mask;
688                 struct in_addr addr;
689 
690                 p = token;
691                 while (*p != '/' && *p != 0)
692                     p++;
693 
694                 if (*p != '/')
695                 {
696                      IpMask(32, &mask);
697                      err = IpAddr(token, &addr);
698                      if (err)
699                          return -1;
700                 }
701                 else
702                 {
703                     int n;
704                     int nbits;
705                     char s[sizeof(buffer)];
706 
707                     *p = ' ';
708                     n = sscanf(token, "%s %d", s, &nbits);
709                     if (n != 2)
710                         return -1;
711 
712                     err = IpAddr(s, &addr);
713                     if (err)
714                         return -1;
715 
716                     err = IpMask(nbits, &mask);
717                     if (err)
718                         return -1;
719                 }
720 
721                 if (state == STATE_READ_SRC)
722                 {
723                     src_addr = addr;
724                     src_mask = mask;
725                 }
726                 else
727                 {
728                     dst_addr = addr;
729                     dst_mask = mask;
730                 }
731             }
732             state = STATE_READ_KEYWORD;
733             break;
734 
735         default:
736             return -1;
737             break;
738         }
739 
740         token = strtok(NULL, " \t");
741     }
742 #undef STATE_READ_KEYWORD
743 #undef STATE_READ_TYPE
744 #undef STATE_READ_PORT
745 #undef STATE_READ_SERVER
746 #undef STATE_READ_RULE
747 #undef STATE_READ_DELETE
748 #undef STATE_READ_PROTO
749 #undef STATE_READ_SRC
750 #undef STATE_READ_DST
751 
752 /* Convert port strings to numbers.  This needs to be done after
753    the string is parsed, because the prototype might not be designated
754    before the ports (which might be symbolic entries in /etc/services) */
755 
756     if (strlen(str_port) != 0)
757     {
758         int err;
759 
760         err = IpPort(str_port, proto, &proxy_port);
761         if (err)
762             return -1;
763     }
764     else
765     {
766         proxy_port = 0;
767     }
768 
769     if (strlen(str_server_port) != 0)
770     {
771         int err;
772 
773         err = IpPort(str_server_port, proto, &server_port);
774         if (err)
775             return -1;
776     }
777     else
778     {
779         server_port = 0;
780     }
781 
782 /* Check that at least the server address has been defined */
783     if (server_addr.s_addr == 0)
784         return -1;
785 
786 /* Add to linked list */
787     proxy_entry = malloc(sizeof(struct proxy_entry));
788     if (proxy_entry == NULL)
789         return -1;
790 
791     proxy_entry->proxy_type = proxy_type;
792     proxy_entry->rule_index = rule_index;
793     proxy_entry->proto = proto;
794     proxy_entry->proxy_port = htons(proxy_port);
795     proxy_entry->server_port = htons(server_port);
796     proxy_entry->server_addr = server_addr;
797     proxy_entry->src_addr.s_addr = src_addr.s_addr & src_mask.s_addr;
798     proxy_entry->dst_addr.s_addr = dst_addr.s_addr & dst_mask.s_addr;
799     proxy_entry->src_mask = src_mask;
800     proxy_entry->dst_mask = dst_mask;
801 
802     RuleAdd(proxy_entry);
803 
804     return 0;
805 }
806