]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ip/tftpfs.c
merge
[plan9front.git] / sys / src / cmd / ip / tftpfs.c
1 #include <u.h>
2 #include <libc.h>
3 #include <thread.h>
4 #include <auth.h>
5 #include <fcall.h>
6 #include <9p.h>
7 #include <ip.h>
8
/* qid.path bit marking a plain file; directories leave it clear (qid.path is id<<1) */
enum {
	Qdata = 1,
};

/* TFTP packet opcodes */
enum {
	Tftp_READ	= 1,
	Tftp_WRITE	= 2,
	Tftp_DATA	= 3,
	Tftp_ACK	= 4,
	Tftp_ERROR	= 5,
	Tftp_OACK	= 6,
};

enum {
	TftpPort	= 69,			/* server port for read requests */
	Segsize		= 512,			/* data bytes in a full DATA packet */
	Maxpath		= 2+2+Segsize-8,	/* size of Tfile.path, NUL included */
};
24
/*
 * Tfile is one path on one TFTP server.  It is shared by every fid
 * that has walked to it and is reference counted through the
 * embedded Ref (see tfileget/tfileput).
 */
typedef struct Tfile Tfile;
struct Tfile
{
	int		id;		/* unique number; fids use id<<1 as qid.path */
	uchar	addr[IPaddrlen];	/* server address */
	char		path[Maxpath];	/* cleaned path on the server */
	Channel *c;		/* requests for the download proc; nil until first read/stat */
	Tfile		*next;		/* link in the global files list */
	Ref;
};
35
/* default server address given on the command line; IPnoaddr if none */
uchar ipaddr[IPaddrlen];
/* program start time, reported as atime/mtime of every file */
static ulong time0;
/* all live Tfiles, most recently created first */
Tfile *files;
39
40
41 static Tfile*
42 tfileget(uchar *addr, char *path)
43 {
44         Tfile *f;
45         static int id;
46
47         for(f = files; f; f = f->next){
48                 if(memcmp(addr, f->addr, IPaddrlen) == 0 && strcmp(path, f->path) == 0){
49                         incref(f);
50                         return f;
51                 }
52         }
53         f = emalloc9p(sizeof *f);
54         memset(f, 0, sizeof(*f));
55         ipmove(f->addr, addr);
56         strncpy(f->path, path, sizeof(f->path));
57         f->ref = 1;
58         f->id = id++;
59         f->next = files;
60         files = f;
61
62         return f;
63 }
64
65 static void
66 tfileput(Tfile *f)
67 {
68         Channel *c;
69         Tfile **pp;
70
71         if(f==nil || decref(f))
72                 return;
73         if(c = f->c){
74                 f->c = nil;
75                 sendp(c, nil);
76         }
77         for(pp = &files; *pp; pp = &(*pp)->next){
78                 if(*pp == f){
79                         *pp = f->next;
80                         break;
81                 }
82         }
83         free(f);
84 }
85
86 static char*
87 basename(char *p)
88 {
89         char *b;
90
91         for(b = p; *p; p++)
92                 if(*p == '/')
93                         b = p+1;
94         return b;
95 }
96
97 static void
98 tfilestat(Req *r, char *path, vlong length)
99 {
100         memset(&r->d, 0, sizeof(r->d));
101         r->d.uid = estrdup9p("tftp");
102         r->d.gid = estrdup9p("tftp");
103         r->d.name = estrdup9p(basename(path));
104         r->d.atime = r->d.mtime = time0;
105         r->d.length = length;
106         r->d.qid.path = r->fid->qid.path;
107         if(r->fid->qid.path & Qdata){
108                 r->d.qid.type = 0;
109                 r->d.mode = 0555;
110         } else {
111                 r->d.qid.type = QTDIR;
112                 r->d.mode = DMDIR|0555;
113         }
114         respond(r, nil);
115 }
116
117 static void
118 catch(void *, char *msg)
119 {
120         if(strstr(msg, "alarm"))
121                 noted(NCONT);
122         noted(NDFLT);
123 }
124
/*
 * newport returns the next local UDP port to try announcing on,
 * cycling through 5000..5063.  The counter is kept in 0..63 so it
 * never overflows (the old port++ was signed overflow after 2^31 calls).
 */
static int
newport(void)
{
	static int port;
	int p;

	p = 5000 + port;
	port = (port+1) % 64;
	return p;
}
131
132 static int
133 filereq(uchar *buf, char *path)
134 {
135         uchar *p;
136         int n;
137
138         hnputs(buf, Tftp_READ);
139         p = buf+2;
140         n = strlen(path);
141
142         /* hack: remove the trailing dot */
143         if(path[n-1] == '.')
144                 n--;
145
146         memcpy(p, path, n);
147         p += n;
148         *p++ = 0;
149         memcpy(p, "octet", 6);
150         p += 6;
151         return p - buf;
152 }
153
154 static void
155 download(void *aux)
156 {
157         int fd, cfd, last, block, n, ndata;
158         char *err, addr[40], adir[40];
159         uchar *data;
160         Channel *c;
161         Tfile *f;
162         Req *r;
163
164         struct {
165                 Udphdr;
166                 uchar buf[2+2+Segsize+1];
167         } msg;
168
169         c = nil;
170         r = nil;
171         fd = cfd = -1;
172         err = nil;
173         data = nil;
174         ndata = 0;
175
176         if((f = aux) == nil)
177                 goto out;
178         if((c = f->c) == nil)
179                 goto out;
180
181         threadsetname(f->path);
182
183         for(n=0; n<10; n++){
184                 snprint(addr, sizeof(addr), "udp!*!%d", newport());
185                 if((cfd = announce(addr, adir)) >= 0)
186                         break;
187         }
188         if(cfd < 0){
189                 err = "announce: %r";
190                 goto out;
191         }
192         if(write(cfd, "headers", 7) < 0){
193                 err = "write ctl: %r";
194                 goto out;
195         }
196         strcat(adir, "/data");
197         if((fd = open(adir, ORDWR)) < 0){
198                 err = "open: %r";
199                 goto out;
200         }
201
202         n = filereq(msg.buf, f->path);
203         ipmove(msg.raddr, f->addr);
204         hnputs(msg.rport, TftpPort);
205         if(write(fd, &msg, sizeof(Udphdr) + n) < 0){
206                 err = "send read request: %r";
207                 goto out;
208         }
209
210         notify(catch);
211
212         last = 0;
213         while(!last){
214                 alarm(5000);
215                 if((n = read(fd, &msg, sizeof(Udphdr) + sizeof(msg.buf)-1)) < 0){
216                         err = "receive response: %r";
217                         goto out;
218                 }
219                 alarm(0);
220
221                 n -= sizeof(Udphdr);
222                 msg.buf[n] = 0;
223                 switch(nhgets(msg.buf)){
224                 case Tftp_ERROR:
225                         werrstr((char*)msg.buf+4);
226                         err = "%r";
227                         goto out;
228
229                 case Tftp_DATA:
230                         if(n < 4)
231                                 continue;
232                         block = nhgets(msg.buf+2);
233                         if((n -= 4) > 0){
234                                 data = erealloc9p(data, ndata + n);
235                                 memcpy(data + ndata, msg.buf+4, n);
236                                 ndata += n;
237
238 rloop:                  /* hanlde read request while downloading */
239                                 if((r != nil) && (r->ifcall.type == Tread) && (r->ifcall.offset < ndata)){
240                                         readbuf(r, data, ndata);
241                                         respond(r, nil);
242                                         r = nil;
243                                 }
244                                 if((r == nil) && (nbrecv(c, &r) == 1)){
245                                         if(r == nil){
246                                                 chanfree(c);
247                                                 c = nil;
248                                                 goto out;
249                                         }
250                                         goto rloop;
251                                 }
252                         }
253                         if(n < Segsize)
254                                 last = 1;
255                         hnputs(msg.buf, Tftp_ACK);
256                         hnputs(msg.buf+2, block);
257                         if(write(fd, &msg, sizeof(Udphdr) + 4) < 0){
258                                 err = "send acknowledge: %r";
259                                 goto out;
260                         }
261                         break;
262                 }
263         }
264
265 out:
266         alarm(0);
267         if(cfd >= 0)
268                 close(cfd);
269         if(fd >= 0)
270                 close(fd);
271
272         if(c){
273                 while((r != nil) || (r = recvp(c))){
274                         if(err){
275                                 char buf[ERRMAX];
276
277                                 snprint(buf, sizeof(buf), err);
278                                 respond(r, buf);
279                         } else {
280                                 switch(r->ifcall.type){
281                                 case Tread:
282                                         readbuf(r, data, ndata);
283                                         respond(r, nil);
284                                         break;
285                                 case Tstat:
286                                         tfilestat(r, f->path, ndata);
287                                         break;
288                                 default:
289                                         respond(r, "bug in fs");
290                                 }
291                         }
292                         r = nil;
293                 }
294                 chanfree(c);
295         }
296         free(data);
297 }
298
299 static void
300 fsattach(Req *r)
301 {
302         Tfile *f;
303
304         if(r->ifcall.aname && r->ifcall.aname[0]){
305                 uchar addr[IPaddrlen];
306
307                 if(parseip(addr, r->ifcall.aname) < 0){
308                         respond(r, "bad ip specified");
309                         return;
310                 }
311                 f = tfileget(addr, "/");
312         } else {
313                 if(ipcmp(ipaddr, IPnoaddr) == 0){
314                         respond(r, "no ipaddr specified");
315                         return;
316                 }
317                 f = tfileget(ipaddr, "/");
318         }
319         r->fid->aux = f;
320         r->fid->qid.type = QTDIR;
321         r->fid->qid.path = f->id<<1;
322         r->fid->qid.vers = 0;
323         r->ofcall.qid = r->fid->qid;
324         respond(r, nil);
325 }
326
327 static char*
328 fswalk1(Fid *fid, char *name, Qid *qid)
329 {
330         Tfile *f;
331         char *t;
332
333         f = fid->aux;
334         t = smprint("%s/%s", f->path, name);
335         f = tfileget(f->addr, cleanname(t));
336         free(t);
337         tfileput(fid->aux); fid->aux = f;
338         fid->qid.type = QTDIR;
339         fid->qid.path = f->id<<1;
340
341         /* hack:
342          * a dot in the path means the path element is not
343          * a directory. to force download of files containing
344          * no dot, a trailing dot can be appended that will
345          * be stripped out in the tftp read request.
346          */
347         if(strchr(f->path, '.') != nil){
348                 fid->qid.type = 0;
349                 fid->qid.path |= Qdata;
350         }
351
352         if(qid)
353                 *qid = fid->qid;
354         return nil;
355 }
356
357 static char*
358 fsclone(Fid *oldfid, Fid *newfid)
359 {
360         Tfile *f;
361
362         f = oldfid->aux;
363         incref(f);
364         newfid->aux = f;
365         return nil;
366 }
367
368 static void
369 fsdestroyfid(Fid *fid)
370 {
371         tfileput(fid->aux);
372         fid->aux = nil;
373 }
374
375 static void
376 fsopen(Req *r)
377 {
378         int m;
379
380         m = r->ifcall.mode & 3;
381         if(m != OREAD && m != OEXEC){
382                 respond(r, "permission denied");
383                 return;
384         }
385         respond(r, nil);
386 }
387
388 static void
389 dispatch(Req *r)
390 {
391         Tfile *f;
392
393         f = r->fid->aux;
394         if(f->c == nil){
395                 f->c = chancreate(sizeof(r), 0);
396                 proccreate(download, f, 16*1024);
397         }
398         sendp(f->c, r);
399 }
400
401 static void
402 fsread(Req *r)
403 {
404         if(r->fid->qid.path & Qdata){
405                 dispatch(r);
406         } else {
407                 respond(r, nil);
408         }
409 }
410
411 static void
412 fsstat(Req *r)
413 {
414         if(r->fid->qid.path & Qdata){
415                 dispatch(r);
416         } else {
417                 tfilestat(r, ((Tfile*)r->fid->aux)->path, 0);
418         }
419 }
420
421 Srv fs = 
422 {
423 .attach=                fsattach,
424 .destroyfid=    fsdestroyfid,
425 .walk1=         fswalk1,
426 .clone=         fsclone,
427 .open=          fsopen,
428 .read=          fsread,
429 .stat=          fsstat,
430 };
431
432 void
433 usage(void)
434 {
435         fprint(2, "usage: tftpfs [-D] [-s srvname] [-m mtpt] [ipaddr]\n");
436         threadexitsall("usage");
437 }
438
439 void
440 threadmain(int argc, char **argv)
441 {
442         char *srvname = nil;
443         char *mtpt = nil;
444
445         time0 = time(0);
446         ipmove(ipaddr, IPnoaddr);
447
448         ARGBEGIN{
449         case 'D':
450                 chatty9p++;
451                 break;
452         case 's':
453                 srvname = EARGF(usage());
454                 break;
455         case 'm':
456                 mtpt = EARGF(usage());
457                 break;
458         default:
459                 usage();
460         }ARGEND;
461
462         switch(argc){
463         case 0:
464                 break;
465         case 1:
466                 if(parseip(ipaddr, *argv) < 0)
467                         usage();
468                 break;
469         default:
470                 usage();
471         }
472
473         if(srvname==nil && mtpt==nil)
474                 usage();
475
476         threadpostmountsrv(&fs, srvname, mtpt, MREPL|MCREATE);
477 }