]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/import.c
merge
[plan9front.git] / sys / src / cmd / import.c
1 #include <u.h>
2 #include <libc.h>
3 #include <auth.h>
4 #include <libsec.h>
5
6 enum {
7         Encnone,
8         Encssl,
9         Enctls,
10 };
11
12 static char *encprotos[] = {
13         [Encnone] =     "clear",
14         [Encssl] =      "ssl",
15         [Enctls] =      "tls",
16                         nil,
17 };
18
19 char            *keyspec = "";
20 char            *filterp;
21 char            *ealgs = "rc4_256 sha1";
22 int             encproto = Encnone;
23 char            *aan = "/bin/aan";
24 char            *anstring  = "tcp!*!0";
25 AuthInfo        *ai;
26 int             debug;
27 int             doauth = 1;
28 int             timedout;
29 int             skiptree;
30
31 int     connect(char*, char*, int);
32 int     passive(void);
33 int     old9p(int);
34 void    catcher(void*, char*);
35 void    sysfatal(char*, ...);
36 void    usage(void);
37 int     filter(int, char *, char *);
38
39 static void     mksecret(char *, uchar *);
40
41 /*
42  * based on libthread's threadsetname, but drags in less library code.
43  * actually just sets the arguments displayed.
44  */
45 void
46 procsetname(char *fmt, ...)
47 {
48         int fd;
49         char *cmdname;
50         char buf[128];
51         va_list arg;
52
53         va_start(arg, fmt);
54         cmdname = vsmprint(fmt, arg);
55         va_end(arg);
56         if (cmdname == nil)
57                 return;
58         snprint(buf, sizeof buf, "#p/%d/args", getpid());
59         if((fd = open(buf, OWRITE)) >= 0){
60                 write(fd, cmdname, strlen(cmdname)+1);
61                 close(fd);
62         }
63         free(cmdname);
64 }
65
66 void
67 post(char *name, char *envname, int srvfd)
68 {
69         int fd;
70         char buf[32];
71
72         fd = create(name, OWRITE, 0600);
73         if(fd < 0)
74                 return;
75         snprint(buf, sizeof(buf), "%d", srvfd);
76         if(write(fd, buf, strlen(buf)) != strlen(buf))
77                 sysfatal("srv write: %r");
78         close(fd);
79         putenv(envname, name);
80 }
81
82 static int
83 lookup(char *s, char *l[])
84 {
85         int i;
86
87         for (i = 0; l[i] != 0; i++)
88                 if (strcmp(l[i], s) == 0)
89                         return i;
90         return -1;
91 }
92
93 void
94 main(int argc, char **argv)
95 {
96         char *mntpt, *srvpost, srvfile[64];
97         int backwards = 0, fd, mntflags, oldserver;
98
99         quotefmtinstall();
100         srvpost = nil;
101         oldserver = 0;
102         mntflags = MREPL;
103         ARGBEGIN{
104         case 'A':
105                 doauth = 0;
106                 break;
107         case 'a':
108                 mntflags = MAFTER;
109                 break;
110         case 'b':
111                 mntflags = MBEFORE;
112                 break;
113         case 'c':
114                 mntflags |= MCREATE;
115                 break;
116         case 'C':
117                 mntflags |= MCACHE;
118                 break;
119         case 'd':
120                 debug++;
121                 break;
122         case 'f':
123                 /* ignored but allowed for compatibility */
124                 break;
125         case 'O':
126         case 'o':
127                 oldserver = 1;
128                 break;
129         case 'E':
130                 if ((encproto = lookup(EARGF(usage()), encprotos)) < 0)
131                         usage();
132                 break;
133         case 'e':
134                 ealgs = EARGF(usage());
135                 if(*ealgs == 0 || strcmp(ealgs, "clear") == 0)
136                         ealgs = nil;
137                 break;
138         case 'k':
139                 keyspec = EARGF(usage());
140                 break;
141         case 'p':
142                 filterp = aan;
143                 break;
144         case 'n':
145                 anstring = EARGF(usage());
146                 break;
147         case 's':
148                 srvpost = EARGF(usage());
149                 break;
150         case 'B':
151                 backwards = 1;
152                 break;
153         case 'z':
154                 skiptree = 1;
155                 break;
156         default:
157                 usage();
158         }ARGEND;
159
160         mntpt = 0;              /* to shut up compiler */
161         if(backwards){
162                 switch(argc) {
163                 default:
164                         mntpt = argv[0];
165                         break;
166                 case 0:
167                         usage();
168                 }
169         } else {
170                 switch(argc) {
171                 case 2:
172                         mntpt = argv[1];
173                         break;
174                 case 3:
175                         mntpt = argv[2];
176                         break;
177                 default:
178                         usage();
179                 }
180         }
181
182         if (encproto == Enctls)
183                 sysfatal("%s: tls has not yet been implemented", argv[0]);
184
185         notify(catcher);
186         alarm(60*1000);
187
188         if (backwards)
189                 fd = passive();
190         else
191                 fd = connect(argv[0], argv[1], oldserver);
192
193         if (!oldserver)
194                 fprint(fd, "impo %s %s\n", filterp? "aan": "nofilter",
195                         encprotos[encproto]);
196
197         if (encproto != Encnone && ealgs && ai) {
198                 uchar key[16], digest[SHA1dlen];
199                 char fromclientsecret[21];
200                 char fromserversecret[21];
201                 int i;
202
203                 assert(ai->nsecret <= sizeof(key)-4);
204                 memmove(key+4, ai->secret, ai->nsecret);
205
206                 /* exchange random numbers */
207                 srand(truerand());
208                 for(i = 0; i < 4; i++)
209                         key[i] = rand();
210                 if(write(fd, key, 4) != 4)
211                         sysfatal("can't write key part: %r");
212                 if(readn(fd, key+12, 4) != 4)
213                         sysfatal("can't read key part: %r");
214
215                 /* scramble into two secrets */
216                 sha1(key, sizeof(key), digest, nil);
217                 mksecret(fromclientsecret, digest);
218                 mksecret(fromserversecret, digest+10);
219
220                 if (filterp)
221                         fd = filter(fd, filterp, backwards ? nil : argv[0]);
222
223                 /* set up encryption */
224                 procsetname("pushssl");
225                 fd = pushssl(fd, ealgs, fromclientsecret, fromserversecret, nil);
226                 if(fd < 0)
227                         sysfatal("can't establish ssl connection: %r");
228         }
229         else if (filterp)
230                 fd = filter(fd, filterp, backwards ? nil : argv[0]);
231
232         if(ai)
233                 auth_freeAI(ai);
234
235         if(srvpost){
236                 snprint(srvfile, sizeof(srvfile), "/srv/%s", srvpost);
237                 remove(srvfile);
238                 post(srvfile, srvpost, fd);
239         }
240         procsetname("mount on %s", mntpt);
241         if(mount(fd, -1, mntpt, mntflags, "") < 0)
242                 sysfatal("can't mount %s: %r", argv[1]);
243         alarm(0);
244
245         if(backwards && argc > 1){
246                 exec(argv[1], &argv[1]);
247                 sysfatal("exec: %r");
248         }
249         exits(0);
250 }
251
252 void
253 catcher(void*, char *msg)
254 {
255         timedout = 1;
256         if(strcmp(msg, "alarm") == 0)
257                 noted(NCONT);
258         noted(NDFLT);
259 }
260
261 int
262 old9p(int fd)
263 {
264         int p[2];
265
266         procsetname("old9p");
267         if(pipe(p) < 0)
268                 sysfatal("pipe: %r");
269
270         switch(rfork(RFPROC|RFMEM|RFFDG|RFNAMEG)) {
271         case -1:
272                 sysfatal("rfork srvold9p: %r");
273         case 0:
274                 if(fd != 1){
275                         dup(fd, 1);
276                         close(fd);
277                 }
278                 if(p[0] != 0){
279                         dup(p[0], 0);
280                         close(p[0]);
281                 }
282                 close(p[1]);
283                 if(0){
284                         fd = open("/sys/log/cpu", OWRITE);
285                         if(fd != 2){
286                                 dup(fd, 2);
287                                 close(fd);
288                         }
289                         execl("/bin/srvold9p", "srvold9p", "-ds", nil);
290                 } else
291                         execl("/bin/srvold9p", "srvold9p", "-s", nil);
292                 sysfatal("exec srvold9p: %r");
293         default:
294                 close(fd);
295                 close(p[0]);
296         }
297         return p[1];
298 }
299
300 int
301 connect(char *system, char *tree, int oldserver)
302 {
303         char buf[ERRMAX], dir[128], *na;
304         int fd, n;
305         char *authp;
306
307         na = netmkaddr(system, 0, "exportfs");
308         procsetname("dial %s", na);
309         if((fd = dial(na, 0, dir, 0)) < 0)
310                 sysfatal("can't dial %s: %r", system);
311
312         if(doauth){
313                 if(oldserver)
314                         authp = "p9sk2";
315                 else
316                         authp = "p9any";
317
318                 procsetname("auth_proxy auth_getkey proto=%q role=client %s",
319                         authp, keyspec);
320                 ai = auth_proxy(fd, auth_getkey, "proto=%q role=client %s",
321                         authp, keyspec);
322                 if(ai == nil)
323                         sysfatal("%r: %s", system);
324         }
325
326         if(!skiptree){
327                 procsetname("writing tree name %s", tree);
328                 n = write(fd, tree, strlen(tree));
329                 if(n < 0)
330                         sysfatal("can't write tree: %r");
331
332                 strcpy(buf, "can't read tree");
333
334                 procsetname("awaiting OK for %s", tree);
335                 n = read(fd, buf, sizeof buf - 1);
336                 if(n!=2 || buf[0]!='O' || buf[1]!='K'){
337                         if (timedout)
338                                 sysfatal("timed out connecting to %s", na);
339                         buf[sizeof buf - 1] = '\0';
340                         sysfatal("bad remote tree: %s", buf);
341                 }
342         }
343
344         if(oldserver)
345                 return old9p(fd);
346         return fd;
347 }
348
349 int
350 passive(void)
351 {
352         int fd;
353
354         /*
355          * Ignore doauth==0 on purpose.  Is it useful here?
356          */
357
358         procsetname("auth_proxy auth_getkey proto=p9any role=server");
359         ai = auth_proxy(0, auth_getkey, "proto=p9any role=server");
360         if(ai == nil)
361                 sysfatal("auth_proxy: %r");
362         if(auth_chuid(ai, nil) < 0)
363                 sysfatal("auth_chuid: %r");
364         putenv("service", "import");
365
366         fd = dup(0, -1);
367         close(0);
368         open("/dev/null", ORDWR);
369         close(1);
370         open("/dev/null", ORDWR);
371
372         return fd;
373 }
374
375 void
376 usage(void)
377 {
378         fprint(2, "usage: import [-abcC] [-A] [-E clear|ssl|tls] "
379 "[-e 'crypt auth'|clear] [-k keypattern] [-p] [-n address ] [-z] host remotefs [mountpoint]\n");
380         exits("usage");
381 }
382
383 int
384 filter(int fd, char *cmd, char *host)
385 {
386         char addr[128], buf[256], *s, *file, *argv[16];
387         int lfd, p[2], len, argc;
388
389         if(host == nil){
390                 /* Get a free port and post it to the client. */
391                 if (announce(anstring, addr) < 0)
392                         sysfatal("filter: Cannot announce %s: %r", anstring);
393
394                 snprint(buf, sizeof(buf), "%s/local", addr);
395                 if ((lfd = open(buf, OREAD)) < 0)
396                         sysfatal("filter: Cannot open %s: %r", buf);
397                 if ((len = read(lfd, buf, sizeof buf - 1)) < 0)
398                         sysfatal("filter: Cannot read %s: %r", buf);
399                 close(lfd);
400                 buf[len] = '\0';
401                 if ((s = strchr(buf, '\n')) != nil)
402                         len = s - buf;
403                 if (write(fd, buf, len) != len) 
404                         sysfatal("filter: cannot write port; %r");
405         } else {
406                 /* Read address string from connection */
407                 if ((len = read(fd, buf, sizeof buf - 1)) < 0)
408                         sysfatal("filter: cannot write port; %r");
409                 buf[len] = '\0';
410
411                 if ((s = strrchr(buf, '!')) == nil)
412                         sysfatal("filter: illegally formatted port %s", buf);
413                 strecpy(addr, addr+sizeof(addr), netmkaddr(host, "tcp", s+1));
414                 strecpy(strrchr(addr, '!'), addr+sizeof(addr), s);
415         }
416
417         if(debug)
418                 fprint(2, "filter: %s\n", addr);
419
420         snprint(buf, sizeof(buf), "%s", cmd);
421         argc = tokenize(buf, argv, nelem(argv)-3);
422         if (argc == 0)
423                 sysfatal("filter: empty command");
424
425         if(host != nil)
426                 argv[argc++] = "-c";
427         argv[argc++] = addr;
428         argv[argc] = nil;
429
430         file = argv[0];
431         if((s = strrchr(argv[0], '/')) != nil)
432                 argv[0] = s+1;
433
434         if(pipe(p) < 0)
435                 sysfatal("pipe: %r");
436
437         switch(rfork(RFNOWAIT|RFPROC|RFMEM|RFFDG|RFREND)) {
438         case -1:
439                 sysfatal("filter: rfork; %r\n");
440         case 0:
441                 close(fd);
442                 if (dup(p[0], 1) < 0)
443                         sysfatal("filter: Cannot dup to 1; %r");
444                 if (dup(p[0], 0) < 0)
445                         sysfatal("filter: Cannot dup to 0; %r");
446                 close(p[0]);
447                 close(p[1]);
448                 exec(file, argv);
449                 sysfatal("filter: exec; %r");
450         default:
451                 dup(p[1], fd);
452                 close(p[0]);
453                 close(p[1]);
454         }
455         return fd;
456 }
457
458 static void
459 mksecret(char *t, uchar *f)
460 {
461         sprint(t, "%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux",
462                 f[0], f[1], f[2], f[3], f[4], f[5], f[6], f[7], f[8], f[9]);
463 }