]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ip/tftpfs.c
merge
[plan9front.git] / sys / src / cmd / ip / tftpfs.c
1 #include <u.h>
2 #include <libc.h>
3 #include <thread.h>
4 #include <auth.h>
5 #include <fcall.h>
6 #include <9p.h>
7 #include <ip.h>
8
/* bit in qid.path marking a regular (downloadable) file rather than a directory */
enum {
	Qdata = 1,
};

/* TFTP opcodes */
enum {
	Tftp_READ	= 1,
	Tftp_WRITE	= 2,
	Tftp_DATA	= 3,
	Tftp_ACK	= 4,
	Tftp_ERROR	= 5,
	Tftp_OACK	= 6,
};

enum {
	TftpPort	= 69,		/* UDP port the read request is sent to */
	Segsize	= 512,		/* DATA payload size; a shorter block ends the transfer */
	/*
	 * size of Tfile.path (and net); a read request for a path of
	 * Maxpath-1 bytes takes 2+(Maxpath-1)+1+6 = 516 bytes, which fits
	 * the 2+2+Segsize+1 byte message buffer used by download.
	 */
	Maxpath	= Segsize - 4,
};
24
25 typedef struct Tfile Tfile;
26 struct Tfile
27 {
28         int     id;
29         uchar   addr[IPaddrlen];
30         char    path[Maxpath];
31         Channel *c;
32         Tfile   *next;
33         Ref;
34 };
35
36 char net[Maxpath];
37 uchar ipaddr[IPaddrlen];
38 static ulong time0;
39 Tfile *files;
40
41 static Tfile*
42 tfileget(uchar *addr, char *path)
43 {
44         Tfile *f;
45         static int id;
46
47         for(f = files; f; f = f->next){
48                 if(memcmp(addr, f->addr, IPaddrlen) == 0 && strcmp(path, f->path) == 0){
49                         incref(f);
50                         return f;
51                 }
52         }
53         f = emalloc9p(sizeof *f);
54         memset(f, 0, sizeof(*f));
55         ipmove(f->addr, addr);
56         strncpy(f->path, path, Maxpath-1);
57         f->ref = 1;
58         f->id = id++;
59         f->next = files;
60         files = f;
61
62         return f;
63 }
64
65 static void
66 tfileput(Tfile *f)
67 {
68         Channel *c;
69         Tfile **pp;
70
71         if(f==nil || decref(f))
72                 return;
73         if(c = f->c){
74                 f->c = nil;
75                 sendp(c, nil);
76         }
77         for(pp = &files; *pp; pp = &(*pp)->next){
78                 if(*pp == f){
79                         *pp = f->next;
80                         break;
81                 }
82         }
83         free(f);
84 }
85
/*
 * Return a pointer into p just past its last '/', or p itself when
 * it contains no '/'.  A trailing '/' yields the empty string.
 */
static char*
basename(char *p)
{
	char *slash;

	return (slash = strrchr(p, '/')) ? slash+1 : p;
}
96
97 static void
98 tfilestat(Req *r, char *path, vlong length)
99 {
100         memset(&r->d, 0, sizeof(r->d));
101         r->d.uid = estrdup9p("tftp");
102         r->d.gid = estrdup9p("tftp");
103         r->d.name = estrdup9p(basename(path));
104         r->d.atime = r->d.mtime = time0;
105         r->d.length = length;
106         r->d.qid.path = r->fid->qid.path;
107         if(r->fid->qid.path & Qdata){
108                 r->d.qid.type = 0;
109                 r->d.mode = 0555;
110         } else {
111                 r->d.qid.type = QTDIR;
112                 r->d.mode = DMDIR|0555;
113         }
114         respond(r, nil);
115 }
116
117 static void
118 catch(void *, char *msg)
119 {
120         if(strstr(msg, "alarm"))
121                 noted(NCONT);
122         noted(NDFLT);
123 }
124
125 static int
126 filereq(uchar *buf, char *path)
127 {
128         uchar *p;
129         int n;
130
131         hnputs(buf, Tftp_READ);
132         p = buf+2;
133         n = strlen(path);
134
135         /* hack: remove the trailing dot */
136         if(path[n-1] == '.')
137                 n--;
138
139         memcpy(p, path, n);
140         p += n;
141         *p++ = 0;
142         memcpy(p, "octet", 6);
143         p += 6;
144         return p - buf;
145 }
146
147 static void
148 download(void *aux)
149 {
150         int fd, cfd, last, block, seq, n, ndata;
151         char *err, adir[40], buf[256];
152         uchar *data;
153         Channel *c;
154         Tfile *f;
155         Req *r;
156
157         struct {
158                 Udphdr;
159                 uchar buf[2+2+Segsize+1];
160         } msg;
161
162         c = nil;
163         r = nil;
164         fd = cfd = -1;
165         err = nil;
166         data = nil;
167         ndata = 0;
168
169         if((f = aux) == nil)
170                 goto out;
171         if((c = f->c) == nil)
172                 goto out;
173
174         threadsetname("%s", f->path);
175
176         snprint(buf, sizeof(buf), "%s/udp!*!0", net);
177         if((cfd = announce(buf, adir)) < 0){
178                 err = "announce: %r";
179                 goto out;
180         }
181         if(write(cfd, "headers", 7) < 0){
182                 err = "write ctl: %r";
183                 goto out;
184         }
185         strcat(adir, "/data");
186         if((fd = open(adir, ORDWR)) < 0){
187                 err = "open: %r";
188                 goto out;
189         }
190
191         n = filereq(msg.buf, f->path);
192         ipmove(msg.raddr, f->addr);
193         hnputs(msg.rport, TftpPort);
194         if(write(fd, &msg, sizeof(Udphdr) + n) < 0){
195                 err = "send read request: %r";
196                 goto out;
197         }
198
199         notify(catch);
200
201         seq = 1;
202         last = 0;
203         while(!last){
204                 alarm(5000);
205                 if((n = read(fd, &msg, sizeof(Udphdr) + sizeof(msg.buf)-1)) < 0){
206                         err = "receive response: %r";
207                         goto out;
208                 }
209                 alarm(0);
210
211                 n -= sizeof(Udphdr);
212                 msg.buf[n] = 0;
213                 switch(nhgets(msg.buf)){
214                 case Tftp_ERROR:
215                         werrstr("%s", (char*)msg.buf+4);
216                         err = "%r";
217                         goto out;
218
219                 case Tftp_DATA:
220                         if(n < 4)
221                                 continue;
222                         block = nhgets(msg.buf+2);
223                         if(block > seq)
224                                 continue;
225                         hnputs(msg.buf, Tftp_ACK);
226                         if(write(fd, &msg, sizeof(Udphdr) + 4) < 0){
227                                 err = "send acknowledge: %r";
228                                 goto out;
229                         }
230                         if(block < seq)
231                                 continue;
232                         seq = block+1;
233                         n -= 4;
234                         if(n < Segsize)
235                                 last = 1;
236                         data = erealloc9p(data, ndata + n);
237                         memcpy(data + ndata, msg.buf+4, n);
238                         ndata += n;
239
240                 rloop:  /* hanlde read request while downloading */
241                         if((r != nil) && (r->ifcall.type == Tread) && (r->ifcall.offset < ndata)){
242                                 readbuf(r, data, ndata);
243                                 respond(r, nil);
244                                 r = nil;
245                         }
246                         if((r == nil) && (nbrecv(c, &r) == 1)){
247                                 if(r == nil){
248                                         chanfree(c);
249                                         c = nil;
250                                         goto out;
251                                 }
252                                 goto rloop;
253                         }
254                         break;
255                 }
256         }
257
258 out:
259         alarm(0);
260         if(cfd >= 0)
261                 close(cfd);
262         if(fd >= 0)
263                 close(fd);
264
265         if(c){
266                 while((r != nil) || (r = recvp(c))){
267                         if(err){
268                                 snprint(buf, sizeof(buf), err);
269                                 respond(r, buf);
270                         } else {
271                                 switch(r->ifcall.type){
272                                 case Tread:
273                                         readbuf(r, data, ndata);
274                                         respond(r, nil);
275                                         break;
276                                 case Tstat:
277                                         tfilestat(r, f->path, ndata);
278                                         break;
279                                 default:
280                                         respond(r, "bug in fs");
281                                 }
282                         }
283                         r = nil;
284                 }
285                 chanfree(c);
286         }
287         free(data);
288 }
289
290 static void
291 fsattach(Req *r)
292 {
293         Tfile *f;
294
295         if(r->ifcall.aname && r->ifcall.aname[0]){
296                 uchar addr[IPaddrlen];
297
298                 if(parseip(addr, r->ifcall.aname) == -1){
299                         respond(r, "bad ip specified");
300                         return;
301                 }
302                 f = tfileget(addr, "/");
303         } else {
304                 if(ipcmp(ipaddr, IPnoaddr) == 0){
305                         respond(r, "no ipaddr specified");
306                         return;
307                 }
308                 f = tfileget(ipaddr, "/");
309         }
310         r->fid->aux = f;
311         r->fid->qid.type = QTDIR;
312         r->fid->qid.path = f->id<<1;
313         r->fid->qid.vers = 0;
314         r->ofcall.qid = r->fid->qid;
315         respond(r, nil);
316 }
317
318 static char*
319 fswalk1(Fid *fid, char *name, Qid *qid)
320 {
321         Tfile *f;
322         char *t;
323
324         f = fid->aux;
325         t = smprint("%s/%s", f->path, name);
326         f = tfileget(f->addr, cleanname(t));
327         free(t);
328         tfileput(fid->aux); fid->aux = f;
329         fid->qid.type = QTDIR;
330         fid->qid.path = f->id<<1;
331
332         /* hack:
333          * a dot in the path means the path element is not
334          * a directory. to force download of files containing
335          * no dot, a trailing dot can be appended that will
336          * be stripped out in the tftp read request.
337          */
338         if(strchr(f->path, '.') != nil){
339                 fid->qid.type = 0;
340                 fid->qid.path |= Qdata;
341         }
342
343         if(qid)
344                 *qid = fid->qid;
345         return nil;
346 }
347
348 static char*
349 fsclone(Fid *oldfid, Fid *newfid)
350 {
351         Tfile *f;
352
353         f = oldfid->aux;
354         incref(f);
355         newfid->aux = f;
356         return nil;
357 }
358
359 static void
360 fsdestroyfid(Fid *fid)
361 {
362         tfileput(fid->aux);
363         fid->aux = nil;
364 }
365
366 static void
367 fsopen(Req *r)
368 {
369         int m;
370
371         m = r->ifcall.mode & 3;
372         if(m != OREAD && m != OEXEC){
373                 respond(r, "permission denied");
374                 return;
375         }
376         respond(r, nil);
377 }
378
379 static void
380 dispatch(Req *r)
381 {
382         Tfile *f;
383
384         f = r->fid->aux;
385         if(f->c == nil){
386                 f->c = chancreate(sizeof(r), 0);
387                 proccreate(download, f, 16*1024);
388         }
389         sendp(f->c, r);
390 }
391
392 static void
393 fsread(Req *r)
394 {
395         if(r->fid->qid.path & Qdata){
396                 dispatch(r);
397         } else {
398                 respond(r, nil);
399         }
400 }
401
402 static void
403 fsstat(Req *r)
404 {
405         if(r->fid->qid.path & Qdata){
406                 dispatch(r);
407         } else {
408                 tfilestat(r, ((Tfile*)r->fid->aux)->path, 0);
409         }
410 }
411
412 Srv fs = 
413 {
414 .attach=        fsattach,
415 .destroyfid=    fsdestroyfid,
416 .walk1=         fswalk1,
417 .clone=         fsclone,
418 .open=          fsopen,
419 .read=          fsread,
420 .stat=          fsstat,
421 };
422
/* Print the command synopsis to standard error and exit all threads. */
void
usage(void)
{
	char *synopsis;

	synopsis = "usage: tftpfs [-D] [-s srvname] [-m mtpt] [-x net] [ipaddr]\n";
	fprint(2, "%s", synopsis);
	threadexitsall("usage");
}
429
430 void
431 threadmain(int argc, char **argv)
432 {
433         char *srvname = nil;
434         char *mtpt = "/n/tftp";
435
436         time0 = time(0);
437         strcpy(net, "/net");
438         ipmove(ipaddr, IPnoaddr);
439
440         ARGBEGIN{
441         case 'D':
442                 chatty9p++;
443                 break;
444         case 's':
445                 srvname = EARGF(usage());
446                 mtpt = nil;
447                 break;
448         case 'm':
449                 mtpt = EARGF(usage());
450                 break;
451         case 'x':
452                 setnetmtpt(net, sizeof net, EARGF(usage()));
453                 break;
454         default:
455                 usage();
456         }ARGEND;
457
458         switch(argc){
459         case 0:
460                 break;
461         case 1:
462                 if(parseip(ipaddr, *argv) == -1)
463                         usage();
464                 break;
465         default:
466                 usage();
467         }
468
469         if(srvname==nil && mtpt==nil)
470                 usage();
471
472         threadpostmountsrv(&fs, srvname, mtpt, MREPL|MCREATE);
473 }