]> git.lizzy.rs Git - plan9front.git/blob - sys/src/9/port/devtls.c
ad6750a36ca922d177dfbe207c1f9a623233a806
[plan9front.git] / sys / src / 9 / port / devtls.c
1 /*
2  *  devtls - record layer for transport layer security 1.0 and secure sockets layer 3.0
3  */
4 #include        "u.h"
5 #include        "../port/lib.h"
6 #include        "mem.h"
7 #include        "dat.h"
8 #include        "fns.h"
9 #include        "../port/error.h"
10
11 #include        <libsec.h>
12
/* forward typedefs so the structures below can refer to each other */
typedef struct OneWay	OneWay;
typedef struct Secret	Secret;
typedef struct TlsErrs	TlsErrs;
typedef struct TlsRec	TlsRec;
17
18 enum {
19         Statlen=        1024,           /* max. length of status or stats message */
20         /* buffer limits */
21         MaxRecLen       = 1<<14,        /* max payload length of a record layer message */
22         MaxCipherRecLen = MaxRecLen + 2048,
23         RecHdrLen       = 5,
24         MaxMacLen       = SHA1dlen,
25
26         /* protocol versions we can accept */
27         SSL3Version     = 0x0300,
28         TLS10Version    = 0x0301,
29         TLS11Version    = 0x0302,
30         MinProtoVersion = 0x0300,       /* limits on version we accept */
31         MaxProtoVersion = 0x03ff,
32
33         /* connection states */
34         SHandshake      = 1 << 0,       /* doing handshake */
35         SOpen           = 1 << 1,       /* application data can be sent */
36         SRClose         = 1 << 2,       /* remote side has closed down */
37         SLClose         = 1 << 3,       /* sent a close notify alert */
38         SAlert          = 1 << 5,       /* sending or sent a fatal alert */
39         SError          = 1 << 6,       /* some sort of error has occured */
40         SClosed         = 1 << 7,       /* it is all over */
41
42         /* record types */
43         RChangeCipherSpec = 20,
44         RAlert,
45         RHandshake,
46         RApplication,
47
48         SSL2ClientHello = 1,
49         HSSL2ClientHello = 9,  /* local convention;  see tlshand.c */
50
51         /* alerts */
52         ECloseNotify                    = 0,
53         EUnexpectedMessage      = 10,
54         EBadRecordMac           = 20,
55         EDecryptionFailed               = 21,
56         ERecordOverflow                 = 22,
57         EDecompressionFailure   = 30,
58         EHandshakeFailure               = 40,
59         ENoCertificate                  = 41,
60         EBadCertificate                 = 42,
61         EUnsupportedCertificate         = 43,
62         ECertificateRevoked             = 44,
63         ECertificateExpired             = 45,
64         ECertificateUnknown     = 46,
65         EIllegalParameter               = 47,
66         EUnknownCa                      = 48,
67         EAccessDenied           = 49,
68         EDecodeError                    = 50,
69         EDecryptError                   = 51,
70         EExportRestriction              = 60,
71         EProtocolVersion                = 70,
72         EInsufficientSecurity   = 71,
73         EInternalError                  = 80,
74         EUserCanceled                   = 90,
75         ENoRenegotiation                = 100,
76         EUnrecognizedName               = 112,
77
78         EMAX = 256
79 };
80
81 struct Secret
82 {
83         char            *encalg;        /* name of encryption alg */
84         char            *hashalg;       /* name of hash alg */
85         int             (*enc)(Secret*, uchar*, int);
86         int             (*dec)(Secret*, uchar*, int);
87         int             (*unpad)(uchar*, int, int);
88         DigestState     *(*mac)(uchar*, ulong, uchar*, ulong, uchar*, DigestState*);
89         int             block;          /* encryption block len, 0 if none */
90         int             maclen;
91         void            *enckey;
92         uchar   mackey[MaxMacLen];
93 };
94
95 struct OneWay
96 {
97         QLock           io;             /* locks io access */
98         QLock           seclock;        /* locks secret paramaters */
99         ulong           seq;
100         Secret          *sec;           /* cipher in use */
101         Secret          *new;           /* cipher waiting for enable */
102 };
103
104 struct TlsRec
105 {
106         Chan    *c;                             /* io channel */
107         int             ref;                            /* serialized by tdlock for atomic destroy */
108         int             version;                        /* version of the protocol we are speaking */
109         char            verset;                 /* version has been set */
110         char            opened;                 /* opened command every issued? */
111         char            err[ERRMAX];            /* error message to return to handshake requests */
112         vlong   handin;                 /* bytes communicated by the record layer */
113         vlong   handout;
114         vlong   datain;
115         vlong   dataout;
116
117         Lock            statelk;
118         int             state;
119         int             debug;
120
121         /* record layer mac functions for different protocol versions */
122         void            (*packMac)(Secret*, uchar*, uchar*, uchar*, uchar*, int, uchar*);
123
124         /* input side -- protected by in.io */
125         OneWay          in;
126         Block           *processed;     /* next bunch of application data */
127         Block           *unprocessed;   /* data read from c but not parsed into records */
128
129         /* handshake queue */
130         Lock            hqlock;                 /* protects hqref, alloc & free of handq, hprocessed */
131         int             hqref;
132         Queue           *handq;         /* queue of handshake messages */
133         Block           *hprocessed;    /* remainder of last block read from handq */
134         QLock           hqread;         /* protects reads for hprocessed, handq */
135
136         /* output side */
137         OneWay          out;
138
139         /* protections */
140         char            *user;
141         int             perm;
142 };
143
144 struct TlsErrs{
145         int     err;
146         int     sslerr;
147         int     tlserr;
148         int     fatal;
149         char    *msg;
150 };
151
152 static TlsErrs tlserrs[] = {
153         {ECloseNotify,                  ECloseNotify,                   ECloseNotify,                   0,      "close notify"},
154         {EUnexpectedMessage,    EUnexpectedMessage,     EUnexpectedMessage,     1, "unexpected message"},
155         {EBadRecordMac,         EBadRecordMac,          EBadRecordMac,          1, "bad record mac"},
156         {EDecryptionFailed,             EIllegalParameter,              EDecryptionFailed,              1, "decryption failed"},
157         {ERecordOverflow,               EIllegalParameter,              ERecordOverflow,                1, "record too long"},
158         {EDecompressionFailure, EDecompressionFailure,  EDecompressionFailure,  1, "decompression failed"},
159         {EHandshakeFailure,             EHandshakeFailure,              EHandshakeFailure,              1, "could not negotiate acceptable security parameters"},
160         {ENoCertificate,                ENoCertificate,                 ECertificateUnknown,    1, "no appropriate certificate available"},
161         {EBadCertificate,               EBadCertificate,                EBadCertificate,                1, "corrupted or invalid certificate"},
162         {EUnsupportedCertificate,       EUnsupportedCertificate,        EUnsupportedCertificate,        1, "unsupported certificate type"},
163         {ECertificateRevoked,   ECertificateRevoked,            ECertificateRevoked,            1, "revoked certificate"},
164         {ECertificateExpired,           ECertificateExpired,            ECertificateExpired,            1, "expired certificate"},
165         {ECertificateUnknown,   ECertificateUnknown,    ECertificateUnknown,    1, "unacceptable certificate"},
166         {EIllegalParameter,             EIllegalParameter,              EIllegalParameter,              1, "illegal parameter"},
167         {EUnknownCa,                    EHandshakeFailure,              EUnknownCa,                     1, "unknown certificate authority"},
168         {EAccessDenied,         EHandshakeFailure,              EAccessDenied,          1, "access denied"},
169         {EDecodeError,                  EIllegalParameter,              EDecodeError,                   1, "error decoding message"},
170         {EDecryptError,                 EIllegalParameter,              EDecryptError,                  1, "error decrypting message"},
171         {EExportRestriction,            EHandshakeFailure,              EExportRestriction,             1, "export restriction violated"},
172         {EProtocolVersion,              EIllegalParameter,              EProtocolVersion,               1, "protocol version not supported"},
173         {EInsufficientSecurity, EHandshakeFailure,              EInsufficientSecurity,  1, "stronger security routines required"},
174         {EInternalError,                        EHandshakeFailure,              EInternalError,                 1, "internal error"},
175         {EUserCanceled,         ECloseNotify,                   EUserCanceled,                  0, "handshake canceled by user"},
176         {ENoRenegotiation,              EUnexpectedMessage,     ENoRenegotiation,               0, "no renegotiation"},
177 };
178
179 enum
180 {
181         /* max. open tls connections */
182         MaxTlsDevs      = 1024
183 };
184
185 static  Lock    tdlock;
186 static  int     tdhiwat;
187 static  int     maxtlsdevs = 128;
188 static  TlsRec  **tlsdevs;
189 static  char    **trnames;
190 static  char    *encalgs;
191 static  char    *hashalgs;
192
193 enum{
194         Qtopdir         = 1,    /* top level directory */
195         Qprotodir,
196         Qclonus,
197         Qencalgs,
198         Qhashalgs,
199         Qconvdir,               /* directory for a conversation */
200         Qdata,
201         Qctl,
202         Qhand,
203         Qstatus,
204         Qstats,
205 };
206
207 #define TYPE(x)         ((x).path & 0xf)
208 #define CONV(x)         (((x).path >> 5)&(MaxTlsDevs-1))
209 #define QID(c, y)       (((c)<<5) | (y))
210
211 static void     checkstate(TlsRec *, int, int);
212 static void     ensure(TlsRec*, Block**, int);
213 static void     consume(Block**, uchar*, int);
214 static Chan*    buftochan(char*);
215 static void     tlshangup(TlsRec*);
216 static void     tlsError(TlsRec*, char *);
217 static void     alertHand(TlsRec*, char *);
218 static TlsRec   *newtls(Chan *c);
219 static TlsRec   *mktlsrec(void);
220 static DigestState*sslmac_md5(uchar *p, ulong len, uchar *key, ulong klen, uchar *digest, DigestState *s);
221 static DigestState*sslmac_sha1(uchar *p, ulong len, uchar *key, ulong klen, uchar *digest, DigestState *s);
222 static DigestState*nomac(uchar *p, ulong len, uchar *key, ulong klen, uchar *digest, DigestState *s);
223 static void     sslPackMac(Secret *sec, uchar *mackey, uchar *seq, uchar *header, uchar *body, int len, uchar *mac);
224 static void     tlsPackMac(Secret *sec, uchar *mackey, uchar *seq, uchar *header, uchar *body, int len, uchar *mac);
225 static void     put64(uchar *p, vlong x);
226 static void     put32(uchar *p, u32int);
227 static void     put24(uchar *p, int);
228 static void     put16(uchar *p, int);
229 static u32int   get32(uchar *p);
230 static int      get16(uchar *p);
231 static void     tlsSetState(TlsRec *tr, int new, int old);
232 static void     rcvAlert(TlsRec *tr, int err);
233 static void     sendAlert(TlsRec *tr, int err);
234 static void     rcvError(TlsRec *tr, int err, char *msg, ...);
235 static int      rc4enc(Secret *sec, uchar *buf, int n);
236 static int      des3enc(Secret *sec, uchar *buf, int n);
237 static int      des3dec(Secret *sec, uchar *buf, int n);
238 static int      aesenc(Secret *sec, uchar *buf, int n);
239 static int      aesdec(Secret *sec, uchar *buf, int n);
240 static int      noenc(Secret *sec, uchar *buf, int n);
241 static int      sslunpad(uchar *buf, int n, int block);
242 static int      tlsunpad(uchar *buf, int n, int block);
243 static void     freeSec(Secret *sec);
244 static char     *tlsstate(int s);
245 static void     pdump(int, void*, char*);
246
247 #pragma varargck        argpos  rcvError        3
248
249 static char *tlsnames[] = {
250 [Qclonus]               "clone",
251 [Qencalgs]      "encalgs",
252 [Qhashalgs]     "hashalgs",
253 [Qdata]         "data",
254 [Qctl]          "ctl",
255 [Qhand]         "hand",
256 [Qstatus]               "status",
257 [Qstats]                "stats",
258 };
259
260 static int convdir[] = { Qctl, Qdata, Qhand, Qstatus, Qstats };
261
262 static int
263 tlsgen(Chan *c, char*, Dirtab *, int, int s, Dir *dp)
264 {
265         Qid q;
266         TlsRec *tr;
267         char *name, *nm;
268         int perm, t;
269
270         q.vers = 0;
271         q.type = QTFILE;
272
273         t = TYPE(c->qid);
274         switch(t) {
275         case Qtopdir:
276                 if(s == DEVDOTDOT){
277                         q.path = QID(0, Qtopdir);
278                         q.type = QTDIR;
279                         devdir(c, q, "#a", 0, eve, 0555, dp);
280                         return 1;
281                 }
282                 if(s > 0)
283                         return -1;
284                 q.path = QID(0, Qprotodir);
285                 q.type = QTDIR;
286                 devdir(c, q, "tls", 0, eve, 0555, dp);
287                 return 1;
288         case Qprotodir:
289                 if(s == DEVDOTDOT){
290                         q.path = QID(0, Qtopdir);
291                         q.type = QTDIR;
292                         devdir(c, q, ".", 0, eve, 0555, dp);
293                         return 1;
294                 }
295                 if(s < 3){
296                         switch(s) {
297                         default:
298                                 return -1;
299                         case 0:
300                                 q.path = QID(0, Qclonus);
301                                 break;
302                         case 1:
303                                 q.path = QID(0, Qencalgs);
304                                 break;
305                         case 2:
306                                 q.path = QID(0, Qhashalgs);
307                                 break;
308                         }
309                         perm = 0444;
310                         if(TYPE(q) == Qclonus)
311                                 perm = 0555;
312                         devdir(c, q, tlsnames[TYPE(q)], 0, eve, perm, dp);
313                         return 1;
314                 }
315                 s -= 3;
316                 if(s >= tdhiwat)
317                         return -1;
318                 q.path = QID(s, Qconvdir);
319                 q.type = QTDIR;
320                 lock(&tdlock);
321                 tr = tlsdevs[s];
322                 if(tr != nil)
323                         nm = tr->user;
324                 else
325                         nm = eve;
326                 if((name = trnames[s]) == nil){
327                         name = trnames[s] = smalloc(16);
328                         sprint(name, "%d", s);
329                 }
330                 devdir(c, q, name, 0, nm, 0555, dp);
331                 unlock(&tdlock);
332                 return 1;
333         case Qconvdir:
334                 if(s == DEVDOTDOT){
335                         q.path = QID(0, Qprotodir);
336                         q.type = QTDIR;
337                         devdir(c, q, "tls", 0, eve, 0555, dp);
338                         return 1;
339                 }
340                 if(s < 0 || s >= nelem(convdir))
341                         return -1;
342                 lock(&tdlock);
343                 tr = tlsdevs[CONV(c->qid)];
344                 if(tr != nil){
345                         nm = tr->user;
346                         perm = tr->perm;
347                 }else{
348                         perm = 0;
349                         nm = eve;
350                 }
351                 t = convdir[s];
352                 if(t == Qstatus || t == Qstats)
353                         perm &= 0444;
354                 q.path = QID(CONV(c->qid), t);
355                 devdir(c, q, tlsnames[t], 0, nm, perm, dp);
356                 unlock(&tdlock);
357                 return 1;
358         case Qclonus:
359         case Qencalgs:
360         case Qhashalgs:
361                 perm = 0444;
362                 if(t == Qclonus)
363                         perm = 0555;
364                 devdir(c, c->qid, tlsnames[t], 0, eve, perm, dp);
365                 return 1;
366         default:
367                 lock(&tdlock);
368                 tr = tlsdevs[CONV(c->qid)];
369                 if(tr != nil){
370                         nm = tr->user;
371                         perm = tr->perm;
372                 }else{
373                         perm = 0;
374                         nm = eve;
375                 }
376                 if(t == Qstatus || t == Qstats)
377                         perm &= 0444;
378                 devdir(c, c->qid, tlsnames[t], 0, nm, perm, dp);
379                 unlock(&tdlock);
380                 return 1;
381         }
382 }
383
384 static Chan*
385 tlsattach(char *spec)
386 {
387         Chan *c;
388
389         c = devattach('a', spec);
390         c->qid.path = QID(0, Qtopdir);
391         c->qid.type = QTDIR;
392         c->qid.vers = 0;
393         return c;
394 }
395
396 static Walkqid*
397 tlswalk(Chan *c, Chan *nc, char **name, int nname)
398 {
399         return devwalk(c, nc, name, nname, nil, 0, tlsgen);
400 }
401
402 static int
403 tlsstat(Chan *c, uchar *db, int n)
404 {
405         return devstat(c, db, n, nil, 0, tlsgen);
406 }
407
408 static Chan*
409 tlsopen(Chan *c, int omode)
410 {
411         TlsRec *tr, **pp;
412         int t, perm;
413
414         perm = 0;
415         omode &= 3;
416         switch(omode) {
417         case OREAD:
418                 perm = 4;
419                 break;
420         case OWRITE:
421                 perm = 2;
422                 break;
423         case ORDWR:
424                 perm = 6;
425                 break;
426         }
427
428         t = TYPE(c->qid);
429         switch(t) {
430         default:
431                 panic("tlsopen");
432         case Qtopdir:
433         case Qprotodir:
434         case Qconvdir:
435                 if(omode != OREAD)
436                         error(Eperm);
437                 break;
438         case Qclonus:
439                 tr = newtls(c);
440                 if(tr == nil)
441                         error(Enodev);
442                 break;
443         case Qctl:
444         case Qdata:
445         case Qhand:
446         case Qstatus:
447         case Qstats:
448                 if((t == Qstatus || t == Qstats) && omode != OREAD)
449                         error(Eperm);
450                 if(waserror()) {
451                         unlock(&tdlock);
452                         nexterror();
453                 }
454                 lock(&tdlock);
455                 pp = &tlsdevs[CONV(c->qid)];
456                 tr = *pp;
457                 if(tr == nil)
458                         error("must open connection using clone");
459                 if((perm & (tr->perm>>6)) != perm
460                 && (strcmp(up->user, tr->user) != 0
461                     || (perm & tr->perm) != perm))
462                         error(Eperm);
463                 if(t == Qhand){
464                         if(waserror()){
465                                 unlock(&tr->hqlock);
466                                 nexterror();
467                         }
468                         lock(&tr->hqlock);
469                         if(tr->handq != nil)
470                                 error(Einuse);
471                         tr->handq = qopen(2 * MaxCipherRecLen, 0, nil, nil);
472                         if(tr->handq == nil)
473                                 error("cannot allocate handshake queue");
474                         tr->hqref = 1;
475                         unlock(&tr->hqlock);
476                         poperror();
477                 }
478                 tr->ref++;
479                 unlock(&tdlock);
480                 poperror();
481                 break;
482         case Qencalgs:
483         case Qhashalgs:
484                 if(omode != OREAD)
485                         error(Eperm);
486                 break;
487         }
488         c->mode = openmode(omode);
489         c->flag |= COPEN;
490         c->offset = 0;
491         c->iounit = qiomaxatomic;
492         return c;
493 }
494
495 static int
496 tlswstat(Chan *c, uchar *dp, int n)
497 {
498         Dir *d;
499         TlsRec *tr;
500         int rv;
501
502         d = nil;
503         if(waserror()){
504                 free(d);
505                 unlock(&tdlock);
506                 nexterror();
507         }
508
509         lock(&tdlock);
510         tr = tlsdevs[CONV(c->qid)];
511         if(tr == nil)
512                 error(Ebadusefd);
513         if(strcmp(tr->user, up->user) != 0)
514                 error(Eperm);
515
516         d = smalloc(n + sizeof *d);
517         rv = convM2D(dp, n, &d[0], (char*) &d[1]);
518         if(rv == 0)
519                 error(Eshortstat);
520         if(!emptystr(d->uid))
521                 kstrdup(&tr->user, d->uid);
522         if(d->mode != ~0UL)
523                 tr->perm = d->mode;
524
525         free(d);
526         poperror();
527         unlock(&tdlock);
528
529         return rv;
530 }
531
532 static void
533 dechandq(TlsRec *tr)
534 {
535         lock(&tr->hqlock);
536         if(--tr->hqref == 0){
537                 if(tr->handq != nil){
538                         qfree(tr->handq);
539                         tr->handq = nil;
540                 }
541                 if(tr->hprocessed != nil){
542                         freeb(tr->hprocessed);
543                         tr->hprocessed = nil;
544                 }
545         }
546         unlock(&tr->hqlock);
547 }
548
549 static void
550 tlsclose(Chan *c)
551 {
552         TlsRec *tr;
553         int t;
554
555         t = TYPE(c->qid);
556         switch(t) {
557         case Qctl:
558         case Qdata:
559         case Qhand:
560         case Qstatus:
561         case Qstats:
562                 if((c->flag & COPEN) == 0)
563                         break;
564
565                 tr = tlsdevs[CONV(c->qid)];
566                 if(tr == nil)
567                         break;
568
569                 if(t == Qhand)
570                         dechandq(tr);
571
572                 lock(&tdlock);
573                 if(--tr->ref > 0) {
574                         unlock(&tdlock);
575                         return;
576                 }
577                 tlsdevs[CONV(c->qid)] = nil;
578                 unlock(&tdlock);
579
580                 if(tr->c != nil && !waserror()){
581                         checkstate(tr, 0, SOpen|SHandshake|SRClose);
582                         sendAlert(tr, ECloseNotify);
583                         poperror();
584                 }
585                 tlshangup(tr);
586                 if(tr->c != nil)
587                         cclose(tr->c);
588                 freeSec(tr->in.sec);
589                 freeSec(tr->in.new);
590                 freeSec(tr->out.sec);
591                 freeSec(tr->out.new);
592                 free(tr->user);
593                 free(tr);
594                 break;
595         }
596 }
597
598 /*
599  *  make sure we have at least 'n' bytes in list 'l'
600  */
601 static void
602 ensure(TlsRec *s, Block **l, int n)
603 {
604         int sofar, i;
605         Block *b, *bl;
606
607         sofar = 0;
608         for(b = *l; b; b = b->next){
609                 sofar += BLEN(b);
610                 if(sofar >= n)
611                         return;
612                 l = &b->next;
613         }
614
615         while(sofar < n){
616                 bl = devtab[s->c->type]->bread(s->c, MaxCipherRecLen + RecHdrLen, 0);
617                 if(bl == 0)
618                         error(Ehungup);
619                 *l = bl;
620                 i = 0;
621                 for(b = bl; b; b = b->next){
622                         i += BLEN(b);
623                         l = &b->next;
624                 }
625                 if(i == 0)
626                         error(Ehungup);
627                 sofar += i;
628         }
629 if(s->debug) pprint("ensure read %d\n", sofar);
630 }
631
632 /*
633  *  copy 'n' bytes from 'l' into 'p' and free
634  *  the bytes in 'l'
635  */
636 static void
637 consume(Block **l, uchar *p, int n)
638 {
639         Block *b;
640         int i;
641
642         for(; *l && n > 0; n -= i){
643                 b = *l;
644                 i = BLEN(b);
645                 if(i > n)
646                         i = n;
647                 memmove(p, b->rp, i);
648                 b->rp += i;
649                 p += i;
650                 if(BLEN(b) < 0)
651                         panic("consume");
652                 if(BLEN(b))
653                         break;
654                 *l = b->next;
655                 freeb(b);
656         }
657 }
658
659 /*
660  *  give back n bytes
661  */
662 static void
663 regurgitate(TlsRec *s, uchar *p, int n)
664 {
665         Block *b;
666
667         if(n <= 0)
668                 return;
669         b = s->unprocessed;
670         if(s->unprocessed == nil || b->rp - b->base < n) {
671                 b = allocb(n);
672                 memmove(b->wp, p, n);
673                 b->wp += n;
674                 b->next = s->unprocessed;
675                 s->unprocessed = b;
676         } else {
677                 b->rp -= n;
678                 memmove(b->rp, p, n);
679         }
680 }
681
682 /*
683  *  remove at most n bytes from the queue
684  */
685 static Block*
686 qgrab(Block **l, int n)
687 {
688         Block *bb, *b;
689         int i;
690
691         b = *l;
692         if(BLEN(b) == n){
693                 *l = b->next;
694                 b->next = nil;
695                 return b;
696         }
697
698         i = 0;
699         for(bb = b; bb != nil && i < n; bb = bb->next)
700                 i += BLEN(bb);
701         if(i > n)
702                 i = n;
703
704         bb = allocb(i);
705         consume(l, bb->wp, i);
706         bb->wp += i;
707         return bb;
708 }
709
710 static void
711 tlsclosed(TlsRec *tr, int new)
712 {
713         lock(&tr->statelk);
714         if(tr->state == SOpen || tr->state == SHandshake)
715                 tr->state = new;
716         else if((new | tr->state) == (SRClose|SLClose))
717                 tr->state = SClosed;
718         unlock(&tr->statelk);
719         alertHand(tr, "close notify");
720 }
721
722 /*
723  *  read and process one tls record layer message
724  *  must be called with tr->in.io held
725  *  We can't let Eintrs lose data, since doing so will get
726  *  us out of sync with the sender and break the reliablity
727  *  of the channel.  Eintr only happens during the reads in
728  *  consume.  Therefore we put back any bytes consumed before
729  *  the last call to ensure.
730  */
731 static void
732 tlsrecread(TlsRec *tr)
733 {
734         OneWay *volatile in;
735         Block *volatile b;
736         uchar *p, seq[8], header[RecHdrLen], hmac[MaxMacLen];
737         int volatile nconsumed;
738         int len, type, ver, unpad_len;
739
740         nconsumed = 0;
741         if(waserror()){
742                 if(strcmp(up->errstr, Eintr) == 0 && !waserror()){
743                         regurgitate(tr, header, nconsumed);
744                         poperror();
745                 }else
746                         tlsError(tr, "channel error");
747                 nexterror();
748         }
749         ensure(tr, &tr->unprocessed, RecHdrLen);
750         consume(&tr->unprocessed, header, RecHdrLen);
751 if(tr->debug)pprint("consumed %d header\n", RecHdrLen);
752         nconsumed = RecHdrLen;
753
754         if((tr->handin == 0) && (header[0] & 0x80)){
755                 /* Cope with an SSL3 ClientHello expressed in SSL2 record format.
756                         This is sent by some clients that we must interoperate
757                         with, such as Java's JSSE and Microsoft's Internet Explorer. */
758                 len = (get16(header) & ~0x8000) - 3;
759                 type = header[2];
760                 ver = get16(header + 3);
761                 if(type != SSL2ClientHello || len < 22)
762                         rcvError(tr, EProtocolVersion, "invalid initial SSL2-like message");
763         }else{  /* normal SSL3 record format */
764                 type = header[0];
765                 ver = get16(header+1);
766                 len = get16(header+3);
767         }
768         if(ver != tr->version && (tr->verset || ver < MinProtoVersion || ver > MaxProtoVersion))
769                 rcvError(tr, EProtocolVersion, "devtls expected ver=%x%s, saw (len=%d) type=%x ver=%x '%.12s'",
770                         tr->version, tr->verset?"/set":"", len, type, ver, (char*)header);
771         if(len > MaxCipherRecLen || len < 0)
772                 rcvError(tr, ERecordOverflow, "record message too long %d", len);
773         ensure(tr, &tr->unprocessed, len);
774         nconsumed = 0;
775         poperror();
776
777         /*
778          * If an Eintr happens after this, we'll get out of sync.
779          * Make sure nothing we call can sleep.
780          * Errors are ok, as they kill the connection.
781          * Luckily, allocb won't sleep, it'll just error out.
782          */
783         b = nil;
784         if(waserror()){
785                 if(b != nil)
786                         freeb(b);
787                 tlsError(tr, "channel error");
788                 nexterror();
789         }
790         b = qgrab(&tr->unprocessed, len);
791 if(tr->debug) pprint("consumed unprocessed %d\n", len);
792
793         in = &tr->in;
794         if(waserror()){
795                 qunlock(&in->seclock);
796                 nexterror();
797         }
798         qlock(&in->seclock);
799         p = b->rp;
800         if(in->sec != nil) {
801                 /* to avoid Canvel-Hiltgen-Vaudenay-Vuagnoux attack, all errors here
802                         should look alike, including timing of the response. */
803                 unpad_len = (*in->sec->dec)(in->sec, p, len);
804
805                 /* excplicit iv */
806                 if(tr->version >= TLS11Version){
807                         len -= in->sec->block;
808                         if(len < 0)
809                                 rcvError(tr, EDecodeError, "runt record message");
810
811                         unpad_len -= in->sec->block;
812                         p += in->sec->block;
813                 }
814
815                 if(unpad_len >= in->sec->maclen)
816                         len = unpad_len - in->sec->maclen;
817
818 if(tr->debug) pprint("decrypted %d\n", unpad_len);
819 if(tr->debug) pdump(unpad_len, p, "decrypted:");
820
821                 /* update length */
822                 put16(header+3, len);
823                 put64(seq, in->seq);
824                 in->seq++;
825                 (*tr->packMac)(in->sec, in->sec->mackey, seq, header, p, len, hmac);
826                 if(unpad_len < in->sec->maclen)
827                         rcvError(tr, EBadRecordMac, "short record mac");
828                 if(constcmp(hmac, p+len, in->sec->maclen) != 0)
829                         rcvError(tr, EBadRecordMac, "record mac mismatch");
830                 b->rp = p;
831                 b->wp = p+len;
832         }
833         qunlock(&in->seclock);
834         poperror();
835         if(len < 0)
836                 rcvError(tr, EDecodeError, "runt record message");
837
838         switch(type) {
839         default:
840                 rcvError(tr, EIllegalParameter, "invalid record message %#x", type);
841                 break;
842         case RChangeCipherSpec:
843                 if(len != 1 || p[0] != 1)
844                         rcvError(tr, EDecodeError, "invalid change cipher spec");
845                 qlock(&in->seclock);
846                 if(in->new == nil){
847                         qunlock(&in->seclock);
848                         rcvError(tr, EUnexpectedMessage, "unexpected change cipher spec");
849                 }
850                 freeSec(in->sec);
851                 in->sec = in->new;
852                 in->new = nil;
853                 in->seq = 0;
854                 qunlock(&in->seclock);
855                 break;
856         case RAlert:
857                 if(len != 2)
858                         rcvError(tr, EDecodeError, "invalid alert");
859                 if(p[0] == 2)
860                         rcvAlert(tr, p[1]);
861                 if(p[0] != 1)
862                         rcvError(tr, EIllegalParameter, "invalid alert fatal code");
863
864                 /*
865                  * propagate non-fatal alerts to handshaker
866                  */
867                 switch(p[1]){
868                 case ECloseNotify:
869                         tlsclosed(tr, SRClose);
870                         if(tr->opened)
871                                 error("tls hungup");
872                         error("close notify");
873                         break;
874                 case ENoRenegotiation:
875                         alertHand(tr, "no renegotiation");
876                         break;
877                 case EUserCanceled:
878                         alertHand(tr, "handshake canceled by user");
879                         break;
880                 case EUnrecognizedName:
881                         /* happens in response to SNI, can be ignored. */
882                         break;
883                 default:
884                         rcvError(tr, EIllegalParameter, "invalid alert code");
885                 }
886                 break;
887         case RHandshake:
888                 /*
889                  * don't worry about dropping the block
890                  * qbwrite always queues even if flow controlled and interrupted.
891                  *
892                  * if there isn't any handshaker, ignore the request,
893                  * but notify the other side we are doing so.
894                  */
895                 lock(&tr->hqlock);
896                 if(tr->handq != nil){
897                         tr->hqref++;
898                         unlock(&tr->hqlock);
899                         if(waserror()){
900                                 dechandq(tr);
901                                 nexterror();
902                         }
903                         b = padblock(b, 1);
904                         *b->rp = RHandshake;
905                         qbwrite(tr->handq, b);
906                         b = nil;
907                         poperror();
908                         dechandq(tr);
909                 }else{
910                         unlock(&tr->hqlock);
911                         if(tr->verset && tr->version != SSL3Version && !waserror()){
912                                 sendAlert(tr, ENoRenegotiation);
913                                 poperror();
914                         }
915                 }
916                 break;
917         case SSL2ClientHello:
918                 lock(&tr->hqlock);
919                 if(tr->handq != nil){
920                         tr->hqref++;
921                         unlock(&tr->hqlock);
922                         if(waserror()){
923                                 dechandq(tr);
924                                 nexterror();
925                         }
926                         /* Pass the SSL2 format data, so that the handshake code can compute
927                                 the correct checksums.  HSSL2ClientHello = HandshakeType 9 is
928                                 unused in RFC2246. */
929                         b = padblock(b, 8);
930                         b->rp[0] = RHandshake;
931                         b->rp[1] = HSSL2ClientHello;
932                         put24(&b->rp[2], len+3);
933                         b->rp[5] = SSL2ClientHello;
934                         put16(&b->rp[6], ver);
935                         qbwrite(tr->handq, b);
936                         b = nil;
937                         poperror();
938                         dechandq(tr);
939                 }else{
940                         unlock(&tr->hqlock);
941                         if(tr->verset && tr->version != SSL3Version && !waserror()){
942                                 sendAlert(tr, ENoRenegotiation);
943                                 poperror();
944                         }
945                 }
946                 break;
947         case RApplication:
948                 if(!tr->opened)
949                         rcvError(tr, EUnexpectedMessage, "application message received before handshake completed");
950                 if(BLEN(b) > 0){
951                         tr->processed = b;
952                         b = nil;
953                 }
954                 break;
955         }
956         if(b != nil)
957                 freeb(b);
958         poperror();
959 }
960
961 /*
962  * got a fatal alert message
963  */
964 static void
965 rcvAlert(TlsRec *tr, int err)
966 {
967         char *s;
968         int i;
969
970         s = "unknown error";
971         for(i=0; i < nelem(tlserrs); i++){
972                 if(tlserrs[i].err == err){
973                         s = tlserrs[i].msg;
974                         break;
975                 }
976         }
977 if(tr->debug) pprint("rcvAlert: %s\n", s);
978
979         tlsError(tr, s);
980         if(!tr->opened)
981                 error(s);
982         error("tls error");
983 }
984
985 /*
986  * found an error while decoding the input stream
987  */
988 static void
989 rcvError(TlsRec *tr, int err, char *fmt, ...)
990 {
991         char msg[ERRMAX];
992         va_list arg;
993
994         va_start(arg, fmt);
995         vseprint(msg, msg+sizeof(msg), fmt, arg);
996         va_end(arg);
997 if(tr->debug) pprint("rcvError: %s\n", msg);
998
999         sendAlert(tr, err);
1000
1001         if(!tr->opened)
1002                 error(msg);
1003         error("tls error");
1004 }
1005
1006 /*
1007  * make sure the next hand operation returns with a 'msg' error
1008  */
1009 static void
1010 alertHand(TlsRec *tr, char *msg)
1011 {
1012         Block *b;
1013         int n;
1014
1015         lock(&tr->hqlock);
1016         if(tr->handq == nil){
1017                 unlock(&tr->hqlock);
1018                 return;
1019         }
1020         tr->hqref++;
1021         unlock(&tr->hqlock);
1022
1023         n = strlen(msg);
1024         if(waserror()){
1025                 dechandq(tr);
1026                 nexterror();
1027         }
1028         b = allocb(n + 2);
1029         *b->wp++ = RAlert;
1030         memmove(b->wp, msg, n + 1);
1031         b->wp += n + 1;
1032
1033         qbwrite(tr->handq, b);
1034
1035         poperror();
1036         dechandq(tr);
1037 }
1038
1039 static void
1040 checkstate(TlsRec *tr, int ishand, int ok)
1041 {
1042         int state;
1043
1044         lock(&tr->statelk);
1045         state = tr->state;
1046         unlock(&tr->statelk);
1047         if(state & ok)
1048                 return;
1049         switch(state){
1050         case SHandshake:
1051         case SOpen:
1052                 break;
1053         case SError:
1054         case SAlert:
1055                 if(ishand)
1056                         error(tr->err);
1057                 error("tls error");
1058         case SRClose:
1059         case SLClose:
1060         case SClosed:
1061                 error("tls hungup");
1062         }
1063         error("tls improperly configured");
1064 }
1065
/*
 * tlsbread: block read for the data and hand files; any other file
 * is passed to devbread.  Returns at most n bytes.
 *
 * Qdata: requires SOpen.  Calls tlsrecread until it leaves
 * application data in tr->processed, then returns up to n bytes of it.
 *
 * Qhand: requires SOpen, SHandshake or SLClose.  Until the connection
 * is opened, calls tlsrecread until handshake data is pending; then
 * returns up to n bytes from tr->hprocessed, refilling it from
 * tr->handq.  Each block on handq starts with a tag byte; an RAlert
 * tag (queued by alertHand) turns the rest of the block into the
 * error string of this read.
 *
 * tr->in.io serializes record reading; tr->hqread serializes
 * consumers of the handshake queue.
 */
static Block*
tlsbread(Chan *c, long n, ulong offset)
{
	int ty;
	Block *b;
	TlsRec *volatile tr;

	ty = TYPE(c->qid);
	switch(ty) {
	default:
		return devbread(c, n, offset);
	case Qhand:
	case Qdata:
		break;
	}

	tr = tlsdevs[CONV(c->qid)];
	if(tr == nil)
		panic("tlsbread");

	if(waserror()){
		qunlock(&tr->in.io);
		nexterror();
	}
	qlock(&tr->in.io);
	if(ty == Qdata){
		checkstate(tr, 0, SOpen);
		/* process incoming records until one yields application data */
		while(tr->processed == nil)
			tlsrecread(tr);

		/* return at most what was asked for */
		b = qgrab(&tr->processed, n);
if(tr->debug) pprint("consumed processed %ld\n", BLEN(b));
if(tr->debug) pdump(BLEN(b), b->rp, "consumed:");
		qunlock(&tr->in.io);
		poperror();
		tr->datain += BLEN(b);
	}else{
		checkstate(tr, 1, SOpen|SHandshake|SLClose);

		/*
		 * it's ok to look at state without the lock
		 * since it only protects reading records,
		 * and we have that tr->in.io held.
		 */
		while(!tr->opened && tr->hprocessed == nil && !qcanread(tr->handq))
			tlsrecread(tr);

		qunlock(&tr->in.io);
		poperror();

		if(waserror()){
			qunlock(&tr->hqread);
			nexterror();
		}
		qlock(&tr->hqread);
		if(tr->hprocessed == nil){
			/*
			 * strip the tag byte; RAlert means the remainder is
			 * a NUL-terminated error message (see alertHand).
			 * NOTE(review): b is dereferenced without a nil check;
			 * confirm qbread cannot return nil here (e.g. once
			 * handq has been closed).
			 */
			b = qbread(tr->handq, MaxRecLen + 1);
			if(*b->rp++ == RAlert){
				kstrcpy(up->errstr, (char*)b->rp, ERRMAX);
				freeb(b);
				nexterror();
			}
			tr->hprocessed = b;
		}
		b = qgrab(&tr->hprocessed, n);
		poperror();
		qunlock(&tr->hqread);
		tr->handin += BLEN(b);
	}

	return b;
}
1139
/*
 * tlsread: read handler for the tls directory and conversation files.
 *	Qstatus		state, version and the current/pending cipher
 *			and hash names in each direction, as text
 *	Qstats		byte counts of data and handshake traffic
 *	Qctl		the conversation number
 *	Qdata, Qhand	bytes obtained from tlsbread
 *	Qencalgs, Qhashalgs	space-separated supported algorithm names
 * Any other file raises Ebadusefd.  Returns the number of bytes read.
 */
static long
tlsread(Chan *c, void *a, long n, vlong off)
{
	Block *volatile b;
	Block *nb;
	uchar *va;
	int i, ty;
	char *buf, *s, *e;
	ulong offset = off;
	TlsRec * tr;

	if(c->qid.type & QTDIR)
		return devdirread(c, a, n, 0, 0, tlsgen);

	/* NOTE(review): unlike tlsbread, tr is not checked for nil here */
	tr = tlsdevs[CONV(c->qid)];
	ty = TYPE(c->qid);
	switch(ty) {
	default:
		error(Ebadusefd);
	case Qstatus:
		buf = smalloc(Statlen);
		/* hold both seclocks so the cipher pointers stay stable */
		qlock(&tr->in.seclock);
		qlock(&tr->out.seclock);
		s = buf;
		e = buf + Statlen;
		s = seprint(s, e, "State: %s\n", tlsstate(tr->state));
		s = seprint(s, e, "Version: %#x\n", tr->version);
		if(tr->in.sec != nil)
			s = seprint(s, e, "EncIn: %s\nHashIn: %s\n", tr->in.sec->encalg, tr->in.sec->hashalg);
		if(tr->in.new != nil)
			s = seprint(s, e, "NewEncIn: %s\nNewHashIn: %s\n", tr->in.new->encalg, tr->in.new->hashalg);
		if(tr->out.sec != nil)
			s = seprint(s, e, "EncOut: %s\nHashOut: %s\n", tr->out.sec->encalg, tr->out.sec->hashalg);
		if(tr->out.new != nil)
			seprint(s, e, "NewEncOut: %s\nNewHashOut: %s\n", tr->out.new->encalg, tr->out.new->hashalg);
		qunlock(&tr->in.seclock);
		qunlock(&tr->out.seclock);
		n = readstr(offset, a, n, buf);
		free(buf);
		return n;
	case Qstats:
		buf = smalloc(Statlen);
		s = buf;
		e = buf + Statlen;
		s = seprint(s, e, "DataIn: %lld\n", tr->datain);
		s = seprint(s, e, "DataOut: %lld\n", tr->dataout);
		s = seprint(s, e, "HandIn: %lld\n", tr->handin);
		seprint(s, e, "HandOut: %lld\n", tr->handout);
		n = readstr(offset, a, n, buf);
		free(buf);
		return n;
	case Qctl:
		buf = smalloc(Statlen);
		snprint(buf, Statlen, "%llud", CONV(c->qid));
		n = readstr(offset, a, n, buf);
		free(buf);
		return n;
	case Qdata:
	case Qhand:
		b = tlsbread(c, n, offset);
		break;
	case Qencalgs:
		return readstr(offset, a, n, encalgs);
	case Qhashalgs:
		return readstr(offset, a, n, hashalgs);
	}

	/* copy the block list to the caller; tlsbread returned at most n bytes */
	if(waserror()){
		freeblist(b);
		nexterror();
	}

	n = 0;
	va = a;
	for(nb = b; nb; nb = nb->next){
		i = BLEN(nb);
		memmove(va+n, nb->rp, i);
		n += i;
	}

	freeblist(b);
	poperror();

	return n;
}
1225
1226 static void
1227 randfill(uchar *buf, int len)
1228 {
1229         while(len-- > 0)
1230                 *buf++ = nrand(256);
1231 }
1232
1233 /*
1234  *  write a block in tls records
1235  */
1236 static void
1237 tlsrecwrite(TlsRec *tr, int type, Block *b)
1238 {
1239         Block *volatile bb;
1240         Block *nb;
1241         uchar *p, seq[8];
1242         OneWay *volatile out;
1243         int n, ivlen, maclen, pad, ok;
1244
1245         out = &tr->out;
1246         bb = b;
1247         if(waserror()){
1248                 qunlock(&out->io);
1249                 if(bb != nil)
1250                         freeb(bb);
1251                 nexterror();
1252         }
1253         qlock(&out->io);
1254 if(tr->debug)pprint("send %ld\n", BLEN(b));
1255 if(tr->debug)pdump(BLEN(b), b->rp, "sent:");
1256
1257
1258         ok = SHandshake|SOpen|SRClose;
1259         if(type == RAlert)
1260                 ok |= SAlert;
1261         while(bb != nil){
1262                 checkstate(tr, type != RApplication, ok);
1263
1264                 /*
1265                  * get at most one maximal record's input,
1266                  * with padding on the front for header and
1267                  * back for mac and maximal block padding.
1268                  */
1269                 if(waserror()){
1270                         qunlock(&out->seclock);
1271                         nexterror();
1272                 }
1273                 qlock(&out->seclock);
1274                 maclen = 0;
1275                 pad = 0;
1276                 ivlen = 0;
1277                 if(out->sec != nil){
1278                         maclen = out->sec->maclen;
1279                         pad = maclen + out->sec->block;
1280                         if(tr->version >= TLS11Version)
1281                                 ivlen = out->sec->block;
1282                 }
1283                 n = BLEN(bb);
1284                 if(n > MaxRecLen){
1285                         n = MaxRecLen;
1286                         nb = allocb(RecHdrLen + ivlen + n + pad);
1287                         memmove(nb->wp + RecHdrLen + ivlen, bb->rp, n);
1288                         bb->rp += n;
1289                 }else{
1290                         /*
1291                          * carefully reuse bb so it will get freed if we're out of memory
1292                          */
1293                         bb = padblock(bb, RecHdrLen + ivlen);
1294                         if(pad)
1295                                 nb = padblock(bb, -pad);
1296                         else
1297                                 nb = bb;
1298                         bb = nil;
1299                 }
1300
1301                 p = nb->rp;
1302                 p[0] = type;
1303                 put16(p+1, tr->version);
1304                 put16(p+3, n);
1305
1306                 if(out->sec != nil){
1307                         put64(seq, out->seq);
1308                         out->seq++;
1309                         (*tr->packMac)(out->sec, out->sec->mackey, seq, p, p + RecHdrLen + ivlen, n, p + RecHdrLen + ivlen + n);
1310                         n += maclen;
1311
1312                         /* explicit iv */
1313                         if(ivlen > 0){
1314                                 randfill(p + RecHdrLen, ivlen);
1315                                 n += ivlen;
1316                         }
1317
1318                         /* encrypt */
1319                         n = (*out->sec->enc)(out->sec, p + RecHdrLen, n);
1320                         nb->wp = p + RecHdrLen + n;
1321
1322                         /* update length */
1323                         put16(p+3, n);
1324                 }
1325                 if(type == RChangeCipherSpec){
1326                         if(out->new == nil)
1327                                 error("change cipher without a new cipher");
1328                         freeSec(out->sec);
1329                         out->sec = out->new;
1330                         out->new = nil;
1331                         out->seq = 0;
1332                 }
1333                 qunlock(&out->seclock);
1334                 poperror();
1335
1336                 /*
1337                  * if bwrite error's, we assume the block is queued.
1338                  * if not, we're out of sync with the receiver and will not recover.
1339                  */
1340                 if(waserror()){
1341                         if(strcmp(up->errstr, "interrupted") != 0)
1342                                 tlsError(tr, "channel error");
1343                         nexterror();
1344                 }
1345                 devtab[tr->c->type]->bwrite(tr->c, nb, 0);
1346                 poperror();
1347         }
1348         qunlock(&out->io);
1349         poperror();
1350 }
1351
1352 static long
1353 tlsbwrite(Chan *c, Block *b, ulong offset)
1354 {
1355         int ty;
1356         ulong n;
1357         TlsRec *tr;
1358
1359         n = BLEN(b);
1360
1361         tr = tlsdevs[CONV(c->qid)];
1362         if(tr == nil)
1363                 panic("tlsbwrite");
1364
1365         ty = TYPE(c->qid);
1366         switch(ty) {
1367         default:
1368                 return devbwrite(c, b, offset);
1369         case Qhand:
1370                 tlsrecwrite(tr, RHandshake, b);
1371                 tr->handout += n;
1372                 break;
1373         case Qdata:
1374                 checkstate(tr, 0, SOpen);
1375                 tlsrecwrite(tr, RApplication, b);
1376                 tr->dataout += n;
1377                 break;
1378         }
1379
1380         return n;
1381 }
1382
/*
 * a MAC algorithm selectable by the "secret" control message.
 * maclen is both the MAC output length and the number of key bytes
 * taken from the secret data; initkey(ha, version, s, key) installs
 * the MAC function, length and key into s for the given protocol
 * version.  Field order matters: hashtab is initialized positionally.
 */
typedef struct Hashalg Hashalg;
struct Hashalg
{
	char	*name;
	int	maclen;
	void	(*initkey)(Hashalg *, int, Secret *, uchar*);
};
1390
1391 static void
1392 initmd5key(Hashalg *ha, int version, Secret *s, uchar *p)
1393 {
1394         s->maclen = ha->maclen;
1395         if(version == SSL3Version)
1396                 s->mac = sslmac_md5;
1397         else
1398                 s->mac = hmac_md5;
1399         memmove(s->mackey, p, ha->maclen);
1400 }
1401
1402 static void
1403 initclearmac(Hashalg *, int, Secret *s, uchar *)
1404 {
1405         s->maclen = 0;
1406         s->mac = nomac;
1407 }
1408
1409 static void
1410 initsha1key(Hashalg *ha, int version, Secret *s, uchar *p)
1411 {
1412         s->maclen = ha->maclen;
1413         if(version == SSL3Version)
1414                 s->mac = sslmac_sha1;
1415         else
1416                 s->mac = hmac_sha1;
1417         memmove(s->mackey, p, ha->maclen);
1418 }
1419
/*
 * MAC algorithms accepted by the "secret" control message, looked up
 * by name in parsehashalg.  The entry with a nil name ends the table.
 * tlsinit also joins these names to form the hashalgs listing.
 */
static Hashalg hashtab[] =
{
	{ "clear", 0, initclearmac, },
	{ "md5", MD5dlen, initmd5key, },
	{ "sha1", SHA1dlen, initsha1key, },
	{ 0 }
};
1427
1428 static Hashalg*
1429 parsehashalg(char *p)
1430 {
1431         Hashalg *ha;
1432
1433         for(ha = hashtab; ha->name; ha++)
1434                 if(strcmp(p, ha->name) == 0)
1435                         return ha;
1436         error("unsupported hash algorithm");
1437         return nil;
1438 }
1439
/*
 * a cipher selectable by the "secret" control message.  keylen and
 * ivlen are byte counts taken from the secret data for each
 * direction; initkey(ea, s, key, iv) installs the cipher routines,
 * block size and key state into s.  Field order matters: encrypttab
 * is initialized positionally.
 */
typedef struct Encalg Encalg;
struct Encalg
{
	char	*name;
	int	keylen;
	int	ivlen;
	void	(*initkey)(Encalg *ea, Secret *, uchar*, uchar*);
};
1448
1449 static void
1450 initRC4key(Encalg *ea, Secret *s, uchar *p, uchar *)
1451 {
1452         s->enckey = smalloc(sizeof(RC4state));
1453         s->enc = rc4enc;
1454         s->dec = rc4enc;
1455         s->block = 0;
1456         setupRC4state(s->enckey, p, ea->keylen);
1457 }
1458
1459 static void
1460 initDES3key(Encalg *, Secret *s, uchar *p, uchar *iv)
1461 {
1462         s->enckey = smalloc(sizeof(DES3state));
1463         s->enc = des3enc;
1464         s->dec = des3dec;
1465         s->block = 8;
1466         setupDES3state(s->enckey, (uchar(*)[8])p, iv);
1467 }
1468
1469 static void
1470 initAESkey(Encalg *ea, Secret *s, uchar *p, uchar *iv)
1471 {
1472         s->enckey = smalloc(sizeof(AESstate));
1473         s->enc = aesenc;
1474         s->dec = aesdec;
1475         s->block = 16;
1476         setupAESstate(s->enckey, p, ea->keylen, iv);
1477 }
1478
1479 static void
1480 initclearenc(Encalg *, Secret *s, uchar *, uchar *)
1481 {
1482         s->enc = noenc;
1483         s->dec = noenc;
1484         s->block = 0;
1485 }
1486
/*
 * ciphers accepted by the "secret" control message, looked up by name
 * in parseencalg: name, key length and IV length in bytes, and the
 * setup routine.  The entry with a nil name ends the table; tlsinit
 * also joins these names to form the encalgs listing.
 */
static Encalg encrypttab[] =
{
	{ "clear", 0, 0, initclearenc },
	{ "rc4_128", 128/8, 0, initRC4key },
	{ "3des_ede_cbc", 3 * 8, 8, initDES3key },
	{ "aes_128_cbc", 128/8, 16, initAESkey },
	{ "aes_256_cbc", 256/8, 16, initAESkey },
	{ 0 }
};
1496
1497 static Encalg*
1498 parseencalg(char *p)
1499 {
1500         Encalg *ea;
1501
1502         for(ea = encrypttab; ea->name; ea++)
1503                 if(strcmp(p, ea->name) == 0)
1504                         return ea;
1505         error("unsupported encryption algorithm");
1506         return nil;
1507 }
1508
/*
 * tlswrite: write handler.
 *
 * Qdata, Qhand: the buffer is cut into pieces of at most MaxRecLen
 * bytes, each passed to tlsbwrite (sent as application or handshake
 * records respectively).  Returns n.
 *
 * Qctl: a control message, one of
 *	fd open-fd version	attach the underlying channel; moves to
 *				SHandshake via tlsSetState
 *	version vers		fix the protocol version (once) and pick
 *				the SSL 3.0 or TLS MAC packing
 *	secret hashalg encalg isclient secretdata
 *				install pending (new) ciphers for both
 *				directions from base64 key material
 *	changecipher		send ChangeCipherSpec, which switches the
 *				outgoing cipher (see tlsrecwrite)
 *	opened			enable application data (state SOpen)
 *	alert n			send alert n; ECloseNotify also marks SLClose
 *	debug [on|off]		set tracing; no argument means on
 * Other files raise Ebadusefd; unknown control messages raise Ebadarg.
 */
static long
tlswrite(Chan *c, void *a, long n, vlong off)
{
	Encalg *ea;
	Hashalg *ha;
	TlsRec *volatile tr;
	Secret *volatile tos, *volatile toc;
	Block *volatile b;
	Cmdbuf *volatile cb;
	int m, ty;
	char *p, *e;
	uchar *volatile x;
	ulong offset = off;

	tr = tlsdevs[CONV(c->qid)];
	if(tr == nil)
		panic("tlswrite");

	ty = TYPE(c->qid);
	switch(ty){
	case Qdata:
	case Qhand:
		/*
		 * do/while: a zero-length write still sends one empty record.
		 * NOTE(review): for Qhand that is an empty handshake record,
		 * which the TLS specs forbid; confirm callers never do this.
		 */
		p = a;
		e = p + n;
		do{
			m = e - p;
			if(m > MaxRecLen)
				m = MaxRecLen;

			b = allocb(m);
			/* the copy from the caller's buffer may fault */
			if(waserror()){
				freeb(b);
				nexterror();
			}
			memmove(b->wp, p, m);
			poperror();
			b->wp += m;

			tlsbwrite(c, b, offset);

			p += m;
		}while(p < e);
		return n;
	case Qctl:
		break;
	default:
		error(Ebadusefd);
		return -1;
	}

	cb = parsecmd(a, n);
	if(waserror()){
		free(cb);
		nexterror();
	}
	if(cb->nf < 1)
		error("short control request");

	/* mutex with operations using what we're about to change */
	if(waserror()){
		qunlock(&tr->in.seclock);
		qunlock(&tr->out.seclock);
		nexterror();
	}
	qlock(&tr->in.seclock);
	qlock(&tr->out.seclock);

	if(strcmp(cb->f[0], "fd") == 0){
		if(cb->nf != 3)
			error("usage: fd open-fd version");
		if(tr->c != nil)
			error(Einuse);
		m = strtol(cb->f[2], nil, 0);
		if(m < MinProtoVersion || m > MaxProtoVersion)
			error("unsupported version");
		tr->c = buftochan(cb->f[1]);
		tr->version = m;
		tlsSetState(tr, SHandshake, SClosed);
	}else if(strcmp(cb->f[0], "version") == 0){
		if(cb->nf != 2)
			error("usage: version vers");
		if(tr->c == nil)
			error("must set fd before version");
		if(tr->verset)
			error("version already set");
		m = strtol(cb->f[1], nil, 0);
		if(m < MinProtoVersion || m > MaxProtoVersion)
			error("unsupported version");
		if(m == SSL3Version)
			tr->packMac = sslPackMac;
		else
			tr->packMac = tlsPackMac;
		tr->verset = 1;
		tr->version = m;
	}else if(strcmp(cb->f[0], "secret") == 0){
		if(cb->nf != 5)
			error("usage: secret hashalg encalg isclient secretdata");
		if(tr->c == nil || !tr->verset)
			error("must set fd and version before secrets");

		/* discard any pending secrets from an earlier "secret" */
		if(tr->in.new != nil){
			freeSec(tr->in.new);
			tr->in.new = nil;
		}
		if(tr->out.new != nil){
			freeSec(tr->out.new);
			tr->out.new = nil;
		}

		ha = parsehashalg(cb->f[1]);
		ea = parseencalg(cb->f[2]);

		/* base64 decodes to less than its input length, so m is ample */
		p = cb->f[4];
		m = (strlen(p)*3)/2;
		x = smalloc(m);
		tos = nil;
		toc = nil;
		if(waserror()){
			freeSec(tos);
			freeSec(toc);
			free(x);
			nexterror();
		}
		m = dec64(x, m, p, strlen(p));
		if(m < 2 * ha->maclen + 2 * ea->keylen + 2 * ea->ivlen)
			error("not enough secret data provided");

		/*
		 * x is laid out as: tos MAC key, toc MAC key, tos cipher key,
		 * toc cipher key, tos IV, toc IV.
		 */
		tos = smalloc(sizeof(Secret));
		toc = smalloc(sizeof(Secret));
		if(!ha->initkey || !ea->initkey)
			error("misimplemented secret algorithm");
		(*ha->initkey)(ha, tr->version, tos, &x[0]);
		(*ha->initkey)(ha, tr->version, toc, &x[ha->maclen]);
		(*ea->initkey)(ea, tos, &x[2 * ha->maclen], &x[2 * ha->maclen + 2 * ea->keylen]);
		(*ea->initkey)(ea, toc, &x[2 * ha->maclen + ea->keylen], &x[2 * ha->maclen + 2 * ea->keylen + ea->ivlen]);

		if(!tos->mac || !tos->enc || !tos->dec
		|| !toc->mac || !toc->enc || !toc->dec)
			error("missing algorithm implementations");
		/* isclient == 0: we receive with tos and send with toc; otherwise the reverse */
		if(strtol(cb->f[3], nil, 0) == 0){
			tr->in.new = tos;
			tr->out.new = toc;
		}else{
			tr->in.new = toc;
			tr->out.new = tos;
		}
		if(tr->version == SSL3Version){
			toc->unpad = sslunpad;
			tos->unpad = sslunpad;
		}else{
			toc->unpad = tlsunpad;
			tos->unpad = tlsunpad;
		}
		toc->encalg = ea->name;
		toc->hashalg = ha->name;
		tos->encalg = ea->name;
		tos->hashalg = ha->name;

		/* NOTE(review): x holds raw key material; consider clearing it before free */
		free(x);
		poperror();
	}else if(strcmp(cb->f[0], "changecipher") == 0){
		if(cb->nf != 1)
			error("usage: changecipher");
		if(tr->out.new == nil)
			error("cannot change cipher spec without setting secret");

		/* drop the seclocks and cb before writing; tlsrecwrite takes out->seclock itself */
		qunlock(&tr->in.seclock);
		qunlock(&tr->out.seclock);
		poperror();
		free(cb);
		poperror();

		/*
		 * the real work is done as the message is written
		 * so the stream is encrypted in sync.
		 */
		b = allocb(1);
		*b->wp++ = 1;
		tlsrecwrite(tr, RChangeCipherSpec, b);
		return n;
	}else if(strcmp(cb->f[0], "opened") == 0){
		if(cb->nf != 1)
			error("usage: opened");
		if(tr->in.sec == nil || tr->out.sec == nil)
			error("cipher must be configured before enabling data messages");
		lock(&tr->statelk);
		if(tr->state != SHandshake && tr->state != SOpen){
			unlock(&tr->statelk);
			error("cannot enable data messages");
		}
		tr->state = SOpen;
		unlock(&tr->statelk);
		tr->opened = 1;
	}else if(strcmp(cb->f[0], "alert") == 0){
		if(cb->nf != 2)
			error("usage: alert n");
		if(tr->c == nil)
			error("must set fd before sending alerts");
		m = strtol(cb->f[1], nil, 0);

		/* release the seclocks and cb before sending, as for changecipher */
		qunlock(&tr->in.seclock);
		qunlock(&tr->out.seclock);
		poperror();
		free(cb);
		poperror();

		sendAlert(tr, m);

		if(m == ECloseNotify)
			tlsclosed(tr, SLClose);

		return n;
	} else if(strcmp(cb->f[0], "debug") == 0){
		if(cb->nf == 2){
			if(strcmp(cb->f[1], "on") == 0)
				tr->debug = 1;
			else
				tr->debug = 0;
		} else
			tr->debug = 1;
	} else
		error(Ebadarg);

	qunlock(&tr->in.seclock);
	qunlock(&tr->out.seclock);
	poperror();
	free(cb);
	poperror();

	return n;
}
1740
1741 static void
1742 tlsinit(void)
1743 {
1744         struct Encalg *e;
1745         struct Hashalg *h;
1746         int n;
1747         char *cp;
1748         static int already;
1749
1750         if(!already){
1751                 fmtinstall('H', encodefmt);
1752                 already = 1;
1753         }
1754
1755         tlsdevs = smalloc(sizeof(TlsRec*) * maxtlsdevs);
1756         trnames = smalloc((sizeof *trnames) * maxtlsdevs);
1757
1758         n = 1;
1759         for(e = encrypttab; e->name != nil; e++)
1760                 n += strlen(e->name) + 1;
1761         cp = encalgs = smalloc(n);
1762         for(e = encrypttab;;){
1763                 strcpy(cp, e->name);
1764                 cp += strlen(e->name);
1765                 e++;
1766                 if(e->name == nil)
1767                         break;
1768                 *cp++ = ' ';
1769         }
1770         *cp = 0;
1771
1772         n = 1;
1773         for(h = hashtab; h->name != nil; h++)
1774                 n += strlen(h->name) + 1;
1775         cp = hashalgs = smalloc(n);
1776         for(h = hashtab;;){
1777                 strcpy(cp, h->name);
1778                 cp += strlen(h->name);
1779                 h++;
1780                 if(h->name == nil)
1781                         break;
1782                 *cp++ = ' ';
1783         }
1784         *cp = 0;
1785 }
1786
1787 Dev tlsdevtab = {
1788         'a',
1789         "tls",
1790
1791         devreset,
1792         tlsinit,
1793         devshutdown,
1794         tlsattach,
1795         tlswalk,
1796         tlsstat,
1797         tlsopen,
1798         devcreate,
1799         tlsclose,
1800         tlsread,
1801         tlsbread,
1802         tlswrite,
1803         tlsbwrite,
1804         devremove,
1805         tlswstat,
1806 };
1807
1808 /* get channel associated with an fd */
1809 static Chan*
1810 buftochan(char *p)
1811 {
1812         Chan *c;
1813         int fd;
1814
1815         if(p == 0)
1816                 error(Ebadarg);
1817         fd = strtoul(p, 0, 0);
1818         if(fd < 0)
1819                 error(Ebadarg);
1820         c = fdtochan(fd, -1, 0, 1);     /* error check and inc ref */
1821         return c;
1822 }
1823
1824 static void
1825 sendAlert(TlsRec *tr, int err)
1826 {
1827         Block *b;
1828         int i, fatal;
1829         char *msg;
1830
1831 if(tr->debug)pprint("sendAlert %d\n", err);
1832         fatal = 1;
1833         msg = "tls unknown alert";
1834         for(i=0; i < nelem(tlserrs); i++) {
1835                 if(tlserrs[i].err == err) {
1836                         msg = tlserrs[i].msg;
1837                         if(tr->version == SSL3Version)
1838                                 err = tlserrs[i].sslerr;
1839                         else
1840                                 err = tlserrs[i].tlserr;
1841                         fatal = tlserrs[i].fatal;
1842                         break;
1843                 }
1844         }
1845
1846         if(!waserror()){
1847                 b = allocb(2);
1848                 *b->wp++ = fatal + 1;
1849                 *b->wp++ = err;
1850                 if(fatal)
1851                         tlsSetState(tr, SAlert, SOpen|SHandshake|SRClose);
1852                 tlsrecwrite(tr, RAlert, b);
1853                 poperror();
1854         }
1855         if(fatal)
1856                 tlsError(tr, msg);
1857 }
1858
1859 static void
1860 tlsError(TlsRec *tr, char *msg)
1861 {
1862         int s;
1863
1864 if(tr->debug)pprint("tlsError %s\n", msg);
1865         lock(&tr->statelk);
1866         s = tr->state;
1867         tr->state = SError;
1868         if(s != SError){
1869                 strncpy(tr->err, msg, ERRMAX - 1);
1870                 tr->err[ERRMAX - 1] = '\0';
1871         }
1872         unlock(&tr->statelk);
1873         if(s != SError)
1874                 alertHand(tr, msg);
1875 }
1876
1877 static void
1878 tlsSetState(TlsRec *tr, int new, int old)
1879 {
1880         lock(&tr->statelk);
1881         if(tr->state & old)
1882                 tr->state = new;
1883         unlock(&tr->statelk);
1884 }
1885
1886 /* hand up a digest connection */
1887 static void
1888 tlshangup(TlsRec *tr)
1889 {
1890         Block *b;
1891
1892         qlock(&tr->in.io);
1893         for(b = tr->processed; b; b = tr->processed){
1894                 tr->processed = b->next;
1895                 freeb(b);
1896         }
1897         if(tr->unprocessed != nil){
1898                 freeb(tr->unprocessed);
1899                 tr->unprocessed = nil;
1900         }
1901         qunlock(&tr->in.io);
1902
1903         tlsSetState(tr, SClosed, ~0);
1904 }
1905
1906 static TlsRec*
1907 newtls(Chan *ch)
1908 {
1909         TlsRec **pp, **ep, **np;
1910         char **nmp;
1911         int t, newmax;
1912
1913         if(waserror()) {
1914                 unlock(&tdlock);
1915                 nexterror();
1916         }
1917         lock(&tdlock);
1918         ep = &tlsdevs[maxtlsdevs];
1919         for(pp = tlsdevs; pp < ep; pp++)
1920                 if(*pp == nil)
1921                         break;
1922         if(pp >= ep) {
1923                 if(maxtlsdevs >= MaxTlsDevs) {
1924                         unlock(&tdlock);
1925                         poperror();
1926                         return nil;
1927                 }
1928                 newmax = 2 * maxtlsdevs;
1929                 if(newmax > MaxTlsDevs)
1930                         newmax = MaxTlsDevs;
1931                 np = smalloc(sizeof(TlsRec*) * newmax);
1932                 memmove(np, tlsdevs, sizeof(TlsRec*) * maxtlsdevs);
1933                 tlsdevs = np;
1934                 pp = &tlsdevs[maxtlsdevs];
1935                 memset(pp, 0, sizeof(TlsRec*)*(newmax - maxtlsdevs));
1936
1937                 nmp = smalloc(sizeof *nmp * newmax);
1938                 memmove(nmp, trnames, sizeof *nmp * maxtlsdevs);
1939                 trnames = nmp;
1940
1941                 maxtlsdevs = newmax;
1942         }
1943         *pp = mktlsrec();
1944         if(pp - tlsdevs >= tdhiwat)
1945                 tdhiwat++;
1946         t = TYPE(ch->qid);
1947         if(t == Qclonus)
1948                 t = Qctl;
1949         ch->qid.path = QID(pp - tlsdevs, t);
1950         ch->qid.vers = 0;
1951         unlock(&tdlock);
1952         poperror();
1953         return *pp;
1954 }
1955
1956 static TlsRec *
1957 mktlsrec(void)
1958 {
1959         TlsRec *tr;
1960
1961         tr = mallocz(sizeof(*tr), 1);
1962         if(tr == nil)
1963                 error(Enomem);
1964         tr->state = SClosed;
1965         tr->ref = 1;
1966         kstrdup(&tr->user, up->user);
1967         tr->perm = 0660;
1968         return tr;
1969 }
1970
1971 static char*
1972 tlsstate(int s)
1973 {
1974         switch(s){
1975         case SHandshake:
1976                 return "Handshaking";
1977         case SOpen:
1978                 return "Established";
1979         case SRClose:
1980                 return "RemoteClosed";
1981         case SLClose:
1982                 return "LocalClosed";
1983         case SAlert:
1984                 return "Alerting";
1985         case SError:
1986                 return "Errored";
1987         case SClosed:
1988                 return "Closed";
1989         }
1990         return "Unknown";
1991 }
1992
1993 static void
1994 freeSec(Secret *s)
1995 {
1996         if(s != nil){
1997                 free(s->enckey);
1998                 free(s);
1999         }
2000 }
2001
2002 static int
2003 noenc(Secret *, uchar *, int n)
2004 {
2005         return n;
2006 }
2007
2008 static int
2009 rc4enc(Secret *sec, uchar *buf, int n)
2010 {
2011         rc4(sec->enckey, buf, n);
2012         return n;
2013 }
2014
/*
 * strip tls block-cipher padding from the n decrypted bytes in buf.
 * the last byte is the pad length pad, and the pad bytes before it must
 * all equal pad as well.  n must be a positive multiple of block, and at
 * least one byte of data must remain.  returns the remaining length, or
 * -1 if the padding is malformed.
 */
static int
tlsunpad(uchar *buf, int n, int block)
{
	int pad, nn, i, bad;

	/* validate n before touching buf[n-1]: n == 0 would read before buf */
	if(n <= 0 || block <= 0 || n % block != 0)
		return -1;
	pad = buf[n - 1];
	nn = n - 1 - pad;
	if(nn <= 0)
		return -1;
	/* compare every pad byte instead of stopping at the first mismatch */
	bad = 0;
	for(i = nn; i < n - 1; i++)
		bad |= buf[i] ^ pad;
	if(bad != 0)
		return -1;
	return nn;
}
2029
/*
 * strip ssl 3.0 block-cipher padding from the n decrypted bytes in buf.
 * ssl 3.0 leaves the pad bytes unspecified, so only the length byte (the
 * last byte) is checked: it must be smaller than the block size, n must
 * be a positive multiple of block, and some data must remain.  returns
 * the remaining length, or -1 if the padding is malformed.
 */
static int
sslunpad(uchar *buf, int n, int block)
{
	int pad, nn;

	/* validate n before touching buf[n-1]: n == 0 would read before buf */
	if(n <= 0 || block <= 0 || n % block != 0)
		return -1;
	pad = buf[n - 1];
	/* ssl 3.0 requires the padding length to be less than the block length */
	if(pad >= block)
		return -1;
	nn = n - 1 - pad;
	if(nn <= 0)
		return -1;
	return nn;
}
2041
/*
 * pad the n bytes in buf out to the next multiple of block, always
 * adding at least one byte.  every added byte, including the final one,
 * holds the count of added bytes minus one.  buf needs room for up to
 * block extra bytes.  returns the padded length.
 */
static int
blockpad(uchar *buf, int n, int block)
{
	int total, pad;

	total = n + block - (n + block) % block;
	pad = total - (n + 1);
	memset(buf + n, pad, total - n);
	return total;
}
2054                 
2055 static int
2056 des3enc(Secret *sec, uchar *buf, int n)
2057 {
2058         n = blockpad(buf, n, 8);
2059         des3CBCencrypt(buf, n, sec->enckey);
2060         return n;
2061 }
2062
2063 static int
2064 des3dec(Secret *sec, uchar *buf, int n)
2065 {
2066         des3CBCdecrypt(buf, n, sec->enckey);
2067         return (*sec->unpad)(buf, n, 8);
2068 }
2069
2070 static int
2071 aesenc(Secret *sec, uchar *buf, int n)
2072 {
2073         n = blockpad(buf, n, 16);
2074         aesCBCencrypt(buf, n, sec->enckey);
2075         return n;
2076 }
2077
2078 static int
2079 aesdec(Secret *sec, uchar *buf, int n)
2080 {
2081         aesCBCdecrypt(buf, n, sec->enckey);
2082         return (*sec->unpad)(buf, n, 16);
2083 }
2084
2085 static DigestState*
2086 nomac(uchar *, ulong, uchar *, ulong, uchar *, DigestState *)
2087 {
2088         return nil;
2089 }
2090
2091 /*
2092  * sslmac: mac calculations for ssl 3.0 only; tls 1.0 uses the standard hmac.
2093  */
2094 static DigestState*
2095 sslmac_x(uchar *p, ulong len, uchar *key, ulong klen, uchar *digest, DigestState *s,
2096         DigestState*(*x)(uchar*, ulong, uchar*, DigestState*), int xlen, int padlen)
2097 {
2098         int i;
2099         uchar pad[48], innerdigest[20];
2100
2101         if(xlen > sizeof(innerdigest)
2102         || padlen > sizeof(pad))
2103                 return nil;
2104
2105         if(klen>64)
2106                 return nil;
2107
2108         /* first time through */
2109         if(s == nil){
2110                 for(i=0; i<padlen; i++)
2111                         pad[i] = 0x36;
2112                 s = (*x)(key, klen, nil, nil);
2113                 s = (*x)(pad, padlen, nil, s);
2114                 if(s == nil)
2115                         return nil;
2116         }
2117
2118         s = (*x)(p, len, nil, s);
2119         if(digest == nil)
2120                 return s;
2121
2122         /* last time through */
2123         for(i=0; i<padlen; i++)
2124                 pad[i] = 0x5c;
2125         (*x)(nil, 0, innerdigest, s);
2126         s = (*x)(key, klen, nil, nil);
2127         s = (*x)(pad, padlen, nil, s);
2128         (*x)(innerdigest, xlen, digest, s);
2129         return nil;
2130 }
2131
2132 static DigestState*
2133 sslmac_sha1(uchar *p, ulong len, uchar *key, ulong klen, uchar *digest, DigestState *s)
2134 {
2135         return sslmac_x(p, len, key, klen, digest, s, sha1, SHA1dlen, 40);
2136 }
2137
2138 static DigestState*
2139 sslmac_md5(uchar *p, ulong len, uchar *key, ulong klen, uchar *digest, DigestState *s)
2140 {
2141         return sslmac_x(p, len, key, klen, digest, s, md5, MD5dlen, 48);
2142 }
2143
2144 static void
2145 sslPackMac(Secret *sec, uchar *mackey, uchar *seq, uchar *header, uchar *body, int len, uchar *mac)
2146 {
2147         DigestState *s;
2148         uchar buf[11];
2149
2150         memmove(buf, seq, 8);
2151         buf[8] = header[0];
2152         buf[9] = header[3];
2153         buf[10] = header[4];
2154
2155         s = (*sec->mac)(buf, 11, mackey, sec->maclen, 0, 0);
2156         (*sec->mac)(body, len, mackey, sec->maclen, mac, s);
2157 }
2158
2159 static void
2160 tlsPackMac(Secret *sec, uchar *mackey, uchar *seq, uchar *header, uchar *body, int len, uchar *mac)
2161 {
2162         DigestState *s;
2163         uchar buf[13];
2164
2165         memmove(buf, seq, 8);
2166         memmove(&buf[8], header, 5);
2167
2168         s = (*sec->mac)(buf, 13, mackey, sec->maclen, 0, 0);
2169         (*sec->mac)(body, len, mackey, sec->maclen, mac, s);
2170 }
2171
/* store x big-endian in p[0..3] */
static void
put32(uchar *p, u32int x)
{
	int i;

	for(i = 3; i >= 0; i--){
		p[i] = x;
		x >>= 8;
	}
}
2180
2181 static void
2182 put64(uchar *p, vlong x)
2183 {
2184         put32(p, (u32int)(x >> 32));
2185         put32(p+4, (u32int)x);
2186 }
2187
/* store the low 24 bits of x big-endian in p[0..2] */
static void
put24(uchar *p, int x)
{
	p[2] = x;
	p[1] = x >> 8;
	p[0] = x >> 16;
}
2195
/* store the low 16 bits of x big-endian in p[0..1] */
static void
put16(uchar *p, int x)
{
	p[1] = x;
	p[0] = x >> 8;
}
2202
/*
 * load a big-endian 32-bit value from p[0..3].  each byte is widened to
 * u32int before shifting: shifting the int-promoted p[0] left by 24
 * would overflow int (undefined behavior) whenever p[0] >= 0x80.
 */
static u32int
get32(uchar *p)
{
	return ((u32int)p[0]<<24) | ((u32int)p[1]<<16) | ((u32int)p[2]<<8) | p[3];
}
2208
/* load a big-endian 16-bit value from p[0..1] */
static int
get16(uchar *p)
{
	return p[0] * 256 + p[1];
}
2214
2215 static char *charmap = "0123456789abcdef";
2216
2217 static void
2218 pdump(int len, void *a, char *tag)
2219 {
2220         uchar *p;
2221         int i;
2222         char buf[65+32];
2223         char *q;
2224
2225         p = a;
2226         strcpy(buf, tag);
2227         while(len > 0){
2228                 q = buf + strlen(tag);
2229                 for(i = 0; len > 0 && i < 32; i++){
2230                         if(*p >= ' ' && *p < 0x7f){
2231                                 *q++ = ' ';
2232                                 *q++ = *p;
2233                         } else {
2234                                 *q++ = charmap[*p>>4];
2235                                 *q++ = charmap[*p & 0xf];
2236                         }
2237                         len--;
2238                         p++;
2239                 }
2240                 *q = 0;
2241
2242                 if(len > 0)
2243                         pprint("%s...\n", buf);
2244                 else
2245                         pprint("%s\n", buf);
2246         }
2247 }