]> git.lizzy.rs Git - plan9front.git/blobdiff - sys/src/lib9p/srv.c
rc: avoid stat calls for directory globbing
[plan9front.git] / sys / src / lib9p / srv.c
index ca0b6a30b241636150f01037c87ef014fdab54ef..f3039e01231bfe8ddb2ffd6279d398f43dfe0a25 100644 (file)
@@ -59,7 +59,8 @@ getreq(Srv *s)
        Req *r;
 
        qlock(&s->rlock);
-       if((n = read9pmsg(s->infd, s->rbuf, s->msize)) <= 0){
+       n = read9pmsg(s->infd, s->rbuf, s->msize);
+       if(n <= 0){
                qunlock(&s->rlock);
                return nil;
        }
@@ -163,23 +164,35 @@ walkandclone(Req *r, char *(*walk1)(Fid*, char*, void*), char *(*clone)(Fid*, Fi
 }
 
 static void
-sversion(Srv*, Req *r)
+sversion(Srv *srv, Req *r)
 {
+       if(srv->rref.ref != 1){
+               respond(r, Ebotch);
+               return;
+       }
        if(strncmp(r->ifcall.version, "9P", 2) != 0){
                r->ofcall.version = "unknown";
+               r->ofcall.msize = 256;
                respond(r, nil);
                return;
        }
-
        r->ofcall.version = "9P2000";
-       r->ofcall.msize = r->ifcall.msize;
+       if(r->ifcall.msize < 256){
+               respond(r, "version: message size too small");
+               return;
+       }
+       if(r->ifcall.msize < 1024*1024)
+               r->ofcall.msize = r->ifcall.msize;
+       else
+               r->ofcall.msize = 1024*1024;
        respond(r, nil);
 }
+
 static void
 rversion(Req *r, char *error)
 {
-       assert(error == nil);
-       changemsize(r->srv, r->ofcall.msize);
+       if(error == nil)
+               changemsize(r->srv, r->ofcall.msize);
 }
 
 static void
@@ -201,8 +214,14 @@ sauth(Srv *srv, Req *r)
 static void
 rauth(Req *r, char *error)
 {
-       if(error && r->afid)
+       if(r->afid == nil)
+               return;
+       if(error){
                closefid(removefid(r->srv->fpool, r->afid->fid));
+               return;
+       }
+       if(r->afid->omode == -1)
+               r->afid->omode = ORDWR;
 }
 
 static void
@@ -352,6 +371,23 @@ rwalk(Req *r, char *error)
        }
 }
 
+static int
+dirwritable(Fid *fid)
+{
+       File *f;
+
+       f = fid->file;
+       if(f){
+               rlock(f);
+               if(f->parent && !hasperm(f->parent, fid->uid, AWRITE)){
+                       runlock(f);
+                       return 0;
+               }
+               runlock(f);
+       }
+       return 1;
+}
+
 static void
 sopen(Srv *srv, Req *r)
 {
@@ -372,7 +408,8 @@ sopen(Srv *srv, Req *r)
        r->ofcall.qid = r->fid->qid;
        switch(r->ifcall.mode&3){
        default:
-               assert(0);
+               respond(r, Ebotch);
+               return;
        case OREAD:
                p = AREAD;      
                break;
@@ -397,9 +434,7 @@ sopen(Srv *srv, Req *r)
                        respond(r, Eperm);
                        return;
                }
-       /* BUG RACE */
-               if((r->ifcall.mode&ORCLOSE)
-               && !hasperm(r->fid->file->parent, r->fid->uid, AWRITE)){
+               if((r->ifcall.mode&ORCLOSE) && !dirwritable(r->fid)){
                        respond(r, Eperm);
                        return;
                }
@@ -415,21 +450,6 @@ sopen(Srv *srv, Req *r)
        else
                respond(r, nil);
 }
-static void
-ropen(Req *r, char *error)
-{
-       char errbuf[ERRMAX];
-       if(error)
-               return;
-       if(chatty9p){
-               snprint(errbuf, sizeof errbuf, "fid mode is 0x%ux\n", r->ifcall.mode);
-               write(2, errbuf, strlen(errbuf));
-       }
-       r->fid->omode = r->ifcall.mode;
-       r->fid->qid = r->ofcall.qid;
-       if(r->ofcall.qid.type&QTDIR)
-               r->fid->diroffset = 0;
-}
 
 static void
 screate(Srv *srv, Req *r)
@@ -447,13 +467,18 @@ screate(Srv *srv, Req *r)
        else
                respond(r, Enocreate);
 }
+
 static void
-rcreate(Req *r, char *error)
+ropen(Req *r, char *error)
 {
        if(error)
                return;
-       r->fid->omode = r->ifcall.mode;
+       if(chatty9p)
+               fprint(2, "fid mode is %x\n", (int)r->ifcall.mode);
+       if(r->ofcall.qid.type&QTDIR)
+               r->fid->diroffset = 0;
        r->fid->qid = r->ofcall.qid;
+       r->fid->omode = r->ifcall.mode;
 }
 
 static void
@@ -465,6 +490,20 @@ sread(Srv *srv, Req *r)
                respond(r, Eunknownfid);
                return;
        }
+       o = r->fid->omode;
+       if(o == -1){
+               respond(r, Ebotch);
+               return;
+       }
+       switch(o & 3){
+       default:
+               respond(r, Ebotch);
+               return;
+       case OREAD:
+       case ORDWR:
+       case OEXEC:
+               break;
+       }
        if((int)r->ifcall.count < 0){
                respond(r, Ebotch);
                return;
@@ -474,18 +513,12 @@ sread(Srv *srv, Req *r)
                respond(r, Ebadoffset);
                return;
        }
-
        if(r->ifcall.count > srv->msize - IOHDRSZ)
                r->ifcall.count = srv->msize - IOHDRSZ;
        r->rbuf = emalloc9p(r->ifcall.count);
        r->ofcall.data = r->rbuf;
-       o = r->fid->omode & 3;
-       if(o != OREAD && o != ORDWR && o != OEXEC){
-               respond(r, Ebotch);
-               return;
-       }
        if((r->fid->qid.type&QTDIR) && r->fid->file){
-               r->ofcall.count = readdirfile(r->fid->rdir, r->rbuf, r->ifcall.count);
+               r->ofcall.count = readdirfile(r->fid->rdir, r->rbuf, r->ifcall.count, r->ifcall.offset);
                respond(r, nil);
                return;
        }
@@ -498,19 +531,35 @@ static void
 rread(Req *r, char *error)
 {
        if(error==nil && (r->fid->qid.type&QTDIR))
-               r->fid->diroffset += r->ofcall.count;
+               r->fid->diroffset = r->ifcall.offset + r->ofcall.count;
 }
 
 static void
 swrite(Srv *srv, Req *r)
 {
        int o;
-       char e[ERRMAX];
 
        if((r->fid = lookupfid(srv->fpool, r->ifcall.fid)) == nil){
                respond(r, Eunknownfid);
                return;
        }
+       o = r->fid->omode;
+       if(o == -1){
+               respond(r, Ebotch);
+               return;
+       }
+       switch(o & 3){
+       default:
+               respond(r, Ebotch);
+               return;
+       case OWRITE:
+       case ORDWR:
+               break;
+       }
+       if(r->fid->qid.type&QTDIR){
+               respond(r, Ebotch);
+               return;
+       }
        if((int)r->ifcall.count < 0){
                respond(r, Ebotch);
                return;
@@ -521,16 +570,10 @@ swrite(Srv *srv, Req *r)
        }
        if(r->ifcall.count > srv->msize - IOHDRSZ)
                r->ifcall.count = srv->msize - IOHDRSZ;
-       o = r->fid->omode & 3;
-       if(o != OWRITE && o != ORDWR){
-               snprint(e, sizeof e, "write on fid with open mode 0x%ux", r->fid->omode);
-               respond(r, e);
-               return;
-       }
        if(srv->write)
                srv->write(r);
        else
-               respond(r, "no srv->write");
+               respond(r, Enowrite);
 }
 static void
 rwrite(Req *r, char *error)
@@ -561,8 +604,7 @@ sremove(Srv *srv, Req *r)
                respond(r, Eunknownfid);
                return;
        }
-       /* BUG RACE */
-       if(r->fid->file && !hasperm(r->fid->file->parent, r->fid->uid, AWRITE)){
+       if(!dirwritable(r->fid)){
                respond(r, Eperm);
                return;
        }
@@ -655,25 +697,32 @@ swstat(Srv *srv, Req *r)
                respond(r, Ebaddir);
                return;
        }
-       if((ushort)~r->d.type){
-               respond(r, "wstat -- attempt to change type");
-               return;
-       }
-       if((uint)~r->d.dev){
-               respond(r, "wstat -- attempt to change dev");
-               return;
-       }
-       if((uchar)~r->d.qid.type || (ulong)~r->d.qid.vers || (uvlong)~r->d.qid.path){
-               respond(r, "wstat -- attempt to change qid");
+       if(r->d.qid.path != ~0 && r->d.qid.path != r->fid->qid.path){
+               respond(r, "wstat -- attempt to change qid.path");
                return;
        }
-       if(r->d.muid && r->d.muid[0]){
-               respond(r, "wstat -- attempt to change muid");
+       if(r->d.qid.vers != ~0 && r->d.qid.vers != r->fid->qid.vers){
+               respond(r, "wstat -- attempt to change qid.vers");
                return;
        }
-       if((ulong)~r->d.mode && ((r->d.mode&DMDIR)>>24) != (r->fid->qid.type&QTDIR)){
-               respond(r, "wstat -- attempt to change DMDIR bit");
-               return;
+       if(r->d.mode != ~0){
+               if(r->d.mode & ~(DMDIR|DMAPPEND|DMEXCL|DMTMP|0777)){
+                       respond(r, "wstat -- unknown bits in mode");
+                       return;
+               }
+               if(r->d.qid.type != (uchar)~0 && r->d.qid.type != ((r->d.mode>>24)&0xFF)){
+                       respond(r, "wstat -- qid.type/mode mismatch");
+                       return;
+               }
+               if(((r->d.mode>>24) ^ r->fid->qid.type) & ~(QTAPPEND|QTEXCL|QTTMP)){
+                       respond(r, "wstat -- attempt to change qid.type");
+                       return;
+               }
+       } else {
+               if(r->d.qid.type != (uchar)~0 && r->d.qid.type != r->fid->qid.type){
+                       respond(r, "wstat -- attempt to change qid.type");
+                       return;
+               }
        }
        srv->wstat(r);
 }
@@ -682,31 +731,21 @@ rwstat(Req*, char*)
 {
 }
 
-void
-srv(Srv *srv)
+static void srvclose(Srv *);
+
+static void
+srvwork(void *v)
 {
+       Srv *srv = v;
        Req *r;
 
-       fmtinstall('D', dirfmt);
-       fmtinstall('F', fcallfmt);
-
-       if(srv->fpool == nil)
-               srv->fpool = allocfidpool(srv->destroyfid);
-       if(srv->rpool == nil)
-               srv->rpool = allocreqpool(srv->destroyreq);
-       if(srv->msize == 0)
-               srv->msize = 8192+IOHDRSZ;
-
-       changemsize(srv, srv->msize);
-
-       srv->fpool->srv = srv;
-       srv->rpool->srv = srv;
-
        while(r = getreq(srv)){
+               incref(&srv->rref);
                if(r->error){
                        respond(r, r->error);
                        continue;       
                }
+               qlock(&srv->slock);
                switch(r->ifcall.type){
                default:
                        respond(r, "unknown message");
@@ -725,8 +764,29 @@ srv(Srv *srv)
                case Tstat:     sstat(srv, r);  break;
                case Twstat:    swstat(srv, r); break;
                }
+               if(srv->sref.ref > 8 && srv->spid != getpid()){
+                       decref(&srv->sref);
+                       qunlock(&srv->slock);
+                       return;
+               }
+               qunlock(&srv->slock);
        }
 
+       if(srv->end && srv->sref.ref == 1)
+               srv->end(srv);
+       if(decref(&srv->sref) == 0)
+               srvclose(srv);
+}
+
+static void
+srvclose(Srv *srv)
+{
+       if(srv->rref.ref || srv->sref.ref)
+               return;
+
+       if(chatty9p)
+               fprint(2, "srvclose\n");
+
        free(srv->rbuf);
        srv->rbuf = nil;
        free(srv->wbuf);
@@ -737,8 +797,54 @@ srv(Srv *srv)
        freereqpool(srv->rpool);
        srv->rpool = nil;
 
-       if(srv->end)
-               srv->end(srv);
+       if(srv->free)
+               srv->free(srv);
+}
+
+void
+srvacquire(Srv *srv)
+{
+       incref(&srv->sref);
+       qlock(&srv->slock);
+}
+
+void
+srvrelease(Srv *srv)
+{
+       if(decref(&srv->sref) == 0){
+               incref(&srv->sref);
+               _forker(srvwork, srv, 0);
+       }
+       qunlock(&srv->slock);
+}
+
+void
+srv(Srv *srv)
+{
+       fmtinstall('D', dirfmt);
+       fmtinstall('F', fcallfmt);
+
+       srv->spid = getpid();
+       memset(&srv->sref, 0, sizeof(srv->sref));
+       memset(&srv->rref, 0, sizeof(srv->rref));
+
+       if(srv->fpool == nil)
+               srv->fpool = allocfidpool(srv->destroyfid);
+       if(srv->rpool == nil)
+               srv->rpool = allocreqpool(srv->destroyreq);
+       if(srv->msize == 0)
+               srv->msize = 8192+IOHDRSZ;
+
+       changemsize(srv, srv->msize);
+
+       srv->fpool->srv = srv;
+       srv->rpool->srv = srv;
+
+       if(srv->start)
+               srv->start(srv);
+
+       incref(&srv->sref);
+       srvwork(srv);
 }
 
 void
@@ -755,8 +861,6 @@ respond(Req *r, char *error)
        r->error = error;
 
        switch(r->ifcall.type){
-       default:
-               assert(0);
        /*
         * Flush is special.  If the handler says so, we return
         * without further processing.  Respond will be called
@@ -771,7 +875,7 @@ respond(Req *r, char *error)
        case Tattach:   rattach(r, error);      break;
        case Twalk:     rwalk(r, error);        break;
        case Topen:     ropen(r, error);        break;
-       case Tcreate:   rcreate(r, error);      break;
+       case Tcreate:   ropen(r, error);        break;
        case Tread:     rread(r, error);        break;
        case Twrite:    rwrite(r, error);       break;
        case Tclunk:    rclunk(r, error);       break;
@@ -791,7 +895,7 @@ if(chatty9p)
        qlock(&srv->wlock);
        n = convS2M(&r->ofcall, srv->wbuf, srv->msize);
        if(n <= 0){
-               fprint(2, "n = %d %F\n", n, &r->ofcall);
+               fprint(2, "msize = %d n = %d %F\n", srv->msize, n, &r->ofcall);
                abort();
        }
        assert(n > 2);
@@ -799,7 +903,7 @@ if(chatty9p)
                closereq(removereq(r->pool, r->ifcall.tag));
        m = write(srv->outfd, srv->wbuf, n);
        if(m != n)
-               sysfatal("lib9p srv: write %d returned %d on fd %d: %r", n, m, srv->outfd);
+               fprint(2, "lib9p srv: write %d returned %d on fd %d: %r", n, m, srv->outfd);
        qunlock(&srv->wlock);
 
        qlock(&r->lk);  /* no one will add flushes now */
@@ -816,6 +920,9 @@ if(chatty9p)
                closereq(r);
        else
                free(r);
+
+       if(decref(&srv->rref) == 0)
+               srvclose(srv);
 }
 
 void
@@ -826,30 +933,3 @@ responderror(Req *r)
        rerrstr(errbuf, sizeof errbuf);
        respond(r, errbuf);
 }
-
-int
-postfd(char *name, int pfd)
-{
-       int fd;
-       char buf[80];
-
-       snprint(buf, sizeof buf, "/srv/%s", name);
-       if(chatty9p)
-               fprint(2, "postfd %s\n", buf);
-       fd = create(buf, OWRITE|ORCLOSE|OCEXEC, 0600);
-       if(fd < 0){
-               if(chatty9p)
-                       fprint(2, "create fails: %r\n");
-               return -1;
-       }
-       if(fprint(fd, "%d", pfd) < 0){
-               if(chatty9p)
-                       fprint(2, "write fails: %r\n");
-               close(fd);
-               return -1;
-       }
-       if(chatty9p)
-               fprint(2, "postfd successful\n");
-       return 0;
-}
-