]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/import.c
libc: add procsetname()
[plan9front.git] / sys / src / cmd / import.c
1 #include <u.h>
2 #include <libc.h>
3 #include <auth.h>
4 #include <libsec.h>
5
6 enum {
7         Encnone,
8         Encssl,
9         Enctls,
10 };
11
12 static char *encprotos[] = {
13         [Encnone] =     "clear",
14         [Encssl] =      "ssl",
15         [Enctls] =      "tls",
16                         nil,
17 };
18
19 char            *keyspec = "";
20 char            *filterp;
21 char            *ealgs = "rc4_256 sha1";
22 int             encproto = Encnone;
23 char            *aan = "/bin/aan";
24 char            *anstring  = "tcp!*!0";
25 AuthInfo        *ai;
26 int             debug;
27 int             doauth = 1;
28 int             timedout;
29 int             skiptree;
30
31 int     connect(char*, char*);
32 int     passive(void);
33 void    catcher(void*, char*);
34 void    sysfatal(char*, ...);
35 void    usage(void);
36 int     filter(int, char *, char *);
37
38 static void     mksecret(char *, uchar *);
39
40 void
41 post(char *name, char *envname, int srvfd)
42 {
43         int fd;
44         char buf[32];
45
46         fd = create(name, OWRITE, 0600);
47         if(fd < 0)
48                 return;
49         snprint(buf, sizeof(buf), "%d", srvfd);
50         if(write(fd, buf, strlen(buf)) != strlen(buf))
51                 sysfatal("srv write: %r");
52         close(fd);
53         putenv(envname, name);
54 }
55
56 static int
57 lookup(char *s, char *l[])
58 {
59         int i;
60
61         for (i = 0; l[i] != 0; i++)
62                 if (strcmp(l[i], s) == 0)
63                         return i;
64         return -1;
65 }
66
67 void
68 main(int argc, char **argv)
69 {
70         char *mntpt, *srvpost, srvfile[64];
71         int backwards = 0, fd, mntflags;
72
73         quotefmtinstall();
74         srvpost = nil;
75         mntflags = MREPL;
76         ARGBEGIN{
77         case 'A':
78                 doauth = 0;
79                 break;
80         case 'a':
81                 mntflags = MAFTER;
82                 break;
83         case 'b':
84                 mntflags = MBEFORE;
85                 break;
86         case 'c':
87                 mntflags |= MCREATE;
88                 break;
89         case 'C':
90                 mntflags |= MCACHE;
91                 break;
92         case 'd':
93                 debug++;
94                 break;
95         case 'f':
96                 /* ignored but allowed for compatibility */
97                 break;
98         case 'E':
99                 if ((encproto = lookup(EARGF(usage()), encprotos)) < 0)
100                         usage();
101                 break;
102         case 'e':
103                 ealgs = EARGF(usage());
104                 if(*ealgs == 0 || strcmp(ealgs, "clear") == 0)
105                         ealgs = nil;
106                 break;
107         case 'k':
108                 keyspec = EARGF(usage());
109                 break;
110         case 'p':
111                 filterp = aan;
112                 break;
113         case 'n':
114                 anstring = EARGF(usage());
115                 break;
116         case 's':
117                 srvpost = EARGF(usage());
118                 break;
119         case 'B':
120                 backwards = 1;
121                 break;
122         case 'z':
123                 skiptree = 1;
124                 break;
125         default:
126                 usage();
127         }ARGEND;
128
129         mntpt = 0;              /* to shut up compiler */
130         if(backwards){
131                 switch(argc) {
132                 default:
133                         mntpt = argv[0];
134                         break;
135                 case 0:
136                         usage();
137                 }
138         } else {
139                 switch(argc) {
140                 case 2:
141                         mntpt = argv[1];
142                         break;
143                 case 3:
144                         mntpt = argv[2];
145                         break;
146                 default:
147                         usage();
148                 }
149         }
150
151         if (encproto == Enctls)
152                 sysfatal("%s: tls has not yet been implemented", argv[0]);
153
154         notify(catcher);
155         alarm(60*1000);
156
157         if (backwards)
158                 fd = passive();
159         else
160                 fd = connect(argv[0], argv[1]);
161
162         fprint(fd, "impo %s %s\n", filterp? "aan": "nofilter", encprotos[encproto]);
163
164         if (encproto != Encnone && ealgs && ai) {
165                 uchar key[16], digest[SHA1dlen];
166                 char fromclientsecret[21];
167                 char fromserversecret[21];
168                 int i;
169
170                 if(ai->nsecret < 8)
171                         sysfatal("secret too small to ssl");
172                 memmove(key+4, ai->secret, 8);
173
174                 /* exchange random numbers */
175                 srand(truerand());
176                 for(i = 0; i < 4; i++)
177                         key[i] = rand();
178                 if(write(fd, key, 4) != 4)
179                         sysfatal("can't write key part: %r");
180                 if(readn(fd, key+12, 4) != 4)
181                         sysfatal("can't read key part: %r");
182
183                 /* scramble into two secrets */
184                 sha1(key, sizeof(key), digest, nil);
185                 mksecret(fromclientsecret, digest);
186                 mksecret(fromserversecret, digest+10);
187
188                 if (filterp)
189                         fd = filter(fd, filterp, backwards ? nil : argv[0]);
190
191                 /* set up encryption */
192                 procsetname("pushssl");
193                 fd = pushssl(fd, ealgs, fromclientsecret, fromserversecret, nil);
194                 if(fd < 0)
195                         sysfatal("can't establish ssl connection: %r");
196         }
197         else if (filterp)
198                 fd = filter(fd, filterp, backwards ? nil : argv[0]);
199
200         if(ai)
201                 auth_freeAI(ai);
202
203         if(srvpost){
204                 snprint(srvfile, sizeof(srvfile), "/srv/%s", srvpost);
205                 remove(srvfile);
206                 post(srvfile, srvpost, fd);
207         }
208         procsetname("mount on %s", mntpt);
209         if(mount(fd, -1, mntpt, mntflags, "") < 0)
210                 sysfatal("can't mount %s: %r", argv[1]);
211         alarm(0);
212
213         if(backwards && argc > 1){
214                 exec(argv[1], &argv[1]);
215                 sysfatal("exec: %r");
216         }
217         exits(0);
218 }
219
220 void
221 catcher(void*, char *msg)
222 {
223         timedout = 1;
224         if(strcmp(msg, "alarm") == 0)
225                 noted(NCONT);
226         noted(NDFLT);
227 }
228
229 int
230 connect(char *system, char *tree)
231 {
232         char buf[ERRMAX], dir[128], *na;
233         int fd, n;
234
235         na = netmkaddr(system, 0, "exportfs");
236         procsetname("dial %s", na);
237         if((fd = dial(na, 0, dir, 0)) < 0)
238                 sysfatal("can't dial %s: %r", system);
239
240         if(doauth){
241                 procsetname("auth_proxy auth_getkey proto=p9any role=client %s", keyspec);
242                 ai = auth_proxy(fd, auth_getkey, "proto=p9any role=client %s", keyspec);
243                 if(ai == nil)
244                         sysfatal("%r: %s", system);
245         }
246
247         if(!skiptree){
248                 procsetname("writing tree name %s", tree);
249                 n = write(fd, tree, strlen(tree));
250                 if(n < 0)
251                         sysfatal("can't write tree: %r");
252
253                 strcpy(buf, "can't read tree");
254
255                 procsetname("awaiting OK for %s", tree);
256                 n = read(fd, buf, sizeof buf - 1);
257                 if(n!=2 || buf[0]!='O' || buf[1]!='K'){
258                         if (timedout)
259                                 sysfatal("timed out connecting to %s", na);
260                         buf[sizeof buf - 1] = '\0';
261                         sysfatal("bad remote tree: %s", buf);
262                 }
263         }
264         return fd;
265 }
266
267 int
268 passive(void)
269 {
270         int fd;
271
272         /*
273          * Ignore doauth==0 on purpose.  Is it useful here?
274          */
275
276         procsetname("auth_proxy auth_getkey proto=p9any role=server");
277         ai = auth_proxy(0, auth_getkey, "proto=p9any role=server");
278         if(ai == nil)
279                 sysfatal("auth_proxy: %r");
280         if(auth_chuid(ai, nil) < 0)
281                 sysfatal("auth_chuid: %r");
282         putenv("service", "import");
283
284         fd = dup(0, -1);
285         close(0);
286         open("/dev/null", ORDWR);
287         close(1);
288         open("/dev/null", ORDWR);
289
290         return fd;
291 }
292
293 void
294 usage(void)
295 {
296         fprint(2, "usage: import [-abcC] [-A] [-E clear|ssl|tls] "
297 "[-e 'crypt auth'|clear] [-k keypattern] [-p] [-n address ] [-z] host remotefs [mountpoint]\n");
298         exits("usage");
299 }
300
301 int
302 filter(int fd, char *cmd, char *host)
303 {
304         char addr[128], buf[256], *s, *file, *argv[16];
305         int lfd, p[2], len, argc;
306
307         if(host == nil){
308                 /* Get a free port and post it to the client. */
309                 if (announce(anstring, addr) < 0)
310                         sysfatal("filter: Cannot announce %s: %r", anstring);
311
312                 snprint(buf, sizeof(buf), "%s/local", addr);
313                 if ((lfd = open(buf, OREAD)) < 0)
314                         sysfatal("filter: Cannot open %s: %r", buf);
315                 if ((len = read(lfd, buf, sizeof buf - 1)) < 0)
316                         sysfatal("filter: Cannot read %s: %r", buf);
317                 close(lfd);
318                 buf[len] = '\0';
319                 if ((s = strchr(buf, '\n')) != nil)
320                         len = s - buf;
321                 if (write(fd, buf, len) != len) 
322                         sysfatal("filter: cannot write port; %r");
323         } else {
324                 /* Read address string from connection */
325                 if ((len = read(fd, buf, sizeof buf - 1)) < 0)
326                         sysfatal("filter: cannot write port; %r");
327                 buf[len] = '\0';
328
329                 if ((s = strrchr(buf, '!')) == nil)
330                         sysfatal("filter: illegally formatted port %s", buf);
331                 strecpy(addr, addr+sizeof(addr), netmkaddr(host, "tcp", s+1));
332                 strecpy(strrchr(addr, '!'), addr+sizeof(addr), s);
333         }
334
335         if(debug)
336                 fprint(2, "filter: %s\n", addr);
337
338         snprint(buf, sizeof(buf), "%s", cmd);
339         argc = tokenize(buf, argv, nelem(argv)-3);
340         if (argc == 0)
341                 sysfatal("filter: empty command");
342
343         if(host != nil)
344                 argv[argc++] = "-c";
345         argv[argc++] = addr;
346         argv[argc] = nil;
347
348         file = argv[0];
349         if((s = strrchr(argv[0], '/')) != nil)
350                 argv[0] = s+1;
351
352         if(pipe(p) < 0)
353                 sysfatal("pipe: %r");
354
355         switch(rfork(RFNOWAIT|RFPROC|RFMEM|RFFDG|RFREND)) {
356         case -1:
357                 sysfatal("filter: rfork; %r\n");
358         case 0:
359                 close(fd);
360                 if (dup(p[0], 1) < 0)
361                         sysfatal("filter: Cannot dup to 1; %r");
362                 if (dup(p[0], 0) < 0)
363                         sysfatal("filter: Cannot dup to 0; %r");
364                 close(p[0]);
365                 close(p[1]);
366                 exec(file, argv);
367                 sysfatal("filter: exec; %r");
368         default:
369                 dup(p[1], fd);
370                 close(p[0]);
371                 close(p[1]);
372         }
373         return fd;
374 }
375
376 static void
377 mksecret(char *t, uchar *f)
378 {
379         sprint(t, "%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux",
380                 f[0], f[1], f[2], f[3], f[4], f[5], f[6], f[7], f[8], f[9]);
381 }