]> git.lizzy.rs Git - plan9front.git/blob - sys/src/9/port/devssl.c
merge
[plan9front.git] / sys / src / 9 / port / devssl.c
1 /*
2  *  devssl - secure sockets layer
3  */
4 #include        "u.h"
5 #include        "../port/lib.h"
6 #include        "mem.h"
7 #include        "dat.h"
8 #include        "fns.h"
9 #include        "../port/error.h"
10
11 #include        <libsec.h>
12
/*
 *  NOSPOOKS selects the full encrypttab (56-bit DES and longer-key RC4);
 *  without it only the 40-bit export-strength ciphers are offered.
 */
13 #define NOSPOOKS 1
14
/*
 *  Cipher/digest state for one direction (in or out) of a conversation.
 */
15 typedef struct OneWay OneWay;
16 struct OneWay
17 {
18         QLock   q;              /* serializes record reads (in) or record writes (out) */
19         QLock   ctlq;           /* held by sslbread while decrypting/checking a record; presumably also by ctl writes that change keys */
20
21         void    *state;         /* encryption state: a DESstate or RC4state built by the init*key routines */
22         int     slen;           /* hash data length: length of secret (initRC4key_40/_128 may shorten it) */
23         uchar   *secret;        /* secret, copied in by setsecret */
24         ulong   mid;            /* message id; incremented once per record */
25 };
26
/*
 *  Values for Dstate.state and Dstate.encryptalg.
 */
27 enum
28 {
29         /* connection states: bit flags kept in Dstate.state */
30         Sincomplete=    0,      /* data reads and writes fail with Ebadusefd */
31         Sclear=         1,      /* records pass through untransformed; cleared when a digest or cipher is chosen */
32         Sencrypting=    2,      /* set by parseencryptalg */
33         Sdigesting=     4,      /* set by parsehashalg */
34         Sdigenc=        Sencrypting|Sdigesting,
35
36         /* encryption algorithms (Dstate.encryptalg, taken from encrypttab) */
37         Noencryption=   0,
38         DESCBC=         1,
39         DESECB=         2,
40         RC4=            3
41 };
42
/*
 *  One ssl conversation; dstate[] holds a pointer to it per slot.
 */
43 typedef struct Dstate Dstate;
44 struct Dstate
45 {
46         Chan    *c;             /* io channel: records are read from and written to it */
47         uchar   state;          /* state of connection: S* connection-state flags */
48         int     ref;            /* open ctl/data/secret chans; serialized by dslock for atomic destroy */
49
50         uchar   encryptalg;     /* encryption algorithm */
51         ushort  blocklen;       /* blocking length: cipher block size, 1 for stream ciphers (no padding) */
52
53         ushort  diglen;         /* length of digest */
54         DigestState *(*hf)(uchar*, ulong, uchar*, DigestState*);        /* hash func */
55
56         /* for SSL format */
57         int     max;            /* maximum unpadded data per msg */
58         int     maxpad;         /* maximum padded data per msg */
59
60         /* input side */
61         OneWay  in;
62         Block   *processed;     /* decoded data of the current record not yet returned by sslbread */
63         Block   *unprocessed;   /* raw bytes read from c that do not yet form a consumed record */
64
65         /* output side */
66         OneWay  out;
67
68         /* protections */
69         char    *user;          /* owner: checked by sslopen and sslwstat, changed by sslwstat */
70         int     perm;           /* mode bits used by sslopen's access check, changed by sslwstat */
71 };
72
/*
 *  Size limits and the global conversation table.
 */
73 enum
74 {
75         Maxdmsg=        1<<16,  /* bytes requested per bread of the underlying channel (see ensure) */
76         Maxdstate=      512,    /* max. open ssl conn's; must be a power of 2 (CONV masks with Maxdstate-1) */
77 };
78
79 static  Lock    dslock;         /* taken around dstate[] updates and Dstate.ref changes */
80 static  int     dshiwat;        /* sslgen lists conversation dirs 0..dshiwat-1; presumably raised by dsclone/dsnew */
81 static  char    *dsname[Maxdstate];     /* cached decimal names of the conversation directories */
82 static  Dstate  *dstate[Maxdstate];     /* conversation per slot; nil when unused */
83 static  char    *encalgs;       /* text returned by reads of the encalgs file */
84 static  char    *hashalgs;      /* text returned by reads of the hashalgs file */
85
/*
 *  Qid types.  A Qid path packs the type into its low bits and the
 *  conversation number from bit 5 up: see TYPE, CONV and QID below.
 */
86 enum{
87         Qtopdir         = 1,    /* top level directory */
88         Qprotodir,              /* #D/ssl */
89         Qclonus,                /* #D/ssl/clone */
90         Qconvdir,               /* directory for a conversation */
91         Qdata,                  /* this and the following are files in a conversation directory */
92         Qctl,
93         Qsecretin,
94         Qsecretout,
95         Qencalgs,
96         Qhashalgs,
97 };
98
99 #define TYPE(x)         ((x).path & 0xf)        /* 4 bits suffice: every Q type is < 16 */
100 #define CONV(x)         (((x).path >> 5)&(Maxdstate-1))
101 #define QID(c, y)       (((c)<<5) | (y))
102
/* forward declarations of the file's static helpers */
103 static void     ensure(Dstate*, Block**, int);
104 static void     consume(Block**, uchar*, int);
105 static void     setsecret(OneWay*, uchar*, int);
106 static Block*   encryptb(Dstate*, Block*, int);
107 static Block*   decryptb(Dstate*, Block*);
108 static Block*   digestb(Dstate*, Block*, int);
109 static void     checkdigestb(Dstate*, Block*);
110 static Chan*    buftochan(char*);
111 static void     sslhangup(Dstate*);
112 static Dstate*  dsclone(Chan *c);
113 static void     dsnew(Chan *c, Dstate **);
114 static long     sslput(Dstate *s, Block * volatile b);
115
/*
 *  File names indexed by Qid type.  sslgen uses them for the clone
 *  file, for the entries of a conversation directory and when
 *  stat'ing a conversation file.
 */
116 char *sslnames[] = {
117 [Qclonus]       "clone",
118 [Qdata]         "data",
119 [Qctl]          "ctl",
120 [Qsecretin]     "secretin",
121 [Qsecretout]    "secretout",
122 [Qencalgs]      "encalgs",
123 [Qhashalgs]     "hashalgs",
124 };
125
126 static int
127 sslgen(Chan *c, char*, Dirtab *d, int nd, int s, Dir *dp)
128 {
129         Qid q;
130         Dstate *ds;
131         char name[16], *p, *nm;
132         int ft;
133
134         USED(nd);
135         USED(d);
136
137         q.type = QTFILE;
138         q.vers = 0;
139
140         ft = TYPE(c->qid);
141         switch(ft) {
142         case Qtopdir:
143                 if(s == DEVDOTDOT){
144                         q.path = QID(0, Qtopdir);
145                         q.type = QTDIR;
146                         devdir(c, q, "#D", 0, eve, 0555, dp);
147                         return 1;
148                 }
149                 if(s > 0)
150                         return -1;
151                 q.path = QID(0, Qprotodir);
152                 q.type = QTDIR;
153                 devdir(c, q, "ssl", 0, eve, 0555, dp);
154                 return 1;
155         case Qprotodir:
156                 if(s == DEVDOTDOT){
157                         q.path = QID(0, Qtopdir);
158                         q.type = QTDIR;
159                         devdir(c, q, ".", 0, eve, 0555, dp);
160                         return 1;
161                 }
162                 if(s < dshiwat) {
163                         q.path = QID(s, Qconvdir);
164                         q.type = QTDIR;
165                         ds = dstate[s];
166                         if(ds != 0)
167                                 nm = ds->user;
168                         else
169                                 nm = eve;
170                         if(dsname[s] == nil){
171                                 sprint(name, "%d", s);
172                                 kstrdup(&dsname[s], name);
173                         }
174                         devdir(c, q, dsname[s], 0, nm, 0555, dp);
175                         return 1;
176                 }
177                 if(s > dshiwat)
178                         return -1;
179                 q.path = QID(0, Qclonus);
180                 devdir(c, q, "clone", 0, eve, 0555, dp);
181                 return 1;
182         case Qconvdir:
183                 if(s == DEVDOTDOT){
184                         q.path = QID(0, Qprotodir);
185                         q.type = QTDIR;
186                         devdir(c, q, "ssl", 0, eve, 0555, dp);
187                         return 1;
188                 }
189                 ds = dstate[CONV(c->qid)];
190                 if(ds != 0)
191                         nm = ds->user;
192                 else
193                         nm = eve;
194                 switch(s) {
195                 default:
196                         return -1;
197                 case 0:
198                         q.path = QID(CONV(c->qid), Qctl);
199                         p = "ctl";
200                         break;
201                 case 1:
202                         q.path = QID(CONV(c->qid), Qdata);
203                         p = "data";
204                         break;
205                 case 2:
206                         q.path = QID(CONV(c->qid), Qsecretin);
207                         p = "secretin";
208                         break;
209                 case 3:
210                         q.path = QID(CONV(c->qid), Qsecretout);
211                         p = "secretout";
212                         break;
213                 case 4:
214                         q.path = QID(CONV(c->qid), Qencalgs);
215                         p = "encalgs";
216                         break;
217                 case 5:
218                         q.path = QID(CONV(c->qid), Qhashalgs);
219                         p = "hashalgs";
220                         break;
221                 }
222                 devdir(c, q, p, 0, nm, 0660, dp);
223                 return 1;
224         case Qclonus:
225                 devdir(c, c->qid, sslnames[TYPE(c->qid)], 0, eve, 0555, dp);
226                 return 1;
227         default:
228                 ds = dstate[CONV(c->qid)];
229                 if(ds != 0)
230                         nm = ds->user;
231                 else
232                         nm = eve;
233                 devdir(c, c->qid, sslnames[TYPE(c->qid)], 0, nm, 0660, dp);
234                 return 1;
235         }
236 }
237
238 static Chan*
239 sslattach(char *spec)
240 {
241         Chan *c;
242
243         c = devattach('D', spec);
244         c->qid.path = QID(0, Qtopdir);
245         c->qid.vers = 0;
246         c->qid.type = QTDIR;
247         return c;
248 }
249
250 static Walkqid*
251 sslwalk(Chan *c, Chan *nc, char **name, int nname)
252 {
253         return devwalk(c, nc, name, nname, nil, 0, sslgen);
254 }
255
256 static int
257 sslstat(Chan *c, uchar *db, int n)
258 {
259         return devstat(c, db, n, nil, 0, sslgen);
260 }
261
262 static Chan*
263 sslopen(Chan *c, int omode)
264 {
265         Dstate *s, **pp;
266         int perm;
267         int ft;
268
269         perm = 0;
270         omode &= 3;
271         switch(omode) {
272         case OREAD:
273                 perm = 4;
274                 break;
275         case OWRITE:
276                 perm = 2;
277                 break;
278         case ORDWR:
279                 perm = 6;
280                 break;
281         }
282
283         ft = TYPE(c->qid);
284         switch(ft) {
285         default:
286                 panic("sslopen");
287         case Qtopdir:
288         case Qprotodir:
289         case Qconvdir:
290                 if(omode != OREAD)
291                         error(Eperm);
292                 break;
293         case Qclonus:
294                 s = dsclone(c);
295                 if(s == 0)
296                         error(Enodev);
297                 break;
298         case Qctl:
299         case Qdata:
300         case Qsecretin:
301         case Qsecretout:
302                 if(waserror()) {
303                         unlock(&dslock);
304                         nexterror();
305                 }
306                 lock(&dslock);
307                 pp = &dstate[CONV(c->qid)];
308                 s = *pp;
309                 if(s == 0)
310                         dsnew(c, pp);
311                 else {
312                         if((perm & (s->perm>>6)) != perm
313                            && (strcmp(up->user, s->user) != 0
314                              || (perm & s->perm) != perm))
315                                 error(Eperm);
316
317                         s->ref++;
318                 }
319                 unlock(&dslock);
320                 poperror();
321                 break;
322         case Qencalgs:
323         case Qhashalgs:
324                 if(omode != OREAD)
325                         error(Eperm);
326                 break;
327         }
328         c->mode = openmode(omode);
329         c->flag |= COPEN;
330         c->offset = 0;
331         return c;
332 }
333
334 static int
335 sslwstat(Chan *c, uchar *db, int n)
336 {
337         Dir *dir;
338         Dstate *s;
339         int m;
340
341         s = dstate[CONV(c->qid)];
342         if(s == 0)
343                 error(Ebadusefd);
344         if(strcmp(s->user, up->user) != 0)
345                 error(Eperm);
346
347         dir = smalloc(sizeof(Dir)+n);
348         m = convM2D(db, n, &dir[0], (char*)&dir[1]);
349         if(m == 0){
350                 free(dir);
351                 error(Eshortstat);
352         }
353
354         if(!emptystr(dir->uid))
355                 kstrdup(&s->user, dir->uid);
356         if(dir->mode != ~0UL)
357                 s->perm = dir->mode;
358
359         free(dir);
360         return m;
361 }
362
363 static void
364 sslclose(Chan *c)
365 {
366         Dstate *s;
367         int ft;
368
369         ft = TYPE(c->qid);
370         switch(ft) {
371         case Qctl:
372         case Qdata:
373         case Qsecretin:
374         case Qsecretout:
375                 if((c->flag & COPEN) == 0)
376                         break;
377
378                 s = dstate[CONV(c->qid)];
379                 if(s == 0)
380                         break;
381
382                 lock(&dslock);
383                 if(--s->ref > 0) {
384                         unlock(&dslock);
385                         break;
386                 }
387                 dstate[CONV(c->qid)] = 0;
388                 unlock(&dslock);
389
390                 if(s->user != nil)
391                         free(s->user);
392                 sslhangup(s);
393                 if(s->c)
394                         cclose(s->c);
395                 if(s->in.secret)
396                         free(s->in.secret);
397                 if(s->out.secret)
398                         free(s->out.secret);
399                 if(s->in.state)
400                         free(s->in.state);
401                 if(s->out.state)
402                         free(s->out.state);
403                 free(s);
404
405         }
406 }
407
408 /*
409  *  make sure we have at least 'n' bytes in list 'l'
410  */
411 static void
412 ensure(Dstate *s, Block **l, int n)
413 {
414         int sofar, i;
415         Block *b, *bl;
416
417         sofar = 0;
418         for(b = *l; b; b = b->next){
419                 sofar += BLEN(b);
420                 if(sofar >= n)
421                         return;
422                 l = &b->next;
423         }
424
425         while(sofar < n){
426                 bl = devtab[s->c->type]->bread(s->c, Maxdmsg, 0);
427                 if(bl == 0)
428                         nexterror();
429                 *l = bl;
430                 i = 0;
431                 for(b = bl; b; b = b->next){
432                         i += BLEN(b);
433                         l = &b->next;
434                 }
435                 if(i == 0)
436                         error(Ehungup);
437                 sofar += i;
438         }
439 }
440
441 /*
442  *  copy 'n' bytes from 'l' into 'p' and free
443  *  the bytes in 'l'
444  */
445 static void
446 consume(Block **l, uchar *p, int n)
447 {
448         Block *b;
449         int i;
450
451         for(; *l && n > 0; n -= i){
452                 b = *l;
453                 i = BLEN(b);
454                 if(i > n)
455                         i = n;
456                 memmove(p, b->rp, i);
457                 b->rp += i;
458                 p += i;
459                 if(BLEN(b) < 0)
460                         panic("consume");
461                 if(BLEN(b))
462                         break;
463                 *l = b->next;
464                 freeb(b);
465         }
466 }
467
468 /*
469  *  give back n bytes
470 static void
471 regurgitate(Dstate *s, uchar *p, int n)
472 {
473         Block *b;
474
475         if(n <= 0)
476                 return;
477         b = s->unprocessed;
478         if(s->unprocessed == nil || b->rp - b->base < n) {
479                 b = allocb(n);
480                 memmove(b->wp, p, n);
481                 b->wp += n;
482                 b->next = s->unprocessed;
483                 s->unprocessed = b;
484         } else {
485                 b->rp -= n;
486                 memmove(b->rp, p, n);
487         }
488 }
489  */
490
491 /*
 *  Remove at most n bytes from the front of the block list *l and
 *  return them as their own nil-terminated list (nil if *l is empty).
 *  If discard is set, everything after those n bytes is freed and *l
 *  is left empty; otherwise the remainder stays on *l.  A block that
 *  straddles the n-byte boundary is split: it keeps its first part and
 *  its tail is copied into a newly allocated block.  If fewer than n
 *  bytes are queued, the whole list is returned and *l is emptied.
 *  n must not be negative: the split block would shrink below zero
 *  length and the "qtake" panic would fire.
 */
495 static Block*
496 qtake(Block **l, int n, int discard)
497 {
498         Block *nb, *b, *first;
499         int i;
500
501         first = *l;
502         for(b = first; b; b = b->next){
503                 i = BLEN(b);
504                 if(i == n){                     /* boundary at the end of b */
505                         if(discard){
506                                 freeblist(b->next);
507                                 *l = 0;
508                         } else
509                                 *l = b->next;
510                         b->next = 0;
511                         return first;
512                 } else if(i > n){               /* boundary inside b: split it */
513                         i -= n;                 /* i = bytes of b beyond the boundary */
514                         if(discard){
515                                 freeblist(b->next);
516                                 b->wp -= i;
517                                 *l = 0;
518                         } else {
519                                 nb = allocb(i);         /* the tail moves to a new block left on *l */
520                                 memmove(nb->wp, b->rp+n, i);
521                                 nb->wp += i;
522                                 b->wp -= i;
523                                 nb->next = b->next;
524                                 *l = nb;
525                         }
526                         b->next = 0;
527                         if(BLEN(b) < 0)
528                                 panic("qtake");
529                         return first;
530                 } else
531                         n -= i;                 /* all of b is taken; continue with the next block */
532                 if(BLEN(b) < 0)
533                         panic("qtake");
534         }
535         *l = 0;                                 /* ran out of blocks before n bytes */
536         return first;
537 }
538
539 /*
540  *  We can't let Eintr's lose data since the program
541  *  doing the read may be able to handle it.  The only
542  *  places Eintr is possible is during the read's in consume.
543  *  Therefore, we make sure we can always put back the bytes
544  *  consumed before the last ensure.
545  */
546 static Block*
547 sslbread(Chan *c, long n, ulong)
548 {
549         Dstate * volatile s;
550         Block *b;
551         uchar consumed[3], *p;
552         int toconsume;
553         int len, pad;
554
555         s = dstate[CONV(c->qid)];
556         if(s == 0)
557                 panic("sslbread");
558         if(s->state == Sincomplete)
559                 error(Ebadusefd);
560
561         qlock(&s->in.q);
562         if(waserror()){
563                 qunlock(&s->in.q);
564                 nexterror();
565         }
566
567         if(s->processed == 0){
568                 /*
569                  * Read in the whole message.  Until we've got it all,
570                  * it stays on s->unprocessed, so that if we get Eintr,
571                  * we'll pick up where we left off.
572                  */
573                 ensure(s, &s->unprocessed, 3);
574                 s->unprocessed = pullupblock(s->unprocessed, 2);
575                 p = s->unprocessed->rp;
576                 if(p[0] & 0x80){
577                         len = ((p[0] & 0x7f)<<8) | p[1];
578                         ensure(s, &s->unprocessed, len);
579                         pad = 0;
580                         toconsume = 2;
581                 } else {
582                         s->unprocessed = pullupblock(s->unprocessed, 3);
583                         len = ((p[0] & 0x3f)<<8) | p[1];
584                         pad = p[2];
585                         if(pad > len){
586                                 print("pad %d buf len %d\n", pad, len);
587                                 error("bad pad in ssl message");
588                         }
589                         toconsume = 3;
590                 }
591                 ensure(s, &s->unprocessed, toconsume+len);
592
593                 /* skip header */
594                 consume(&s->unprocessed, consumed, toconsume);
595
596                 /* grab the next message and decode/decrypt it */
597                 b = qtake(&s->unprocessed, len, 0);
598
599                 if(blocklen(b) != len)
600                         print("devssl: sslbread got wrong count %d != %d", blocklen(b), len);
601
602                 if(waserror()){
603                         qunlock(&s->in.ctlq);
604                         if(b != nil)
605                                 freeb(b);
606                         nexterror();
607                 }
608                 qlock(&s->in.ctlq);
609                 switch(s->state){
610                 case Sencrypting:
611                         if(b == nil)
612                                 error("ssl message too short (encrypting)");
613                         b = decryptb(s, b);
614                         break;
615                 case Sdigesting:
616                         b = pullupblock(b, s->diglen);
617                         if(b == nil)
618                                 error("ssl message too short (digesting)");
619                         checkdigestb(s, b);
620                         pullblock(&b, s->diglen);
621                         len -= s->diglen;
622                         break;
623                 case Sdigenc:
624                         b = decryptb(s, b);
625                         b = pullupblock(b, s->diglen);
626                         if(b == nil)
627                                 error("ssl message too short (dig+enc)");
628                         checkdigestb(s, b);
629                         pullblock(&b, s->diglen);
630                         len -= s->diglen;
631                         break;
632                 }
633
634                 /* remove pad */
635                 if(pad)
636                         s->processed = qtake(&b, len - pad, 1);
637                 else
638                         s->processed = b;
639                 b = nil;
640                 s->in.mid++;
641                 qunlock(&s->in.ctlq);
642                 poperror();
643         }
644
645         /* return at most what was asked for */
646         b = qtake(&s->processed, n, 0);
647
648         qunlock(&s->in.q);
649         poperror();
650
651         return b;
652 }
653
654 static long
655 sslread(Chan *c, void *a, long n, vlong off)
656 {
657         Block * volatile b;
658         Block *nb;
659         uchar *va;
660         int i;
661         char buf[128];
662         ulong offset = off;
663         int ft;
664
665         if(c->qid.type & QTDIR)
666                 return devdirread(c, a, n, 0, 0, sslgen);
667
668         ft = TYPE(c->qid);
669         switch(ft) {
670         default:
671                 error(Ebadusefd);
672         case Qctl:
673                 ft = CONV(c->qid);
674                 sprint(buf, "%d", ft);
675                 return readstr(offset, a, n, buf);
676         case Qdata:
677                 b = sslbread(c, n, offset);
678                 break;
679         case Qencalgs:
680                 return readstr(offset, a, n, encalgs);
681                 break;
682         case Qhashalgs:
683                 return readstr(offset, a, n, hashalgs);
684                 break;
685         }
686
687         if(waserror()){
688                 freeblist(b);
689                 nexterror();
690         }
691
692         n = 0;
693         va = a;
694         for(nb = b; nb; nb = nb->next){
695                 i = BLEN(nb);
696                 memmove(va+n, nb->rp, i);
697                 n += i;
698         }
699
700         freeblist(b);
701         poperror();
702
703         return n;
704 }
705
706 /*
707  *  this algorithm doesn't have to be great since we're just
708  *  trying to obscure the block fill
709  */
710 static void
711 randfill(uchar *buf, int len)
712 {
713         while(len-- > 0)
714                 *buf++ = nrand(256);
715 }
716
717 static long
718 sslbwrite(Chan *c, Block *b, ulong)
719 {
720         Dstate * volatile s;
721         long rv;
722
723         s = dstate[CONV(c->qid)];
724         if(s == nil)
725                 panic("sslbwrite");
726
727         if(s->state == Sincomplete){
728                 freeb(b);
729                 error(Ebadusefd);
730         }
731
732         /* lock so split writes won't interleave */
733         if(waserror()){
734                 qunlock(&s->out.q);
735                 nexterror();
736         }
737         qlock(&s->out.q);
738
739         rv = sslput(s, b);
740
741         poperror();
742         qunlock(&s->out.q);
743
744         return rv;
745 }
746
747 /*
 *  Send the data in block b (sslput takes ownership of b) as one or
 *  more SSL records: a count header, then, depending on s->state, a
 *  digest and/or encryption.  Returns the number of data bytes sent
 *  (header, digest and pad bytes are not counted).  Only the bytes of
 *  b itself are sent; b->next is not followed here.
 *
 *  The write is interruptable.  If it is interrupted, we'll get out of
 *  sync with the far side.  Not much we can do about it since we don't
 *  know if any bytes have been written.
 */
753 static long
754 sslput(Dstate *s, Block * volatile b)
755 {
756         Block *nb;
757         int h, n, m, pad, rv;
758         uchar *p;
759         int offset;
760
761         if(waserror()){                 /* frees the unsent rest of b; NOTE(review): a record already built in nb is not freed here */
762                 if(b != nil)
763                         freeb(b);
764                 nexterror();
765         }
766
767         rv = 0;
768         while(b != nil){
769                 m = n = BLEN(b);
770                 h = s->diglen + 2;      /* 2-byte header plus room for the digest */
771
772                 /* trim to maximum block size */
773                 pad = 0;
774                 if(m > s->max){
775                         m = s->max;
776                 } else if(s->blocklen != 1){
777                         pad = (m + s->diglen)%s->blocklen;      /* bytes past the last whole cipher block */
778                         if(pad){
779                                 if(m > s->maxpad){      /* too big to pad: send maxpad bytes unpadded (presumably maxpad+diglen is a blocklen multiple - TODO confirm) */
780                                         pad = 0;
781                                         m = s->maxpad;
782                                 } else {
783                                         pad = s->blocklen - pad;
784                                         h++;            /* padded records use a 3-byte header */
785                                 }
786                         }
787                 }
788
789                 rv += m;
790                 if(m != n){                     /* only part of b fits: copy it into a fresh record block */
791                         nb = allocb(m + h + pad);
792                         memmove(nb->wp + h, b->rp, m);
793                         nb->wp += m + h;
794                         b->rp += m;
795                 } else {
796                         /* add header space */
797                         nb = padblock(b, h);    /* all of b goes out; nb now owns it */
798                         b = 0;
799                 }
800                 m += s->diglen;                 /* the record length counts the digest */
801
802                 /* SSL style count */
803                 if(pad){
804                         nb = padblock(nb, -pad);
805                         randfill(nb->wp, pad);
806                         nb->wp += pad;
807                         m += pad;
808
809                         p = nb->rp;
810                         p[0] = (m>>8);          /* 3-byte header: length (top bit clear), then pad count */
811                         p[1] = m;
812                         p[2] = pad;
813                         offset = 3;
814                 } else {
815                         p = nb->rp;
816                         p[0] = (m>>8) | 0x80;   /* 2-byte header: top bit set means no pad byte */
817                         p[1] = m;
818                         offset = 2;
819                 }
820
821                 switch(s->state){               /* transform from offset on: digest space, data, pad (digestb presumably fills the reserved digest bytes - TODO confirm) */
822                 case Sencrypting:
823                         nb = encryptb(s, nb, offset);
824                         break;
825                 case Sdigesting:
826                         nb = digestb(s, nb, offset);
827                         break;
828                 case Sdigenc:
829                         nb = digestb(s, nb, offset);
830                         nb = encryptb(s, nb, offset);
831                         break;
832                 }
833
834                 s->out.mid++;
835
836                 m = BLEN(nb);
837                 devtab[s->c->type]->bwrite(s->c, nb, s->c->offset);
838                 s->c->offset += m;
839         }
840
841         poperror();
842         return rv;
843 }
844
845 static void
846 setsecret(OneWay *w, uchar *secret, int n)
847 {
848         if(w->secret)
849                 free(w->secret);
850
851         w->secret = smalloc(n);
852         memmove(w->secret, secret, n);
853         w->slen = n;
854 }
855
856 static void
857 initDESkey(OneWay *w)
858 {
859         if(w->state){
860                 free(w->state);
861                 w->state = 0;
862         }
863
864         w->state = smalloc(sizeof(DESstate));
865         if(w->slen >= 16)
866                 setupDESstate(w->state, w->secret, w->secret+8);
867         else if(w->slen >= 8)
868                 setupDESstate(w->state, w->secret, 0);
869         else
870                 error("secret too short");
871 }
872
873 /*
874  *  40 bit DES is the same as 56 bit DES.  However,
875  *  16 bits of the key are masked to zero.
876  */
877 static void
878 initDESkey_40(OneWay *w)
879 {
880         uchar key[8];
881
882         if(w->state){
883                 free(w->state);
884                 w->state = 0;
885         }
886
887         if(w->slen >= 8){
888                 memmove(key, w->secret, 8);
889                 key[0] &= 0x0f;
890                 key[2] &= 0x0f;
891                 key[4] &= 0x0f;
892                 key[6] &= 0x0f;
893         }
894
895         w->state = smalloc(sizeof(DESstate));
896         if(w->slen >= 16)
897                 setupDESstate(w->state, key, w->secret+8);
898         else if(w->slen >= 8)
899                 setupDESstate(w->state, key, 0);
900         else
901                 error("secret too short");
902 }
903
904 static void
905 initRC4key(OneWay *w)
906 {
907         if(w->state){
908                 free(w->state);
909                 w->state = 0;
910         }
911
912         w->state = smalloc(sizeof(RC4state));
913         setupRC4state(w->state, w->secret, w->slen);
914 }
915
916 /*
917  *  40 bit RC4 is the same as n-bit RC4.  However,
918  *  we ignore all but the first 40 bits of the key.
919  */
920 static void
921 initRC4key_40(OneWay *w)
922 {
923         if(w->state){
924                 free(w->state);
925                 w->state = 0;
926         }
927
928         if(w->slen > 5)
929                 w->slen = 5;
930
931         w->state = smalloc(sizeof(RC4state));
932         setupRC4state(w->state, w->secret, w->slen);
933 }
934
935 /*
936  *  128 bit RC4 is the same as n-bit RC4.  However,
937  *  we ignore all but the first 128 bits of the key.
938  */
939 static void
940 initRC4key_128(OneWay *w)
941 {
942         if(w->state){
943                 free(w->state);
944                 w->state = 0;
945         }
946
947         if(w->slen > 16)
948                 w->slen = 16;
949
950         w->state = smalloc(sizeof(RC4state));
951         setupRC4state(w->state, w->secret, w->slen);
952 }
953
954
/*
 *  Digest algorithms accepted by parsehashalg: name, digest length
 *  in bytes and the libsec hash function.
 */
955 typedef struct Hashalg Hashalg;
956 struct Hashalg
957 {
958         char    *name;
959         int     diglen;
960         DigestState *(*hf)(uchar*, ulong, uchar*, DigestState*);
961 };
962
963 Hashalg hashtab[] =
964 {
965         { "md4", MD4dlen, md4, },
966         { "md5", MD5dlen, md5, },
967         { "sha1", SHA1dlen, sha1, },
968         { "sha", SHA1dlen, sha1, },             /* alias for sha1 */
969         { 0 }                                   /* terminator: parsehashalg stops at a nil name */
970 };
971
972 static int
973 parsehashalg(char *p, Dstate *s)
974 {
975         Hashalg *ha;
976
977         for(ha = hashtab; ha->name; ha++){
978                 if(strcmp(p, ha->name) == 0){
979                         s->hf = ha->hf;
980                         s->diglen = ha->diglen;
981                         s->state &= ~Sclear;
982                         s->state |= Sdigesting;
983                         return 0;
984                 }
985         }
986         return -1;
987 }
988
/*
 *  Cipher table entry: name, cipher block length (1 = stream cipher,
 *  no padding in sslput), algorithm code stored in Dstate.encryptalg,
 *  and the routine that builds a OneWay's key state from its secret.
 */
989 typedef struct Encalg Encalg;
990 struct Encalg
991 {
992         char    *name;
993         int     blocklen;
994         int     alg;
995         void    (*keyinit)(OneWay*);
996 };
997
998 #ifdef NOSPOOKS
999 Encalg encrypttab[] =
1000 {
1001         { "descbc", 8, DESCBC, initDESkey, },           /* DEPRECATED -- use des_56_cbc */
1002         { "desecb", 8, DESECB, initDESkey, },           /* DEPRECATED -- use des_56_ecb */
1003         { "des_56_cbc", 8, DESCBC, initDESkey, },
1004         { "des_56_ecb", 8, DESECB, initDESkey, },
1005         { "des_40_cbc", 8, DESCBC, initDESkey_40, },
1006         { "des_40_ecb", 8, DESECB, initDESkey_40, },
1007         { "rc4", 1, RC4, initRC4key_40, },              /* DEPRECATED -- use rc4_X      */
1008         { "rc4_256", 1, RC4, initRC4key, },
1009         { "rc4_128", 1, RC4, initRC4key_128, },
1010         { "rc4_40", 1, RC4, initRC4key_40, },
1011         { 0 }
1012 };
1013 #else
1014 Encalg encrypttab[] =
1015 {
1016         { "des_40_cbc", 8, DESCBC, initDESkey_40, },
1017         { "des_40_ecb", 8, DESECB, initDESkey_40, },
1018         { "rc4", 1, RC4, initRC4key_40, },              /* DEPRECATED -- use rc4_X      */
1019         { "rc4_40", 1, RC4, initRC4key_40, },
1020         { 0 }
1021 };
1022 #endif NOSPOOKS
1023
1024 static int
1025 parseencryptalg(char *p, Dstate *s)
1026 {
1027         Encalg *ea;
1028
1029         for(ea = encrypttab; ea->name; ea++){
1030                 if(strcmp(p, ea->name) == 0){
1031                         s->encryptalg = ea->alg;
1032                         s->blocklen = ea->blocklen;
1033                         (*ea->keyinit)(&s->in);
1034                         (*ea->keyinit)(&s->out);
1035                         s->state &= ~Sclear;
1036                         s->state |= Sencrypting;
1037                         return 0;
1038                 }
1039         }
1040         return -1;
1041 }
1042
1043 static long
1044 sslwrite(Chan *c, void *a, long n, vlong)
1045 {
1046         Dstate * volatile s;
1047         Block * volatile b;
1048         int m, t;
1049         char *p, *np, *e, buf[128];
1050         uchar *x;
1051
1052         s = dstate[CONV(c->qid)];
1053         if(s == 0)
1054                 panic("sslwrite");
1055
1056         t = TYPE(c->qid);
1057         if(t == Qdata){
1058                 if(s->state == Sincomplete)
1059                         error(Ebadusefd);
1060
1061                 /* lock should a write gets split over multiple records */
1062                 if(waserror()){
1063                         qunlock(&s->out.q);
1064                         nexterror();
1065                 }
1066                 qlock(&s->out.q);
1067
1068                 p = a;
1069                 e = p + n;
1070                 do {
1071                         m = e - p;
1072                         if(m > s->max)
1073                                 m = s->max;
1074
1075                         b = allocb(m);
1076                         if(waserror()){
1077                                 freeb(b);
1078                                 nexterror();
1079                         }
1080                         memmove(b->wp, p, m);
1081                         poperror();
1082                         b->wp += m;
1083
1084                         sslput(s, b);
1085
1086                         p += m;
1087                 } while(p < e);
1088
1089                 poperror();
1090                 qunlock(&s->out.q);
1091                 return n;
1092         }
1093
1094         /* mutex with operations using what we're about to change */
1095         if(waserror()){
1096                 qunlock(&s->in.ctlq);
1097                 qunlock(&s->out.q);
1098                 nexterror();
1099         }
1100         qlock(&s->in.ctlq);
1101         qlock(&s->out.q);
1102
1103         switch(t){
1104         default:
1105                 panic("sslwrite");
1106         case Qsecretin:
1107                 setsecret(&s->in, a, n);
1108                 goto out;
1109         case Qsecretout:
1110                 setsecret(&s->out, a, n);
1111                 goto out;
1112         case Qctl:
1113                 break;
1114         }
1115
1116         if(n >= sizeof(buf))
1117                 error("arg too long");
1118         strncpy(buf, a, n);
1119         buf[n] = 0;
1120         p = strchr(buf, '\n');
1121         if(p)
1122                 *p = 0;
1123         p = strchr(buf, ' ');
1124         if(p)
1125                 *p++ = 0;
1126
1127         if(strcmp(buf, "fd") == 0){
1128                 s->c = buftochan(p);
1129
1130                 /* default is clear (msg delimiters only) */
1131                 s->state = Sclear;
1132                 s->blocklen = 1;
1133                 s->diglen = 0;
1134                 s->maxpad = s->max = (1<<15) - s->diglen - 1;
1135                 s->in.mid = 0;
1136                 s->out.mid = 0;
1137         } else if(strcmp(buf, "alg") == 0 && p != 0){
1138                 s->blocklen = 1;
1139                 s->diglen = 0;
1140
1141                 if(s->c == 0)
1142                         error("must set fd before algorithm");
1143
1144                 s->state = Sclear;
1145                 s->maxpad = s->max = (1<<15) - s->diglen - 1;
1146                 if(strcmp(p, "clear") == 0){
1147                         goto out;
1148                 }
1149
1150                 if(s->in.secret && s->out.secret == 0)
1151                         setsecret(&s->out, s->in.secret, s->in.slen);
1152                 if(s->out.secret && s->in.secret == 0)
1153                         setsecret(&s->in, s->out.secret, s->out.slen);
1154                 if(s->in.secret == 0 || s->out.secret == 0)
1155                         error("algorithm but no secret");
1156
1157                 s->hf = 0;
1158                 s->encryptalg = Noencryption;
1159                 s->blocklen = 1;
1160
1161                 for(;;){
1162                         np = strchr(p, ' ');
1163                         if(np)
1164                                 *np++ = 0;
1165
1166                         if(parsehashalg(p, s) < 0)
1167                         if(parseencryptalg(p, s) < 0)
1168                                 error("bad algorithm");
1169
1170                         if(np == 0)
1171                                 break;
1172                         p = np;
1173                 }
1174
1175                 if(s->hf == 0 && s->encryptalg == Noencryption)
1176                         error("bad algorithm");
1177
1178                 if(s->blocklen != 1){
1179                         s->max = (1<<15) - s->diglen - 1;
1180                         s->max -= s->max % s->blocklen;
1181                         s->maxpad = (1<<14) - s->diglen - 1;
1182                         s->maxpad -= s->maxpad % s->blocklen;
1183                 } else
1184                         s->maxpad = s->max = (1<<15) - s->diglen - 1;
1185         } else if(strcmp(buf, "secretin") == 0 && p != 0) {
1186                 m = (strlen(p)*3)/2;
1187                 x = smalloc(m);
1188                 t = dec64(x, m, p, strlen(p));
1189                 if(t <= 0){
1190                         free(x);
1191                         error(Ebadarg);
1192                 }
1193                 setsecret(&s->in, x, t);
1194                 free(x);
1195         } else if(strcmp(buf, "secretout") == 0 && p != 0) {
1196                 m = (strlen(p)*3)/2 + 1;
1197                 x = smalloc(m);
1198                 t = dec64(x, m, p, strlen(p));
1199                 if(t <= 0){
1200                         free(x);
1201                         error(Ebadarg);
1202                 }
1203                 setsecret(&s->out, x, t);
1204                 free(x);
1205         } else
1206                 error(Ebadarg);
1207
1208 out:
1209         qunlock(&s->in.ctlq);
1210         qunlock(&s->out.q);
1211         poperror();
1212         return n;
1213 }
1214
1215 static void
1216 sslinit(void)
1217 {
1218         struct Encalg *e;
1219         struct Hashalg *h;
1220         int n;
1221         char *cp;
1222
1223         n = 1;
1224         for(e = encrypttab; e->name != nil; e++)
1225                 n += strlen(e->name) + 1;
1226         cp = encalgs = smalloc(n);
1227         for(e = encrypttab;;){
1228                 strcpy(cp, e->name);
1229                 cp += strlen(e->name);
1230                 e++;
1231                 if(e->name == nil)
1232                         break;
1233                 *cp++ = ' ';
1234         }
1235         *cp = 0;
1236
1237         n = 1;
1238         for(h = hashtab; h->name != nil; h++)
1239                 n += strlen(h->name) + 1;
1240         cp = hashalgs = smalloc(n);
1241         for(h = hashtab;;){
1242                 strcpy(cp, h->name);
1243                 cp += strlen(h->name);
1244                 h++;
1245                 if(h->name == nil)
1246                         break;
1247                 *cp++ = ' ';
1248         }
1249         *cp = 0;
1250 }
1251
/*
 *  device table for devssl.  the initializer is positional, so the
 *  entries must stay in the order of the Dev structure's fields
 *  (declared elsewhere).  buftochan compares channels' devtab entries
 *  against &ssldevtab to refuse layering ssl over its own files.
 */
Dev ssldevtab = {
	'D',		/* presumably the device character (#D); confirm */
	"ssl",

	devreset,
	sslinit,	/* builds the encalgs and hashalgs name lists */
	devshutdown,
	sslattach,
	sslwalk,
	sslstat,
	sslopen,
	devcreate,
	sslclose,
	sslread,
	sslbread,
	sslwrite,
	sslbwrite,
	devremove,
	sslwstat,
};
1272
1273 static Block*
1274 encryptb(Dstate *s, Block *b, int offset)
1275 {
1276         uchar *p, *ep, *p2, *ip, *eip;
1277         DESstate *ds;
1278
1279         switch(s->encryptalg){
1280         case DESECB:
1281                 ds = s->out.state;
1282                 ep = b->rp + BLEN(b);
1283                 for(p = b->rp + offset; p < ep; p += 8)
1284                         block_cipher(ds->expanded, p, 0);
1285                 break;
1286         case DESCBC:
1287                 ds = s->out.state;
1288                 ep = b->rp + BLEN(b);
1289                 for(p = b->rp + offset; p < ep; p += 8){
1290                         p2 = p;
1291                         ip = ds->ivec;
1292                         for(eip = ip+8; ip < eip; )
1293                                 *p2++ ^= *ip++;
1294                         block_cipher(ds->expanded, p, 0);
1295                         memmove(ds->ivec, p, 8);
1296                 }
1297                 break;
1298         case RC4:
1299                 rc4(s->out.state, b->rp + offset, BLEN(b) - offset);
1300                 break;
1301         }
1302         return b;
1303 }
1304
1305 static Block*
1306 decryptb(Dstate *s, Block *bin)
1307 {
1308         Block *b, **l;
1309         uchar *p, *ep, *tp, *ip, *eip;
1310         DESstate *ds;
1311         uchar tmp[8];
1312         int i;
1313
1314         l = &bin;
1315         for(b = bin; b; b = b->next){
1316                 /* make sure we have a multiple of s->blocklen */
1317                 if(s->blocklen > 1){
1318                         i = BLEN(b);
1319                         if(i % s->blocklen){
1320                                 *l = b = pullupblock(b, i + s->blocklen - (i%s->blocklen));
1321                                 if(b == 0)
1322                                         error("ssl encrypted message too short");
1323                         }
1324                 }
1325                 l = &b->next;
1326
1327                 /* decrypt */
1328                 switch(s->encryptalg){
1329                 case DESECB:
1330                         ds = s->in.state;
1331                         ep = b->rp + BLEN(b);
1332                         for(p = b->rp; p < ep; p += 8)
1333                                 block_cipher(ds->expanded, p, 1);
1334                         break;
1335                 case DESCBC:
1336                         ds = s->in.state;
1337                         ep = b->rp + BLEN(b);
1338                         for(p = b->rp; p < ep;){
1339                                 memmove(tmp, p, 8);
1340                                 block_cipher(ds->expanded, p, 1);
1341                                 tp = tmp;
1342                                 ip = ds->ivec;
1343                                 for(eip = ip+8; ip < eip; ){
1344                                         *p++ ^= *ip;
1345                                         *ip++ = *tp++;
1346                                 }
1347                         }
1348                         break;
1349                 case RC4:
1350                         rc4(s->in.state, b->rp, BLEN(b));
1351                         break;
1352                 }
1353         }
1354         return bin;
1355 }
1356
1357 static Block*
1358 digestb(Dstate *s, Block *b, int offset)
1359 {
1360         uchar *p;
1361         DigestState ss;
1362         uchar msgid[4];
1363         ulong n, h;
1364         OneWay *w;
1365
1366         w = &s->out;
1367
1368         memset(&ss, 0, sizeof(ss));
1369         h = s->diglen + offset;
1370         n = BLEN(b) - h;
1371
1372         /* hash secret + message */
1373         (*s->hf)(w->secret, w->slen, 0, &ss);
1374         (*s->hf)(b->rp + h, n, 0, &ss);
1375
1376         /* hash message id */
1377         p = msgid;
1378         n = w->mid;
1379         *p++ = n>>24;
1380         *p++ = n>>16;
1381         *p++ = n>>8;
1382         *p = n;
1383         (*s->hf)(msgid, 4, b->rp + offset, &ss);
1384
1385         return b;
1386 }
1387
1388 static void
1389 checkdigestb(Dstate *s, Block *bin)
1390 {
1391         uchar *p;
1392         DigestState ss;
1393         uchar msgid[4];
1394         int n, h;
1395         OneWay *w;
1396         uchar digest[128];
1397         Block *b;
1398
1399         w = &s->in;
1400
1401         memset(&ss, 0, sizeof(ss));
1402
1403         /* hash secret */
1404         (*s->hf)(w->secret, w->slen, 0, &ss);
1405
1406         /* hash message */
1407         h = s->diglen;
1408         for(b = bin; b; b = b->next){
1409                 n = BLEN(b) - h;
1410                 if(n < 0)
1411                         panic("checkdigestb");
1412                 (*s->hf)(b->rp + h, n, 0, &ss);
1413                 h = 0;
1414         }
1415
1416         /* hash message id */
1417         p = msgid;
1418         n = w->mid;
1419         *p++ = n>>24;
1420         *p++ = n>>16;
1421         *p++ = n>>8;
1422         *p = n;
1423         (*s->hf)(msgid, 4, digest, &ss);
1424
1425         if(memcmp(digest, bin->rp, s->diglen) != 0)
1426                 error("bad digest");
1427 }
1428
1429 /* get channel associated with an fd */
1430 static Chan*
1431 buftochan(char *p)
1432 {
1433         Chan *c;
1434         int fd;
1435
1436         if(p == 0)
1437                 error(Ebadarg);
1438         fd = strtoul(p, 0, 0);
1439         if(fd < 0)
1440                 error(Ebadarg);
1441         c = fdtochan(fd, -1, 0, 1);     /* error check and inc ref */
1442         if(devtab[c->type] == &ssldevtab){
1443                 cclose(c);
1444                 error("cannot ssl encrypt devssl files");
1445         }
1446         return c;
1447 }
1448
1449 /* hand up a digest connection */
1450 static void
1451 sslhangup(Dstate *s)
1452 {
1453         Block *b;
1454
1455         qlock(&s->in.q);
1456         for(b = s->processed; b; b = s->processed){
1457                 s->processed = b->next;
1458                 freeb(b);
1459         }
1460         if(s->unprocessed){
1461                 freeb(s->unprocessed);
1462                 s->unprocessed = 0;
1463         }
1464         s->state = Sincomplete;
1465         qunlock(&s->in.q);
1466 }
1467
1468 static Dstate*
1469 dsclone(Chan *ch)
1470 {
1471         int i;
1472         Dstate *ret;
1473
1474         if(waserror()) {
1475                 unlock(&dslock);
1476                 nexterror();
1477         }
1478         lock(&dslock);
1479         ret = nil;
1480         for(i=0; i<Maxdstate; i++){
1481                 if(dstate[i] == nil){
1482                         dsnew(ch, &dstate[i]);
1483                         ret = dstate[i];
1484                         break;
1485                 }
1486         }
1487         unlock(&dslock);
1488         poperror();
1489         return ret;
1490 }
1491
1492 static void
1493 dsnew(Chan *ch, Dstate **pp)
1494 {
1495         Dstate *s;
1496         int t;
1497
1498         *pp = s = malloc(sizeof(*s));
1499         if(!s)
1500                 error(Enomem);
1501         if(pp - dstate >= dshiwat)
1502                 dshiwat++;
1503         memset(s, 0, sizeof(*s));
1504         s->state = Sincomplete;
1505         s->ref = 1;
1506         kstrdup(&s->user, up->user);
1507         s->perm = 0660;
1508         t = TYPE(ch->qid);
1509         if(t == Qclonus)
1510                 t = Qctl;
1511         ch->qid.path = QID(pp - dstate, t);
1512         ch->qid.vers = 0;
1513 }