]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ip/tftpfs.c
tftp: prevent it from hanging if ack packets get lost
[plan9front.git] / sys / src / cmd / ip / tftpfs.c
1 #include <u.h>
2 #include <libc.h>
3 #include <thread.h>
4 #include <auth.h>
5 #include <fcall.h>
6 #include <9p.h>
7 #include <ip.h>
8
/* low bit of qid.path: set for data files, clear for directories */
enum {
	Qdata = 1,
};

/* tftp packet opcodes */
enum {
	Tftp_READ = 1,
	Tftp_WRITE,
	Tftp_DATA,
	Tftp_ACK,
	Tftp_ERROR,
	Tftp_OACK,
};

enum {
	TftpPort = 69,			/* server port read requests are sent to */
	Segsize = 512,			/* payload bytes in a full DATA packet */
	/*
	 * size of Tfile.path (including the NUL); bounds the path
	 * so that filereq's read request fits download's buffer.
	 */
	Maxpath = 2+2+Segsize-8,
};
24
25 typedef struct Tfile Tfile;
26 struct Tfile
27 {
28         int             id;
29         uchar   addr[IPaddrlen];
30         char            path[Maxpath];
31         Channel *c;
32         Tfile           *next;
33         Ref;
34 };
35
36 uchar ipaddr[IPaddrlen];
37 static ulong time0;
38 Tfile *files;
39
40
41 static Tfile*
42 tfileget(uchar *addr, char *path)
43 {
44         Tfile *f;
45         static int id;
46
47         for(f = files; f; f = f->next){
48                 if(memcmp(addr, f->addr, IPaddrlen) == 0 && strcmp(path, f->path) == 0){
49                         incref(f);
50                         return f;
51                 }
52         }
53         f = emalloc9p(sizeof *f);
54         memset(f, 0, sizeof(*f));
55         ipmove(f->addr, addr);
56         strncpy(f->path, path, sizeof(f->path));
57         f->ref = 1;
58         f->id = id++;
59         f->next = files;
60         files = f;
61
62         return f;
63 }
64
65 static void
66 tfileput(Tfile *f)
67 {
68         Channel *c;
69         Tfile **pp;
70
71         if(f==nil || decref(f))
72                 return;
73         if(c = f->c){
74                 f->c = nil;
75                 sendp(c, nil);
76         }
77         for(pp = &files; *pp; pp = &(*pp)->next){
78                 if(*pp == f){
79                         *pp = f->next;
80                         break;
81                 }
82         }
83         free(f);
84 }
85
/*
 * Return a pointer into p at its last path element: the text
 * after the final '/', or p itself when p contains no '/'.
 */
static char*
basename(char *p)
{
	char *slash;

	slash = strrchr(p, '/');
	return slash ? slash+1 : p;
}
96
97 static void
98 tfilestat(Req *r, char *path, vlong length)
99 {
100         memset(&r->d, 0, sizeof(r->d));
101         r->d.uid = estrdup9p("tftp");
102         r->d.gid = estrdup9p("tftp");
103         r->d.name = estrdup9p(basename(path));
104         r->d.atime = r->d.mtime = time0;
105         r->d.length = length;
106         r->d.qid.path = r->fid->qid.path;
107         if(r->fid->qid.path & Qdata){
108                 r->d.qid.type = 0;
109                 r->d.mode = 0555;
110         } else {
111                 r->d.qid.type = QTDIR;
112                 r->d.mode = DMDIR|0555;
113         }
114         respond(r, nil);
115 }
116
117 static void
118 catch(void *, char *msg)
119 {
120         if(strstr(msg, "alarm"))
121                 noted(NCONT);
122         noted(NDFLT);
123 }
124
125 static int
126 filereq(uchar *buf, char *path)
127 {
128         uchar *p;
129         int n;
130
131         hnputs(buf, Tftp_READ);
132         p = buf+2;
133         n = strlen(path);
134
135         /* hack: remove the trailing dot */
136         if(path[n-1] == '.')
137                 n--;
138
139         memcpy(p, path, n);
140         p += n;
141         *p++ = 0;
142         memcpy(p, "octet", 6);
143         p += 6;
144         return p - buf;
145 }
146
147 static void
148 download(void *aux)
149 {
150         int fd, cfd, last, block, seq, n, ndata;
151         char *err, adir[40];
152         uchar *data;
153         Channel *c;
154         Tfile *f;
155         Req *r;
156
157         struct {
158                 Udphdr;
159                 uchar buf[2+2+Segsize+1];
160         } msg;
161
162         c = nil;
163         r = nil;
164         fd = cfd = -1;
165         err = nil;
166         data = nil;
167         ndata = 0;
168
169         if((f = aux) == nil)
170                 goto out;
171         if((c = f->c) == nil)
172                 goto out;
173
174         threadsetname(f->path);
175
176         if((cfd = announce("udp!*!0", adir)) < 0){
177                 err = "announce: %r";
178                 goto out;
179         }
180         if(write(cfd, "headers", 7) < 0){
181                 err = "write ctl: %r";
182                 goto out;
183         }
184         strcat(adir, "/data");
185         if((fd = open(adir, ORDWR)) < 0){
186                 err = "open: %r";
187                 goto out;
188         }
189
190         n = filereq(msg.buf, f->path);
191         ipmove(msg.raddr, f->addr);
192         hnputs(msg.rport, TftpPort);
193         if(write(fd, &msg, sizeof(Udphdr) + n) < 0){
194                 err = "send read request: %r";
195                 goto out;
196         }
197
198         notify(catch);
199
200         seq = 1;
201         last = 0;
202         while(!last){
203                 alarm(5000);
204                 if((n = read(fd, &msg, sizeof(Udphdr) + sizeof(msg.buf)-1)) < 0){
205                         err = "receive response: %r";
206                         goto out;
207                 }
208                 alarm(0);
209
210                 n -= sizeof(Udphdr);
211                 msg.buf[n] = 0;
212                 switch(nhgets(msg.buf)){
213                 case Tftp_ERROR:
214                         werrstr((char*)msg.buf+4);
215                         err = "%r";
216                         goto out;
217
218                 case Tftp_DATA:
219                         if(n < 4)
220                                 continue;
221                         block = nhgets(msg.buf+2);
222                         if(block > seq)
223                                 continue;
224                         hnputs(msg.buf, Tftp_ACK);
225                         if(write(fd, &msg, sizeof(Udphdr) + 4) < 0){
226                                 err = "send acknowledge: %r";
227                                 goto out;
228                         }
229                         if(block < seq)
230                                 continue;
231                         seq = block+1;
232                         n -= 4;
233                         if(n < Segsize)
234                                 last = 1;
235                         data = erealloc9p(data, ndata + n);
236                         memcpy(data + ndata, msg.buf+4, n);
237                         ndata += n;
238
239                 rloop:  /* hanlde read request while downloading */
240                         if((r != nil) && (r->ifcall.type == Tread) && (r->ifcall.offset < ndata)){
241                                 readbuf(r, data, ndata);
242                                 respond(r, nil);
243                                 r = nil;
244                         }
245                         if((r == nil) && (nbrecv(c, &r) == 1)){
246                                 if(r == nil){
247                                         chanfree(c);
248                                         c = nil;
249                                         goto out;
250                                 }
251                                 goto rloop;
252                         }
253                         break;
254                 }
255         }
256
257 out:
258         alarm(0);
259         if(cfd >= 0)
260                 close(cfd);
261         if(fd >= 0)
262                 close(fd);
263
264         if(c){
265                 while((r != nil) || (r = recvp(c))){
266                         if(err){
267                                 char buf[ERRMAX];
268
269                                 snprint(buf, sizeof(buf), err);
270                                 respond(r, buf);
271                         } else {
272                                 switch(r->ifcall.type){
273                                 case Tread:
274                                         readbuf(r, data, ndata);
275                                         respond(r, nil);
276                                         break;
277                                 case Tstat:
278                                         tfilestat(r, f->path, ndata);
279                                         break;
280                                 default:
281                                         respond(r, "bug in fs");
282                                 }
283                         }
284                         r = nil;
285                 }
286                 chanfree(c);
287         }
288         free(data);
289 }
290
291 static void
292 fsattach(Req *r)
293 {
294         Tfile *f;
295
296         if(r->ifcall.aname && r->ifcall.aname[0]){
297                 uchar addr[IPaddrlen];
298
299                 if(parseip(addr, r->ifcall.aname) < 0){
300                         respond(r, "bad ip specified");
301                         return;
302                 }
303                 f = tfileget(addr, "/");
304         } else {
305                 if(ipcmp(ipaddr, IPnoaddr) == 0){
306                         respond(r, "no ipaddr specified");
307                         return;
308                 }
309                 f = tfileget(ipaddr, "/");
310         }
311         r->fid->aux = f;
312         r->fid->qid.type = QTDIR;
313         r->fid->qid.path = f->id<<1;
314         r->fid->qid.vers = 0;
315         r->ofcall.qid = r->fid->qid;
316         respond(r, nil);
317 }
318
319 static char*
320 fswalk1(Fid *fid, char *name, Qid *qid)
321 {
322         Tfile *f;
323         char *t;
324
325         f = fid->aux;
326         t = smprint("%s/%s", f->path, name);
327         f = tfileget(f->addr, cleanname(t));
328         free(t);
329         tfileput(fid->aux); fid->aux = f;
330         fid->qid.type = QTDIR;
331         fid->qid.path = f->id<<1;
332
333         /* hack:
334          * a dot in the path means the path element is not
335          * a directory. to force download of files containing
336          * no dot, a trailing dot can be appended that will
337          * be stripped out in the tftp read request.
338          */
339         if(strchr(f->path, '.') != nil){
340                 fid->qid.type = 0;
341                 fid->qid.path |= Qdata;
342         }
343
344         if(qid)
345                 *qid = fid->qid;
346         return nil;
347 }
348
349 static char*
350 fsclone(Fid *oldfid, Fid *newfid)
351 {
352         Tfile *f;
353
354         f = oldfid->aux;
355         incref(f);
356         newfid->aux = f;
357         return nil;
358 }
359
360 static void
361 fsdestroyfid(Fid *fid)
362 {
363         tfileput(fid->aux);
364         fid->aux = nil;
365 }
366
367 static void
368 fsopen(Req *r)
369 {
370         int m;
371
372         m = r->ifcall.mode & 3;
373         if(m != OREAD && m != OEXEC){
374                 respond(r, "permission denied");
375                 return;
376         }
377         respond(r, nil);
378 }
379
380 static void
381 dispatch(Req *r)
382 {
383         Tfile *f;
384
385         f = r->fid->aux;
386         if(f->c == nil){
387                 f->c = chancreate(sizeof(r), 0);
388                 proccreate(download, f, 16*1024);
389         }
390         sendp(f->c, r);
391 }
392
393 static void
394 fsread(Req *r)
395 {
396         if(r->fid->qid.path & Qdata){
397                 dispatch(r);
398         } else {
399                 respond(r, nil);
400         }
401 }
402
403 static void
404 fsstat(Req *r)
405 {
406         if(r->fid->qid.path & Qdata){
407                 dispatch(r);
408         } else {
409                 tfilestat(r, ((Tfile*)r->fid->aux)->path, 0);
410         }
411 }
412
413 Srv fs = 
414 {
415 .attach=                fsattach,
416 .destroyfid=    fsdestroyfid,
417 .walk1=         fswalk1,
418 .clone=         fsclone,
419 .open=          fsopen,
420 .read=          fsread,
421 .stat=          fsstat,
422 };
423
/* Print the command synopsis and exit all threads. */
void
usage(void)
{
	fprint(2, "usage: tftpfs [-D] [-s srvname] [-m mtpt] [ipaddr]\n");
	threadexitsall("usage");
}
430
431 void
432 threadmain(int argc, char **argv)
433 {
434         char *srvname = nil;
435         char *mtpt = "/n/tftp";
436
437         time0 = time(0);
438         ipmove(ipaddr, IPnoaddr);
439
440         ARGBEGIN{
441         case 'D':
442                 chatty9p++;
443                 break;
444         case 's':
445                 srvname = EARGF(usage());
446                 mtpt = nil;
447                 break;
448         case 'm':
449                 mtpt = EARGF(usage());
450                 break;
451         default:
452                 usage();
453         }ARGEND;
454
455         switch(argc){
456         case 0:
457                 break;
458         case 1:
459                 if(parseip(ipaddr, *argv) < 0)
460                         usage();
461                 break;
462         default:
463                 usage();
464         }
465
466         if(srvname==nil && mtpt==nil)
467                 usage();
468
469         threadpostmountsrv(&fs, srvname, mtpt, MREPL|MCREATE);
470 }