]> git.lizzy.rs Git - plan9front.git/blobdiff - sys/src/cmd/ip/socksd.c
etheriwl: don't break controller on command flush timeout
[plan9front.git] / sys / src / cmd / ip / socksd.c
index 2128a5ff68847713c8af23791ef01acceb47e04c..8e3ff19c69bbf1c1f660a6097826afdce6d9ebaf 100644 (file)
@@ -3,6 +3,8 @@
 #include <ip.h>
 
 int socksver;
+char inside[128];
+char outside[128];
 
 int
 str2addr(char *s, uchar *a)
@@ -76,6 +78,7 @@ addr2str(char *proto, uchar *a){
                case 0x03:
                        n = *a++;
                        port = nhgets(a+n);
+                       n = utfnlen((char*)a, n);
                        snprint(s, sizeof(s), "%s!%.*s!%d", proto, n, (char*)a, port);
                        return s;
                }
@@ -84,6 +87,104 @@ addr2str(char *proto, uchar *a){
        return s;
 }
 
+int
+udprelay(int fd, char *dir)
+{
+       struct {
+               Udphdr;
+               uchar data[8*1024];
+       } msg;
+       char addr[128], ldir[40];
+       int r, n, rfd, cfd;
+       uchar *p;
+
+       snprint(addr, sizeof(addr), "%s/udp!*!0", outside);
+       if((cfd = announce(addr, ldir)) < 0)
+               return -1;
+       if(write(cfd, "headers", 7) != 7)
+               return -1;
+       strcat(ldir, "/data");
+       if((rfd = open(ldir, ORDWR)) < 0)
+               return -1;
+       close(cfd);
+       
+       if((r = rfork(RFMEM|RFPROC|RFNOWAIT)) <= 0)
+               return r;
+
+       if((cfd = listen(dir, ldir)) < 0)
+               return -1;
+       close(fd);      /* close inside udp server */
+       if((fd = accept(cfd, ldir)) < 0)
+               return -1;
+
+       switch(rfork(RFMEM|RFPROC|RFNOWAIT)){
+       case -1:
+               return -1;
+       case 0:
+               while((r = read(fd, msg.data, sizeof(msg.data))) > 0){
+                       if(r < 4)
+                               continue;
+                       p = msg.data;
+                       if(p[0] | p[1] | p[2])
+                               continue;
+                       p += 3;
+                       switch(*p++){
+                       default:
+                               continue;
+                       case 0x01:
+                               r -= 2+1+1+4+2;
+                               if(r < 0)
+                                       continue;
+                               v4tov6(msg.raddr, p);
+                               p += 4;
+                               break;
+                       case 0x04:
+                               r -= 2+1+1+16+2;
+                               if(r < 0)
+                                       continue;
+                               memmove(msg.raddr, p, 16);
+                               p += 16;
+                               break;
+                       }
+                       memmove(msg.rport, p, 2);
+                       p += 2;
+                       memmove(msg.data, p, r);
+                       write(rfd, &msg, sizeof(Udphdr)+r);
+               }
+               break;
+       default:
+               while((r = read(rfd, &msg, sizeof(msg))) > 0){
+                       r -= sizeof(Udphdr);
+                       if(r < 0)
+                               continue;
+                       p = msg.data;
+                       if(isv4(msg.raddr))
+                               n = 2+1+1+4+2;
+                       else
+                               n = 2+1+1+16+2;
+                       if(r+n > sizeof(msg.data))
+                               r = sizeof(msg.data)-n;
+                       memmove(p+n, p, r);
+                       *p++ = 0;
+                       *p++ = 0;
+                       *p++ = 0;
+                       if(isv4(msg.raddr)){
+                               *p++ = 0x01;
+                               v6tov4(p, msg.raddr);
+                               p += 4;
+                       } else {
+                               *p++ = 0x04;
+                               memmove(p, msg.raddr, 16);
+                               p += 16;
+                       }
+                       memmove(p, msg.rport, 2);
+                       r += n;
+                       write(fd, msg.data, r);
+               }
+       }
+       return -1;
+}
+
 int
 sockerr(int err)
 {
@@ -98,13 +199,21 @@ void
 main(int argc, char *argv[])
 {
        uchar buf[8*1024], *p;
-       char dir[40], *s;
+       char addr[128], dir[40], ldir[40], *s;
+       int cmd, fd, cfd, n;
        NetConnInfo *nc;
-       int fd, cfd, n;
 
        fmtinstall('I', eipfmt);
 
+       setnetmtpt(inside, sizeof(inside), 0);
+       setnetmtpt(outside, sizeof(outside), 0);
        ARGBEGIN {
+       case 'x':
+               setnetmtpt(inside, sizeof(inside), ARGF());
+               break;
+       case 'o':
+               setnetmtpt(outside, sizeof(outside), ARGF());
+               break;
        } ARGEND;
 
        /* ver+cmd or ver+nmethod */
@@ -178,26 +287,33 @@ main(int argc, char *argv[])
                }
        }
 
-       nc = nil;
        dir[0] = 0;
        fd = cfd = -1;
-       switch(buf[1]){
+       cmd = buf[1];
+       switch(cmd){
        case 0x01:      /* CONNECT */
-               if((s = addr2str("tcp", buf)) == nil)
-                       return;
+               snprint(addr, sizeof(addr), "%s/tcp", outside);
+               if((s = addr2str(addr, buf)) == nil)
+                       break;
+               alarm(30000);
                fd = dial(s, 0, dir, &cfd);
+               alarm(0);
+               break;
+       case 0x02:      /* BIND */
+               if(myipaddr(buf, outside) < 0)
+                       break;
+               snprint(addr, sizeof(addr), "%s/tcp!%I!0", outside, buf);
+               fd = announce(addr, dir);
+               break;
+       case 0x03:      /* UDP */
+               if(myipaddr(buf, inside) < 0)
+                       break;
+               snprint(addr, sizeof(addr), "%s/udp!%I!0", inside, buf);
+               fd = announce(addr, dir);
                break;
        }
 
-       if(fd >= 0){
-               if((nc = getnetconninfo(dir, -1)) == nil){
-                       if(cfd >= 0)
-                               close(cfd);
-                       close(fd);
-                       fd = cfd = -1;
-               }
-       }
-
+Reply:
        /* reply */
        buf[1] = sockerr(fd < 0);                       /* status */
        if(socksver == 4){
@@ -217,12 +333,38 @@ main(int argc, char *argv[])
                        return;
                }
        }
-       if((n = str2addr(nc->laddr, buf)) <= 0)
+       if((nc = getnetconninfo(dir, cfd)) == nil)
+               return;
+       if((n = str2addr((cmd & 0x100) ? nc->raddr : nc->laddr, buf)) <= 0)
                return;
        if(write(1, buf, n) != n)
                return;
 
-       /* reley data */
+       switch(cmd){
+       default:
+               return;
+       case 0x01:      /* CONNECT */
+               break;
+       case 0x02:      /* BIND */
+               cfd = listen(dir, ldir);
+               close(fd);
+               fd = -1;
+               if(cfd >= 0){
+                       strcpy(dir, ldir);
+                       fd = accept(cfd, dir);
+               }
+               cmd |= 0x100;
+               goto Reply;
+       case 0x102:
+               break;          
+       case 0x03:      /* UDP */
+               if(udprelay(fd, dir) == 0)
+                       while(read(0, buf, sizeof(buf)) > 0)
+                               ;
+               goto Hangup;
+       }
+       
+       /* relay data */
        switch(rfork(RFMEM|RFPROC|RFFDG|RFNOWAIT)){
        case -1:
                return;
@@ -232,12 +374,12 @@ main(int argc, char *argv[])
        default:
                dup(fd, 1);
        }
-       close(fd);
        while((n = read(0, buf, sizeof(buf))) > 0)
                if(write(1, buf, n) != n)
                        break;
+Hangup:
        if(cfd >= 0)
                hangup(cfd);
-       exits(0);
+       postnote(PNGROUP, getpid(), "kill");
 }