]> git.lizzy.rs Git - plan9front.git/blob - sys/src/9/port/devtls.c
devtls: allocate cipher states in secret memory
[plan9front.git] / sys / src / 9 / port / devtls.c
1 /*
2  *  devtls - record layer for transport layer security 1.2 and secure sockets layer 3.0
3  */
4 #include        "u.h"
5 #include        "../port/lib.h"
6 #include        "mem.h"
7 #include        "dat.h"
8 #include        "fns.h"
9 #include        "../port/error.h"
10
11 #include        <libsec.h>
12
/* forward declarations of the structures defined further down in this file */
13 typedef struct OneWay   OneWay;
14 typedef struct Secret   Secret;
15 typedef struct TlsRec   TlsRec;
16 typedef struct TlsErrs  TlsErrs;
17
/*
 * Record layer constants: buffer limits, accepted protocol versions,
 * connection state bits, record types and alert codes.
 */
18 enum {
19         Statlen=        1024,           /* max. length of status or stats message */
20         /* buffer limits */
21         MaxRecLen       = 1<<14,        /* max payload length of a record layer message */
22         MaxCipherRecLen = MaxRecLen + 2048,     /* longest record body tlsrecread accepts */
23         RecHdrLen       = 5,            /* bytes in a record header */
24         MaxMacLen       = SHA2_256dlen, /* size of the mac key and digest buffers */
25
26         /* protocol versions we can accept */
27         SSL3Version     = 0x0300,
28         TLS10Version    = 0x0301,
29         TLS11Version    = 0x0302,
30         TLS12Version    = 0x0303,
31         MinProtoVersion = 0x0300,       /* limits on version we accept */
32         MaxProtoVersion = 0x03ff,
33
34         /* connection states; bit 1 << 4 is not assigned */
35         SHandshake      = 1 << 0,       /* doing handshake */
36         SOpen           = 1 << 1,       /* application data can be sent */
37         SRClose         = 1 << 2,       /* remote side has closed down */
38         SLClose         = 1 << 3,       /* sent a close notify alert */
39         SAlert          = 1 << 5,       /* sending or sent a fatal alert */
40         SError          = 1 << 6,       /* some sort of error has occurred */
41         SClosed         = 1 << 7,       /* it is all over */
42
43         /* record types */
44         RChangeCipherSpec = 20,
45         RAlert,
46         RHandshake,
47         RApplication,
48
49         SSL2ClientHello = 1,
50         HSSL2ClientHello = 9,  /* local convention;  see tlshand.c */
51
52         /* alerts */
53         ECloseNotify                    = 0,
54         EUnexpectedMessage      = 10,
55         EBadRecordMac           = 20,
56         EDecryptionFailed               = 21,
57         ERecordOverflow                 = 22,
58         EDecompressionFailure   = 30,
59         EHandshakeFailure               = 40,
60         ENoCertificate                  = 41,
61         EBadCertificate                 = 42,
62         EUnsupportedCertificate         = 43,
63         ECertificateRevoked             = 44,
64         ECertificateExpired             = 45,
65         ECertificateUnknown     = 46,
66         EIllegalParameter               = 47,
67         EUnknownCa                      = 48,
68         EAccessDenied           = 49,
69         EDecodeError                    = 50,
70         EDecryptError                   = 51,
71         EExportRestriction              = 60,
72         EProtocolVersion                = 70,
73         EInsufficientSecurity   = 71,
74         EInternalError                  = 80,
75         EUserCanceled                   = 90,
76         ENoRenegotiation                = 100,
77         EUnrecognizedName               = 112,
78
79         EMAX = 256
80 };
81
/*
 * Cipher and mac parameters for one direction of a connection.
 * A OneWay points at the Secret in use and at the one waiting
 * to be enabled.
 */
82 struct Secret
83 {
84         char            *encalg;        /* name of encryption alg */
85         char            *hashalg;       /* name of hash alg */
86
        /* AEAD entry points; tlsrecread takes the AEAD path when aead_dec is non-nil */
87         int             (*aead_enc)(Secret*, uchar*, int, uchar*, uchar*, int);
88         int             (*aead_dec)(Secret*, uchar*, int, uchar*, uchar*, int);
89
        /* non-AEAD cipher, padding removal and mac routines */
90         int             (*enc)(Secret*, uchar*, int);
91         int             (*dec)(Secret*, uchar*, int);
92         int             (*unpad)(uchar*, int, int);
93         DigestState*    (*mac)(uchar*, ulong, uchar*, ulong, uchar*, DigestState*);
94
95         int             block;          /* encryption block len, 0 if none */
96         int             maclen;         /* # bytes of record mac / authentication tag */
97         int             recivlen;       /* # bytes of record iv for AEAD ciphers */
98         void            *enckey;        /* cipher state; presumably owned by this Secret and released by freeSec -- TODO confirm */
99         uchar           mackey[MaxMacLen];      /* mac key, at most MaxMacLen bytes */
100 };
101
/* per-direction (input or output) state of a connection */
102 struct OneWay
103 {
104         QLock           io;             /* locks io access */
105         QLock           seclock;        /* locks secret parameters */
106         u64int          seq;            /* record sequence number fed to packAAD; reset on change cipher spec */
107         Secret          *sec;           /* cipher in use */
108         Secret          *new;           /* cipher waiting for enable */
109 };
110
/* state of one tls connection; entries of tlsdevs point at these */
111 struct TlsRec
112 {
113         Chan    *c;                             /* io channel */
114         int             ref;                            /* serialized by tdlock for atomic destroy */
115         int             version;                        /* version of the protocol we are speaking */
116         char            verset;                 /* version has been set */
117         char            opened;                 /* opened command ever issued? */
118         char            err[ERRMAX];            /* error message to return to handshake requests */
119         vlong   handin;                 /* bytes communicated by the record layer */
120         vlong   handout;
121         vlong   datain;
122         vlong   dataout;
123
124         Lock            statelk;                /* protects state */
125         int             state;                  /* one of the connection state bits (SHandshake, SOpen, ...) */
126         int             debug;                  /* non-zero enables the pprint tracing */
127
128         /*
129          * function to generate authenticated data blob for different
130          * protocol versions
131          */
132         int             (*packAAD)(u64int, uchar*, uchar*);
133
134         /* input side -- protected by in.io */
135         OneWay          in;
136         Block           *processed;     /* next bunch of application data */
137         Block           *unprocessed;   /* data read from c but not parsed into records */
138
139         /* handshake queue */
140         Lock            hqlock;                 /* protects hqref, alloc & free of handq, hprocessed */
141         int             hqref;
142         Queue           *handq;         /* queue of handshake messages */
143         Block           *hprocessed;    /* remainder of last block read from handq */
144         QLock           hqread;         /* protects reads for hprocessed, handq */
145
146         /* output side */
147         OneWay          out;
148
149         /* protections */
150         char            *user;          /* owner reported by tlsgen and checked by tlswstat */
151         int             perm;           /* permission bits of the conversation files */
152 };
153
/*
 * One row of the tlserrs alert table.
 * NOTE(review): judging by the table contents, sslerr and tlserr look like
 * the alert codes to use for SSL 3.0 and TLS peers respectively -- confirm
 * against the code that consults the table.
 */
154 struct TlsErrs{
155         int     err;            /* alert code used as the lookup key */
156         int     sslerr;
157         int     tlserr;
158         int     fatal;          /* 1 for alerts marked fatal in the table, 0 otherwise */
159         char    *msg;           /* human readable description */
160 };
161
/*
 * Alert table; columns follow the TlsErrs fields:
 * err, sslerr, tlserr, fatal, msg.
 */
162 static TlsErrs tlserrs[] = {
163         {ECloseNotify,                  ECloseNotify,                   ECloseNotify,                   0,      "close notify"},
164         {EUnexpectedMessage,    EUnexpectedMessage,     EUnexpectedMessage,     1, "unexpected message"},
165         {EBadRecordMac,         EBadRecordMac,          EBadRecordMac,          1, "bad record mac"},
166         {EDecryptionFailed,             EIllegalParameter,              EDecryptionFailed,              1, "decryption failed"},
167         {ERecordOverflow,               EIllegalParameter,              ERecordOverflow,                1, "record too long"},
168         {EDecompressionFailure, EDecompressionFailure,  EDecompressionFailure,  1, "decompression failed"},
169         {EHandshakeFailure,             EHandshakeFailure,              EHandshakeFailure,              1, "could not negotiate acceptable security parameters"},
170         {ENoCertificate,                ENoCertificate,                 ECertificateUnknown,    1, "no appropriate certificate available"},
171         {EBadCertificate,               EBadCertificate,                EBadCertificate,                1, "corrupted or invalid certificate"},
172         {EUnsupportedCertificate,       EUnsupportedCertificate,        EUnsupportedCertificate,        1, "unsupported certificate type"},
173         {ECertificateRevoked,   ECertificateRevoked,            ECertificateRevoked,            1, "revoked certificate"},
174         {ECertificateExpired,           ECertificateExpired,            ECertificateExpired,            1, "expired certificate"},
175         {ECertificateUnknown,   ECertificateUnknown,    ECertificateUnknown,    1, "unacceptable certificate"},
176         {EIllegalParameter,             EIllegalParameter,              EIllegalParameter,              1, "illegal parameter"},
177         {EUnknownCa,                    EHandshakeFailure,              EUnknownCa,                     1, "unknown certificate authority"},
178         {EAccessDenied,         EHandshakeFailure,              EAccessDenied,          1, "access denied"},
179         {EDecodeError,                  EIllegalParameter,              EDecodeError,                   1, "error decoding message"},
180         {EDecryptError,                 EIllegalParameter,              EDecryptError,                  1, "error decrypting message"},
181         {EExportRestriction,            EHandshakeFailure,              EExportRestriction,             1, "export restriction violated"},
182         {EProtocolVersion,              EIllegalParameter,              EProtocolVersion,               1, "protocol version not supported"},
183         {EInsufficientSecurity, EHandshakeFailure,              EInsufficientSecurity,  1, "stronger security routines required"},
184         {EInternalError,                        EHandshakeFailure,              EInternalError,                 1, "internal error"},
185         {EUserCanceled,         ECloseNotify,                   EUserCanceled,                  0, "handshake canceled by user"},
186         {ENoRenegotiation,              EUnexpectedMessage,     ENoRenegotiation,               0, "no renegotiation"},
187 };
188
189 enum
190 {
191         /* max. open tls connections; CONV() masks with MaxTlsDevs-1, so this must stay a power of two */
192         MaxTlsDevs      = 1024
193 };
194
/* device-wide state */
195 static  Lock    tdlock;         /* held around accesses to tlsdevs, trnames and TlsRec.ref */
196 static  int     tdhiwat;        /* tlsgen lists conversation directories 0 .. tdhiwat-1 */
197 static  int     maxtlsdevs = 128;       /* presumably the current capacity of tlsdevs -- TODO confirm */
198 static  TlsRec  **tlsdevs;      /* connections, indexed by conversation number; nil when free */
199 static  char    **trnames;      /* decimal directory names, filled in lazily by tlsgen */
200 static  char    *encalgs;       /* presumably the text served by the encalgs file -- TODO confirm */
201 static  char    *hashalgs;      /* presumably the text served by the hashalgs file -- TODO confirm */
202
/* file types, stored in the low bits of qid.path (see TYPE and QID) */
203 enum{
204         Qtopdir         = 1,    /* top level directory */
205         Qprotodir,              /* the "tls" directory */
206         Qclonus,                /* clone */
207         Qencalgs,
208         Qhashalgs,
209         Qconvdir,               /* directory for a conversation */
210         Qdata,
211         Qctl,
212         Qhand,
213         Qstatus,
214         Qstats,
215 };
216
/*
 * qid.path layout: conversation number shifted left by 5, file type in
 * the low bits.  TYPE only looks at the low 4 bits, which is enough for
 * every Q* value defined above (the largest is Qstats).
 */
217 #define TYPE(x)         ((x).path & 0xf)
218 #define CONV(x)         (((x).path >> 5)&(MaxTlsDevs-1))
219 #define QID(c, y)       (((c)<<5) | (y))
220
/* prototypes for the file-local helpers */
221 static void     checkstate(TlsRec *, int, int);
222 static void     ensure(TlsRec*, Block**, int);
223 static void     consume(Block**, uchar*, int);
224 static Chan*    buftochan(char*);
225 static void     tlshangup(TlsRec*);
226 static void     tlsError(TlsRec*, char *);
227 static void     alertHand(TlsRec*, char *);
228 static TlsRec   *newtls(Chan *c);
229 static TlsRec   *mktlsrec(void);
230 static DigestState*sslmac_md5(uchar *p, ulong len, uchar *key, ulong klen, uchar *digest, DigestState *s);
231 static DigestState*sslmac_sha1(uchar *p, ulong len, uchar *key, ulong klen, uchar *digest, DigestState *s);
232 static DigestState*nomac(uchar *p, ulong len, uchar *key, ulong klen, uchar *digest, DigestState *s);
233 static int      sslPackAAD(u64int, uchar*, uchar*);
234 static int      tlsPackAAD(u64int, uchar*, uchar*);
235 static void     packMac(Secret*, uchar*, int, uchar*, int, uchar*);
236 static void     put64(uchar *p, u64int);
237 static void     put32(uchar *p, u32int);
238 static void     put24(uchar *p, int);
239 static void     put16(uchar *p, int);
240 static u32int   get32(uchar *p);
241 static int      get16(uchar *p);
242 static void     tlsSetState(TlsRec *tr, int new, int old);
243 static void     rcvAlert(TlsRec *tr, int err);
244 static void     sendAlert(TlsRec *tr, int err);
245 static void     rcvError(TlsRec *tr, int err, char *msg, ...);
246 static int      rc4enc(Secret *sec, uchar *buf, int n);
247 static int      des3enc(Secret *sec, uchar *buf, int n);
248 static int      des3dec(Secret *sec, uchar *buf, int n);
249 static int      aesenc(Secret *sec, uchar *buf, int n);
250 static int      aesdec(Secret *sec, uchar *buf, int n);
251 static int      ccpoly_aead_enc(Secret *sec, uchar *aad, int aadlen, uchar *reciv, uchar *data, int len);
252 static int      ccpoly_aead_dec(Secret *sec, uchar *aad, int aadlen, uchar *reciv, uchar *data, int len);
253 static int      aesgcm_aead_enc(Secret *sec, uchar *aad, int aadlen, uchar *reciv, uchar *data, int len);
254 static int      aesgcm_aead_dec(Secret *sec, uchar *aad, int aadlen, uchar *reciv, uchar *data, int len);
255 static int      noenc(Secret *sec, uchar *buf, int n);
256 static int      sslunpad(uchar *buf, int n, int block);
257 static int      tlsunpad(uchar *buf, int n, int block);
258 static void     freeSec(Secret *sec);
259 static char     *tlsstate(int s);
260 static void     pdump(int, void*, char*);
261
262 #pragma varargck        argpos  rcvError        3
263
/* file names, indexed by Q* file type; directories have no entry */
264 static char *tlsnames[] = {
265 [Qclonus]               "clone",
266 [Qencalgs]      "encalgs",
267 [Qhashalgs]     "hashalgs",
268 [Qdata]         "data",
269 [Qctl]          "ctl",
270 [Qhand]         "hand",
271 [Qstatus]               "status",
272 [Qstats]                "stats",
273 };
274
/* files listed inside a conversation directory, in the order tlsgen produces them */
275 static int convdir[] = { Qctl, Qdata, Qhand, Qstatus, Qstats };
276
277 static int
278 tlsgen(Chan *c, char*, Dirtab *, int, int s, Dir *dp)
279 {
280         Qid q;
281         TlsRec *tr;
282         char *name, *nm;
283         int perm, t;
284
285         q.vers = 0;
286         q.type = QTFILE;
287
288         t = TYPE(c->qid);
289         switch(t) {
290         case Qtopdir:
291                 if(s == DEVDOTDOT){
292                         q.path = QID(0, Qtopdir);
293                         q.type = QTDIR;
294                         devdir(c, q, "#a", 0, eve, 0555, dp);
295                         return 1;
296                 }
297                 if(s > 0)
298                         return -1;
299                 q.path = QID(0, Qprotodir);
300                 q.type = QTDIR;
301                 devdir(c, q, "tls", 0, eve, 0555, dp);
302                 return 1;
303         case Qprotodir:
304                 if(s == DEVDOTDOT){
305                         q.path = QID(0, Qtopdir);
306                         q.type = QTDIR;
307                         devdir(c, q, ".", 0, eve, 0555, dp);
308                         return 1;
309                 }
310                 if(s < 3){
311                         switch(s) {
312                         default:
313                                 return -1;
314                         case 0:
315                                 q.path = QID(0, Qclonus);
316                                 break;
317                         case 1:
318                                 q.path = QID(0, Qencalgs);
319                                 break;
320                         case 2:
321                                 q.path = QID(0, Qhashalgs);
322                                 break;
323                         }
324                         perm = 0444;
325                         if(TYPE(q) == Qclonus)
326                                 perm = 0555;
327                         devdir(c, q, tlsnames[TYPE(q)], 0, eve, perm, dp);
328                         return 1;
329                 }
330                 s -= 3;
331                 if(s >= tdhiwat)
332                         return -1;
333                 q.path = QID(s, Qconvdir);
334                 q.type = QTDIR;
335                 lock(&tdlock);
336                 tr = tlsdevs[s];
337                 if(tr != nil)
338                         nm = tr->user;
339                 else
340                         nm = eve;
341                 if((name = trnames[s]) == nil){
342                         name = trnames[s] = smalloc(16);
343                         sprint(name, "%d", s);
344                 }
345                 devdir(c, q, name, 0, nm, 0555, dp);
346                 unlock(&tdlock);
347                 return 1;
348         case Qconvdir:
349                 if(s == DEVDOTDOT){
350                         q.path = QID(0, Qprotodir);
351                         q.type = QTDIR;
352                         devdir(c, q, "tls", 0, eve, 0555, dp);
353                         return 1;
354                 }
355                 if(s < 0 || s >= nelem(convdir))
356                         return -1;
357                 lock(&tdlock);
358                 tr = tlsdevs[CONV(c->qid)];
359                 if(tr != nil){
360                         nm = tr->user;
361                         perm = tr->perm;
362                 }else{
363                         perm = 0;
364                         nm = eve;
365                 }
366                 t = convdir[s];
367                 if(t == Qstatus || t == Qstats)
368                         perm &= 0444;
369                 q.path = QID(CONV(c->qid), t);
370                 devdir(c, q, tlsnames[t], 0, nm, perm, dp);
371                 unlock(&tdlock);
372                 return 1;
373         case Qclonus:
374         case Qencalgs:
375         case Qhashalgs:
376                 perm = 0444;
377                 if(t == Qclonus)
378                         perm = 0555;
379                 devdir(c, c->qid, tlsnames[t], 0, eve, perm, dp);
380                 return 1;
381         default:
382                 lock(&tdlock);
383                 tr = tlsdevs[CONV(c->qid)];
384                 if(tr != nil){
385                         nm = tr->user;
386                         perm = tr->perm;
387                 }else{
388                         perm = 0;
389                         nm = eve;
390                 }
391                 if(t == Qstatus || t == Qstats)
392                         perm &= 0444;
393                 devdir(c, c->qid, tlsnames[t], 0, nm, perm, dp);
394                 unlock(&tdlock);
395                 return 1;
396         }
397 }
398
399 static Chan*
400 tlsattach(char *spec)
401 {
402         Chan *c;
403
404         c = devattach('a', spec);
405         c->qid.path = QID(0, Qtopdir);
406         c->qid.type = QTDIR;
407         c->qid.vers = 0;
408         return c;
409 }
410
411 static Walkqid*
412 tlswalk(Chan *c, Chan *nc, char **name, int nname)
413 {
414         return devwalk(c, nc, name, nname, nil, 0, tlsgen);
415 }
416
417 static int
418 tlsstat(Chan *c, uchar *db, int n)
419 {
420         return devstat(c, db, n, nil, 0, tlsgen);
421 }
422
/*
 * Open a file of the device.
 *
 * Directories and the encalgs/hashalgs files may only be opened OREAD.
 * Opening clone obtains a connection from newtls.  Opening a file of an
 * existing conversation checks permissions against the connection's
 * owner and adds a reference to it; opening hand additionally creates
 * the handshake queue, which may only exist once per connection.
 * Raises an error on failure; otherwise marks c open and returns it.
 */
423 static Chan*
424 tlsopen(Chan *c, int omode)
425 {
426         TlsRec *tr, **pp;
427         int t;
428
429         t = TYPE(c->qid);
430         switch(t) {
431         default:
432                 panic("tlsopen");
433         case Qtopdir:
434         case Qprotodir:
435         case Qconvdir:
436                 if(omode != OREAD)
437                         error(Eperm);
438                 break;
439         case Qclonus:
440                 tr = newtls(c);
441                 if(tr == nil)
442                         error(Enodev);
443                 break;
444         case Qctl:
445         case Qdata:
446         case Qhand:
447         case Qstatus:
448         case Qstats:
449                 if((t == Qstatus || t == Qstats) && omode != OREAD)
450                         error(Eperm);
                /* error label installed before the lock is taken: any error below releases tdlock */
451                 if(waserror()) {
452                         unlock(&tdlock);
453                         nexterror();
454                 }
455                 lock(&tdlock);
456                 pp = &tlsdevs[CONV(c->qid)];
457                 tr = *pp;
458                 if(tr == nil)
459                         error("must open connection using clone");
460                 devpermcheck(tr->user, tr->perm, omode);
461                 if(t == Qhand){
                        /* nested label: releases hqlock, then falls through to the tdlock label above */
462                         if(waserror()){
463                                 unlock(&tr->hqlock);
464                                 nexterror();
465                         }
466                         lock(&tr->hqlock);
467                         if(tr->handq != nil)
468                                 error(Einuse);
469                         tr->handq = qopen(2 * MaxCipherRecLen, 0, nil, nil);
470                         if(tr->handq == nil)
471                                 error("cannot allocate handshake queue");
472                         tr->hqref = 1;
473                         unlock(&tr->hqlock);
474                         poperror();
475                 }
476                 tr->ref++;      /* dropped again by tlsclose */
477                 unlock(&tdlock);
478                 poperror();
479                 break;
480         case Qencalgs:
481         case Qhashalgs:
482                 if(omode != OREAD)
483                         error(Eperm);
484                 break;
485         }
486         c->mode = openmode(omode);
487         c->flag |= COPEN;
488         c->offset = 0;
489         c->iounit = MaxRecLen;
490         return c;
491 }
492
/*
 * Change the owner and/or permissions of a conversation.
 *
 * dp/n is the marshalled Dir supplied by the caller.  Only the current
 * owner of the connection may wstat it.  A non-empty uid replaces the
 * owner and a mode other than ~0 replaces the permissions.
 * Returns the value of convM2D; raises an error on failure.
 */
493 static int
494 tlswstat(Chan *c, uchar *dp, int n)
495 {
496         Dir *d;
497         TlsRec *tr;
498         int rv;
499
500         d = nil;
        /* on any error below: free the scratch Dir (free(nil) is fine) and drop tdlock */
501         if(waserror()){
502                 free(d);
503                 unlock(&tdlock);
504                 nexterror();
505         }
506
507         lock(&tdlock);
508         tr = tlsdevs[CONV(c->qid)];
509         if(tr == nil)
510                 error(Ebadusefd);
511         if(strcmp(tr->user, up->user) != 0)
512                 error(Eperm);
513
        /* one Dir followed by n bytes of string storage for convM2D */
514         d = smalloc(n + sizeof *d);
515         rv = convM2D(dp, n, &d[0], (char*) &d[1]);
516         if(rv == 0)
517                 error(Eshortstat);
518         if(!emptystr(d->uid))
519                 kstrdup(&tr->user, d->uid);
520         if(d->mode != ~0UL)
521                 tr->perm = d->mode;
522
523         free(d);
524         poperror();
525         unlock(&tdlock);
526
527         return rv;
528 }
529
530 static void
531 dechandq(TlsRec *tr)
532 {
533         lock(&tr->hqlock);
534         if(--tr->hqref == 0){
535                 if(tr->handq != nil){
536                         qfree(tr->handq);
537                         tr->handq = nil;
538                 }
539                 if(tr->hprocessed != nil){
540                         freeb(tr->hprocessed);
541                         tr->hprocessed = nil;
542                 }
543         }
544         unlock(&tr->hqlock);
545 }
546
/*
 * Close a file of the device.
 *
 * Only opened conversation files matter.  Closing hand drops the
 * handshake queue reference.  When the last reference to the connection
 * goes away its tlsdevs slot is cleared, a close notify alert is
 * attempted, and the connection with its secrets is freed.
 */
547 static void
548 tlsclose(Chan *c)
549 {
550         TlsRec *tr;
551         int t;
552
553         t = TYPE(c->qid);
554         switch(t) {
555         case Qctl:
556         case Qdata:
557         case Qhand:
558         case Qstatus:
559         case Qstats:
560                 if((c->flag & COPEN) == 0)
561                         break;
562
                /* NOTE(review): tlsdevs is read here before tdlock is taken -- confirm this is safe */
563                 tr = tlsdevs[CONV(c->qid)];
564                 if(tr == nil)
565                         break;
566
567                 if(t == Qhand)
568                         dechandq(tr);
569
570                 lock(&tdlock);
571                 if(--tr->ref > 0) {
572                         unlock(&tdlock);
573                         return;
574                 }
                /* last reference: unhook the connection so nobody else can find it */
575                 tlsdevs[CONV(c->qid)] = nil;
576                 unlock(&tdlock);
577
                /* best effort close notify; an error raised while sending is swallowed here */
578                 if(tr->c != nil && !waserror()){
579                         checkstate(tr, 0, SOpen|SHandshake|SRClose);
580                         sendAlert(tr, ECloseNotify);
581                         poperror();
582                 }
583                 tlshangup(tr);
584                 if(tr->c != nil)
585                         cclose(tr->c);
586                 freeSec(tr->in.sec);
587                 freeSec(tr->in.new);
588                 freeSec(tr->out.sec);
589                 freeSec(tr->out.new);
590                 free(tr->user);
591                 free(tr);
592                 break;
593         }
594 }
595
596 /*
597  *  make sure we have at least 'n' bytes in list 'l'
598  */
599 static void
600 ensure(TlsRec *s, Block **l, int n)
601 {
602         int sofar, i;
603         Block *b, *bl;
604
605         sofar = 0;
606         for(b = *l; b; b = b->next){
607                 sofar += BLEN(b);
608                 if(sofar >= n)
609                         return;
610                 l = &b->next;
611         }
612
613         while(sofar < n){
614                 bl = devtab[s->c->type]->bread(s->c, MaxCipherRecLen + RecHdrLen, 0);
615                 if(bl == 0)
616                         error(Ehungup);
617                 *l = bl;
618                 i = 0;
619                 for(b = bl; b; b = b->next){
620                         i += BLEN(b);
621                         l = &b->next;
622                 }
623                 if(i == 0)
624                         error(Ehungup);
625                 sofar += i;
626         }
627 if(s->debug) pprint("ensure read %d\n", sofar);
628 }
629
630 /*
631  *  copy 'n' bytes from 'l' into 'p' and free
632  *  the bytes in 'l'
633  */
634 static void
635 consume(Block **l, uchar *p, int n)
636 {
637         Block *b;
638         int i;
639
640         for(; *l && n > 0; n -= i){
641                 b = *l;
642                 i = BLEN(b);
643                 if(i > n)
644                         i = n;
645                 memmove(p, b->rp, i);
646                 b->rp += i;
647                 p += i;
648                 if(BLEN(b) < 0)
649                         panic("consume");
650                 if(BLEN(b))
651                         break;
652                 *l = b->next;
653                 freeb(b);
654         }
655 }
656
657 /*
658  *  give back n bytes
659  */
660 static void
661 regurgitate(TlsRec *s, uchar *p, int n)
662 {
663         Block *b;
664
665         if(n <= 0)
666                 return;
667         b = s->unprocessed;
668         if(s->unprocessed == nil || b->rp - b->base < n) {
669                 b = allocb(n);
670                 memmove(b->wp, p, n);
671                 b->wp += n;
672                 b->next = s->unprocessed;
673                 s->unprocessed = b;
674         } else {
675                 b->rp -= n;
676                 memmove(b->rp, p, n);
677         }
678 }
679
680 /*
681  *  remove at most n bytes from the queue
682  */
683 static Block*
684 qgrab(Block **l, int n)
685 {
686         Block *bb, *b;
687         int i;
688
689         b = *l;
690         if(BLEN(b) == n){
691                 *l = b->next;
692                 b->next = nil;
693                 return b;
694         }
695
696         i = 0;
697         for(bb = b; bb != nil && i < n; bb = bb->next)
698                 i += BLEN(bb);
699         if(i > n)
700                 i = n;
701
702         bb = allocb(i);
703         consume(l, bb->wp, i);
704         bb->wp += i;
705         return bb;
706 }
707
708 static void
709 tlsclosed(TlsRec *tr, int new)
710 {
711         lock(&tr->statelk);
712         if(tr->state == SOpen || tr->state == SHandshake)
713                 tr->state = new;
714         else if((new | tr->state) == (SRClose|SLClose))
715                 tr->state = SClosed;
716         unlock(&tr->statelk);
717         alertHand(tr, "close notify");
718 }
719
720 /*
721  *  read and process one tls record layer message
722  *  must be called with tr->in.io held
723  *  We can't let Eintrs lose data, since doing so will get
724  *  us out of sync with the sender and break the reliability
725  *  of the channel.  Eintr only happens during the reads in
726  *  ensure.  Therefore we put back any bytes consumed before
727  *  the last call to ensure.
728  */
729 static void
730 tlsrecread(TlsRec *tr)
731 {
732         OneWay *volatile in;
733         Block *volatile b;
734         uchar *p, aad[8+RecHdrLen], header[RecHdrLen], hmac[MaxMacLen];
735         int volatile nconsumed;
736         int len, type, ver, unpad_len, aadlen, ivlen;
737         Secret *sec;
738
739         nconsumed = 0;
740         if(waserror()){
741                 if(strcmp(up->errstr, Eintr) == 0 && !waserror()){
742                         regurgitate(tr, header, nconsumed);
743                         poperror();
744                 }else
745                         tlsError(tr, "channel error");
746                 nexterror();
747         }
748         ensure(tr, &tr->unprocessed, RecHdrLen);
749         consume(&tr->unprocessed, header, RecHdrLen);
750 if(tr->debug)pprint("consumed %d header\n", RecHdrLen);
751         nconsumed = RecHdrLen;
752
753         if((tr->handin == 0) && (header[0] & 0x80)){
754                 /* Cope with an SSL3 ClientHello expressed in SSL2 record format.
755                         This is sent by some clients that we must interoperate
756                         with, such as Java's JSSE and Microsoft's Internet Explorer. */
757                 len = (get16(header) & ~0x8000) - 3;
758                 type = header[2];
759                 ver = get16(header + 3);
760                 if(type != SSL2ClientHello || len < 22)
761                         rcvError(tr, EProtocolVersion, "invalid initial SSL2-like message");
762         }else{  /* normal SSL3 record format */
763                 type = header[0];
764                 ver = get16(header+1);
765                 len = get16(header+3);
766         }
767         if(ver != tr->version && (tr->verset || ver < MinProtoVersion || ver > MaxProtoVersion))
768                 rcvError(tr, EProtocolVersion, "devtls expected ver=%x%s, saw (len=%d) type=%x ver=%x '%.12s'",
769                         tr->version, tr->verset?"/set":"", len, type, ver, (char*)header);
770         if(len > MaxCipherRecLen || len < 0)
771                 rcvError(tr, ERecordOverflow, "record message too long %d", len);
772         ensure(tr, &tr->unprocessed, len);
773         nconsumed = 0;
774         poperror();
775
776         /*
777          * If an Eintr happens after this, we'll get out of sync.
778          * Make sure nothing we call can sleep.
779          * Errors are ok, as they kill the connection.
780          * Luckily, allocb won't sleep, it'll just error out.
781          */
782         b = nil;
783         if(waserror()){
784                 if(b != nil)
785                         freeb(b);
786                 tlsError(tr, "channel error");
787                 nexterror();
788         }
789         b = qgrab(&tr->unprocessed, len);
790 if(tr->debug) pprint("consumed unprocessed %d\n", len);
791
792         in = &tr->in;
793         if(waserror()){
794                 qunlock(&in->seclock);
795                 nexterror();
796         }
797         qlock(&in->seclock);
798         p = b->rp;
799         sec = in->sec;
800         if(sec != nil) {
801                 /* to avoid Canvel-Hiltgen-Vaudenay-Vuagnoux attack, all errors here
802                         should look alike, including timing of the response. */
803                 if(sec->aead_dec != nil)
804                         unpad_len = len;
805                 else {
806                         unpad_len = (*sec->dec)(sec, p, len);
807 if(tr->debug) pprint("decrypted %d\n", unpad_len);
808 if(tr->debug) pdump(unpad_len, p, "decrypted:");
809                 }
810
811                 ivlen = sec->recivlen;
812                 if(tr->version >= TLS11Version){
813                         if(ivlen == 0)
814                                 ivlen = sec->block;
815                 }
816                 len -= ivlen;
817                 if(len < 0)
818                         rcvError(tr, EDecodeError, "runt record message");
819                 unpad_len -= ivlen;
820                 p += ivlen;
821
822                 if(unpad_len >= sec->maclen)
823                         len = unpad_len - sec->maclen;
824
825                 /* update length */
826                 put16(header+3, len);
827                 aadlen = (*tr->packAAD)(in->seq++, header, aad);
828                 if(sec->aead_dec != nil) {
829                         len = (*sec->aead_dec)(sec, aad, aadlen, p - ivlen, p, unpad_len);
830                         if(len < 0)
831                                 rcvError(tr, EBadRecordMac, "record mac mismatch");
832                 } else {
833                         packMac(sec, aad, aadlen, p, len, hmac);
834                         if(unpad_len < sec->maclen)
835                                 rcvError(tr, EBadRecordMac, "short record mac");
836                         if(tsmemcmp(hmac, p + len, sec->maclen) != 0)
837                                 rcvError(tr, EBadRecordMac, "record mac mismatch");
838                 }
839                 b->rp = p;
840                 b->wp = p+len;
841         }
842         qunlock(&in->seclock);
843         poperror();
844         if(len < 0)
845                 rcvError(tr, EDecodeError, "runt record message");
846
847         switch(type) {
848         default:
849                 rcvError(tr, EIllegalParameter, "invalid record message %#x", type);
850                 break;
851         case RChangeCipherSpec:
852                 if(len != 1 || p[0] != 1)
853                         rcvError(tr, EDecodeError, "invalid change cipher spec");
854                 qlock(&in->seclock);
855                 if(in->new == nil){
856                         qunlock(&in->seclock);
857                         rcvError(tr, EUnexpectedMessage, "unexpected change cipher spec");
858                 }
859                 freeSec(in->sec);
860                 in->sec = in->new;
861                 in->new = nil;
862                 in->seq = 0;
863                 qunlock(&in->seclock);
864                 break;
865         case RAlert:
866                 if(len != 2)
867                         rcvError(tr, EDecodeError, "invalid alert");
868                 if(p[0] == 2)
869                         rcvAlert(tr, p[1]);
870                 if(p[0] != 1)
871                         rcvError(tr, EIllegalParameter, "invalid alert fatal code");
872
873                 /*
874                  * propagate non-fatal alerts to handshaker
875                  */
876                 switch(p[1]){
877                 case ECloseNotify:
878                         tlsclosed(tr, SRClose);
879                         if(tr->opened)
880                                 error("tls hungup");
881                         error("close notify");
882                         break;
883                 case ENoRenegotiation:
884                         alertHand(tr, "no renegotiation");
885                         break;
886                 case EUserCanceled:
887                         alertHand(tr, "handshake canceled by user");
888                         break;
889                 case EUnrecognizedName:
890                         /* happens in response to SNI, can be ignored. */
891                         break;
892                 default:
893                         rcvError(tr, EIllegalParameter, "invalid alert code");
894                 }
895                 break;
896         case RHandshake:
897                 /*
898                  * don't worry about dropping the block
899                  * qbwrite always queues even if flow controlled and interrupted.
900                  *
901                  * if there isn't any handshaker, ignore the request,
902                  * but notify the other side we are doing so.
903                  */
904                 lock(&tr->hqlock);
905                 if(tr->handq != nil){
906                         tr->hqref++;
907                         unlock(&tr->hqlock);
908                         if(waserror()){
909                                 dechandq(tr);
910                                 nexterror();
911                         }
912                         b = padblock(b, 1);
913                         *b->rp = RHandshake;
914                         qbwrite(tr->handq, b);
915                         b = nil;
916                         poperror();
917                         dechandq(tr);
918                 }else{
919                         unlock(&tr->hqlock);
920                         if(tr->verset && tr->version != SSL3Version && !waserror()){
921                                 sendAlert(tr, ENoRenegotiation);
922                                 poperror();
923                         }
924                 }
925                 break;
926         case SSL2ClientHello:
927                 lock(&tr->hqlock);
928                 if(tr->handq != nil){
929                         tr->hqref++;
930                         unlock(&tr->hqlock);
931                         if(waserror()){
932                                 dechandq(tr);
933                                 nexterror();
934                         }
935                         /* Pass the SSL2 format data, so that the handshake code can compute
936                                 the correct checksums.  HSSL2ClientHello = HandshakeType 9 is
937                                 unused in RFC2246. */
938                         b = padblock(b, 8);
939                         b->rp[0] = RHandshake;
940                         b->rp[1] = HSSL2ClientHello;
941                         put24(&b->rp[2], len+3);
942                         b->rp[5] = SSL2ClientHello;
943                         put16(&b->rp[6], ver);
944                         qbwrite(tr->handq, b);
945                         b = nil;
946                         poperror();
947                         dechandq(tr);
948                 }else{
949                         unlock(&tr->hqlock);
950                         if(tr->verset && tr->version != SSL3Version && !waserror()){
951                                 sendAlert(tr, ENoRenegotiation);
952                                 poperror();
953                         }
954                 }
955                 break;
956         case RApplication:
957                 if(!tr->opened)
958                         rcvError(tr, EUnexpectedMessage, "application message received before handshake completed");
959                 if(BLEN(b) > 0){
960                         tr->processed = b;
961                         b = nil;
962                 }
963                 break;
964         }
965         if(b != nil)
966                 freeb(b);
967         poperror();
968 }
969
970 /*
971  * got a fatal alert message
972  */
973 static void
974 rcvAlert(TlsRec *tr, int err)
975 {
976         char *s;
977         int i;
978
979         s = "unknown error";
980         for(i=0; i < nelem(tlserrs); i++){
981                 if(tlserrs[i].err == err){
982                         s = tlserrs[i].msg;
983                         break;
984                 }
985         }
986 if(tr->debug) pprint("rcvAlert: %s\n", s);
987
988         tlsError(tr, s);
989         if(!tr->opened)
990                 error(s);
991         error("tls error");
992 }
993
994 /*
995  * found an error while decoding the input stream
996  */
997 static void
998 rcvError(TlsRec *tr, int err, char *fmt, ...)
999 {
1000         char msg[ERRMAX];
1001         va_list arg;
1002
1003         va_start(arg, fmt);
1004         vseprint(msg, msg+sizeof(msg), fmt, arg);
1005         va_end(arg);
1006 if(tr->debug) pprint("rcvError: %s\n", msg);
1007
1008         sendAlert(tr, err);
1009
1010         if(!tr->opened)
1011                 error(msg);
1012         error("tls error");
1013 }
1014
1015 /*
1016  * make sure the next hand operation returns with a 'msg' error
1017  */
1018 static void
1019 alertHand(TlsRec *tr, char *msg)
1020 {
1021         Block *b;
1022         int n;
1023
1024         lock(&tr->hqlock);
1025         if(tr->handq == nil){
1026                 unlock(&tr->hqlock);
1027                 return;
1028         }
1029         tr->hqref++;
1030         unlock(&tr->hqlock);
1031
1032         n = strlen(msg);
1033         if(waserror()){
1034                 dechandq(tr);
1035                 nexterror();
1036         }
1037         b = allocb(n + 2);
1038         *b->wp++ = RAlert;
1039         memmove(b->wp, msg, n + 1);
1040         b->wp += n + 1;
1041
1042         qbwrite(tr->handq, b);
1043
1044         poperror();
1045         dechandq(tr);
1046 }
1047
1048 static void
1049 checkstate(TlsRec *tr, int ishand, int ok)
1050 {
1051         int state;
1052
1053         lock(&tr->statelk);
1054         state = tr->state;
1055         unlock(&tr->statelk);
1056         if(state & ok)
1057                 return;
1058         switch(state){
1059         case SHandshake:
1060         case SOpen:
1061                 break;
1062         case SError:
1063         case SAlert:
1064                 if(ishand)
1065                         error(tr->err);
1066                 error("tls error");
1067         case SRClose:
1068         case SLClose:
1069         case SClosed:
1070                 error("tls hungup");
1071         }
1072         error("tls improperly configured");
1073 }
1074
1075 static Block*
1076 tlsbread(Chan *c, long n, ulong offset)
1077 {
1078         int ty;
1079         Block *b;
1080         TlsRec *volatile tr;
1081
1082         ty = TYPE(c->qid);
1083         switch(ty) {
1084         default:
1085                 return devbread(c, n, offset);
1086         case Qhand:
1087         case Qdata:
1088                 break;
1089         }
1090
1091         tr = tlsdevs[CONV(c->qid)];
1092         if(tr == nil)
1093                 panic("tlsbread");
1094
1095         if(waserror()){
1096                 qunlock(&tr->in.io);
1097                 nexterror();
1098         }
1099         qlock(&tr->in.io);
1100         if(ty == Qdata){
1101                 checkstate(tr, 0, SOpen);
1102                 while(tr->processed == nil)
1103                         tlsrecread(tr);
1104
1105                 /* return at most what was asked for */
1106                 b = qgrab(&tr->processed, n);
1107 if(tr->debug) pprint("consumed processed %zd\n", BLEN(b));
1108 if(tr->debug) pdump(BLEN(b), b->rp, "consumed:");
1109                 qunlock(&tr->in.io);
1110                 poperror();
1111                 tr->datain += BLEN(b);
1112         }else{
1113                 checkstate(tr, 1, SOpen|SHandshake|SLClose);
1114
1115                 /*
1116                  * it's ok to look at state without the lock
1117                  * since it only protects reading records,
1118                  * and we have that tr->in.io held.
1119                  */
1120                 while(!tr->opened && tr->hprocessed == nil && !qcanread(tr->handq))
1121                         tlsrecread(tr);
1122
1123                 qunlock(&tr->in.io);
1124                 poperror();
1125
1126                 if(waserror()){
1127                         qunlock(&tr->hqread);
1128                         nexterror();
1129                 }
1130                 qlock(&tr->hqread);
1131                 if(tr->hprocessed == nil){
1132                         b = qbread(tr->handq, MaxRecLen + 1);
1133                         if(*b->rp++ == RAlert){
1134                                 kstrcpy(up->errstr, (char*)b->rp, ERRMAX);
1135                                 freeb(b);
1136                                 nexterror();
1137                         }
1138                         tr->hprocessed = b;
1139                 }
1140                 b = qgrab(&tr->hprocessed, n);
1141                 poperror();
1142                 qunlock(&tr->hqread);
1143                 tr->handin += BLEN(b);
1144         }
1145
1146         return b;
1147 }
1148
1149 static long
1150 tlsread(Chan *c, void *a, long n, vlong off)
1151 {
1152         Block *volatile b;
1153         Block *nb;
1154         uchar *va;
1155         int i, ty;
1156         char *buf, *s, *e;
1157         ulong offset = off;
1158         TlsRec * tr;
1159
1160         if(c->qid.type & QTDIR)
1161                 return devdirread(c, a, n, 0, 0, tlsgen);
1162
1163         tr = tlsdevs[CONV(c->qid)];
1164         ty = TYPE(c->qid);
1165         switch(ty) {
1166         default:
1167                 error(Ebadusefd);
1168         case Qstatus:
1169                 buf = smalloc(Statlen);
1170                 qlock(&tr->in.seclock);
1171                 qlock(&tr->out.seclock);
1172                 s = buf;
1173                 e = buf + Statlen;
1174                 s = seprint(s, e, "State: %s\n", tlsstate(tr->state));
1175                 s = seprint(s, e, "Version: %#x\n", tr->version);
1176                 if(tr->in.sec != nil)
1177                         s = seprint(s, e, "EncIn: %s\nHashIn: %s\n", tr->in.sec->encalg, tr->in.sec->hashalg);
1178                 if(tr->in.new != nil)
1179                         s = seprint(s, e, "NewEncIn: %s\nNewHashIn: %s\n", tr->in.new->encalg, tr->in.new->hashalg);
1180                 if(tr->out.sec != nil)
1181                         s = seprint(s, e, "EncOut: %s\nHashOut: %s\n", tr->out.sec->encalg, tr->out.sec->hashalg);
1182                 if(tr->out.new != nil)
1183                         s = seprint(s, e, "NewEncOut: %s\nNewHashOut: %s\n", tr->out.new->encalg, tr->out.new->hashalg);
1184                 if(tr->c != nil)
1185                         seprint(s, e, "Chan: %s\n", chanpath(tr->c));
1186                 qunlock(&tr->in.seclock);
1187                 qunlock(&tr->out.seclock);
1188                 n = readstr(offset, a, n, buf);
1189                 free(buf);
1190                 return n;
1191         case Qstats:
1192                 buf = smalloc(Statlen);
1193                 s = buf;
1194                 e = buf + Statlen;
1195                 s = seprint(s, e, "DataIn: %lld\n", tr->datain);
1196                 s = seprint(s, e, "DataOut: %lld\n", tr->dataout);
1197                 s = seprint(s, e, "HandIn: %lld\n", tr->handin);
1198                 seprint(s, e, "HandOut: %lld\n", tr->handout);
1199                 n = readstr(offset, a, n, buf);
1200                 free(buf);
1201                 return n;
1202         case Qctl:
1203                 buf = smalloc(Statlen);
1204                 snprint(buf, Statlen, "%llud", CONV(c->qid));
1205                 n = readstr(offset, a, n, buf);
1206                 free(buf);
1207                 return n;
1208         case Qdata:
1209         case Qhand:
1210                 b = tlsbread(c, n, offset);
1211                 break;
1212         case Qencalgs:
1213                 return readstr(offset, a, n, encalgs);
1214         case Qhashalgs:
1215                 return readstr(offset, a, n, hashalgs);
1216         }
1217
1218         if(waserror()){
1219                 freeblist(b);
1220                 nexterror();
1221         }
1222
1223         n = 0;
1224         va = a;
1225         for(nb = b; nb; nb = nb->next){
1226                 i = BLEN(nb);
1227                 memmove(va+n, nb->rp, i);
1228                 n += i;
1229         }
1230
1231         freeblist(b);
1232         poperror();
1233
1234         return n;
1235 }
1236
1237 static void
1238 randfill(uchar *buf, int len)
1239 {
1240         while(len-- > 0)
1241                 *buf++ = nrand(256);
1242 }
1243
1244 /*
1245  *  write a block in tls records
1246  */
1247 static void
1248 tlsrecwrite(TlsRec *tr, int type, Block *b)
1249 {
1250         Block *volatile bb;
1251         Block *nb;
1252         uchar *p, aad[8+RecHdrLen];
1253         OneWay *volatile out;
1254         int n, ivlen, maclen, aadlen, pad, ok;
1255         Secret *sec;
1256
1257         out = &tr->out;
1258         bb = b;
1259         if(waserror()){
1260                 qunlock(&out->io);
1261                 if(bb != nil)
1262                         freeb(bb);
1263                 nexterror();
1264         }
1265         qlock(&out->io);
1266 if(tr->debug)pprint("send %zd\n", BLEN(b));
1267 if(tr->debug)pdump(BLEN(b), b->rp, "sent:");
1268
1269
1270         ok = SHandshake|SOpen|SRClose;
1271         if(type == RAlert)
1272                 ok |= SAlert;
1273         while(bb != nil){
1274                 checkstate(tr, type != RApplication, ok);
1275
1276                 /*
1277                  * get at most one maximal record's input,
1278                  * with padding on the front for header and
1279                  * back for mac and maximal block padding.
1280                  */
1281                 if(waserror()){
1282                         qunlock(&out->seclock);
1283                         nexterror();
1284                 }
1285                 qlock(&out->seclock);
1286                 maclen = 0;
1287                 pad = 0;
1288                 ivlen = 0;
1289                 sec = out->sec;
1290                 if(sec != nil){
1291                         maclen = sec->maclen;
1292                         pad = maclen + sec->block;
1293                         ivlen = sec->recivlen;
1294                         if(tr->version >= TLS11Version){
1295                                 if(ivlen == 0)
1296                                         ivlen = sec->block;
1297                         }
1298                 }
1299                 n = BLEN(bb);
1300                 if(n > MaxRecLen){
1301                         n = MaxRecLen;
1302                         nb = allocb(RecHdrLen + ivlen + n + pad);
1303                         memmove(nb->wp + RecHdrLen + ivlen, bb->rp, n);
1304                         bb->rp += n;
1305                 }else{
1306                         /*
1307                          * carefully reuse bb so it will get freed if we're out of memory
1308                          */
1309                         bb = padblock(bb, RecHdrLen + ivlen);
1310                         if(pad)
1311                                 nb = padblock(bb, -pad);
1312                         else
1313                                 nb = bb;
1314                         bb = nil;
1315                 }
1316
1317                 p = nb->rp;
1318                 p[0] = type;
1319                 put16(p+1, tr->version);
1320                 put16(p+3, n);
1321
1322                 if(sec != nil){
1323                         aadlen = (*tr->packAAD)(out->seq++, p, aad);
1324                         if(sec->aead_enc != nil)
1325                                 n = (*sec->aead_enc)(sec, aad, aadlen, p + RecHdrLen, p + RecHdrLen + ivlen, n) + ivlen;
1326                         else {
1327                                 if(ivlen > 0)
1328                                         randfill(p + RecHdrLen, ivlen);
1329                                 packMac(sec, aad, aadlen, p + RecHdrLen + ivlen, n, p + RecHdrLen + ivlen + n);
1330                                 n = (*sec->enc)(sec, p + RecHdrLen, ivlen + n + maclen);
1331                         }
1332                         nb->wp = p + RecHdrLen + n;
1333
1334                         /* update length */
1335                         put16(p+3, n);
1336                 }
1337                 if(type == RChangeCipherSpec){
1338                         if(out->new == nil)
1339                                 error("change cipher without a new cipher");
1340                         freeSec(out->sec);
1341                         out->sec = out->new;
1342                         out->new = nil;
1343                         out->seq = 0;
1344                 }
1345                 qunlock(&out->seclock);
1346                 poperror();
1347
1348                 /*
1349                  * if bwrite error's, we assume the block is queued.
1350                  * if not, we're out of sync with the receiver and will not recover.
1351                  */
1352                 if(waserror()){
1353                         if(strcmp(up->errstr, "interrupted") != 0)
1354                                 tlsError(tr, "channel error");
1355                         nexterror();
1356                 }
1357                 devtab[tr->c->type]->bwrite(tr->c, nb, 0);
1358                 poperror();
1359         }
1360         qunlock(&out->io);
1361         poperror();
1362 }
1363
1364 static long
1365 tlsbwrite(Chan *c, Block *b, ulong offset)
1366 {
1367         int ty;
1368         ulong n;
1369         TlsRec *tr;
1370
1371         n = BLEN(b);
1372
1373         tr = tlsdevs[CONV(c->qid)];
1374         if(tr == nil)
1375                 panic("tlsbwrite");
1376
1377         ty = TYPE(c->qid);
1378         switch(ty) {
1379         default:
1380                 return devbwrite(c, b, offset);
1381         case Qhand:
1382                 tlsrecwrite(tr, RHandshake, b);
1383                 tr->handout += n;
1384                 break;
1385         case Qdata:
1386                 checkstate(tr, 0, SOpen);
1387                 tlsrecwrite(tr, RApplication, b);
1388                 tr->dataout += n;
1389                 break;
1390         }
1391
1392         return n;
1393 }
1394
1395 typedef struct Hashalg Hashalg;
1396 struct Hashalg
1397 {
1398         char    *name;
1399         int     maclen;
1400         void    (*initkey)(Hashalg *, int, Secret *, uchar*);
1401 };
1402
1403 static void
1404 initmd5key(Hashalg *ha, int version, Secret *s, uchar *p)
1405 {
1406         s->maclen = ha->maclen;
1407         if(version == SSL3Version)
1408                 s->mac = sslmac_md5;
1409         else
1410                 s->mac = hmac_md5;
1411         memmove(s->mackey, p, ha->maclen);
1412 }
1413
1414 static void
1415 initclearmac(Hashalg *, int, Secret *s, uchar *)
1416 {
1417         s->mac = nomac;
1418 }
1419
1420 static void
1421 initsha1key(Hashalg *ha, int version, Secret *s, uchar *p)
1422 {
1423         s->maclen = ha->maclen;
1424         if(version == SSL3Version)
1425                 s->mac = sslmac_sha1;
1426         else
1427                 s->mac = hmac_sha1;
1428         memmove(s->mackey, p, ha->maclen);
1429 }
1430
1431 static void
1432 initsha2_256key(Hashalg *ha, int version, Secret *s, uchar *p)
1433 {
1434         if(version == SSL3Version)
1435                 error("sha256 cannot be used with SSL");
1436         s->maclen = ha->maclen;
1437         s->mac = hmac_sha2_256;
1438         memmove(s->mackey, p, ha->maclen);
1439 }
1440
1441 static Hashalg hashtab[] =
1442 {
1443         { "clear",      0,              initclearmac, },
1444         { "md5",        MD5dlen,        initmd5key, },
1445         { "sha1",       SHA1dlen,       initsha1key, },
1446         { "sha256",     SHA2_256dlen,   initsha2_256key, },
1447         { 0 }
1448 };
1449
1450 static Hashalg*
1451 parsehashalg(char *p)
1452 {
1453         Hashalg *ha;
1454
1455         for(ha = hashtab; ha->name; ha++)
1456                 if(strcmp(p, ha->name) == 0)
1457                         return ha;
1458         error("unsupported hash algorithm");
1459         return nil;
1460 }
1461
1462 typedef struct Encalg Encalg;
1463 struct Encalg
1464 {
1465         char    *name;
1466         int     keylen;
1467         int     ivlen;
1468         void    (*initkey)(Encalg *ea, Secret *, uchar*, uchar*);
1469 };
1470
1471 static void
1472 initRC4key(Encalg *ea, Secret *s, uchar *p, uchar *)
1473 {
1474         s->enckey = secalloc(sizeof(RC4state));
1475         s->enc = rc4enc;
1476         s->dec = rc4enc;
1477         setupRC4state(s->enckey, p, ea->keylen);
1478 }
1479
1480 static void
1481 initDES3key(Encalg *, Secret *s, uchar *p, uchar *iv)
1482 {
1483         s->enckey = secalloc(sizeof(DES3state));
1484         s->enc = des3enc;
1485         s->dec = des3dec;
1486         s->block = 8;
1487         setupDES3state(s->enckey, (uchar(*)[8])p, iv);
1488 }
1489
1490 static void
1491 initAESkey(Encalg *ea, Secret *s, uchar *p, uchar *iv)
1492 {
1493         s->enckey = secalloc(sizeof(AESstate));
1494         s->enc = aesenc;
1495         s->dec = aesdec;
1496         s->block = 16;
1497         setupAESstate(s->enckey, p, ea->keylen, iv);
1498 }
1499
1500 static void
1501 initccpolykey(Encalg *ea, Secret *s, uchar *p, uchar *iv)
1502 {
1503         s->enckey = secalloc(sizeof(Chachastate));
1504         s->aead_enc = ccpoly_aead_enc;
1505         s->aead_dec = ccpoly_aead_dec;
1506         s->maclen = Poly1305dlen;
1507         if(ea->ivlen == 0) {
1508                 /* older draft version, iv is 64-bit sequence number */
1509                 setupChachastate(s->enckey, p, ea->keylen, nil, 64/8, 20);
1510         } else {
1511                 /* IETF standard, 96-bit iv xored with sequence number */
1512                 memmove(s->mackey, iv, ea->ivlen);
1513                 setupChachastate(s->enckey, p, ea->keylen, iv, ea->ivlen, 20);
1514         }
1515 }
1516
1517 static void
1518 initaesgcmkey(Encalg *ea, Secret *s, uchar *p, uchar *iv)
1519 {
1520         s->enckey = secalloc(sizeof(AESGCMstate));
1521         s->aead_enc = aesgcm_aead_enc;
1522         s->aead_dec = aesgcm_aead_dec;
1523         s->maclen = 16;
1524         s->recivlen = 8;
1525         memmove(s->mackey, iv, ea->ivlen);
1526         randfill(s->mackey + ea->ivlen, s->recivlen);
1527         setupAESGCMstate(s->enckey, p, ea->keylen, nil, 0);
1528 }
1529
1530 static void
1531 initclearenc(Encalg *, Secret *s, uchar *, uchar *)
1532 {
1533         s->enc = noenc;
1534         s->dec = noenc;
1535 }
1536
1537 static Encalg encrypttab[] =
1538 {
1539         { "clear", 0, 0, initclearenc },
1540         { "rc4_128", 128/8, 0, initRC4key },
1541         { "3des_ede_cbc", 3 * 8, 8, initDES3key },
1542         { "aes_128_cbc", 128/8, 16, initAESkey },
1543         { "aes_256_cbc", 256/8, 16, initAESkey },
1544         { "ccpoly64_aead", 256/8, 0, initccpolykey },
1545         { "ccpoly96_aead", 256/8, 96/8, initccpolykey },
1546         { "aes_128_gcm_aead", 128/8, 4, initaesgcmkey },
1547         { "aes_256_gcm_aead", 256/8, 4, initaesgcmkey },
1548         { 0 }
1549 };
1550
1551 static Encalg*
1552 parseencalg(char *p)
1553 {
1554         Encalg *ea;
1555
1556         for(ea = encrypttab; ea->name; ea++)
1557                 if(strcmp(p, ea->name) == 0)
1558                         return ea;
1559         error("unsupported encryption algorithm");
1560         return nil;
1561 }
1562
1563 static long
1564 tlswrite(Chan *c, void *a, long n, vlong off)
1565 {
1566         Encalg *ea;
1567         Hashalg *ha;
1568         TlsRec *volatile tr;
1569         Secret *volatile tos, *volatile toc;
1570         Block *volatile b;
1571         Cmdbuf *volatile cb;
1572         int m, ty;
1573         char *p, *e;
1574         uchar *volatile x;
1575         ulong offset = off;
1576
1577         tr = tlsdevs[CONV(c->qid)];
1578         if(tr == nil)
1579                 panic("tlswrite");
1580
1581         ty = TYPE(c->qid);
1582         switch(ty){
1583         case Qdata:
1584         case Qhand:
1585                 p = a;
1586                 e = p + n;
1587                 do{
1588                         m = e - p;
1589                         if(m > c->iounit)
1590                                 m = c->iounit;
1591
1592                         b = allocb(m);
1593                         if(waserror()){
1594                                 freeb(b);
1595                                 nexterror();
1596                         }
1597                         memmove(b->wp, p, m);
1598                         poperror();
1599                         b->wp += m;
1600
1601                         tlsbwrite(c, b, offset);
1602
1603                         p += m;
1604                 }while(p < e);
1605                 return n;
1606         case Qctl:
1607                 break;
1608         default:
1609                 error(Ebadusefd);
1610                 return -1;
1611         }
1612
1613         cb = parsecmd(a, n);
1614         if(waserror()){
1615                 free(cb);
1616                 nexterror();
1617         }
1618         if(cb->nf < 1)
1619                 error("short control request");
1620
1621         /* mutex with operations using what we're about to change */
1622         if(waserror()){
1623                 qunlock(&tr->in.seclock);
1624                 qunlock(&tr->out.seclock);
1625                 nexterror();
1626         }
1627         qlock(&tr->in.seclock);
1628         qlock(&tr->out.seclock);
1629
1630         if(strcmp(cb->f[0], "fd") == 0){
1631                 if(cb->nf != 3)
1632                         error("usage: fd open-fd version");
1633                 if(tr->c != nil)
1634                         error(Einuse);
1635                 m = strtol(cb->f[2], nil, 0);
1636                 if(m < MinProtoVersion || m > MaxProtoVersion)
1637                         error("unsupported version");
1638                 tr->c = buftochan(cb->f[1]);
1639                 tr->version = m;
1640                 tlsSetState(tr, SHandshake, SClosed);
1641         }else if(strcmp(cb->f[0], "version") == 0){
1642                 if(cb->nf != 2)
1643                         error("usage: version vers");
1644                 if(tr->c == nil)
1645                         error("must set fd before version");
1646                 if(tr->verset)
1647                         error("version already set");
1648                 m = strtol(cb->f[1], nil, 0);
1649                 if(m < MinProtoVersion || m > MaxProtoVersion)
1650                         error("unsupported version");
1651                 if(m == SSL3Version)
1652                         tr->packAAD = sslPackAAD;
1653                 else
1654                         tr->packAAD = tlsPackAAD;
1655                 tr->verset = 1;
1656                 tr->version = m;
1657         }else if(strcmp(cb->f[0], "secret") == 0){
1658                 if(cb->nf != 5)
1659                         error("usage: secret hashalg encalg isclient secretdata");
1660                 if(tr->c == nil || !tr->verset)
1661                         error("must set fd and version before secrets");
1662
1663                 if(tr->in.new != nil){
1664                         freeSec(tr->in.new);
1665                         tr->in.new = nil;
1666                 }
1667                 if(tr->out.new != nil){
1668                         freeSec(tr->out.new);
1669                         tr->out.new = nil;
1670                 }
1671
1672                 ha = parsehashalg(cb->f[1]);
1673                 ea = parseencalg(cb->f[2]);
1674
1675                 p = cb->f[4];
1676                 m = (strlen(p)*3)/2 + 1;
1677                 x = secalloc(m);
1678                 tos = secalloc(sizeof(Secret));
1679                 toc = secalloc(sizeof(Secret));
1680                 if(waserror()){
1681                         secfree(x);
1682                         freeSec(tos);
1683                         freeSec(toc);
1684                         nexterror();
1685                 }
1686
1687                 m = dec64(x, m, p, strlen(p));
1688                 memset(p, 0, strlen(p));
1689                 if(m < 2 * ha->maclen + 2 * ea->keylen + 2 * ea->ivlen)
1690                         error("not enough secret data provided");
1691
1692                 if(!ha->initkey || !ea->initkey)
1693                         error("misimplemented secret algorithm");
1694
1695                 (*ha->initkey)(ha, tr->version, tos, &x[0]);
1696                 (*ha->initkey)(ha, tr->version, toc, &x[ha->maclen]);
1697                 (*ea->initkey)(ea, tos, &x[2 * ha->maclen], &x[2 * ha->maclen + 2 * ea->keylen]);
1698                 (*ea->initkey)(ea, toc, &x[2 * ha->maclen + ea->keylen], &x[2 * ha->maclen + 2 * ea->keylen + ea->ivlen]);
1699
1700                 if(!tos->aead_enc || !tos->aead_dec || !toc->aead_enc || !toc->aead_dec)
1701                         if(!tos->mac || !tos->enc || !tos->dec || !toc->mac || !toc->enc || !toc->dec)
1702                                 error("missing algorithm implementations");
1703
1704                 if(strtol(cb->f[3], nil, 0) == 0){
1705                         tr->in.new = tos;
1706                         tr->out.new = toc;
1707                 }else{
1708                         tr->in.new = toc;
1709                         tr->out.new = tos;
1710                 }
1711                 if(tr->version == SSL3Version){
1712                         toc->unpad = sslunpad;
1713                         tos->unpad = sslunpad;
1714                 }else{
1715                         toc->unpad = tlsunpad;
1716                         tos->unpad = tlsunpad;
1717                 }
1718                 toc->encalg = ea->name;
1719                 toc->hashalg = ha->name;
1720                 tos->encalg = ea->name;
1721                 tos->hashalg = ha->name;
1722
1723                 secfree(x);
1724                 poperror();
1725         }else if(strcmp(cb->f[0], "changecipher") == 0){
1726                 if(cb->nf != 1)
1727                         error("usage: changecipher");
1728                 if(tr->out.new == nil)
1729                         error("cannot change cipher spec without setting secret");
1730
1731                 qunlock(&tr->in.seclock);
1732                 qunlock(&tr->out.seclock);
1733                 poperror();
1734                 free(cb);
1735                 poperror();
1736
1737                 /*
1738                  * the real work is done as the message is written
1739                  * so the stream is encrypted in sync.
1740                  */
1741                 b = allocb(1);
1742                 *b->wp++ = 1;
1743                 tlsrecwrite(tr, RChangeCipherSpec, b);
1744                 return n;
1745         }else if(strcmp(cb->f[0], "opened") == 0){
1746                 if(cb->nf != 1)
1747                         error("usage: opened");
1748                 if(tr->in.sec == nil || tr->out.sec == nil)
1749                         error("cipher must be configured before enabling data messages");
1750                 lock(&tr->statelk);
1751                 if(tr->state != SHandshake && tr->state != SOpen){
1752                         unlock(&tr->statelk);
1753                         error("cannot enable data messages");
1754                 }
1755                 tr->state = SOpen;
1756                 unlock(&tr->statelk);
1757                 tr->opened = 1;
1758         }else if(strcmp(cb->f[0], "alert") == 0){
1759                 if(cb->nf != 2)
1760                         error("usage: alert n");
1761                 if(tr->c == nil)
1762                         error("must set fd before sending alerts");
1763                 m = strtol(cb->f[1], nil, 0);
1764
1765                 qunlock(&tr->in.seclock);
1766                 qunlock(&tr->out.seclock);
1767                 poperror();
1768                 free(cb);
1769                 poperror();
1770
1771                 sendAlert(tr, m);
1772
1773                 if(m == ECloseNotify)
1774                         tlsclosed(tr, SLClose);
1775
1776                 return n;
1777         } else if(strcmp(cb->f[0], "debug") == 0){
1778                 if(cb->nf == 2){
1779                         if(strcmp(cb->f[1], "on") == 0)
1780                                 tr->debug = 1;
1781                         else
1782                                 tr->debug = 0;
1783                 } else
1784                         tr->debug = 1;
1785         } else
1786                 error(Ebadarg);
1787
1788         qunlock(&tr->in.seclock);
1789         qunlock(&tr->out.seclock);
1790         poperror();
1791         free(cb);
1792         poperror();
1793
1794         return n;
1795 }
1796
1797 static void
1798 tlsinit(void)
1799 {
1800         struct Encalg *e;
1801         struct Hashalg *h;
1802         int n;
1803         char *cp;
1804         static int already;
1805
1806         if(!already){
1807                 fmtinstall('H', encodefmt);
1808                 already = 1;
1809         }
1810
1811         tlsdevs = smalloc(sizeof(TlsRec*) * maxtlsdevs);
1812         trnames = smalloc((sizeof *trnames) * maxtlsdevs);
1813
1814         n = 1;
1815         for(e = encrypttab; e->name != nil; e++)
1816                 n += strlen(e->name) + 1;
1817         cp = encalgs = smalloc(n);
1818         for(e = encrypttab;;){
1819                 strcpy(cp, e->name);
1820                 cp += strlen(e->name);
1821                 e++;
1822                 if(e->name == nil)
1823                         break;
1824                 *cp++ = ' ';
1825         }
1826         *cp = 0;
1827
1828         n = 1;
1829         for(h = hashtab; h->name != nil; h++)
1830                 n += strlen(h->name) + 1;
1831         cp = hashalgs = smalloc(n);
1832         for(h = hashtab;;){
1833                 strcpy(cp, h->name);
1834                 cp += strlen(h->name);
1835                 h++;
1836                 if(h->name == nil)
1837                         break;
1838                 *cp++ = ' ';
1839         }
1840         *cp = 0;
1841 }
1842
1843 Dev tlsdevtab = {
1844         'a',
1845         "tls",
1846
1847         devreset,
1848         tlsinit,
1849         devshutdown,
1850         tlsattach,
1851         tlswalk,
1852         tlsstat,
1853         tlsopen,
1854         devcreate,
1855         tlsclose,
1856         tlsread,
1857         tlsbread,
1858         tlswrite,
1859         tlsbwrite,
1860         devremove,
1861         tlswstat,
1862 };
1863
1864 /* get channel associated with an fd */
1865 static Chan*
1866 buftochan(char *p)
1867 {
1868         Chan *c;
1869         int fd;
1870
1871         if(p == 0)
1872                 error(Ebadarg);
1873         fd = strtoul(p, 0, 0);
1874         if(fd < 0)
1875                 error(Ebadarg);
1876         c = fdtochan(fd, ORDWR, 1, 1);  /* error check and inc ref */
1877         return c;
1878 }
1879
1880 static void
1881 sendAlert(TlsRec *tr, int err)
1882 {
1883         Block *b;
1884         int i, fatal;
1885         char *msg;
1886
1887 if(tr->debug)pprint("sendAlert %d\n", err);
1888         fatal = 1;
1889         msg = "tls unknown alert";
1890         for(i=0; i < nelem(tlserrs); i++) {
1891                 if(tlserrs[i].err == err) {
1892                         msg = tlserrs[i].msg;
1893                         if(tr->version == SSL3Version)
1894                                 err = tlserrs[i].sslerr;
1895                         else
1896                                 err = tlserrs[i].tlserr;
1897                         fatal = tlserrs[i].fatal;
1898                         break;
1899                 }
1900         }
1901
1902         if(!waserror()){
1903                 b = allocb(2);
1904                 *b->wp++ = fatal + 1;
1905                 *b->wp++ = err;
1906                 if(fatal)
1907                         tlsSetState(tr, SAlert, SOpen|SHandshake|SRClose);
1908                 tlsrecwrite(tr, RAlert, b);
1909                 poperror();
1910         }
1911         if(fatal)
1912                 tlsError(tr, msg);
1913 }
1914
1915 static void
1916 tlsError(TlsRec *tr, char *msg)
1917 {
1918         int s;
1919
1920 if(tr->debug)pprint("tlsError %s\n", msg);
1921         lock(&tr->statelk);
1922         s = tr->state;
1923         tr->state = SError;
1924         if(s != SError){
1925                 strncpy(tr->err, msg, ERRMAX - 1);
1926                 tr->err[ERRMAX - 1] = '\0';
1927         }
1928         unlock(&tr->statelk);
1929         if(s != SError)
1930                 alertHand(tr, msg);
1931 }
1932
1933 static void
1934 tlsSetState(TlsRec *tr, int new, int old)
1935 {
1936         lock(&tr->statelk);
1937         if(tr->state & old)
1938                 tr->state = new;
1939         unlock(&tr->statelk);
1940 }
1941
1942 /* hand up a digest connection */
1943 static void
1944 tlshangup(TlsRec *tr)
1945 {
1946         Block *b;
1947
1948         qlock(&tr->in.io);
1949         for(b = tr->processed; b; b = tr->processed){
1950                 tr->processed = b->next;
1951                 freeb(b);
1952         }
1953         if(tr->unprocessed != nil){
1954                 freeb(tr->unprocessed);
1955                 tr->unprocessed = nil;
1956         }
1957         qunlock(&tr->in.io);
1958
1959         tlsSetState(tr, SClosed, ~0);
1960 }
1961
1962 static TlsRec*
1963 newtls(Chan *ch)
1964 {
1965         TlsRec **pp, **ep, **np;
1966         char **nmp;
1967         int t, newmax;
1968
1969         if(waserror()) {
1970                 unlock(&tdlock);
1971                 nexterror();
1972         }
1973         lock(&tdlock);
1974         ep = &tlsdevs[maxtlsdevs];
1975         for(pp = tlsdevs; pp < ep; pp++)
1976                 if(*pp == nil)
1977                         break;
1978         if(pp >= ep) {
1979                 if(maxtlsdevs >= MaxTlsDevs) {
1980                         unlock(&tdlock);
1981                         poperror();
1982                         return nil;
1983                 }
1984                 newmax = 2 * maxtlsdevs;
1985                 if(newmax > MaxTlsDevs)
1986                         newmax = MaxTlsDevs;
1987                 np = smalloc(sizeof(TlsRec*) * newmax);
1988                 memmove(np, tlsdevs, sizeof(TlsRec*) * maxtlsdevs);
1989                 tlsdevs = np;
1990                 pp = &tlsdevs[maxtlsdevs];
1991                 memset(pp, 0, sizeof(TlsRec*)*(newmax - maxtlsdevs));
1992
1993                 nmp = smalloc(sizeof *nmp * newmax);
1994                 memmove(nmp, trnames, sizeof *nmp * maxtlsdevs);
1995                 trnames = nmp;
1996
1997                 maxtlsdevs = newmax;
1998         }
1999         *pp = mktlsrec();
2000         if(pp - tlsdevs >= tdhiwat)
2001                 tdhiwat++;
2002         t = TYPE(ch->qid);
2003         if(t == Qclonus)
2004                 t = Qctl;
2005         ch->qid.path = QID(pp - tlsdevs, t);
2006         ch->qid.vers = 0;
2007         unlock(&tdlock);
2008         poperror();
2009         return *pp;
2010 }
2011
2012 static TlsRec *
2013 mktlsrec(void)
2014 {
2015         TlsRec *tr;
2016
2017         tr = mallocz(sizeof(*tr), 1);
2018         if(tr == nil)
2019                 error(Enomem);
2020         tr->state = SClosed;
2021         tr->ref = 1;
2022         kstrdup(&tr->user, up->user);
2023         tr->perm = 0660;
2024         return tr;
2025 }
2026
2027 static char*
2028 tlsstate(int s)
2029 {
2030         switch(s){
2031         case SHandshake:
2032                 return "Handshaking";
2033         case SOpen:
2034                 return "Established";
2035         case SRClose:
2036                 return "RemoteClosed";
2037         case SLClose:
2038                 return "LocalClosed";
2039         case SAlert:
2040                 return "Alerting";
2041         case SError:
2042                 return "Errored";
2043         case SClosed:
2044                 return "Closed";
2045         }
2046         return "Unknown";
2047 }
2048
2049 static void
2050 freeSec(Secret *s)
2051 {
2052         if(s == nil)
2053                 return;
2054         secfree(s->enckey);
2055         secfree(s);
2056 }
2057
2058 static int
2059 noenc(Secret *, uchar *, int n)
2060 {
2061         return n;
2062 }
2063
2064 static int
2065 rc4enc(Secret *sec, uchar *buf, int n)
2066 {
2067         rc4(sec->enckey, buf, n);
2068         return n;
2069 }
2070
/*
 * check and strip tls block padding from the n decrypted bytes in
 * buf.  the last byte holds the pad length and every pad byte in
 * front of it must hold the same value.  returns the length of the
 * data without padding, or -1 when n is not a positive multiple of
 * block, the padding is malformed, or no data would remain.
 */
static int
tlsunpad(uchar *buf, int n, int block)
{
	int pad, nn;

	/* without at least one byte there is no pad length to read */
	if(n <= 0)
		return -1;

	pad = buf[n - 1];
	nn = n - 1 - pad;
	if(nn <= 0 || n % block)
		return -1;
	/* walk back over the pad bytes, buf[n-2] down to buf[nn] */
	while(--n > nn)
		if(pad != buf[n - 1])
			return -1;
	return nn;
}
2085
/*
 * strip ssl 3.0 block padding from the n decrypted bytes in buf.
 * only the final pad length byte is used; the contents of the pad
 * bytes are not checked.  returns the length of the data without
 * padding, or -1 when n is not a positive multiple of block or no
 * data would remain.
 */
static int
sslunpad(uchar *buf, int n, int block)
{
	int pad, nn;

	/* without at least one byte there is no pad length to read */
	if(n <= 0)
		return -1;

	pad = buf[n - 1];
	nn = n - 1 - pad;
	if(nn <= 0 || n % block)
		return -1;
	return nn;
}
2097
2098 static int
2099 blockpad(uchar *buf, int n, int block)
2100 {
2101         int pad, nn;
2102
2103         nn = n + block;
2104         nn -= nn % block;
2105         pad = nn - (n + 1);
2106         while(n < nn)
2107                 buf[n++] = pad;
2108         return nn;
2109 }
2110                 
2111 static int
2112 des3enc(Secret *sec, uchar *buf, int n)
2113 {
2114         n = blockpad(buf, n, 8);
2115         des3CBCencrypt(buf, n, sec->enckey);
2116         return n;
2117 }
2118
2119 static int
2120 des3dec(Secret *sec, uchar *buf, int n)
2121 {
2122         des3CBCdecrypt(buf, n, sec->enckey);
2123         return (*sec->unpad)(buf, n, 8);
2124 }
2125
2126 static int
2127 aesenc(Secret *sec, uchar *buf, int n)
2128 {
2129         n = blockpad(buf, n, 16);
2130         aesCBCencrypt(buf, n, sec->enckey);
2131         return n;
2132 }
2133
2134 static int
2135 aesdec(Secret *sec, uchar *buf, int n)
2136 {
2137         aesCBCdecrypt(buf, n, sec->enckey);
2138         return (*sec->unpad)(buf, n, 16);
2139 }
2140
2141 static void
2142 ccpoly_aead_setiv(Secret *sec, uchar seq[8])
2143 {
2144         uchar iv[ChachaIVlen];
2145         Chachastate *cs;
2146         int i;
2147
2148         cs = (Chachastate*)sec->enckey;
2149         if(cs->ivwords == 2){
2150                 chacha_setiv(cs, seq);
2151                 return;
2152         }
2153
2154         memmove(iv, sec->mackey, ChachaIVlen);
2155         for(i=0; i<8; i++)
2156                 iv[i+(ChachaIVlen-8)] ^= seq[i];
2157
2158         chacha_setiv(cs, iv);
2159
2160         memset(iv, 0, sizeof(iv));
2161 }
2162
2163 static int
2164 ccpoly_aead_enc(Secret *sec, uchar *aad, int aadlen, uchar *reciv, uchar *data, int len)
2165 {
2166         USED(reciv);
2167         ccpoly_aead_setiv(sec, aad);
2168         ccpoly_encrypt(data, len, aad, aadlen, data+len, sec->enckey);
2169         return len + sec->maclen;
2170 }
2171
2172 static int
2173 ccpoly_aead_dec(Secret *sec, uchar *aad, int aadlen, uchar *reciv, uchar *data, int len)
2174 {
2175         USED(reciv);
2176         len -= sec->maclen;
2177         if(len < 0)
2178                 return -1;
2179         ccpoly_aead_setiv(sec, aad);
2180         if(ccpoly_decrypt(data, len, aad, aadlen, data+len, sec->enckey) != 0)
2181                 return -1;
2182         return len;
2183 }
2184
2185 static int
2186 aesgcm_aead_enc(Secret *sec, uchar *aad, int aadlen, uchar *reciv, uchar *data, int len)
2187 {
2188         uchar iv[12];
2189         int i;
2190
2191         memmove(iv, sec->mackey, 4+8);
2192         for(i=0; i<8; i++) iv[4+i] ^= aad[i];
2193         memmove(reciv, iv+4, 8);
2194         aesgcm_setiv(sec->enckey, iv, 12);
2195         memset(iv, 0, sizeof(iv));
2196         aesgcm_encrypt(data, len, aad, aadlen, data+len, sec->enckey);
2197         return len + sec->maclen;
2198 }
2199
2200 static int
2201 aesgcm_aead_dec(Secret *sec, uchar *aad, int aadlen, uchar *reciv, uchar *data, int len)
2202 {
2203         uchar iv[12];
2204
2205         len -= sec->maclen;
2206         if(len < 0)
2207                 return -1;
2208         memmove(iv, sec->mackey, 4);
2209         memmove(iv+4, reciv, 8);
2210         aesgcm_setiv(sec->enckey, iv, 12);
2211         memset(iv, 0, sizeof(iv));
2212         if(aesgcm_decrypt(data, len, aad, aadlen, data+len, sec->enckey) != 0)
2213                 return -1;
2214         return len;
2215 }
2216
2217
2218 static DigestState*
2219 nomac(uchar *, ulong, uchar *, ulong, uchar *, DigestState *)
2220 {
2221         return nil;
2222 }
2223
2224 /*
2225  * sslmac: mac calculations for ssl 3.0 only; tls 1.0 uses the standard hmac.
2226  */
2227 static DigestState*
2228 sslmac_x(uchar *p, ulong len, uchar *key, ulong klen, uchar *digest, DigestState *s,
2229         DigestState*(*x)(uchar*, ulong, uchar*, DigestState*), int xlen, int padlen)
2230 {
2231         int i;
2232         uchar pad[48], innerdigest[20];
2233
2234         if(xlen > sizeof(innerdigest)
2235         || padlen > sizeof(pad))
2236                 return nil;
2237
2238         if(klen>64)
2239                 return nil;
2240
2241         /* first time through */
2242         if(s == nil){
2243                 for(i=0; i<padlen; i++)
2244                         pad[i] = 0x36;
2245                 s = (*x)(key, klen, nil, nil);
2246                 s = (*x)(pad, padlen, nil, s);
2247                 if(s == nil)
2248                         return nil;
2249         }
2250
2251         s = (*x)(p, len, nil, s);
2252         if(digest == nil)
2253                 return s;
2254
2255         /* last time through */
2256         for(i=0; i<padlen; i++)
2257                 pad[i] = 0x5c;
2258         (*x)(nil, 0, innerdigest, s);
2259         s = (*x)(key, klen, nil, nil);
2260         s = (*x)(pad, padlen, nil, s);
2261         (*x)(innerdigest, xlen, digest, s);
2262         return nil;
2263 }
2264
2265 static DigestState*
2266 sslmac_sha1(uchar *p, ulong len, uchar *key, ulong klen, uchar *digest, DigestState *s)
2267 {
2268         return sslmac_x(p, len, key, klen, digest, s, sha1, SHA1dlen, 40);
2269 }
2270
2271 static DigestState*
2272 sslmac_md5(uchar *p, ulong len, uchar *key, ulong klen, uchar *digest, DigestState *s)
2273 {
2274         return sslmac_x(p, len, key, klen, digest, s, md5, MD5dlen, 48);
2275 }
2276
2277 static int
2278 sslPackAAD(u64int seq, uchar *hdr, uchar *aad)
2279 {
2280         put64(aad, seq);
2281         aad[8] = hdr[0];
2282         aad[9] = hdr[3];
2283         aad[10] = hdr[4];
2284         return 11;
2285 }
2286
2287 static int
2288 tlsPackAAD(u64int seq, uchar *hdr, uchar *aad)
2289 {
2290         put64(aad, seq);
2291         aad[8] = hdr[0];
2292         aad[9] = hdr[1];
2293         aad[10] = hdr[2];
2294         aad[11] = hdr[3];
2295         aad[12] = hdr[4];
2296         return 13;
2297 }
2298
2299 static void
2300 packMac(Secret *sec, uchar *aad, int aadlen, uchar *body, int bodylen, uchar *mac)
2301 {
2302         DigestState *s;
2303
2304         s = (*sec->mac)(aad, aadlen, sec->mackey, sec->maclen, nil, nil);
2305         (*sec->mac)(body, bodylen, sec->mackey, sec->maclen, mac, s);
2306 }
2307
2308 static void
2309 put32(uchar *p, u32int x)
2310 {
2311         p[0] = x>>24;
2312         p[1] = x>>16;
2313         p[2] = x>>8;
2314         p[3] = x;
2315 }
2316
2317 static void
2318 put64(uchar *p, u64int x)
2319 {
2320         put32(p, x >> 32);
2321         put32(p+4, x);
2322 }
2323
2324 static void
2325 put24(uchar *p, int x)
2326 {
2327         p[0] = x>>16;
2328         p[1] = x>>8;
2329         p[2] = x;
2330 }
2331
2332 static void
2333 put16(uchar *p, int x)
2334 {
2335         p[0] = x>>8;
2336         p[1] = x;
2337 }
2338
/*
 * read four bytes at p, most significant byte first, as a 32 bit
 * unsigned value.  each byte is widened to u32int before shifting
 * so that a top byte of 0x80 or more is never shifted into the
 * sign bit of an int.
 */
static u32int
get32(uchar *p)
{
	return (u32int)p[0]<<24 | (u32int)p[1]<<16 | (u32int)p[2]<<8 | (u32int)p[3];
}
2344
2345 static int
2346 get16(uchar *p)
2347 {
2348         return (p[0]<<8)|p[1];
2349 }
2350
2351 static char *charmap = "0123456789abcdef";
2352
2353 static void
2354 pdump(int len, void *a, char *tag)
2355 {
2356         uchar *p;
2357         int i;
2358         char buf[65+32];
2359         char *q;
2360
2361         p = a;
2362         strcpy(buf, tag);
2363         while(len > 0){
2364                 q = buf + strlen(tag);
2365                 for(i = 0; len > 0 && i < 32; i++){
2366                         if(*p >= ' ' && *p < 0x7f){
2367                                 *q++ = ' ';
2368                                 *q++ = *p;
2369                         } else {
2370                                 *q++ = charmap[*p>>4];
2371                                 *q++ = charmap[*p & 0xf];
2372                         }
2373                         len--;
2374                         p++;
2375                 }
2376                 *q = 0;
2377
2378                 if(len > 0)
2379                         pprint("%s...\n", buf);
2380                 else
2381                         pprint("%s\n", buf);
2382         }
2383 }