]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/ssh/sshserve.c
kernel: keep segment locked for data2txt
[plan9front.git] / sys / src / cmd / ssh / sshserve.c
1 #include "ssh.h"
2
3 char *cipherlist = "blowfish rc4 3des";
4 char *authlist = "tis";
5
6 void fromnet(Conn*);
7 void startcmd(Conn*, char*, int*, int*);
8 int maxmsg = 256*1024;
9
10 Cipher *allcipher[] = {
11         &cipherrc4,
12         &cipherblowfish,
13         &cipher3des,
14         &cipherdes,
15         &ciphernone,
16         &ciphertwiddle,
17 };
18
19 Authsrv *allauthsrv[] = {
20         &authsrvpassword,
21         &authsrvtis,
22 };
23
24 Cipher*
25 findcipher(char *name, Cipher **list, int nlist)
26 {
27         int i;
28
29         for(i=0; i<nlist; i++)
30                 if(strcmp(name, list[i]->name) == 0)
31                         return list[i];
32         error("unknown cipher %s", name);
33         return nil;
34 }
35
36 Authsrv*
37 findauthsrv(char *name, Authsrv **list, int nlist)
38 {
39         int i;
40
41         for(i=0; i<nlist; i++)
42                 if(strcmp(name, list[i]->name) == 0)
43                         return list[i];
44         error("unknown authsrv %s", name);
45         return nil;
46 }
47
/*
 * Print the command synopsis on file descriptor 2 and exit with status
 * "usage".  The synopsis lists every option main accepts, including -D.
 */
void
usage(void)
{
	fprint(2, "usage: sshserve [-D debuglevel] [-A authlist] [-c cipherlist] client-ip-address\n");
	exits("usage");
}
54
55 void
56 main(int argc, char **argv)
57 {
58         char *f[16];
59         int i;
60         Conn c;
61
62         fmtinstall('B', mpfmt);
63         fmtinstall('H', encodefmt);
64         atexit(atexitkiller);
65         atexitkill(getpid());
66
67         memset(&c, 0, sizeof c);
68
69         ARGBEGIN{
70         case 'D':
71                 debuglevel = atoi(EARGF(usage()));
72                 break;
73         case 'A':
74                 authlist = EARGF(usage());
75                 break;
76         case 'c':
77                 cipherlist = EARGF(usage());
78                 break;
79         default:
80                 usage();
81         }ARGEND
82
83         if(argc != 1)
84                 usage();
85         c.host = argv[0];
86
87         sshlog("connect from %s", c.host);
88
89         /* limit of 768 bits in remote host key? */
90         c.serverpriv = rsagen(768, 6, 0);
91         if(c.serverpriv == nil)
92                 sysfatal("rsagen failed: %r");
93         c.serverkey = &c.serverpriv->pub;
94
95         c.nokcipher = getfields(cipherlist, f, nelem(f), 1, ", ");
96         c.okcipher = emalloc(sizeof(Cipher*)*c.nokcipher);
97         for(i=0; i<c.nokcipher; i++)
98                 c.okcipher[i] = findcipher(f[i], allcipher, nelem(allcipher));
99
100         c.nokauthsrv = getfields(authlist, f, nelem(f), 1, ", ");
101         c.okauthsrv = emalloc(sizeof(Authsrv*)*c.nokauthsrv);
102         for(i=0; i<c.nokauthsrv; i++)
103                 c.okauthsrv[i] = findauthsrv(f[i], allauthsrv, nelem(allauthsrv));
104
105         sshserverhandshake(&c);
106
107         fromnet(&c);
108 }
109
110 void
111 fromnet(Conn *c)
112 {
113         int infd, kidpid, n;
114         char *cmd;
115         Msg *m;
116
117         infd = kidpid = -1;
118         for(;;){
119                 m = recvmsg(c, -1);
120                 if(m == nil)
121                         exits(nil);
122                 switch(m->type){
123                 default:
124                         //badmsg(m, 0);
125                         sendmsg(allocmsg(c, SSH_SMSG_FAILURE, 0));
126                         break;
127
128                 case SSH_MSG_DISCONNECT:
129                         sysfatal("client disconnected");
130
131                 case SSH_CMSG_REQUEST_PTY:
132                         sendmsg(allocmsg(c, SSH_SMSG_SUCCESS, 0));
133                         break;
134
135                 case SSH_CMSG_X11_REQUEST_FORWARDING:
136                         sendmsg(allocmsg(c, SSH_SMSG_FAILURE, 0));
137                         break;
138
139                 case SSH_CMSG_MAX_PACKET_SIZE:
140                         maxmsg = getlong(m);
141                         sendmsg(allocmsg(c, SSH_SMSG_SUCCESS, 0));
142                         break;
143
144                 case SSH_CMSG_REQUEST_COMPRESSION:
145                         sendmsg(allocmsg(c, SSH_SMSG_FAILURE, 0));
146                         break;
147
148                 case SSH_CMSG_EXEC_SHELL:
149                         startcmd(c, nil, &kidpid, &infd);
150                         goto InteractiveMode;
151
152                 case SSH_CMSG_EXEC_CMD:
153                         cmd = getstring(m);
154                         startcmd(c, cmd, &kidpid, &infd);
155                         goto InteractiveMode;
156                 }
157                 free(m);
158         }
159
160 InteractiveMode:
161         for(;;){
162                 free(m);
163                 m = recvmsg(c, -1);
164                 if(m == nil)
165                         exits(nil);
166                 switch(m->type){
167                 default:
168                         badmsg(m, 0);
169
170                 case SSH_MSG_DISCONNECT:
171                         postnote(PNGROUP, kidpid, "hangup");
172                         sysfatal("client disconnected");
173
174                 case SSH_CMSG_STDIN_DATA:
175                         if(infd != 0){
176                                 n = getlong(m);
177                                 write(infd, getbytes(m, n), n);
178                         }
179                         break;
180
181                 case SSH_CMSG_EOF:
182                         close(infd);
183                         infd = -1;
184                         break;
185
186                 case SSH_CMSG_EXIT_CONFIRMATION:
187                         /* sent by some clients as dying breath */
188                         exits(nil);
189         
190                 case SSH_CMSG_WINDOW_SIZE:
191                         /* we don't care */
192                         break;
193                 }
194         }
195 }
196
197 void
198 copyout(Conn *c, int fd, int mtype)
199 {
200         char buf[8192];
201         int n, max, pid;
202         Msg *m;
203
204         max = sizeof buf;
205         if(max > maxmsg - 32)   /* 32 is an overestimate of packet overhead */
206                 max = maxmsg - 32;
207         if(max <= 0)
208                 sysfatal("maximum message size too small");
209         
210         switch(pid = rfork(RFPROC|RFMEM|RFNOWAIT)){
211         case -1:
212                 sysfatal("fork: %r");
213         case 0:
214                 break;
215         default:
216                 atexitkill(pid);
217                 return;
218         }
219
220         while((n = read(fd, buf, max)) > 0){
221                 m = allocmsg(c, mtype, 4+n);
222                 putlong(m, n);
223                 putbytes(m, buf, n);
224                 sendmsg(m);
225         }
226         exits(nil);
227 }
228
229 void
230 startcmd(Conn *c, char *cmd, int *kidpid, int *kidin)
231 {
232         int i, pid, kpid;
233         int pfd[3][2];
234         char *dir;
235         char *sysname, *tz;
236         Msg *m;
237         Waitmsg *w;
238
239         for(i=0; i<3; i++)
240                 if(pipe(pfd[i]) < 0)
241                         sysfatal("pipe: %r");
242
243         sysname = getenv("sysname");
244         tz = getenv("timezone");
245
246         switch(pid = rfork(RFPROC|RFMEM|RFNOWAIT)){
247         case -1:
248                 sysfatal("fork: %r");
249         case 0:
250                 switch(kpid = rfork(RFPROC|RFNOTEG|RFENVG|RFFDG)){
251                 case -1:
252                         sysfatal("fork: %r");
253                 case 0:
254                         for(i=0; i<3; i++){
255                                 if(dup(pfd[i][1], i) < 0)
256                                         sysfatal("dup: %r");
257                                 close(pfd[i][0]);
258                                 close(pfd[i][1]);
259                         }
260                         putenv("user", c->user);
261                         if(sysname)
262                                 putenv("sysname", sysname);
263                         if(tz)
264                                 putenv("tz", tz);
265         
266                         dir = smprint("/usr/%s", c->user);
267                         if(dir == nil || chdir(dir) < 0)
268                                 chdir("/");
269                         if(cmd){
270                                 putenv("service", "rx");
271                                 execl("/bin/rc", "rc", "-lc", cmd, nil);
272                                 sysfatal("cannot exec /bin/rc: %r");
273                         }else{
274                                 putenv("service", "con");
275                                 execl("/bin/ip/telnetd", "telnetd", "-tn", nil);
276                                 sysfatal("cannot exec /bin/ip/telnetd: %r");
277                         }
278                 default:
279                         *kidpid = kpid;
280                         rendezvous(kidpid, 0);
281                         for(;;){
282                                 if((w = wait()) == nil)
283                                         sysfatal("wait: %r");
284                                 if(w->pid == kpid)
285                                         break;
286                                 free(w);
287                         }
288                         if(w->msg[0]){
289                                 m = allocmsg(c, SSH_MSG_DISCONNECT, 4+strlen(w->msg));
290                                 putstring(m, w->msg);
291                                 sendmsg(m);
292                         }else{
293                                 m = allocmsg(c, SSH_SMSG_EXITSTATUS, 4);
294                                 putlong(m, 0);
295                                 sendmsg(m);
296                         }
297                         for(i=0; i<3; i++)
298                                 close(pfd[i][0]);
299                         free(w);
300                         exits(nil);     
301                         break;
302                 }
303         default:
304                 atexitkill(pid);
305                 rendezvous(kidpid, 0);
306                 break;
307         }
308
309         for(i=0; i<3; i++)
310                 close(pfd[i][1]);
311
312         copyout(c, pfd[1][0], SSH_SMSG_STDOUT_DATA);
313         copyout(c, pfd[2][0], SSH_SMSG_STDERR_DATA);
314         *kidin = pfd[0][0];
315 }