]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/import.c
kbdfs: simplfy
[plan9front.git] / sys / src / cmd / import.c
1 #include <u.h>
2 #include <libc.h>
3 #include <auth.h>
4 #include <libsec.h>
5
6 enum {
7         Encnone,
8         Encssl,
9         Enctls,
10 };
11
12 static char *encprotos[] = {
13         [Encnone] =     "clear",
14         [Encssl] =      "ssl",
15         [Enctls] =      "tls",
16                         nil,
17 };
18
19 char            *keyspec = "";
20 char            *filterp;
21 char            *ealgs = "rc4_256 sha1";
22 int             encproto = Encnone;
23 char            *aan = "/bin/aan";
24 AuthInfo        *ai;
25 int             debug;
26 int             doauth = 1;
27 int             timedout;
28
29 int     connect(char*, char*, int);
30 int     passive(void);
31 int     old9p(int);
32 void    catcher(void*, char*);
33 void    sysfatal(char*, ...);
34 void    usage(void);
35 int     filter(int, char *, char *);
36
37 static void     mksecret(char *, uchar *);
38
39 /*
40  * based on libthread's threadsetname, but drags in less library code.
41  * actually just sets the arguments displayed.
42  */
43 void
44 procsetname(char *fmt, ...)
45 {
46         int fd;
47         char *cmdname;
48         char buf[128];
49         va_list arg;
50
51         va_start(arg, fmt);
52         cmdname = vsmprint(fmt, arg);
53         va_end(arg);
54         if (cmdname == nil)
55                 return;
56         snprint(buf, sizeof buf, "#p/%d/args", getpid());
57         if((fd = open(buf, OWRITE)) >= 0){
58                 write(fd, cmdname, strlen(cmdname)+1);
59                 close(fd);
60         }
61         free(cmdname);
62 }
63
64 void
65 post(char *name, char *envname, int srvfd)
66 {
67         int fd;
68         char buf[32];
69
70         fd = create(name, OWRITE, 0600);
71         if(fd < 0)
72                 return;
73         sprint(buf, "%d",srvfd);
74         if(write(fd, buf, strlen(buf)) != strlen(buf))
75                 sysfatal("srv write: %r");
76         close(fd);
77         putenv(envname, name);
78 }
79
80 static int
81 lookup(char *s, char *l[])
82 {
83         int i;
84
85         for (i = 0; l[i] != 0; i++)
86                 if (strcmp(l[i], s) == 0)
87                         return i;
88         return -1;
89 }
90
91 void
92 main(int argc, char **argv)
93 {
94         char *mntpt, *srvpost, srvfile[64];
95         int backwards = 0, fd, mntflags, oldserver;
96
97         quotefmtinstall();
98         srvpost = nil;
99         oldserver = 0;
100         mntflags = MREPL;
101         ARGBEGIN{
102         case 'A':
103                 doauth = 0;
104                 break;
105         case 'a':
106                 mntflags = MAFTER;
107                 break;
108         case 'b':
109                 mntflags = MBEFORE;
110                 break;
111         case 'c':
112                 mntflags |= MCREATE;
113                 break;
114         case 'C':
115                 mntflags |= MCACHE;
116                 break;
117         case 'd':
118                 debug++;
119                 break;
120         case 'f':
121                 /* ignored but allowed for compatibility */
122                 break;
123         case 'O':
124         case 'o':
125                 oldserver = 1;
126                 break;
127         case 'E':
128                 if ((encproto = lookup(EARGF(usage()), encprotos)) < 0)
129                         usage();
130                 break;
131         case 'e':
132                 ealgs = EARGF(usage());
133                 if(*ealgs == 0 || strcmp(ealgs, "clear") == 0)
134                         ealgs = nil;
135                 break;
136         case 'k':
137                 keyspec = EARGF(usage());
138                 break;
139         case 'p':
140                 filterp = aan;
141                 break;
142         case 's':
143                 srvpost = EARGF(usage());
144                 break;
145         case 'B':
146                 backwards = 1;
147                 break;
148         default:
149                 usage();
150         }ARGEND;
151
152         mntpt = 0;              /* to shut up compiler */
153         if(backwards){
154                 switch(argc) {
155                 default:
156                         mntpt = argv[0];
157                         break;
158                 case 0:
159                         usage();
160                 }
161         } else {
162                 switch(argc) {
163                 case 2:
164                         mntpt = argv[1];
165                         break;
166                 case 3:
167                         mntpt = argv[2];
168                         break;
169                 default:
170                         usage();
171                 }
172         }
173
174         if (encproto == Enctls)
175                 sysfatal("%s: tls has not yet been implemented", argv[0]);
176
177         notify(catcher);
178         alarm(60*1000);
179
180         if(backwards)
181                 fd = passive();
182         else
183                 fd = connect(argv[0], argv[1], oldserver);
184
185         if (!oldserver)
186                 fprint(fd, "impo %s %s\n", filterp? "aan": "nofilter",
187                         encprotos[encproto]);
188
189         if (encproto != Encnone && ealgs && ai) {
190                 uchar key[16];
191                 uchar digest[SHA1dlen];
192                 char fromclientsecret[21];
193                 char fromserversecret[21];
194                 int i;
195
196                 memmove(key+4, ai->secret, ai->nsecret);
197
198                 /* exchange random numbers */
199                 srand(truerand());
200                 for(i = 0; i < 4; i++)
201                         key[i] = rand();
202                 if(write(fd, key, 4) != 4)
203                         sysfatal("can't write key part: %r");
204                 if(readn(fd, key+12, 4) != 4)
205                         sysfatal("can't read key part: %r");
206
207                 /* scramble into two secrets */
208                 sha1(key, sizeof(key), digest, nil);
209                 mksecret(fromclientsecret, digest);
210                 mksecret(fromserversecret, digest+10);
211
212                 if (filterp)
213                         fd = filter(fd, filterp, argv[0]);
214
215                 /* set up encryption */
216                 procsetname("pushssl");
217                 fd = pushssl(fd, ealgs, fromclientsecret, fromserversecret, nil);
218                 if(fd < 0)
219                         sysfatal("can't establish ssl connection: %r");
220         }
221         else if (filterp)
222                 fd = filter(fd, filterp, argv[0]);
223
224         if(srvpost){
225                 sprint(srvfile, "/srv/%s", srvpost);
226                 remove(srvfile);
227                 post(srvfile, srvpost, fd);
228         }
229         procsetname("mount on %s", mntpt);
230         if(mount(fd, -1, mntpt, mntflags, "") < 0)
231                 sysfatal("can't mount %s: %r", argv[1]);
232         alarm(0);
233
234         if(backwards && argc > 1){
235                 exec(argv[1], &argv[1]);
236                 sysfatal("exec: %r");
237         }
238         exits(0);
239 }
240
241 void
242 catcher(void*, char *msg)
243 {
244         timedout = 1;
245         if(strcmp(msg, "alarm") == 0)
246                 noted(NCONT);
247         noted(NDFLT);
248 }
249
250 int
251 old9p(int fd)
252 {
253         int p[2];
254
255         procsetname("old9p");
256         if(pipe(p) < 0)
257                 sysfatal("pipe: %r");
258
259         switch(rfork(RFPROC|RFFDG|RFNAMEG)) {
260         case -1:
261                 sysfatal("rfork srvold9p: %r");
262         case 0:
263                 if(fd != 1){
264                         dup(fd, 1);
265                         close(fd);
266                 }
267                 if(p[0] != 0){
268                         dup(p[0], 0);
269                         close(p[0]);
270                 }
271                 close(p[1]);
272                 if(0){
273                         fd = open("/sys/log/cpu", OWRITE);
274                         if(fd != 2){
275                                 dup(fd, 2);
276                                 close(fd);
277                         }
278                         execl("/bin/srvold9p", "srvold9p", "-ds", nil);
279                 } else
280                         execl("/bin/srvold9p", "srvold9p", "-s", nil);
281                 sysfatal("exec srvold9p: %r");
282         default:
283                 close(fd);
284                 close(p[0]);
285         }
286         return p[1];
287 }
288
289 int
290 connect(char *system, char *tree, int oldserver)
291 {
292         char buf[ERRMAX], dir[128], *na;
293         int fd, n;
294         char *authp;
295
296         na = netmkaddr(system, 0, "exportfs");
297         procsetname("dial %s", na);
298         if((fd = dial(na, 0, dir, 0)) < 0)
299                 sysfatal("can't dial %s: %r", system);
300
301         if(doauth){
302                 if(oldserver)
303                         authp = "p9sk2";
304                 else
305                         authp = "p9any";
306
307                 procsetname("auth_proxy auth_getkey proto=%q role=client %s",
308                         authp, keyspec);
309                 ai = auth_proxy(fd, auth_getkey, "proto=%q role=client %s",
310                         authp, keyspec);
311                 if(ai == nil)
312                         sysfatal("%r: %s", system);
313         }
314
315         procsetname("writing tree name %s", tree);
316         n = write(fd, tree, strlen(tree));
317         if(n < 0)
318                 sysfatal("can't write tree: %r");
319
320         strcpy(buf, "can't read tree");
321
322         procsetname("awaiting OK for %s", tree);
323         n = read(fd, buf, sizeof buf - 1);
324         if(n!=2 || buf[0]!='O' || buf[1]!='K'){
325                 if (timedout)
326                         sysfatal("timed out connecting to %s", na);
327                 buf[sizeof buf - 1] = '\0';
328                 sysfatal("bad remote tree: %s", buf);
329         }
330
331         if(oldserver)
332                 return old9p(fd);
333         return fd;
334 }
335
336 int
337 passive(void)
338 {
339         int fd;
340
341         /*
342          * Ignore doauth==0 on purpose.  Is it useful here?
343          */
344
345         procsetname("auth_proxy auth_getkey proto=p9any role=server");
346         ai = auth_proxy(0, auth_getkey, "proto=p9any role=server");
347         if(ai == nil)
348                 sysfatal("auth_proxy: %r");
349         if(auth_chuid(ai, nil) < 0)
350                 sysfatal("auth_chuid: %r");
351         putenv("service", "import");
352
353         fd = dup(0, -1);
354         close(0);
355         open("/dev/null", ORDWR);
356         close(1);
357         open("/dev/null", ORDWR);
358
359         return fd;
360 }
361
362 void
363 usage(void)
364 {
365         fprint(2, "usage: import [-abcC] [-A] [-E clear|ssl|tls] "
366 "[-e 'crypt auth'|clear] [-k keypattern] [-p] host remotefs [mountpoint]\n");
367         exits("usage");
368 }
369
370 /* Network on fd1, mount driver on fd0 */
371 int
372 filter(int fd, char *cmd, char *host)
373 {
374         int p[2], len, argc;
375         char newport[256], buf[256], *s;
376         char *argv[16], *file, *pbuf;
377
378         if ((len = read(fd, newport, sizeof newport - 1)) < 0)
379                 sysfatal("filter: cannot write port; %r");
380         newport[len] = '\0';
381
382         if ((s = strchr(newport, '!')) == nil)
383                 sysfatal("filter: illegally formatted port %s", newport);
384
385         strecpy(buf, buf+sizeof buf, netmkaddr(host, "tcp", "0"));
386         pbuf = strrchr(buf, '!');
387         strecpy(pbuf, buf+sizeof buf, s);
388
389         if(debug)
390                 fprint(2, "filter: remote port %s\n", newport);
391
392         argc = tokenize(cmd, argv, nelem(argv)-2);
393         if (argc == 0)
394                 sysfatal("filter: empty command");
395         argv[argc++] = "-c";
396         argv[argc++] = buf;
397         argv[argc] = nil;
398         file = argv[0];
399         if (s = strrchr(argv[0], '/'))
400                 argv[0] = s+1;
401
402         if(pipe(p) < 0)
403                 sysfatal("pipe: %r");
404
405         switch(rfork(RFNOWAIT|RFPROC|RFFDG)) {
406         case -1:
407                 sysfatal("rfork record module: %r");
408         case 0:
409                 dup(p[0], 1);
410                 dup(p[0], 0);
411                 close(p[0]);
412                 close(p[1]);
413                 exec(file, argv);
414                 sysfatal("exec record module: %r");
415         default:
416                 close(fd);
417                 close(p[0]);
418         }
419         return p[1];
420 }
421
422 static void
423 mksecret(char *t, uchar *f)
424 {
425         sprint(t, "%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux",
426                 f[0], f[1], f[2], f[3], f[4], f[5], f[6], f[7], f[8], f[9]);
427 }