]> git.lizzy.rs Git - plan9front.git/blobdiff - sys/src/cmd/cpu.c
cpu: cleanup ssl code, make sure -p works for any auth method
[plan9front.git] / sys / src / cmd / cpu.c
index b797ff142524cb33aaf82596f6d9cfedbdc3f0c4..44792fc9025bca6ae28942f66f3ea9ed50624942 100644 (file)
@@ -24,7 +24,7 @@ void  writestr(int, char*, char*, int);
 int    readstr(int, char*, int);
 char   *rexcall(int*, char*, char*);
 int    setamalg(char*);
-char *keyspec = "";
+char   *keyspec = "";
 
 int    notechan;
 int    exportpid;
@@ -43,6 +43,11 @@ char *ealgs = "rc4_256 sha1";
 /* message size for exportfs; may be larger so we can do big graphics in CPU window */
 int    msgsize = Maxfdata+IOHDRSZ;
 
+/* encryption mechanisms */
+static int     clear(int);
+
+int (*encryption)(int) = clear;
+
 /* authentication mechanisms */
 static int     netkeyauth(int);
 static int     netkeysrvauth(int, char*);
@@ -56,8 +61,7 @@ struct AuthMethod {
        char    *name;                  /* name of method */
        int     (*cf)(int);             /* client side authentication */
        int     (*sf)(int, char*);      /* server side authentication */
-} authmethod[] =
-{
+} authmethod[] = {
        { "p9",         p9auth,         srvp9auth,},
        { "netkey",     netkeyauth,     netkeysrvauth,},
        { "none",       noauth,         srvnoauth,},
@@ -73,7 +77,7 @@ char  *aan = "/bin/aan";
 char   *anstring = "tcp!*!0";
 char   *filterp = nil;
 
-int    filter(int fd, char *host);
+int filter(int fd, char *host);
 
 void
 usage(void)
@@ -335,7 +339,7 @@ old9p(int fd)
 void
 remoteside(int old)
 {
-       char user[MaxStr], home[MaxStr], buf[MaxStr], xdir[MaxStr], cmd[MaxStr];
+       char user[MaxStr], buf[MaxStr], xdir[MaxStr], cmd[MaxStr];
        int i, n, fd, badchdir, gotcmd;
 
        rfork(RFENVG);
@@ -360,14 +364,12 @@ remoteside(int old)
        } else
                writestr(fd, "", "", 1);
 
-       fd = (*am->sf)(fd, user);
-       if(fd < 0)
+       if((fd = (*am->sf)(fd, user)) < 0)
                fatal("srvauth: %r");
-
-       /* Set environment values for the user */
-       putenv("user", user);
-       snprint(home, sizeof(home), "/usr/%s", user);
-       putenv("home", home);
+       if((fd = filter(fd, nil)) < 0)
+               fatal("filter: %r");
+       if((fd = encryption(fd)) < 0)
+               fatal("encrypt: %r");
 
        /* Now collect invoking cpu's current directory or possibly a command */
        gotcmd = 0;
@@ -380,15 +382,11 @@ remoteside(int old)
                        fatal("dir: %r");
        }
 
-       /* Establish the new process at the current working directory of the
-        * gnot */
+       /* Establish the new process at the current working directory of the gnot */
        badchdir = 0;
-       if(strcmp(xdir, "NO") == 0)
-               chdir(home);
-       else if(chdir(xdir) < 0) {
-               badchdir = 1;
-               chdir(home);
-       }
+       if(strcmp(xdir, "NO") != 0)
+               if(chdir(xdir) < 0)
+                       badchdir = 1;
 
        /* Start the gnot serving its namespace */
        writestr(fd, "FS", "FS", 0);
@@ -436,14 +434,13 @@ char*
 rexcall(int *fd, char *host, char *service)
 {
        char *na;
-       char dir[MaxStr];
        char err[ERRMAX];
        char msg[MaxStr];
        int n;
 
        na = netmkaddr(host, 0, service);
        procsetname("dialing %s", na);
-       if((*fd = dial(na, 0, dir, 0)) < 0)
+       if((*fd = dial(na, 0, 0, 0)) < 0)
                return "can't dial";
 
        /* negotiate aan filter extension */
@@ -453,7 +450,7 @@ rexcall(int *fd, char *host, char *service)
                if(n < 0)
                        return "negotiating aan";
                if(*err){
-                       werrstr(err);
+                       errstr(err, sizeof err);
                        return negstr;
                }
        }
@@ -470,16 +467,19 @@ rexcall(int *fd, char *host, char *service)
        if(n < 0)
                return negstr;
        if(*err){
-               werrstr(err);
+               errstr(err, sizeof err);
                return negstr;
        }
 
        /* authenticate */
        procsetname("%s: auth via %s", origargs, am->name);
-       *fd = (*am->cf)(*fd);
-       if(*fd < 0)
+       if((*fd = (*am->cf)(*fd)) < 0)
                return "can't authenticate";
-       return 0;
+       if((*fd = filter(*fd, system)) < 0)
+               return "can't filter";
+       if((*fd = encryption(*fd)) < 0)
+               return "can't encrypt";
+       return nil;
 }
 
 void
@@ -550,7 +550,7 @@ netkeyauth(int fd)
                if(readstr(fd, chall, sizeof chall) < 0)
                        break;
                if(*chall == 0)
-                       return filter(fd, system);
+                       return fd;
                print("challenge: %s\nresponse: ", chall);
                if(readln(resp, sizeof(resp)) < 0)
                        break;
@@ -590,7 +590,21 @@ netkeysrvauth(int fd, char *user)
        if(auth_chuid(ai, 0) < 0)
                fatal("newns: %r");
        auth_freeAI(ai);
-       return filter(fd, nil);
+       return fd;
+}
+
+static int
+clear(int fd)
+{
+       return fd;
+}
+
+static char sslsecret[2][21];
+
+static int
+sslencrypt(int fd)
+{
+       return pushssl(fd, ealgs, sslsecret[0], sslsecret[1], nil);
 }
 
 static void
@@ -600,55 +614,63 @@ mksecret(char *t, uchar *f)
                f[0], f[1], f[2], f[3], f[4], f[5], f[6], f[7], f[8], f[9]);
 }
 
-/*
- *  plan9 authentication followed by rc4 encryption
- */
 static int
-p9auth(int fd)
+sslsetup(int fd, uchar *secret, int nsecret, int isclient)
 {
        uchar key[16], digest[SHA1dlen];
-       char fromclientsecret[21];
-       char fromserversecret[21];
-       AuthInfo *ai;
        int i;
 
-       procsetname("%s: auth_proxy proto=%q role=client %s",
-               origargs, p9authproto, keyspec);
-       ai = auth_proxy(fd, auth_getkey, "proto=%q role=client %s", p9authproto, keyspec);
-       if(ai == nil)
-               return -1;
-       if(ealgs == nil){
-               auth_freeAI(ai);
+       if(ealgs == nil)
                return fd;
+
+       if(nsecret < 8){
+               werrstr("secret too small to ssl");
+               return -1;
        }
-       assert(ai->nsecret <= sizeof(key)-4);
-       memmove(key+4, ai->secret, ai->nsecret);
-       auth_freeAI(ai);
+       memmove(key+4, secret, 8);
 
        /* exchange random numbers */
        srand(truerand());
-       for(i = 0; i < 4; i++)
-               key[i] = rand();
-       procsetname("writing p9 key");
-       if(write(fd, key, 4) != 4)
-               return -1;
-       procsetname("reading p9 key");
-       if(readn(fd, key+12, 4) != 4)
-               return -1;
+
+       if(isclient){
+               for(i = 0; i < 4; i++)
+                       key[i] = rand();
+               if(write(fd, key, 4) != 4)
+                       return -1;
+               if(readn(fd, key+12, 4) != 4)
+                       return -1;
+       } else {
+               for(i = 0; i < 4; i++)
+                       key[i+12] = rand();
+               if(readn(fd, key, 4) != 4)
+                       return -1;
+               if(write(fd, key+12, 4) != 4)
+                       return -1;
+       }
 
        /* scramble into two secrets */
        sha1(key, sizeof(key), digest, nil);
-       mksecret(fromclientsecret, digest);
-       mksecret(fromserversecret, digest+10);
+       mksecret(sslsecret[isclient == 0], digest);
+       mksecret(sslsecret[isclient != 0], digest+10);
 
-       if((fd = filter(fd, system)) < 0)
-               return -1;
+       encryption = sslencrypt;
+
+       return fd;
+}
+
+/*
+ *  plan9 authentication followed by rc4 encryption
+ */
+static int
+p9auth(int fd)
+{
+       AuthInfo *ai;
 
-       /* set up encryption */
-       procsetname("pushssl");
-       fd = pushssl(fd, ealgs, fromclientsecret, fromserversecret, nil);
-       if(fd < 0)
-               werrstr("can't establish ssl connection: %r");
+       ai = auth_proxy(fd, auth_getkey, "proto=%q role=client %s", p9authproto, keyspec);
+       if(ai == nil)
+               return -1;
+       fd = sslsetup(fd, ai->secret, ai->nsecret, 1);
+       auth_freeAI(ai);
        return fd;
 }
 
@@ -668,25 +690,10 @@ srvnoauth(int fd, char *user)
        return fd;
 }
 
-void
-loghex(uchar *p, int n)
-{
-       char buf[100];
-       int i;
-
-       for(i = 0; i < n; i++)
-               sprint(buf+2*i, "%2.2ux", p[i]);
-       syslog(0, "cpu", "%s", buf);
-}
-
 static int
 srvp9auth(int fd, char *user)
 {
-       uchar key[16], digest[SHA1dlen];
-       char fromclientsecret[21];
-       char fromserversecret[21];
        AuthInfo *ai;
-       int i;
 
        ai = auth_proxy(fd, nil, "proto=%q role=server %s", p9authproto, keyspec);
        if(ai == nil)
@@ -694,35 +701,8 @@ srvp9auth(int fd, char *user)
        if(auth_chuid(ai, nil) < 0)
                fatal("newns: %r");
        snprint(user, MaxStr, "%s", ai->cuid);
-       if(ealgs == nil){
-               auth_freeAI(ai);
-               return fd;
-       }
-       assert(ai->nsecret <= sizeof(key)-4);
-       memmove(key+4, ai->secret, ai->nsecret);
+       fd = sslsetup(fd, ai->secret, ai->nsecret, 0);
        auth_freeAI(ai);
-
-       /* exchange random numbers */
-       srand(truerand());
-       for(i = 0; i < 4; i++)
-               key[i+12] = rand();
-       if(readn(fd, key, 4) != 4)
-               return -1;
-       if(write(fd, key+12, 4) != 4)
-               return -1;
-
-       /* scramble into two secrets */
-       sha1(key, sizeof(key), digest, nil);
-       mksecret(fromclientsecret, digest);
-       mksecret(fromserversecret, digest+10);
-
-       if((fd = filter(fd, nil)) < 0)
-               return -1;
-
-       /* set up encryption */
-       fd = pushssl(fd, ealgs, fromserversecret, fromclientsecret, nil);
-       if(fd < 0)
-               werrstr("can't establish ssl connection: %r");
        return fd;
 }
 
@@ -784,7 +764,8 @@ filter(int fd, char *host)
                buf[len] = '\0';
                if((s = strrchr(buf, '!')) == nil)
                        fatal("filter: malformed remote port: %s", buf);
-               snprint(addr, sizeof(addr), "%s", netmkaddr(host, "tcp", s+1));
+               strecpy(addr, addr+sizeof(addr), netmkaddr(host, "tcp", s+1));
+               strecpy(strrchr(addr, '!'), addr+sizeof(addr), s);
        }
 
        snprint(buf, sizeof(buf), "%s", filterp);