]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/import.c
add import -z option to skip initial tree negotiation (from mycroftiv)
[plan9front.git] / sys / src / cmd / import.c
1 #include <u.h>
2 #include <libc.h>
3 #include <auth.h>
4 #include <libsec.h>
5
6 enum {
7         Encnone,
8         Encssl,
9         Enctls,
10 };
11
12 static char *encprotos[] = {
13         [Encnone] =     "clear",
14         [Encssl] =      "ssl",
15         [Enctls] =      "tls",
16                         nil,
17 };
18
19 char            *keyspec = "";
20 char            *filterp;
21 char            *ealgs = "rc4_256 sha1";
22 int             encproto = Encnone;
23 char            *aan = "/bin/aan";
24 AuthInfo        *ai;
25 int             debug;
26 int             doauth = 1;
27 int             timedout;
28 int             skiptree;
29
30 int     connect(char*, char*, int);
31 int     passive(void);
32 int     old9p(int);
33 void    catcher(void*, char*);
34 void    sysfatal(char*, ...);
35 void    usage(void);
36 int     filter(int, char *, char *);
37
38 static void     mksecret(char *, uchar *);
39
40 /*
41  * based on libthread's threadsetname, but drags in less library code.
42  * actually just sets the arguments displayed.
43  */
44 void
45 procsetname(char *fmt, ...)
46 {
47         int fd;
48         char *cmdname;
49         char buf[128];
50         va_list arg;
51
52         va_start(arg, fmt);
53         cmdname = vsmprint(fmt, arg);
54         va_end(arg);
55         if (cmdname == nil)
56                 return;
57         snprint(buf, sizeof buf, "#p/%d/args", getpid());
58         if((fd = open(buf, OWRITE)) >= 0){
59                 write(fd, cmdname, strlen(cmdname)+1);
60                 close(fd);
61         }
62         free(cmdname);
63 }
64
65 void
66 post(char *name, char *envname, int srvfd)
67 {
68         int fd;
69         char buf[32];
70
71         fd = create(name, OWRITE, 0600);
72         if(fd < 0)
73                 return;
74         snprint(buf, sizeof(buf), "%d", srvfd);
75         if(write(fd, buf, strlen(buf)) != strlen(buf))
76                 sysfatal("srv write: %r");
77         close(fd);
78         putenv(envname, name);
79 }
80
81 static int
82 lookup(char *s, char *l[])
83 {
84         int i;
85
86         for (i = 0; l[i] != 0; i++)
87                 if (strcmp(l[i], s) == 0)
88                         return i;
89         return -1;
90 }
91
92 void
93 main(int argc, char **argv)
94 {
95         char *mntpt, *srvpost, srvfile[64];
96         int backwards = 0, fd, mntflags, oldserver;
97
98         quotefmtinstall();
99         srvpost = nil;
100         oldserver = 0;
101         mntflags = MREPL;
102         ARGBEGIN{
103         case 'A':
104                 doauth = 0;
105                 break;
106         case 'a':
107                 mntflags = MAFTER;
108                 break;
109         case 'b':
110                 mntflags = MBEFORE;
111                 break;
112         case 'c':
113                 mntflags |= MCREATE;
114                 break;
115         case 'C':
116                 mntflags |= MCACHE;
117                 break;
118         case 'd':
119                 debug++;
120                 break;
121         case 'f':
122                 /* ignored but allowed for compatibility */
123                 break;
124         case 'O':
125         case 'o':
126                 oldserver = 1;
127                 break;
128         case 'E':
129                 if ((encproto = lookup(EARGF(usage()), encprotos)) < 0)
130                         usage();
131                 break;
132         case 'e':
133                 ealgs = EARGF(usage());
134                 if(*ealgs == 0 || strcmp(ealgs, "clear") == 0)
135                         ealgs = nil;
136                 break;
137         case 'k':
138                 keyspec = EARGF(usage());
139                 break;
140         case 'p':
141                 filterp = aan;
142                 break;
143         case 's':
144                 srvpost = EARGF(usage());
145                 break;
146         case 'B':
147                 backwards = 1;
148                 break;
149         case 'z':
150                 skiptree = 1;
151                 break;
152         default:
153                 usage();
154         }ARGEND;
155
156         mntpt = 0;              /* to shut up compiler */
157         if(backwards){
158                 switch(argc) {
159                 default:
160                         mntpt = argv[0];
161                         break;
162                 case 0:
163                         usage();
164                 }
165         } else {
166                 switch(argc) {
167                 case 2:
168                         mntpt = argv[1];
169                         break;
170                 case 3:
171                         mntpt = argv[2];
172                         break;
173                 default:
174                         usage();
175                 }
176         }
177
178         if (encproto == Enctls)
179                 sysfatal("%s: tls has not yet been implemented", argv[0]);
180
181         notify(catcher);
182         alarm(60*1000);
183
184         if(backwards)
185                 fd = passive();
186         else
187                 fd = connect(argv[0], argv[1], oldserver);
188
189         if (!oldserver)
190                 fprint(fd, "impo %s %s\n", filterp? "aan": "nofilter",
191                         encprotos[encproto]);
192
193         if (encproto != Encnone && ealgs && ai) {
194                 uchar key[16], digest[SHA1dlen];
195                 char fromclientsecret[21];
196                 char fromserversecret[21];
197                 int i;
198
199                 assert(ai->nsecret <= sizeof(key)-4);
200                 memmove(key+4, ai->secret, ai->nsecret);
201
202                 /* exchange random numbers */
203                 srand(truerand());
204                 for(i = 0; i < 4; i++)
205                         key[i] = rand();
206                 if(write(fd, key, 4) != 4)
207                         sysfatal("can't write key part: %r");
208                 if(readn(fd, key+12, 4) != 4)
209                         sysfatal("can't read key part: %r");
210
211                 /* scramble into two secrets */
212                 sha1(key, sizeof(key), digest, nil);
213                 mksecret(fromclientsecret, digest);
214                 mksecret(fromserversecret, digest+10);
215
216                 if (filterp)
217                         fd = filter(fd, filterp, argv[0]);
218
219                 /* set up encryption */
220                 procsetname("pushssl");
221                 fd = pushssl(fd, ealgs, fromclientsecret, fromserversecret, nil);
222                 if(fd < 0)
223                         sysfatal("can't establish ssl connection: %r");
224         }
225         else if (filterp)
226                 fd = filter(fd, filterp, argv[0]);
227
228         if(ai)
229                 auth_freeAI(ai);
230
231         if(srvpost){
232                 snprint(srvfile, sizeof(srvfile), "/srv/%s", srvpost);
233                 remove(srvfile);
234                 post(srvfile, srvpost, fd);
235         }
236         procsetname("mount on %s", mntpt);
237         if(mount(fd, -1, mntpt, mntflags, "") < 0)
238                 sysfatal("can't mount %s: %r", argv[1]);
239         alarm(0);
240
241         if(backwards && argc > 1){
242                 exec(argv[1], &argv[1]);
243                 sysfatal("exec: %r");
244         }
245         exits(0);
246 }
247
248 void
249 catcher(void*, char *msg)
250 {
251         timedout = 1;
252         if(strcmp(msg, "alarm") == 0)
253                 noted(NCONT);
254         noted(NDFLT);
255 }
256
257 int
258 old9p(int fd)
259 {
260         int p[2];
261
262         procsetname("old9p");
263         if(pipe(p) < 0)
264                 sysfatal("pipe: %r");
265
266         switch(rfork(RFPROC|RFMEM|RFFDG|RFNAMEG)) {
267         case -1:
268                 sysfatal("rfork srvold9p: %r");
269         case 0:
270                 if(fd != 1){
271                         dup(fd, 1);
272                         close(fd);
273                 }
274                 if(p[0] != 0){
275                         dup(p[0], 0);
276                         close(p[0]);
277                 }
278                 close(p[1]);
279                 if(0){
280                         fd = open("/sys/log/cpu", OWRITE);
281                         if(fd != 2){
282                                 dup(fd, 2);
283                                 close(fd);
284                         }
285                         execl("/bin/srvold9p", "srvold9p", "-ds", nil);
286                 } else
287                         execl("/bin/srvold9p", "srvold9p", "-s", nil);
288                 sysfatal("exec srvold9p: %r");
289         default:
290                 close(fd);
291                 close(p[0]);
292         }
293         return p[1];
294 }
295
296 int
297 connect(char *system, char *tree, int oldserver)
298 {
299         char buf[ERRMAX], dir[128], *na;
300         int fd, n;
301         char *authp;
302
303         na = netmkaddr(system, 0, "exportfs");
304         procsetname("dial %s", na);
305         if((fd = dial(na, 0, dir, 0)) < 0)
306                 sysfatal("can't dial %s: %r", system);
307
308         if(doauth){
309                 if(oldserver)
310                         authp = "p9sk2";
311                 else
312                         authp = "p9any";
313
314                 procsetname("auth_proxy auth_getkey proto=%q role=client %s",
315                         authp, keyspec);
316                 ai = auth_proxy(fd, auth_getkey, "proto=%q role=client %s",
317                         authp, keyspec);
318                 if(ai == nil)
319                         sysfatal("%r: %s", system);
320         }
321
322         if(!skiptree){
323                 procsetname("writing tree name %s", tree);
324                 n = write(fd, tree, strlen(tree));
325                 if(n < 0)
326                         sysfatal("can't write tree: %r");
327
328                 strcpy(buf, "can't read tree");
329
330                 procsetname("awaiting OK for %s", tree);
331                 n = read(fd, buf, sizeof buf - 1);
332                 if(n!=2 || buf[0]!='O' || buf[1]!='K'){
333                         if (timedout)
334                                 sysfatal("timed out connecting to %s", na);
335                         buf[sizeof buf - 1] = '\0';
336                         sysfatal("bad remote tree: %s", buf);
337                 }
338         }
339
340         if(oldserver)
341                 return old9p(fd);
342         return fd;
343 }
344
345 int
346 passive(void)
347 {
348         int fd;
349
350         /*
351          * Ignore doauth==0 on purpose.  Is it useful here?
352          */
353
354         procsetname("auth_proxy auth_getkey proto=p9any role=server");
355         ai = auth_proxy(0, auth_getkey, "proto=p9any role=server");
356         if(ai == nil)
357                 sysfatal("auth_proxy: %r");
358         if(auth_chuid(ai, nil) < 0)
359                 sysfatal("auth_chuid: %r");
360         putenv("service", "import");
361
362         fd = dup(0, -1);
363         close(0);
364         open("/dev/null", ORDWR);
365         close(1);
366         open("/dev/null", ORDWR);
367
368         return fd;
369 }
370
371 void
372 usage(void)
373 {
374         fprint(2, "usage: import [-abcC] [-A] [-E clear|ssl|tls] "
375 "[-e 'crypt auth'|clear] [-k keypattern] [-p] [-z] host remotefs [mountpoint]\n");
376         exits("usage");
377 }
378
379 /* Network on fd1, mount driver on fd0 */
380 int
381 filter(int fd, char *cmd, char *host)
382 {
383         char addr[128], buf[256], *s, *file, *argv[16];
384         int p[2], len, argc;
385
386         if ((len = read(fd, buf, sizeof buf - 1)) < 0)
387                 sysfatal("filter: cannot write port; %r");
388         buf[len] = '\0';
389
390         if ((s = strrchr(buf, '!')) == nil)
391                 sysfatal("filter: illegally formatted port %s", buf);
392         snprint(addr, sizeof(addr), "%s", netmkaddr(host, "tcp", s+1));
393
394         if(debug)
395                 fprint(2, "filter: remote %s\n", addr);
396
397         snprint(buf, sizeof(buf), "%s", cmd);
398         argc = tokenize(buf, argv, nelem(argv)-2);
399         if (argc == 0)
400                 sysfatal("filter: empty command");
401         argv[argc++] = "-c";
402         argv[argc++] = addr;
403         argv[argc] = nil;
404         file = argv[0];
405         if (s = strrchr(argv[0], '/'))
406                 argv[0] = s+1;
407
408         if(pipe(p) < 0)
409                 sysfatal("pipe: %r");
410
411         switch(rfork(RFNOWAIT|RFPROC|RFMEM|RFFDG)) {
412         case -1:
413                 sysfatal("filter: rfork; %r");
414         case 0:
415                 dup(p[0], 1);
416                 dup(p[0], 0);
417                 close(p[0]);
418                 close(p[1]);
419                 exec(file, argv);
420                 sysfatal("filter: exec; %r");
421         default:
422                 close(fd);
423                 close(p[0]);
424         }
425         return p[1];
426 }
427
428 static void
429 mksecret(char *t, uchar *f)
430 {
431         sprint(t, "%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux",
432                 f[0], f[1], f[2], f[3], f[4], f[5], f[6], f[7], f[8], f[9]);
433 }