]> git.lizzy.rs Git - plan9front.git/blob - sys/src/libsec/port/tlshand.c
48c281068e9507f21b33db3cd4831b0cbf680477
[plan9front.git] / sys / src / libsec / port / tlshand.c
1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <auth.h>
5 #include <mp.h>
6 #include <libsec.h>
7
8 // The main groups of functions are:
9 //              client/server - main handshake protocol definition
//              message functions - formatting handshake messages
11 //              cipher choices - catalog of digest and encrypt algorithms
12 //              security functions - PKCS#1, sslHMAC, session keygen
13 //              general utility functions - malloc, serialization
14 // The handshake protocol builds on the TLS/SSL3 record layer protocol,
15 // which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.
16
// Sizes and flags shared by the handshake code in this file.
enum {
        TLSFinishedLen = 12,    // Finished verify length for TLS (RFC 2246 verify_data is 12 bytes)
        SSL3FinishedLen = MD5dlen+SHA1dlen,     // Finished verify length for SSL3; also sizes Finished.verify
        MaxKeyData = 136,       // amount of secret we may need
        MaxChunk = 1<<15,       // size of the handshake receive and send buffers
        RandomSize = 32,        // bytes in the client and server randoms
        SidSize = 32,           // size of the local session id buffer in tlsServer2
        MasterSecretSize = 48,  // bytes in the master secret (TlsSec.sec)
        AQueue = 0,             // msgSend `act` value used for all but the last message of a flight
        AFlush = 1,             // msgSend `act` value used for the last message of a flight
};
28
29 typedef struct TlsSec TlsSec;
30
// Counted byte string; data is declared with one element but holds len bytes.
typedef struct Bytes{
        int len;
        uchar data[1];  // [len]
} Bytes;
35
// Counted array of ints; data is declared with one element but holds len ints.
typedef struct Ints{
        int len;
        int data[1];  // [len]
} Ints;
40
// One supported cipher suite; see the cipherAlgs table for the instances.
typedef struct Algs{
        char *enc;      // encryption algorithm name, e.g. "aes_128_cbc"
        char *digest;   // digest algorithm name, e.g. "sha1"
        int nsecret;    // bytes of key material the suite needs
        int tlsid;      // cipher suite id (one of the TLS_* constants)
        int ok;         // not set by the table initializers; presumably filled in by initCiphers -- TODO confirm
} Algs;
48
// Description of a named elliptic curve; see the namedcurves table.
// The domain parameters are stored as hexadecimal strings.
typedef struct Namedcurve{
        int tlsid;      // curve id sent in the elliptic_curves extension
        char *name;

        char *p;
        char *a;
        char *b;
        char *G;
        char *n;
        char *h;
} Namedcurve;
60
// Verify data of a Finished handshake message.
typedef struct Finished{
        uchar verify[SSL3FinishedLen];  // sized for the SSL3 case (MD5dlen+SHA1dlen)
        int n;                          // number of bytes of verify in use
} Finished;
65
// Digest states kept side by side for the handshake; passed by value
// to tlsSecFinished and the *SetFinished functions.
typedef struct HandshakeHash {
        MD5state        md5;
        SHAstate        sha1;
        SHA2_256state   sha2_256;
} HandshakeHash;
71
// State of one handshake in progress.  tlsServer2 allocates it with emalloc
// and callers release it with tlsConnectionFree.
typedef struct TlsConnection{
        TlsSec *sec;    // security management goo
        int hand, ctl;  // record layer file descriptors
        int erred;              // set when tlsError called
        int (*trace)(char*fmt, ...); // for debugging
        int version;    // protocol we are speaking
        int verset;             // version has been set
        int ver2hi;             // server got a version 2 hello
        int isClient;   // is this the client or server?
        Bytes *sid;             // SessionID
        Bytes *cert;    // only last - no chain

        Lock statelk;
        int state;              // must be set using setstate

        // input buffer for handshake messages
        uchar recvbuf[MaxChunk];
        uchar *rp, *ep;

        // output buffer
        uchar sendbuf[MaxChunk];
        uchar *sendp;

        uchar crandom[RandomSize];      // client random
        uchar srandom[RandomSize];      // server random
        int clientVersion;      // version in ClientHello
        int cipher;
        char *digest;   // name of digest algorithm to use
        char *enc;              // name of encryption algorithm to use
        int nsecret;    // amount of secret data to init keys

        // for finished messages
        HandshakeHash   handhash;
        Finished        finished;
} TlsConnection;
107
// One handshake message in decoded form.  tag holds a handshake type
// (HClientHello, HServerHello, ...) and selects which member of u is valid.
// Pointer members are heap objects; callers use msgClear(&m) between
// messages and on the error path.
typedef struct Msg{
        int tag;
        union {
                struct {
                        int version;
                        uchar   random[RandomSize];
                        Bytes*  sid;
                        Ints*   ciphers;
                        Bytes*  compressors;
                        Bytes*  extensions;
                } clientHello;
                struct {
                        int version;
                        uchar   random[RandomSize];
                        Bytes*  sid;
                        int     cipher;
                        int     compressor;
                        Bytes*  extensions;
                } serverHello;
                struct {
                        int ncert;      // number of entries in certs
                        Bytes **certs;
                } certificate;
                struct {
                        Bytes *types;
                        int nca;        // number of entries in cas
                        Bytes **cas;
                } certificateRequest;
                struct {
                        Bytes *key;
                } clientKeyExchange;
                struct {
                        Bytes *dh_p;
                        Bytes *dh_g;
                        Bytes *dh_Ys;
                        Bytes *dh_signature;
                        int curve;
                } serverKeyExchange;
                struct {
                        Bytes *signature;
                } certificateVerify;
                Finished finished;
        } u;
} Msg;
152
// Secret-handling state of a connection: master secret, both randoms,
// the RSA key handles and the version-dependent PRF / Finished routines.
typedef struct TlsSec{
        char *server;   // name of remote; nil for server
        int ok; // <0 killed; == 0 in progress; >0 reusable
        RSApub *rsapub;
        AuthRpc *rpc;   // factotum for rsa private key
        uchar sec[MasterSecretSize];    // master secret
        uchar crandom[RandomSize];      // client random
        uchar srandom[RandomSize];      // server random
        int clientVers;         // version in ClientHello
        int vers;                       // final version
        // byte generation and handshake checksum
        // prf arguments, as used by tlsServer/tlsClient:
        //      output buffer and length, key and length, label,
        //      first seed and length, second seed and length
        void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
        void (*setFinished)(TlsSec*, HandshakeHash, uchar*, int);
        int nfin;
} TlsSec;
168
169
// Protocol version numbers (major in the high byte, minor in the low byte).
enum {
        SSL3Version     = 0x0300,
        TLS10Version    = 0x0301,
        TLS11Version    = 0x0302,
        TLS12Version    = 0x0303,
        ProtocolVersion = TLS12Version, // maximum version we speak
        MinProtoVersion = 0x0300,       // limits on version we accept
        MaxProtoVersion = 0x03ff,
};
179
// handshake type
// (values stored in Msg.tag; the unnumbered entries continue from the
// preceding explicit value)
enum {
        HHelloRequest,          // 0
        HClientHello,           // 1
        HServerHello,           // 2
        HSSL2ClientHello = 9,  /* local convention;  see devtls.c */
        HCertificate = 11,
        HServerKeyExchange,     // 12
        HCertificateRequest,    // 13
        HServerHelloDone,       // 14
        HCertificateVerify,     // 15
        HClientKeyExchange,     // 16
        HFinished = 20,
        HMax                    // 21, one past the largest type
};
195
// alerts
// (passed as the err argument of tlsError)
enum {
        ECloseNotify = 0,
        EUnexpectedMessage = 10,
        EBadRecordMac = 20,
        EDecryptionFailed = 21,
        ERecordOverflow = 22,
        EDecompressionFailure = 30,
        EHandshakeFailure = 40,
        ENoCertificate = 41,
        EBadCertificate = 42,
        EUnsupportedCertificate = 43,
        ECertificateRevoked = 44,
        ECertificateExpired = 45,
        ECertificateUnknown = 46,
        EIllegalParameter = 47,
        EUnknownCa = 48,
        EAccessDenied = 49,
        EDecodeError = 50,
        EDecryptError = 51,
        EExportRestriction = 60,
        EProtocolVersion = 70,
        EInsufficientSecurity = 71,
        EInternalError = 80,
        EUserCanceled = 90,
        ENoRenegotiation = 100,
        EMax = 256
};
224
// cipher suites
// Only the suites listed in cipherAlgs are actually offered; the rest are
// here for reference and for isDHE/isECDHE.
enum {
        TLS_NULL_WITH_NULL_NULL                 = 0x0000,
        TLS_RSA_WITH_NULL_MD5                   = 0x0001,
        TLS_RSA_WITH_NULL_SHA                   = 0x0002,
        TLS_RSA_EXPORT_WITH_RC4_40_MD5          = 0x0003,
        TLS_RSA_WITH_RC4_128_MD5                = 0x0004,
        TLS_RSA_WITH_RC4_128_SHA                = 0x0005,
        TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5      = 0X0006,
        TLS_RSA_WITH_IDEA_CBC_SHA               = 0X0007,
        TLS_RSA_EXPORT_WITH_DES40_CBC_SHA       = 0X0008,
        TLS_RSA_WITH_DES_CBC_SHA                = 0X0009,
        TLS_RSA_WITH_3DES_EDE_CBC_SHA           = 0X000A,
        TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA    = 0X000B,
        TLS_DH_DSS_WITH_DES_CBC_SHA             = 0X000C,
        TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA        = 0X000D,
        TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA    = 0X000E,
        TLS_DH_RSA_WITH_DES_CBC_SHA             = 0X000F,
        TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA        = 0X0010,
        TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA   = 0X0011,
        TLS_DHE_DSS_WITH_DES_CBC_SHA            = 0X0012,
        TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA       = 0X0013,       // ZZZ must be implemented for tls1.0 compliance
        TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA   = 0X0014,
        TLS_DHE_RSA_WITH_DES_CBC_SHA            = 0X0015,
        TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA       = 0X0016,
        TLS_DH_anon_EXPORT_WITH_RC4_40_MD5      = 0x0017,
        TLS_DH_anon_WITH_RC4_128_MD5            = 0x0018,
        TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA   = 0X0019,
        TLS_DH_anon_WITH_DES_CBC_SHA            = 0X001A,
        TLS_DH_anon_WITH_3DES_EDE_CBC_SHA       = 0X001B,

        TLS_RSA_WITH_AES_128_CBC_SHA            = 0X002f,       // aes, aka rijndael with 128 bit blocks
        TLS_DH_DSS_WITH_AES_128_CBC_SHA         = 0X0030,
        TLS_DH_RSA_WITH_AES_128_CBC_SHA         = 0X0031,
        TLS_DHE_DSS_WITH_AES_128_CBC_SHA        = 0X0032,
        TLS_DHE_RSA_WITH_AES_128_CBC_SHA        = 0X0033,
        TLS_DH_anon_WITH_AES_128_CBC_SHA        = 0X0034,
        TLS_RSA_WITH_AES_256_CBC_SHA            = 0X0035,
        TLS_DH_DSS_WITH_AES_256_CBC_SHA         = 0X0036,
        TLS_DH_RSA_WITH_AES_256_CBC_SHA         = 0X0037,
        TLS_DHE_DSS_WITH_AES_256_CBC_SHA        = 0X0038,
        TLS_DHE_RSA_WITH_AES_256_CBC_SHA        = 0X0039,
        TLS_DH_anon_WITH_AES_256_CBC_SHA        = 0X003A,

        TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA      = 0xC013,
        TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA      = 0xC014,
        TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA  = 0xC009,
        TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA = 0xC00A,
        // NOTE(review): CipherMax follows 0xC00A, so its value is 0xC00B,
        // which is smaller than the two ECDHE_RSA ids above.  Confirm that
        // every use of CipherMax tolerates ids >= CipherMax.
        CipherMax
};
275
// compression methods
// (only the null method is defined; see the compressors table)
enum {
        CompressionNull = 0,
        CompressionMax
};
281
// Catalog of the cipher suites this code knows how to run.
// Each entry is {enc, digest, nsecret, tlsid}; the ok field is left zero here.
// nsecret appears to be 2 * (cipher key + IV + MAC key) bytes, one set per
// direction -- TODO confirm against setSecrets.
// NOTE(review): entries look ordered strongest first; whether the order is
// used as a preference is decided elsewhere (okCipher/makeciphers).
static Algs cipherAlgs[] = {
        {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
        {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA},
        {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA},
        {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
        {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_256_CBC_SHA},
        {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_128_CBC_SHA},
        {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA},
        {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_RSA_WITH_AES_128_CBC_SHA},
        {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA},
        {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
        {"rc4_128", "sha1",     2*(16+SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
        {"rc4_128", "md5",      2*(16+MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
};
296
// Compression methods we support: only the null method.
static uchar compressors[] = {
        CompressionNull,
};
300
// Elliptic curves offered in the client's elliptic_curves extension and
// accepted by tlsSecECDHEc.  Fields after the name are, in order,
// p, a, b, G, n and h of Namedcurve, written as hexadecimal strings.
static Namedcurve namedcurves[] = {
{0x0017, "secp256r1",
        "FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF",
        "FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFC",
        "5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B",
        "046B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C2964FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5",
        "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551",
        "1"}
};
310
// EC point formats advertised in the ec_point_formats extension.
// CompressionNull (0) is reused here as the value for the uncompressed format.
static uchar pointformats[] = {
        CompressionNull /* support of uncompressed point format is mandatory */
};
314
315 static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chain);
316 static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen, int (*trace)(char*fmt, ...));
317 static void     msgClear(Msg *m);
318 static char* msgPrint(char *buf, int n, Msg *m);
319 static int      msgRecv(TlsConnection *c, Msg *m);
320 static int      msgSend(TlsConnection *c, Msg *m, int act);
321 static void     tlsError(TlsConnection *c, int err, char *msg, ...);
322 #pragma varargck argpos tlsError 3
323 static int setVersion(TlsConnection *c, int version);
324 static int finishedMatch(TlsConnection *c, Finished *f);
325 static void tlsConnectionFree(TlsConnection *c);
326
327 static int setAlgs(TlsConnection *c, int a);
328 static int okCipher(Ints *cv);
329 static int okCompression(Bytes *cv);
330 static int initCiphers(void);
331 static Ints* makeciphers(void);
332
333 static TlsSec*  tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
334 static int      tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm);
335 static TlsSec*  tlsSecInitc(int cvers, uchar *crandom);
336 static Bytes*   tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers);
337 static Bytes*   tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, Bytes *p, Bytes *g, Bytes *Ys);
338 static Bytes*   tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys);
339 static int      tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient);
340 static void     tlsSecOk(TlsSec *sec);
341 static void     tlsSecKill(TlsSec *sec);
342 static void     tlsSecClose(TlsSec *sec);
343 static void     setMasterSecret(TlsSec *sec, Bytes *pm);
344 static void     serverMasterSecret(TlsSec *sec, Bytes *epm);
345 static void     setSecrets(TlsSec *sec, uchar *kd, int nkd);
346 static Bytes*   clientMasterSecret(TlsSec *sec, RSApub *pub);
347 static Bytes*   pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
348 static Bytes*   pkcs1_decrypt(TlsSec *sec, Bytes *cipher);
349 static void     tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient);
350 static void     tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient);
351 static void     sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient);
352 static void     sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
353                         uchar *seed0, int nseed0, uchar *seed1, int nseed1);
354 static int setVers(TlsSec *sec, int version);
355
356 static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
357 static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
358 static void factotum_rsa_close(AuthRpc*rpc);
359
360 static void* emalloc(int);
361 static void* erealloc(void*, int);
362 static void put32(uchar *p, u32int);
363 static void put24(uchar *p, int);
364 static void put16(uchar *p, int);
365 static u32int get32(uchar *p);
366 static int get24(uchar *p);
367 static int get16(uchar *p);
368 static Bytes* newbytes(int len);
369 static Bytes* makebytes(uchar* buf, int len);
370 static Bytes* mptobytes(mpint* big);
371 static mpint* bytestomp(Bytes* bytes);
372 static void freebytes(Bytes* b);
373 static Ints* newints(int len);
374 static void freeints(Ints* b);
375
376 /* x509.c */
377 extern mpint* pkcs1padbuf(uchar *buf, int len, mpint *modulus);
378
379 //================= client/server ========================
380
381 //      push TLS onto fd, returning new (application) file descriptor
382 //              or -1 if error.
383 int
384 tlsServer(int fd, TLSconn *conn)
385 {
386         char buf[8];
387         char dname[64];
388         int n, data, ctl, hand;
389         TlsConnection *tls;
390
391         if(conn == nil)
392                 return -1;
393         ctl = open("#a/tls/clone", ORDWR);
394         if(ctl < 0)
395                 return -1;
396         n = read(ctl, buf, sizeof(buf)-1);
397         if(n < 0){
398                 close(ctl);
399                 return -1;
400         }
401         buf[n] = 0;
402         snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", buf);
403         snprint(dname, sizeof(dname), "#a/tls/%s/hand", buf);
404         hand = open(dname, ORDWR);
405         if(hand < 0){
406                 close(ctl);
407                 return -1;
408         }
409         fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
410         tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
411         snprint(dname, sizeof(dname), "#a/tls/%s/data", buf);
412         data = open(dname, ORDWR);
413         close(hand);
414         close(ctl);
415         if(data < 0 || tls == nil){
416                 if(tls != nil)
417                         tlsConnectionFree(tls);
418                 return -1;
419         }
420         free(conn->cert);
421         conn->cert = 0;  // client certificates are not yet implemented
422         conn->certlen = 0;
423         conn->sessionIDlen = tls->sid->len;
424         conn->sessionID = emalloc(conn->sessionIDlen);
425         memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
426         if(conn->sessionKey != nil
427         && conn->sessionType != nil
428         && strcmp(conn->sessionType, "ttls") == 0)
429                 tls->sec->prf(
430                         conn->sessionKey, conn->sessionKeylen,
431                         tls->sec->sec, MasterSecretSize,
432                         conn->sessionConst, 
433                         tls->sec->crandom, RandomSize,
434                         tls->sec->srandom, RandomSize);
435         tlsConnectionFree(tls);
436         close(fd);
437         return data;
438 }
439
440 static uchar*
441 tlsClientExtensions(TLSconn *conn, int *plen)
442 {
443         uchar *b, *p;
444         int i, n, m;
445
446         p = b = nil;
447
448         // RFC6066 - Server Name Identification
449         if(conn->serverName != nil){
450                 n = strlen(conn->serverName);
451
452                 m = p - b;
453                 b = erealloc(b, m + 2+2+2+1+2+n);
454                 p = b + m;
455
456                 put16(p, 0), p += 2;            /* Type: server_name */
457                 put16(p, 2+1+2+n), p += 2;      /* Length */
458                 put16(p, 1+2+n), p += 2;        /* Server Name list length */
459                 *p++ = 0;                       /* Server Name Type: host_name */
460                 put16(p, n), p += 2;            /* Server Name length */
461                 memmove(p, conn->serverName, n);
462                 p += n;
463         }
464
465         // ECDHE
466         if(1){
467                 m = p - b;
468                 b = erealloc(b, m + 2+2+2+nelem(namedcurves)*2 + 2+2+1+nelem(pointformats));
469                 p = b + m;
470
471                 n = nelem(namedcurves);
472                 put16(p, 0x000a), p += 2;       /* Type: elliptic_curves */
473                 put16(p, (n+1)*2), p += 2;      /* Length */
474                 put16(p, n*2), p += 2;          /* Elliptic Curves Length */
475                 for(i=0; i < n; i++){           /* Elliptic curves */
476                         put16(p, namedcurves[i].tlsid);
477                         p += 2;
478                 }
479
480                 n = nelem(pointformats);
481                 put16(p, 0x000b), p += 2;       /* Type: ec_point_formats */
482                 put16(p, n+1), p += 2;          /* Length */
483                 *p++ = n;                       /* EC point formats Length */
484                 for(i=0; i < n; i++)            /* Elliptic curves point formats */
485                         *p++ = pointformats[i];
486         }
487         
488         *plen = p - b;
489         return b;
490 }
491
492 //      push TLS onto fd, returning new (application) file descriptor
493 //              or -1 if error.
494 int
495 tlsClient(int fd, TLSconn *conn)
496 {
497         char buf[8];
498         char dname[64];
499         int n, data, ctl, hand;
500         TlsConnection *tls;
501         uchar *ext;
502
503         if(conn == nil)
504                 return -1;
505         ctl = open("#a/tls/clone", ORDWR);
506         if(ctl < 0)
507                 return -1;
508         n = read(ctl, buf, sizeof(buf)-1);
509         if(n < 0){
510                 close(ctl);
511                 return -1;
512         }
513         buf[n] = 0;
514         snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", buf);
515         snprint(dname, sizeof(dname), "#a/tls/%s/hand", buf);
516         hand = open(dname, ORDWR);
517         if(hand < 0){
518                 close(ctl);
519                 return -1;
520         }
521         snprint(dname, sizeof(dname), "#a/tls/%s/data", buf);
522         data = open(dname, ORDWR);
523         if(data < 0){
524                 close(hand);
525                 close(ctl);
526                 return -1;
527         }
528         fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
529         ext = tlsClientExtensions(conn, &n);
530         tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen, 
531                 ext, n, conn->trace);
532         free(ext);
533         close(hand);
534         close(ctl);
535         if(tls == nil){
536                 close(data);
537                 return -1;
538         }
539         conn->certlen = tls->cert->len;
540         conn->cert = emalloc(conn->certlen);
541         memcpy(conn->cert, tls->cert->data, conn->certlen);
542         conn->sessionIDlen = tls->sid->len;
543         conn->sessionID = emalloc(conn->sessionIDlen);
544         memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
545         if(conn->sessionKey != nil
546         && conn->sessionType != nil
547         && strcmp(conn->sessionType, "ttls") == 0)
548                 tls->sec->prf(
549                         conn->sessionKey, conn->sessionKeylen,
550                         tls->sec->sec, MasterSecretSize,
551                         conn->sessionConst, 
552                         tls->sec->crandom, RandomSize,
553                         tls->sec->srandom, RandomSize);
554         tlsConnectionFree(tls);
555         close(fd);
556         return data;
557 }
558
559 static int
560 countchain(PEMChain *p)
561 {
562         int i = 0;
563
564         while (p) {
565                 i++;
566                 p = p->next;
567         }
568         return i;
569 }
570
// Server side of the handshake, run over the record layer descriptors
// ctl and hand.  cert/certlen is the server certificate, chp an optional
// chain of further certificates sent after it, trace an optional debug
// print function.
// Message flow implemented here:
//      <- ClientHello
//      -> ServerHello, Certificate, ServerHelloDone
//      <- ClientKeyExchange (RSA), Finished
//      -> Finished
// Returns the connection state on success (caller frees it with
// tlsConnectionFree) or nil on failure, in which case the state has
// already been freed.
static TlsConnection *
tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chp)
{
        TlsConnection *c;
        Msg m;
        Bytes *csid;
        uchar sid[SidSize], kd[MaxKeyData];
        char *secrets;
        int cipher, compressor, nsid, rv, numcerts, i;

        if(trace)
                trace("tlsServer2\n");
        if(!initCiphers())
                return nil;
        c = emalloc(sizeof(TlsConnection));
        c->ctl = ctl;
        c->hand = hand;
        c->trace = trace;
        c->version = ProtocolVersion;

        // m is cleared up front so that msgClear at Err is safe at any point
        memset(&m, 0, sizeof(m));
        if(!msgRecv(c, &m)){
                if(trace)
                        trace("initial msgRecv failed\n");
                goto Err;
        }
        if(m.tag != HClientHello) {
                tlsError(c, EUnexpectedMessage, "expected a client hello");
                goto Err;
        }
        c->clientVersion = m.u.clientHello.version;
        if(trace)
                trace("ClientHello version %x\n", c->clientVersion);
        if(setVersion(c, c->clientVersion) < 0) {
                tlsError(c, EIllegalParameter, "incompatible version");
                goto Err;
        }

        // pick a cipher suite and compression method from the client's offers
        memmove(c->crandom, m.u.clientHello.random, RandomSize);
        cipher = okCipher(m.u.clientHello.ciphers);
        if(cipher < 0) {
                // reply with EInsufficientSecurity if we know that's the case
                if(cipher == -2)
                        tlsError(c, EInsufficientSecurity, "cipher suites too weak");
                else
                        tlsError(c, EHandshakeFailure, "no matching cipher suite");
                goto Err;
        }
        if(!setAlgs(c, cipher)){
                tlsError(c, EHandshakeFailure, "no matching cipher suite");
                goto Err;
        }
        compressor = okCompression(m.u.clientHello.compressors);
        if(compressor < 0) {
                tlsError(c, EHandshakeFailure, "no matching compressor");
                goto Err;
        }

        // set up the security state; tlsSecInits fills in sid/nsid and c->srandom
        csid = m.u.clientHello.sid;
        if(trace)
                trace("  cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
        c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
        if(c->sec == nil){
                tlsError(c, EHandshakeFailure, "can't initialize security: %r");
                goto Err;
        }
        c->sec->rpc = factotum_rsa_open(cert, certlen);
        if(c->sec->rpc == nil){
                tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
                goto Err;
        }
        c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
        if(c->sec->rsapub == nil){
                tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
                goto Err;
        }
        msgClear(&m);

        // first flight: ServerHello, Certificate and ServerHelloDone;
        // the first two are queued and the last one flushes
        m.tag = HServerHello;
        m.u.serverHello.version = c->version;
        memmove(m.u.serverHello.random, c->srandom, RandomSize);
        m.u.serverHello.cipher = cipher;
        m.u.serverHello.compressor = compressor;
        c->sid = makebytes(sid, nsid);
        // the message gets its own copy of the session id
        m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
        if(!msgSend(c, &m, AQueue))
                goto Err;
        msgClear(&m);

        // our certificate goes first, then the entries of the chain in list order
        m.tag = HCertificate;
        numcerts = countchain(chp);
        m.u.certificate.ncert = 1 + numcerts;
        m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
        m.u.certificate.certs[0] = makebytes(cert, certlen);
        for (i = 0; i < numcerts && chp; i++, chp = chp->next)
                m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
        if(!msgSend(c, &m, AQueue))
                goto Err;
        msgClear(&m);

        m.tag = HServerHelloDone;
        if(!msgSend(c, &m, AFlush))
                goto Err;
        msgClear(&m);

        // client's key exchange: recover the master secret, then derive
        // c->nsecret bytes of key material into kd
        if(!msgRecv(c, &m))
                goto Err;
        if(m.tag != HClientKeyExchange) {
                tlsError(c, EUnexpectedMessage, "expected a client key exchange");
                goto Err;
        }
        if(tlsSecRSAs(c->sec, c->version, m.u.clientKeyExchange.key) < 0){
                tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
                goto Err;
        }
        setSecrets(c->sec, kd, c->nsecret);
        if(trace)
                trace("tls secrets\n");
        // hand the keys to the record layer as base64 text, then wipe our copies
        secrets = (char*)emalloc(2*c->nsecret);
        enc64(secrets, 2*c->nsecret, kd, c->nsecret);
        rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
        memset(secrets, 0, 2*c->nsecret);
        free(secrets);
        memset(kd, 0, c->nsecret);
        if(rv < 0){
                tlsError(c, EHandshakeFailure, "can't set keys: %r");
                goto Err;
        }
        msgClear(&m);

        /* no CertificateVerify; skip to Finished */
        // isclient=1: presumably computes the value expected in the client's
        // Finished into c->finished, which finishedMatch then compares
        if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){
                tlsError(c, EInternalError, "can't set finished: %r");
                goto Err;
        }
        if(!msgRecv(c, &m))
                goto Err;
        if(m.tag != HFinished) {
                tlsError(c, EUnexpectedMessage, "expected a finished");
                goto Err;
        }
        if(!finishedMatch(c, &m.u.finished)) {
                tlsError(c, EHandshakeFailure, "finished verification failed");
                goto Err;
        }
        msgClear(&m);

        /* change cipher spec */
        if(fprint(c->ctl, "changecipher") < 0){
                tlsError(c, EInternalError, "can't enable cipher: %r");
                goto Err;
        }

        // isclient=0: our own Finished value, sent as the last handshake message
        if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){
                tlsError(c, EInternalError, "can't set finished: %r");
                goto Err;
        }
        m.tag = HFinished;
        m.u.finished = c->finished;
        if(!msgSend(c, &m, AFlush))
                goto Err;
        msgClear(&m);
        if(trace)
                trace("tls finished\n");

        // tell the record layer that the handshake is complete
        if(fprint(c->ctl, "opened") < 0)
                goto Err;
        tlsSecOk(c->sec);
        return c;

Err:
        msgClear(&m);
        tlsConnectionFree(c);
        return 0;
}
746
747 static int
748 isDHE(int tlsid)
749 {
750         switch(tlsid){
751         case TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA:
752         case TLS_DHE_DSS_WITH_DES_CBC_SHA:
753         case TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA:
754         case TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA:
755         case TLS_DHE_RSA_WITH_DES_CBC_SHA:
756         case TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA:
757         case TLS_DHE_DSS_WITH_AES_128_CBC_SHA:
758         case TLS_DHE_RSA_WITH_AES_128_CBC_SHA:
759         case TLS_DHE_DSS_WITH_AES_256_CBC_SHA:
760         case TLS_DHE_RSA_WITH_AES_256_CBC_SHA:
761                 return 1;
762         }
763         return 0;
764 }
765
766 static int
767 isECDHE(int tlsid)
768 {
769         switch(tlsid){
770         case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:
771         case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
772         case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA:
773         case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
774                 return 1;
775         }
776         return 0;
777 }
778
779 static Bytes*
780 tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, 
781         Bytes *p, Bytes *g, Bytes *Ys)
782 {
783         mpint *G, *P, *Y, *K;
784         Bytes *epm;
785         DHstate dh;
786
787         if(p == nil || g == nil || Ys == nil)
788                 return nil;
789
790         memmove(sec->srandom, srandom, RandomSize);
791         if(setVers(sec, vers) < 0)
792                 return nil;
793
794         epm = nil;
795         P = bytestomp(p);
796         G = bytestomp(g);
797         Y = bytestomp(Ys);
798         K = nil;
799
800         if(P == nil || G == nil || Y == nil || dh_new(&dh, P, G) == nil)
801                 goto Out;
802         epm = mptobytes(dh.y);
803         K = dh_finish(&dh, Y);
804         if(K == nil){
805                 freebytes(epm);
806                 epm = nil;
807                 goto Out;
808         }
809         setMasterSecret(sec, mptobytes(K));
810
811 Out:
812         mpfree(K);
813         mpfree(Y);
814         mpfree(G);
815         mpfree(P);
816
817         return epm;
818 }
819
820 static ECpoint*
821 bytestoec(ECdomain *dom, Bytes *bp, ECpoint *ret)
822 {
823         char *hex = "0123456789ABCDEF";
824         char *s;
825         int i;
826
827         s = emalloc(2*bp->len + 1);
828         for(i=0; i < bp->len; i++){
829                 s[2*i] = hex[bp->data[i]>>4 & 15];
830                 s[2*i+1] = hex[bp->data[i] & 15];
831         }
832         s[2*bp->len] = '\0';
833         ret = strtoec(dom, s, nil, ret);
834         free(s);
835         return ret;
836 }
837
838 static Bytes*
839 ectobytes(int type, ECpoint *p)
840 {
841         Bytes *bx, *by, *bp;
842
843         bx = mptobytes(p->x);
844         by = mptobytes(p->y);
845         bp = newbytes(bx->len + by->len + 1);
846         bp->data[0] =  type;
847         memmove(bp->data+1, bx->data, bx->len);
848         memmove(bp->data+1+bx->len, by->data, by->len);
849         freebytes(bx);
850         freebytes(by);
851         return bp;
852 }
853
854 static Bytes*
855 tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys)
856 {
857         Namedcurve *nc, *enc;
858         Bytes *epm;
859         ECdomain dom;
860         ECpoint G, K, Y;
861         ECpriv Q;
862
863         if(Ys == nil)
864                 return nil;
865
866         enc = &namedcurves[nelem(namedcurves)];
867         for(nc = namedcurves; nc != enc; nc++)
868                 if(nc->tlsid == curve)
869                         break;
870
871         if(nc == enc)
872                 return nil;
873                 
874         memmove(sec->srandom, srandom, RandomSize);
875         if(setVers(sec, vers) < 0)
876                 return nil;
877         
878         epm = nil;
879
880         memset(&dom, 0, sizeof(dom));
881         dom.p = strtomp(nc->p, nil, 16, nil);
882         dom.a = strtomp(nc->a, nil, 16, nil);
883         dom.b = strtomp(nc->b, nil, 16, nil);
884         dom.n = strtomp(nc->n, nil, 16, nil);
885         dom.h = strtomp(nc->h, nil, 16, nil);
886
887         memset(&G, 0, sizeof(G));
888         G.x = mpnew(0);
889         G.y = mpnew(0);
890
891         memset(&Q, 0, sizeof(Q));
892         Q.x = mpnew(0);
893         Q.y = mpnew(0);
894         Q.d = mpnew(0);
895
896         memset(&K, 0, sizeof(K));
897         K.x = mpnew(0);
898         K.y = mpnew(0);
899
900         memset(&Y, 0, sizeof(Y));
901         Y.x = mpnew(0);
902         Y.y = mpnew(0);
903
904         if(dom.p == nil || dom.a == nil || dom.b == nil || dom.n == nil || dom.h == nil)
905                 goto Out;
906         if(Q.x == nil || Q.y == nil || Q.d == nil)
907                 goto Out;
908         if(G.x == nil || G.y == nil)
909                 goto Out;
910         if(K.x == nil || K.y == nil)
911                 goto Out;
912         if(Y.x == nil || Y.y == nil)
913                 goto Out;
914
915         dom.G = strtoec(&dom, nc->G, nil, &G);
916         if(dom.G == nil)
917                 goto Out;
918
919         if(bytestoec(&dom, Ys, &Y) == nil)
920                 goto Out;
921
922         if(ecgen(&dom, &Q) == nil)
923                 goto Out;
924
925         ecmul(&dom, &Y, Q.d, &K);
926         setMasterSecret(sec, mptobytes(K.x));
927
928         /* 0x04 = uncompressed public key */
929         epm = ectobytes(0x04, &Q);
930         
931 Out:
932         mpfree(Y.x);
933         mpfree(Y.y);
934
935         mpfree(K.x);
936         mpfree(K.y);
937
938         mpfree(Q.x);
939         mpfree(Q.y);
940         mpfree(Q.d);
941
942         mpfree(G.x);
943         mpfree(G.y);
944
945         mpfree(dom.p);
946         mpfree(dom.a);
947         mpfree(dom.b);
948         mpfree(dom.n);
949         mpfree(dom.h);
950
951         return epm;
952 }
953
954 static TlsConnection *
955 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen,
956         int (*trace)(char*fmt, ...))
957 {
958         TlsConnection *c;
959         Msg m;
960         uchar kd[MaxKeyData];
961         char *secrets;
962         int creq, dhx, rv, cipher;
963         mpint *signedMP, *paddedHashes;
964         Bytes *epm;
965
966         if(!initCiphers())
967                 return nil;
968         epm = nil;
969         c = emalloc(sizeof(TlsConnection));
970         c->version = ProtocolVersion;
971
972         // client certificate signature not implemented for TLS1.2
973         if(cert != nil && certlen > 0 && c->version >= TLS12Version)
974                 c->version = TLS11Version;
975
976         c->ctl = ctl;
977         c->hand = hand;
978         c->trace = trace;
979         c->isClient = 1;
980         c->clientVersion = c->version;
981
982         c->sec = tlsSecInitc(c->clientVersion, c->crandom);
983         if(c->sec == nil)
984                 goto Err;
985         /* client hello */
986         memset(&m, 0, sizeof(m));
987         m.tag = HClientHello;
988         m.u.clientHello.version = c->clientVersion;
989         memmove(m.u.clientHello.random, c->crandom, RandomSize);
990         m.u.clientHello.sid = makebytes(csid, ncsid);
991         m.u.clientHello.ciphers = makeciphers();
992         m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
993         m.u.clientHello.extensions = makebytes(ext, extlen);
994         if(!msgSend(c, &m, AFlush))
995                 goto Err;
996         msgClear(&m);
997
998         /* server hello */
999         if(!msgRecv(c, &m))
1000                 goto Err;
1001         if(m.tag != HServerHello) {
1002                 tlsError(c, EUnexpectedMessage, "expected a server hello");
1003                 goto Err;
1004         }
1005         if(setVersion(c, m.u.serverHello.version) < 0) {
1006                 tlsError(c, EIllegalParameter, "incompatible version %r");
1007                 goto Err;
1008         }
1009         memmove(c->srandom, m.u.serverHello.random, RandomSize);
1010         c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
1011         if(c->sid->len != 0 && c->sid->len != SidSize) {
1012                 tlsError(c, EIllegalParameter, "invalid server session identifier");
1013                 goto Err;
1014         }
1015         cipher = m.u.serverHello.cipher;
1016         if(!setAlgs(c, cipher)) {
1017                 tlsError(c, EIllegalParameter, "invalid cipher suite");
1018                 goto Err;
1019         }
1020         if(m.u.serverHello.compressor != CompressionNull) {
1021                 tlsError(c, EIllegalParameter, "invalid compression");
1022                 goto Err;
1023         }
1024         msgClear(&m);
1025
1026         /* certificate */
1027         if(!msgRecv(c, &m) || m.tag != HCertificate) {
1028                 tlsError(c, EUnexpectedMessage, "expected a certificate");
1029                 goto Err;
1030         }
1031         if(m.u.certificate.ncert < 1) {
1032                 tlsError(c, EIllegalParameter, "runt certificate");
1033                 goto Err;
1034         }
1035         c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
1036         msgClear(&m);
1037
1038         /* server key exchange */
1039         dhx = isDHE(cipher) || isECDHE(cipher);
1040         if(!msgRecv(c, &m))
1041                 goto Err;
1042         if(m.tag == HServerKeyExchange) {
1043                 if(!dhx){
1044                         tlsError(c, EUnexpectedMessage, "got an server key exchange");
1045                         goto Err;
1046                 }
1047                 if(isECDHE(cipher))
1048                         epm = tlsSecECDHEc(c->sec, c->srandom, c->version,
1049                                 m.u.serverKeyExchange.curve,
1050                                 m.u.serverKeyExchange.dh_Ys);
1051                 else
1052                         epm = tlsSecDHEc(c->sec, c->srandom, c->version,
1053                                 m.u.serverKeyExchange.dh_p, 
1054                                 m.u.serverKeyExchange.dh_g,
1055                                 m.u.serverKeyExchange.dh_Ys);
1056                 if(epm == nil)
1057                         goto Badcert;
1058                 msgClear(&m);
1059                 if(!msgRecv(c, &m))
1060                         goto Err;
1061         } else if(dhx){
1062                 tlsError(c, EUnexpectedMessage, "expected server key exchange");
1063                 goto Err;
1064         }
1065
1066         /* certificate request (optional) */
1067         creq = 0;
1068         if(m.tag == HCertificateRequest) {
1069                 creq = 1;
1070                 msgClear(&m);
1071                 if(!msgRecv(c, &m))
1072                         goto Err;
1073         }
1074
1075         if(m.tag != HServerHelloDone) {
1076                 tlsError(c, EUnexpectedMessage, "expected a server hello done");
1077                 goto Err;
1078         }
1079         msgClear(&m);
1080
1081         if(!dhx)
1082                 epm = tlsSecRSAc(c->sec, c->sid->data, c->sid->len, c->srandom,
1083                         c->cert->data, c->cert->len, c->version);
1084
1085         if(epm == nil){
1086         Badcert:
1087                 tlsError(c, EBadCertificate, "bad certificate: %r");
1088                 goto Err;
1089         }
1090
1091         setSecrets(c->sec, kd, c->nsecret);
1092         secrets = (char*)emalloc(2*c->nsecret);
1093         enc64(secrets, 2*c->nsecret, kd, c->nsecret);
1094         rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
1095         memset(secrets, 0, 2*c->nsecret);
1096         free(secrets);
1097         memset(kd, 0, c->nsecret);
1098         if(rv < 0){
1099                 tlsError(c, EHandshakeFailure, "can't set keys: %r");
1100                 goto Err;
1101         }
1102
1103         if(creq) {
1104                 if(cert != nil && certlen > 0){
1105                         m.u.certificate.ncert = 1;
1106                         m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
1107                         m.u.certificate.certs[0] = makebytes(cert, certlen);
1108                 }               
1109                 m.tag = HCertificate;
1110                 if(!msgSend(c, &m, AFlush))
1111                         goto Err;
1112                 msgClear(&m);
1113         }
1114
1115         /* client key exchange */
1116         m.tag = HClientKeyExchange;
1117         m.u.clientKeyExchange.key = epm;
1118         epm = nil;
1119         if(m.u.clientKeyExchange.key == nil) {
1120                 tlsError(c, EHandshakeFailure, "can't set secret: %r");
1121                 goto Err;
1122         }
1123          
1124         if(!msgSend(c, &m, AFlush))
1125                 goto Err;
1126         msgClear(&m);
1127
1128         /* certificate verify */
1129         if(creq && cert != nil && certlen > 0) {
1130                 uchar hshashes[MD5dlen+SHA1dlen]; /* content of signature */
1131                 HandshakeHash hsave;
1132
1133                 /* save the state for the Finish message */
1134                 hsave = c->handhash;
1135                 md5(nil, 0, hshashes, &c->handhash.md5);
1136                 sha1(nil, 0, hshashes+MD5dlen, &c->handhash.sha1);
1137                 c->handhash = hsave;
1138
1139                 c->sec->rpc = factotum_rsa_open(cert, certlen);
1140                 if(c->sec->rpc == nil){
1141                         tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
1142                         goto Err;
1143                 }
1144                 c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
1145                 if(c->sec->rsapub == nil){
1146                         tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
1147                         goto Err;
1148                 }
1149
1150                 paddedHashes = pkcs1padbuf(hshashes, MD5dlen+SHA1dlen, c->sec->rsapub->n);
1151                 signedMP = factotum_rsa_decrypt(c->sec->rpc, paddedHashes);
1152                 if(signedMP == nil){
1153                         tlsError(c, EHandshakeFailure, "factotum_rsa_decrypt: %r");
1154                         goto Err;
1155                 }
1156                 m.u.certificateVerify.signature = mptobytes(signedMP);
1157                 mpfree(signedMP);
1158
1159                 m.tag = HCertificateVerify;
1160                 if(!msgSend(c, &m, AFlush))
1161                         goto Err;
1162                 msgClear(&m);
1163         } 
1164
1165         /* change cipher spec */
1166         if(fprint(c->ctl, "changecipher") < 0){
1167                 tlsError(c, EInternalError, "can't enable cipher: %r");
1168                 goto Err;
1169         }
1170
1171         // Cipherchange must occur immediately before Finished to avoid
1172         // potential hole;  see section 4.3 of Wagner Schneier 1996.
1173         if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){
1174                 tlsError(c, EInternalError, "can't set finished 1: %r");
1175                 goto Err;
1176         }
1177         m.tag = HFinished;
1178         m.u.finished = c->finished;
1179         if(!msgSend(c, &m, AFlush)) {
1180                 tlsError(c, EInternalError, "can't flush after client Finished: %r");
1181                 goto Err;
1182         }
1183         msgClear(&m);
1184
1185         if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){
1186                 tlsError(c, EInternalError, "can't set finished 0: %r");
1187                 goto Err;
1188         }
1189         if(!msgRecv(c, &m)) {
1190                 tlsError(c, EInternalError, "can't read server Finished: %r");
1191                 goto Err;
1192         }
1193         if(m.tag != HFinished) {
1194                 tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
1195                 goto Err;
1196         }
1197
1198         if(!finishedMatch(c, &m.u.finished)) {
1199                 tlsError(c, EHandshakeFailure, "finished verification failed");
1200                 goto Err;
1201         }
1202         msgClear(&m);
1203
1204         if(fprint(c->ctl, "opened") < 0){
1205                 if(trace)
1206                         trace("unable to do final open: %r\n");
1207                 goto Err;
1208         }
1209         tlsSecOk(c->sec);
1210         return c;
1211
1212 Err:
1213         free(epm);
1214         msgClear(&m);
1215         tlsConnectionFree(c);
1216         return 0;
1217 }
1218
1219
1220 //================= message functions ========================
1221
1222 static void
1223 msgHash(TlsConnection *c, uchar *p, int n)
1224 {
1225         md5(p, n, 0, &c->handhash.md5);
1226         sha1(p, n, 0, &c->handhash.sha1);
1227         if(c->version >= TLS12Version)
1228                 sha2_256(p, n, 0, &c->handhash.sha2_256);
1229 }
1230
1231 static int
1232 msgSend(TlsConnection *c, Msg *m, int act)
1233 {
1234         uchar *p; // sendp = start of new message;  p = write pointer
1235         int nn, n, i;
1236
1237         if(c->sendp == nil)
1238                 c->sendp = c->sendbuf;
1239         p = c->sendp;
1240         if(c->trace)
1241                 c->trace("send %s", msgPrint((char*)p, (sizeof(c->sendbuf)) - (p - c->sendbuf), m));
1242
1243         p[0] = m->tag;  // header - fill in size later
1244         p += 4;
1245
1246         switch(m->tag) {
1247         default:
1248                 tlsError(c, EInternalError, "can't encode a %d", m->tag);
1249                 goto Err;
1250         case HClientHello:
1251                 // version
1252                 put16(p, m->u.clientHello.version);
1253                 p += 2;
1254
1255                 // random
1256                 memmove(p, m->u.clientHello.random, RandomSize);
1257                 p += RandomSize;
1258
1259                 // sid
1260                 n = m->u.clientHello.sid->len;
1261                 assert(n < 256);
1262                 p[0] = n;
1263                 memmove(p+1, m->u.clientHello.sid->data, n);
1264                 p += n+1;
1265
1266                 n = m->u.clientHello.ciphers->len;
1267                 assert(n > 0 && n < 200);
1268                 put16(p, n*2);
1269                 p += 2;
1270                 for(i=0; i<n; i++) {
1271                         put16(p, m->u.clientHello.ciphers->data[i]);
1272                         p += 2;
1273                 }
1274
1275                 n = m->u.clientHello.compressors->len;
1276                 assert(n > 0);
1277                 p[0] = n;
1278                 memmove(p+1, m->u.clientHello.compressors->data, n);
1279                 p += n+1;
1280
1281                 if(m->u.clientHello.extensions == nil)
1282                         break;
1283                 n = m->u.clientHello.extensions->len;
1284                 if(n == 0)
1285                         break;
1286                 put16(p, n);
1287                 memmove(p+2, m->u.clientHello.extensions->data, n);
1288                 p += n+2;
1289                 break;
1290         case HServerHello:
1291                 put16(p, m->u.serverHello.version);
1292                 p += 2;
1293
1294                 // random
1295                 memmove(p, m->u.serverHello.random, RandomSize);
1296                 p += RandomSize;
1297
1298                 // sid
1299                 n = m->u.serverHello.sid->len;
1300                 assert(n < 256);
1301                 p[0] = n;
1302                 memmove(p+1, m->u.serverHello.sid->data, n);
1303                 p += n+1;
1304
1305                 put16(p, m->u.serverHello.cipher);
1306                 p += 2;
1307                 p[0] = m->u.serverHello.compressor;
1308                 p += 1;
1309
1310                 if(m->u.serverHello.extensions == nil)
1311                         break;
1312                 n = m->u.serverHello.extensions->len;
1313                 if(n == 0)
1314                         break;
1315                 put16(p, n);
1316                 memmove(p+2, m->u.serverHello.extensions->data, n);
1317                 p += n+2;
1318                 break;
1319         case HServerHelloDone:
1320                 break;
1321         case HCertificate:
1322                 nn = 0;
1323                 for(i = 0; i < m->u.certificate.ncert; i++)
1324                         nn += 3 + m->u.certificate.certs[i]->len;
1325                 if(p + 3 + nn - c->sendbuf > sizeof(c->sendbuf)) {
1326                         tlsError(c, EInternalError, "output buffer too small for certificate");
1327                         goto Err;
1328                 }
1329                 put24(p, nn);
1330                 p += 3;
1331                 for(i = 0; i < m->u.certificate.ncert; i++){
1332                         put24(p, m->u.certificate.certs[i]->len);
1333                         p += 3;
1334                         memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
1335                         p += m->u.certificate.certs[i]->len;
1336                 }
1337                 break;
1338         case HCertificateVerify:
1339                 put16(p, m->u.certificateVerify.signature->len);
1340                 p += 2;
1341                 memmove(p, m->u.certificateVerify.signature->data, m->u.certificateVerify.signature->len);
1342                 p += m->u.certificateVerify.signature->len;
1343                 break;  
1344         case HClientKeyExchange:
1345                 n = m->u.clientKeyExchange.key->len;
1346                 if(c->version != SSL3Version){
1347                         if(isECDHE(c->cipher))
1348                                 *p++ = n;
1349                         else
1350                                 put16(p, n), p += 2;
1351                 }
1352                 memmove(p, m->u.clientKeyExchange.key->data, n);
1353                 p += n;
1354                 break;
1355         case HFinished:
1356                 memmove(p, m->u.finished.verify, m->u.finished.n);
1357                 p += m->u.finished.n;
1358                 break;
1359         }
1360
1361         // go back and fill in size
1362         n = p - c->sendp;
1363         assert(p <= c->sendbuf + sizeof(c->sendbuf));
1364         put24(c->sendp+1, n-4);
1365
1366         // remember hash of Handshake messages
1367         if(m->tag != HHelloRequest)
1368                 msgHash(c, c->sendp, n);
1369
1370         c->sendp = p;
1371         if(act == AFlush){
1372                 c->sendp = c->sendbuf;
1373                 if(write(c->hand, c->sendbuf, p - c->sendbuf) < 0){
1374                         fprint(2, "write error: %r\n");
1375                         goto Err;
1376                 }
1377         }
1378         msgClear(m);
1379         return 1;
1380 Err:
1381         msgClear(m);
1382         return 0;
1383 }
1384
1385 static uchar*
1386 tlsReadN(TlsConnection *c, int n)
1387 {
1388         uchar *p;
1389         int nn, nr;
1390
1391         nn = c->ep - c->rp;
1392         if(nn < n){
1393                 if(c->rp != c->recvbuf){
1394                         memmove(c->recvbuf, c->rp, nn);
1395                         c->rp = c->recvbuf;
1396                         c->ep = &c->recvbuf[nn];
1397                 }
1398                 for(; nn < n; nn += nr) {
1399                         nr = read(c->hand, &c->rp[nn], n - nn);
1400                         if(nr <= 0)
1401                                 return nil;
1402                         c->ep += nr;
1403                 }
1404         }
1405         p = c->rp;
1406         c->rp += n;
1407         return p;
1408 }
1409
1410 static int
1411 msgRecv(TlsConnection *c, Msg *m)
1412 {
1413         uchar *p;
1414         int type, n, nn, i, nsid, nrandom, nciph;
1415
1416         for(;;) {
1417                 p = tlsReadN(c, 4);
1418                 if(p == nil)
1419                         return 0;
1420                 type = p[0];
1421                 n = get24(p+1);
1422
1423                 if(type != HHelloRequest)
1424                         break;
1425                 if(n != 0) {
1426                         tlsError(c, EDecodeError, "invalid hello request during handshake");
1427                         return 0;
1428                 }
1429         }
1430
1431         if(n > sizeof(c->recvbuf)) {
1432                 tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->recvbuf));
1433                 return 0;
1434         }
1435
1436         if(type == HSSL2ClientHello){
1437                 /* Cope with an SSL3 ClientHello expressed in SSL2 record format.
1438                         This is sent by some clients that we must interoperate
1439                         with, such as Java's JSSE and Microsoft's Internet Explorer. */
1440                 p = tlsReadN(c, n);
1441                 if(p == nil)
1442                         return 0;
1443                 msgHash(c, p, n);
1444                 m->tag = HClientHello;
1445                 if(n < 22)
1446                         goto Short;
1447                 m->u.clientHello.version = get16(p+1);
1448                 p += 3;
1449                 n -= 3;
1450                 nn = get16(p); /* cipher_spec_len */
1451                 nsid = get16(p + 2);
1452                 nrandom = get16(p + 4);
1453                 p += 6;
1454                 n -= 6;
1455                 if(nsid != 0    /* no sid's, since shouldn't restart using ssl2 header */
1456                                 || nrandom < 16 || nn % 3)
1457                         goto Err;
1458                 if(c->trace && (n - nrandom != nn))
1459                         c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
1460                 /* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
1461                 nciph = 0;
1462                 for(i = 0; i < nn; i += 3)
1463                         if(p[i] == 0)
1464                                 nciph++;
1465                 m->u.clientHello.ciphers = newints(nciph);
1466                 nciph = 0;
1467                 for(i = 0; i < nn; i += 3)
1468                         if(p[i] == 0)
1469                                 m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1470                 p += nn;
1471                 m->u.clientHello.sid = makebytes(nil, 0);
1472                 if(nrandom > RandomSize)
1473                         nrandom = RandomSize;
1474                 memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1475                 memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1476                 m->u.clientHello.compressors = newbytes(1);
1477                 m->u.clientHello.compressors->data[0] = CompressionNull;
1478                 goto Ok;
1479         }
1480         msgHash(c, p, 4);
1481
1482         p = tlsReadN(c, n);
1483         if(p == nil)
1484                 return 0;
1485
1486         msgHash(c, p, n);
1487
1488         m->tag = type;
1489
1490         switch(type) {
1491         default:
1492                 tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1493                 goto Err;
1494         case HClientHello:
1495                 if(n < 2)
1496                         goto Short;
1497                 m->u.clientHello.version = get16(p);
1498                 p += 2;
1499                 n -= 2;
1500
1501                 if(n < RandomSize)
1502                         goto Short;
1503                 memmove(m->u.clientHello.random, p, RandomSize);
1504                 p += RandomSize;
1505                 n -= RandomSize;
1506                 if(n < 1 || n < p[0]+1)
1507                         goto Short;
1508                 m->u.clientHello.sid = makebytes(p+1, p[0]);
1509                 p += m->u.clientHello.sid->len+1;
1510                 n -= m->u.clientHello.sid->len+1;
1511
1512                 if(n < 2)
1513                         goto Short;
1514                 nn = get16(p);
1515                 p += 2;
1516                 n -= 2;
1517
1518                 if((nn & 1) || n < nn || nn < 2)
1519                         goto Short;
1520                 m->u.clientHello.ciphers = newints(nn >> 1);
1521                 for(i = 0; i < nn; i += 2)
1522                         m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1523                 p += nn;
1524                 n -= nn;
1525
1526                 if(n < 1 || n < p[0]+1 || p[0] == 0)
1527                         goto Short;
1528                 nn = p[0];
1529                 m->u.clientHello.compressors = makebytes(p+1, nn);
1530                 p += nn + 1;
1531                 n -= nn + 1;
1532
1533                 if(n < 2)
1534                         break;
1535                 nn = get16(p);
1536                 if(nn > n-2)
1537                         goto Short;
1538                 m->u.clientHello.extensions = makebytes(p+2, nn);
1539                 n -= nn + 2;
1540                 break;
1541         case HServerHello:
1542                 if(n < 2)
1543                         goto Short;
1544                 m->u.serverHello.version = get16(p);
1545                 p += 2;
1546                 n -= 2;
1547
1548                 if(n < RandomSize)
1549                         goto Short;
1550                 memmove(m->u.serverHello.random, p, RandomSize);
1551                 p += RandomSize;
1552                 n -= RandomSize;
1553
1554                 if(n < 1 || n < p[0]+1)
1555                         goto Short;
1556                 m->u.serverHello.sid = makebytes(p+1, p[0]);
1557                 p += m->u.serverHello.sid->len+1;
1558                 n -= m->u.serverHello.sid->len+1;
1559
1560                 if(n < 3)
1561                         goto Short;
1562                 m->u.serverHello.cipher = get16(p);
1563                 m->u.serverHello.compressor = p[2];
1564                 p += 3;
1565                 n -= 3;
1566
1567                 if(n < 2)
1568                         break;
1569                 nn = get16(p);
1570                 if(nn > n-2)
1571                         goto Short;
1572                 m->u.serverHello.extensions = makebytes(p+2, nn);
1573                 n -= nn + 2;
1574                 break;
1575         case HCertificate:
1576                 if(n < 3)
1577                         goto Short;
1578                 nn = get24(p);
1579                 p += 3;
1580                 n -= 3;
1581                 if(nn == 0 && n > 0)
1582                         goto Short;
1583                 /* certs */
1584                 i = 0;
1585                 while(n > 0) {
1586                         if(n < 3)
1587                                 goto Short;
1588                         nn = get24(p);
1589                         p += 3;
1590                         n -= 3;
1591                         if(nn > n)
1592                                 goto Short;
1593                         m->u.certificate.ncert = i+1;
1594                         m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes*));
1595                         m->u.certificate.certs[i] = makebytes(p, nn);
1596                         p += nn;
1597                         n -= nn;
1598                         i++;
1599                 }
1600                 break;
1601         case HCertificateRequest:
1602                 if(n < 1)
1603                         goto Short;
1604                 nn = p[0];
1605                 p += 1;
1606                 n -= 1;
1607                 if(nn > n)
1608                         goto Short;
1609                 m->u.certificateRequest.types = makebytes(p, nn);
1610                 p += nn;
1611                 n -= nn;
1612                 if(n < 2)
1613                         goto Short;
1614                 nn = get16(p);
1615                 p += 2;
1616                 n -= 2;
1617                 /* nn == 0 can happen; yahoo's servers do it */
1618                 if(nn != n)
1619                         goto Short;
1620                 /* cas */
1621                 i = 0;
1622                 while(n > 0) {
1623                         if(n < 2)
1624                                 goto Short;
1625                         nn = get16(p);
1626                         p += 2;
1627                         n -= 2;
1628                         if(nn < 1 || nn > n)
1629                                 goto Short;
1630                         m->u.certificateRequest.nca = i+1;
1631                         m->u.certificateRequest.cas = erealloc(
1632                                 m->u.certificateRequest.cas, (i+1)*sizeof(Bytes*));
1633                         m->u.certificateRequest.cas[i] = makebytes(p, nn);
1634                         p += nn;
1635                         n -= nn;
1636                         i++;
1637                 }
1638                 break;
1639         case HServerHelloDone:
1640                 break;
1641         case HServerKeyExchange:
1642                 if(n < 2)
1643                         goto Short;
1644                 if(isECDHE(c->cipher)){
1645                         nn = *p;
1646                         p++, n--;
1647                         if(nn != 3 || nn > n) /* not a named curve */
1648                                 goto Short;
1649                         nn = get16(p);
1650                         p += 2, n -= 2;
1651                         m->u.serverKeyExchange.curve = nn;
1652
1653                         nn = *p++, n--;
1654                         if(nn < 1 || nn > n)
1655                                 goto Short;
1656                         m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);
1657                         p += nn, n -= nn;
1658                 }else if(isDHE(c->cipher)){
1659                         nn = get16(p);
1660                         p += 2, n -= 2;
1661                         if(nn < 1 || nn > n)
1662                                 goto Short;
1663                         m->u.serverKeyExchange.dh_p = makebytes(p, nn);
1664                         p += nn, n -= nn;
1665         
1666                         if(n < 2)
1667                                 goto Short;
1668                         nn = get16(p);
1669                         p += 2, n -= 2;
1670                         if(nn < 1 || nn > n)
1671                                 goto Short;
1672                         m->u.serverKeyExchange.dh_g = makebytes(p, nn);
1673                         p += nn, n -= nn;
1674         
1675                         if(n < 2)
1676                                 goto Short;
1677                         nn = get16(p);
1678                         p += 2, n -= 2;
1679                         if(nn < 1 || nn > n)
1680                                 goto Short;
1681                         m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);
1682                         p += nn, n -= nn;
1683                 } else {
1684                         /* should not happen */
1685                         break;
1686                 }
1687                 if(n >= 2){
1688                         if(c->version >= TLS12Version){
1689                                 /* signature hash algorithm */
1690                                 p += 2, n -= 2;
1691                                 if(n < 2)
1692                                         goto Short;
1693                         }
1694                         nn = get16(p);
1695                         p += 2, n -= 2;
1696                         if(nn > 0 && nn <= n){
1697                                 m->u.serverKeyExchange.dh_signature = makebytes(p, nn);
1698                                 n -= nn;
1699                         }
1700                 }
1701                 break;          
1702         case HClientKeyExchange:
1703                 /*
1704                  * this message depends upon the encryption selected
1705                  * assume rsa.
1706                  */
1707                 if(c->version == SSL3Version)
1708                         nn = n;
1709                 else{
1710                         if(n < 2)
1711                                 goto Short;
1712                         nn = get16(p);
1713                         p += 2;
1714                         n -= 2;
1715                 }
1716                 if(n < nn)
1717                         goto Short;
1718                 m->u.clientKeyExchange.key = makebytes(p, nn);
1719                 n -= nn;
1720                 break;
1721         case HFinished:
1722                 m->u.finished.n = c->finished.n;
1723                 if(n < m->u.finished.n)
1724                         goto Short;
1725                 memmove(m->u.finished.verify, p, m->u.finished.n);
1726                 n -= m->u.finished.n;
1727                 break;
1728         }
1729
1730         if(type != HClientHello && type != HServerHello && n != 0)
1731                 goto Short;
1732 Ok:
1733         if(c->trace){
1734                 char *buf;
1735                 buf = emalloc(8000);
1736                 c->trace("recv %s", msgPrint(buf, 8000, m));
1737                 free(buf);
1738         }
1739         return 1;
1740 Short:
1741         tlsError(c, EDecodeError, "handshake message (%d) has invalid length", type);
1742 Err:
1743         msgClear(m);
1744         return 0;
1745 }
1746
1747 static void
1748 msgClear(Msg *m)
1749 {
1750         int i;
1751
1752         switch(m->tag) {
1753         default:
1754                 sysfatal("msgClear: unknown message type: %d", m->tag);
1755         case HHelloRequest:
1756                 break;
1757         case HClientHello:
1758                 freebytes(m->u.clientHello.sid);
1759                 freeints(m->u.clientHello.ciphers);
1760                 freebytes(m->u.clientHello.compressors);
1761                 freebytes(m->u.clientHello.extensions);
1762                 break;
1763         case HServerHello:
1764                 freebytes(m->u.serverHello.sid);
1765                 freebytes(m->u.serverHello.extensions);
1766                 break;
1767         case HCertificate:
1768                 for(i=0; i<m->u.certificate.ncert; i++)
1769                         freebytes(m->u.certificate.certs[i]);
1770                 free(m->u.certificate.certs);
1771                 break;
1772         case HCertificateRequest:
1773                 freebytes(m->u.certificateRequest.types);
1774                 for(i=0; i<m->u.certificateRequest.nca; i++)
1775                         freebytes(m->u.certificateRequest.cas[i]);
1776                 free(m->u.certificateRequest.cas);
1777                 break;
1778         case HCertificateVerify:
1779                 freebytes(m->u.certificateVerify.signature);
1780                 break;
1781         case HServerHelloDone:
1782                 break;
1783         case HServerKeyExchange:
1784                 freebytes(m->u.serverKeyExchange.dh_p);
1785                 freebytes(m->u.serverKeyExchange.dh_g);
1786                 freebytes(m->u.serverKeyExchange.dh_Ys);
1787                 freebytes(m->u.serverKeyExchange.dh_signature);
1788                 break;
1789         case HClientKeyExchange:
1790                 freebytes(m->u.clientKeyExchange.key);
1791                 break;
1792         case HFinished:
1793                 break;
1794         }
1795         memset(m, 0, sizeof(Msg));
1796 }
1797
1798 static char *
1799 bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1800 {
1801         int i;
1802
1803         if(s0)
1804                 bs = seprint(bs, be, "%s", s0);
1805         if(b == nil)
1806                 bs = seprint(bs, be, "nil");
1807         else {
1808                 bs = seprint(bs, be, "<%d> [", b->len);
1809                 for(i=0; i<b->len; i++)
1810                         bs = seprint(bs, be, "%.2x ", b->data[i]);
1811         }
1812         bs = seprint(bs, be, "]");
1813         if(s1)
1814                 bs = seprint(bs, be, "%s", s1);
1815         return bs;
1816 }
1817
1818 static char *
1819 intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1820 {
1821         int i;
1822
1823         if(s0)
1824                 bs = seprint(bs, be, "%s", s0);
1825         bs = seprint(bs, be, "[");
1826         if(b == nil)
1827                 bs = seprint(bs, be, "nil");
1828         else
1829                 for(i=0; i<b->len; i++)
1830                         bs = seprint(bs, be, "%x ", b->data[i]);
1831         bs = seprint(bs, be, "]");
1832         if(s1)
1833                 bs = seprint(bs, be, "%s", s1);
1834         return bs;
1835 }
1836
1837 static char*
1838 msgPrint(char *buf, int n, Msg *m)
1839 {
1840         int i;
1841         char *bs = buf, *be = buf+n;
1842
1843         switch(m->tag) {
1844         default:
1845                 bs = seprint(bs, be, "unknown %d\n", m->tag);
1846                 break;
1847         case HClientHello:
1848                 bs = seprint(bs, be, "ClientHello\n");
1849                 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1850                 bs = seprint(bs, be, "\trandom: ");
1851                 for(i=0; i<RandomSize; i++)
1852                         bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1853                 bs = seprint(bs, be, "\n");
1854                 bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1855                 bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1856                 bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1857                 if(m->u.clientHello.extensions != nil)
1858                         bs = bytesPrint(bs, be, "\textensions: ", m->u.clientHello.extensions, "\n");
1859                 break;
1860         case HServerHello:
1861                 bs = seprint(bs, be, "ServerHello\n");
1862                 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1863                 bs = seprint(bs, be, "\trandom: ");
1864                 for(i=0; i<RandomSize; i++)
1865                         bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1866                 bs = seprint(bs, be, "\n");
1867                 bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1868                 bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1869                 bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1870                 if(m->u.serverHello.extensions != nil)
1871                         bs = bytesPrint(bs, be, "\textensions: ", m->u.serverHello.extensions, "\n");
1872                 break;
1873         case HCertificate:
1874                 bs = seprint(bs, be, "Certificate\n");
1875                 for(i=0; i<m->u.certificate.ncert; i++)
1876                         bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1877                 break;
1878         case HCertificateRequest:
1879                 bs = seprint(bs, be, "CertificateRequest\n");
1880                 bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1881                 bs = seprint(bs, be, "\tcertificateauthorities\n");
1882                 for(i=0; i<m->u.certificateRequest.nca; i++)
1883                         bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1884                 break;
1885         case HCertificateVerify:
1886                 bs = seprint(bs, be, "HCertificateVerify\n");
1887                 bs = bytesPrint(bs, be, "\tsignature: ", m->u.certificateVerify.signature,"\n");
1888                 break;  
1889         case HServerHelloDone:
1890                 bs = seprint(bs, be, "ServerHelloDone\n");
1891                 break;
1892         case HServerKeyExchange:
1893                 bs = seprint(bs, be, "HServerKeyExchange\n");
1894                 if(m->u.serverKeyExchange.curve != 0){
1895                         bs = seprint(bs, be, "\tcurve: %.4x\n", m->u.serverKeyExchange.curve);
1896                 } else {
1897                         bs = bytesPrint(bs, be, "\tdh_p: ", m->u.serverKeyExchange.dh_p, "\n");
1898                         bs = bytesPrint(bs, be, "\tdh_g: ", m->u.serverKeyExchange.dh_g, "\n");
1899                 }
1900                 bs = bytesPrint(bs, be, "\tdh_Ys: ", m->u.serverKeyExchange.dh_Ys, "\n");
1901                 bs = bytesPrint(bs, be, "\tdh_signature: ", m->u.serverKeyExchange.dh_signature, "\n");
1902                 break;
1903         case HClientKeyExchange:
1904                 bs = seprint(bs, be, "HClientKeyExchange\n");
1905                 bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1906                 break;
1907         case HFinished:
1908                 bs = seprint(bs, be, "HFinished\n");
1909                 for(i=0; i<m->u.finished.n; i++)
1910                         bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1911                 bs = seprint(bs, be, "\n");
1912                 break;
1913         }
1914         USED(bs);
1915         return buf;
1916 }
1917
1918 static void
1919 tlsError(TlsConnection *c, int err, char *fmt, ...)
1920 {
1921         char msg[512];
1922         va_list arg;
1923
1924         va_start(arg, fmt);
1925         vseprint(msg, msg+sizeof(msg), fmt, arg);
1926         va_end(arg);
1927         if(c->trace)
1928                 c->trace("tlsError: %s\n", msg);
1929         else if(c->erred)
1930                 fprint(2, "double error: %r, %s", msg);
1931         else
1932                 werrstr("tls: local %s", msg);
1933         c->erred = 1;
1934         fprint(c->ctl, "alert %d", err);
1935 }
1936
1937 // commit to specific version number
1938 static int
1939 setVersion(TlsConnection *c, int version)
1940 {
1941         if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1942                 return -1;
1943         if(version > c->version)
1944                 version = c->version;
1945         if(version == SSL3Version) {
1946                 c->version = version;
1947                 c->finished.n = SSL3FinishedLen;
1948         }else {
1949                 c->version = version;
1950                 c->finished.n = TLSFinishedLen;
1951         }
1952         c->verset = 1;
1953         return fprint(c->ctl, "version 0x%x", version);
1954 }
1955
1956 // confirm that received Finished message matches the expected value
1957 static int
1958 finishedMatch(TlsConnection *c, Finished *f)
1959 {
1960         return memcmp(f->verify, c->finished.verify, f->n) == 0;
1961 }
1962
1963 // free memory associated with TlsConnection struct
1964 //              (but don't close the TLS channel itself)
1965 static void
1966 tlsConnectionFree(TlsConnection *c)
1967 {
1968         tlsSecClose(c->sec);
1969         freebytes(c->sid);
1970         freebytes(c->cert);
1971         memset(c, 0, sizeof(c));
1972         free(c);
1973 }
1974
1975
1976 //================= cipher choices ========================
1977
/*
 * Weakness flags, indexed directly by cipher suite id: 1 marks a
 * suite treated as weak, 0 one that is not.  okCipher ANDs the
 * entries for every offered suite below CipherMax to tell "only
 * weak suites were offered" (-2) apart from "nothing usable" (-1).
 * Ids at or above CipherMax are never looked up here.
 */
static int weakCipher[CipherMax] =
{
	1,	/* TLS_NULL_WITH_NULL_NULL */
	1,	/* TLS_RSA_WITH_NULL_MD5 */
	1,	/* TLS_RSA_WITH_NULL_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
	0,	/* TLS_RSA_WITH_RC4_128_MD5 */
	0,	/* TLS_RSA_WITH_RC4_128_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
	0,	/* TLS_RSA_WITH_IDEA_CBC_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DH_DSS_WITH_DES_CBC_SHA */
	0,	/* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DH_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DHE_DSS_WITH_DES_CBC_SHA */
	0,	/* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DHE_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
	1,	/* TLS_DH_anon_WITH_RC4_128_MD5 */
	1,	/* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
	1,	/* TLS_DH_anon_WITH_DES_CBC_SHA */
	1,	/* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
};
2009
2010 static int
2011 setAlgs(TlsConnection *c, int a)
2012 {
2013         int i;
2014
2015         for(i = 0; i < nelem(cipherAlgs); i++){
2016                 if(cipherAlgs[i].tlsid == a){
2017                         c->cipher = a;
2018                         c->enc = cipherAlgs[i].enc;
2019                         c->digest = cipherAlgs[i].digest;
2020                         c->nsecret = cipherAlgs[i].nsecret;
2021                         if(c->nsecret > MaxKeyData)
2022                                 return 0;
2023                         return 1;
2024                 }
2025         }
2026         return 0;
2027 }
2028
2029 static int
2030 okCipher(Ints *cv)
2031 {
2032         int weak, i, j, c;
2033
2034         weak = 1;
2035         for(i = 0; i < cv->len; i++) {
2036                 c = cv->data[i];
2037                 if(c >= CipherMax)
2038                         weak = 0;
2039                 else
2040                         weak &= weakCipher[c];
2041                 if(isDHE(c) || isECDHE(c))
2042                         continue;       /* TODO: not implemented for server */
2043                 for(j = 0; j < nelem(cipherAlgs); j++)
2044                         if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
2045                                 return c;
2046         }
2047         if(weak)
2048                 return -2;
2049         return -1;
2050 }
2051
2052 static int
2053 okCompression(Bytes *cv)
2054 {
2055         int i, j, c;
2056
2057         for(i = 0; i < cv->len; i++) {
2058                 c = cv->data[i];
2059                 for(j = 0; j < nelem(compressors); j++) {
2060                         if(compressors[j] == c)
2061                                 return c;
2062                 }
2063         }
2064         return -1;
2065 }
2066
2067 static Lock     ciphLock;
2068 static int      nciphers;
2069
2070 static int
2071 initCiphers(void)
2072 {
2073         enum {MaxAlgF = 1024, MaxAlgs = 10};
2074         char s[MaxAlgF], *flds[MaxAlgs];
2075         int i, j, n, ok;
2076
2077         lock(&ciphLock);
2078         if(nciphers){
2079                 unlock(&ciphLock);
2080                 return nciphers;
2081         }
2082         j = open("#a/tls/encalgs", OREAD);
2083         if(j < 0){
2084                 werrstr("can't open #a/tls/encalgs: %r");
2085                 return 0;
2086         }
2087         n = read(j, s, MaxAlgF-1);
2088         close(j);
2089         if(n <= 0){
2090                 werrstr("nothing in #a/tls/encalgs: %r");
2091                 return 0;
2092         }
2093         s[n] = 0;
2094         n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
2095         for(i = 0; i < nelem(cipherAlgs); i++){
2096                 ok = 0;
2097                 for(j = 0; j < n; j++){
2098                         if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
2099                                 ok = 1;
2100                                 break;
2101                         }
2102                 }
2103                 cipherAlgs[i].ok = ok;
2104         }
2105
2106         j = open("#a/tls/hashalgs", OREAD);
2107         if(j < 0){
2108                 werrstr("can't open #a/tls/hashalgs: %r");
2109                 return 0;
2110         }
2111         n = read(j, s, MaxAlgF-1);
2112         close(j);
2113         if(n <= 0){
2114                 werrstr("nothing in #a/tls/hashalgs: %r");
2115                 return 0;
2116         }
2117         s[n] = 0;
2118         n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
2119         for(i = 0; i < nelem(cipherAlgs); i++){
2120                 ok = 0;
2121                 for(j = 0; j < n; j++){
2122                         if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
2123                                 ok = 1;
2124                                 break;
2125                         }
2126                 }
2127                 cipherAlgs[i].ok &= ok;
2128                 if(cipherAlgs[i].ok)
2129                         nciphers++;
2130         }
2131         unlock(&ciphLock);
2132         return nciphers;
2133 }
2134
2135 static Ints*
2136 makeciphers(void)
2137 {
2138         Ints *is;
2139         int i, j;
2140
2141         is = newints(nciphers);
2142         j = 0;
2143         for(i = 0; i < nelem(cipherAlgs); i++){
2144                 if(cipherAlgs[i].ok)
2145                         is->data[j++] = cipherAlgs[i].tlsid;
2146         }
2147         return is;
2148 }
2149
2150
2151
2152 //================= security functions ========================
2153
2154 // given X.509 certificate, set up connection to factotum
2155 //      for using corresponding private key
2156 static AuthRpc*
2157 factotum_rsa_open(uchar *cert, int certlen)
2158 {
2159         int afd;
2160         char *s;
2161         mpint *pub = nil;
2162         RSApub *rsapub;
2163         AuthRpc *rpc;
2164
2165         // start talking to factotum
2166         if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
2167                 return nil;
2168         if((rpc = auth_allocrpc(afd)) == nil){
2169                 close(afd);
2170                 return nil;
2171         }
2172         s = "proto=rsa service=tls role=client";
2173         if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
2174                 factotum_rsa_close(rpc);
2175                 return nil;
2176         }
2177
2178         // roll factotum keyring around to match certificate
2179         rsapub = X509toRSApub(cert, certlen, nil, 0);
2180         while(1){
2181                 if(auth_rpc(rpc, "read", nil, 0) != ARok){
2182                         factotum_rsa_close(rpc);
2183                         rpc = nil;
2184                         goto done;
2185                 }
2186                 pub = strtomp(rpc->arg, nil, 16, nil);
2187                 assert(pub != nil);
2188                 if(mpcmp(pub,rsapub->n) == 0)
2189                         break;
2190         }
2191 done:
2192         mpfree(pub);
2193         rsapubfree(rsapub);
2194         return rpc;
2195 }
2196
2197 static mpint*
2198 factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
2199 {
2200         char *p;
2201         int rv;
2202
2203         p = mptoa(cipher, 16, nil, 0);
2204         mpfree(cipher);
2205         if(p == nil)
2206                 return nil;
2207         rv = auth_rpc(rpc, "write", p, strlen(p));
2208         free(p);
2209         if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
2210                 return nil;
2211         return strtomp(rpc->arg, nil, 16, nil);
2212 }
2213
2214 static void
2215 factotum_rsa_close(AuthRpc*rpc)
2216 {
2217         if(!rpc)
2218                 return;
2219         close(rpc->afd);
2220         auth_freerpc(rpc);
2221 }
2222
2223 static void
2224 tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2225 {
2226         uchar ai[MD5dlen], tmp[MD5dlen];
2227         int i, n;
2228         MD5state *s;
2229
2230         // generate a1
2231         s = hmac_md5(label, nlabel, key, nkey, nil, nil);
2232         s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
2233         hmac_md5(seed1, nseed1, key, nkey, ai, s);
2234
2235         while(nbuf > 0) {
2236                 s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
2237                 s = hmac_md5(label, nlabel, key, nkey, nil, s);
2238                 s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
2239                 hmac_md5(seed1, nseed1, key, nkey, tmp, s);
2240                 n = MD5dlen;
2241                 if(n > nbuf)
2242                         n = nbuf;
2243                 for(i = 0; i < n; i++)
2244                         buf[i] ^= tmp[i];
2245                 buf += n;
2246                 nbuf -= n;
2247                 hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
2248                 memmove(ai, tmp, MD5dlen);
2249         }
2250 }
2251
2252 static void
2253 tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2254 {
2255         uchar ai[SHA1dlen], tmp[SHA1dlen];
2256         int i, n;
2257         SHAstate *s;
2258
2259         // generate a1
2260         s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
2261         s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
2262         hmac_sha1(seed1, nseed1, key, nkey, ai, s);
2263
2264         while(nbuf > 0) {
2265                 s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
2266                 s = hmac_sha1(label, nlabel, key, nkey, nil, s);
2267                 s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
2268                 hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
2269                 n = SHA1dlen;
2270                 if(n > nbuf)
2271                         n = nbuf;
2272                 for(i = 0; i < n; i++)
2273                         buf[i] ^= tmp[i];
2274                 buf += n;
2275                 nbuf -= n;
2276                 hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
2277                 memmove(ai, tmp, SHA1dlen);
2278         }
2279 }
2280
2281 static void
2282 p_sha256(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed, int nseed)
2283 {
2284         uchar ai[SHA2_256dlen], tmp[SHA2_256dlen];
2285         SHAstate *s;
2286         int n;
2287
2288         // generate a1
2289         s = hmac_sha2_256(label, nlabel, key, nkey, nil, nil);
2290         hmac_sha2_256(seed, nseed, key, nkey, ai, s);
2291
2292         while(nbuf > 0) {
2293                 s = hmac_sha2_256(ai, SHA2_256dlen, key, nkey, nil, nil);
2294                 s = hmac_sha2_256(label, nlabel, key, nkey, nil, s);
2295                 hmac_sha2_256(seed, nseed, key, nkey, tmp, s);
2296                 n = SHA2_256dlen;
2297                 if(n > nbuf)
2298                         n = nbuf;
2299                 memmove(buf, tmp, n);
2300                 buf += n;
2301                 nbuf -= n;
2302                 hmac_sha2_256(ai, SHA2_256dlen, key, nkey, tmp, nil);
2303                 memmove(ai, tmp, SHA2_256dlen);
2304         }
2305 }
2306
2307 // fill buf with md5(args)^sha1(args)
2308 static void
2309 tls10PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2310 {
2311         int nlabel = strlen(label);
2312         int n = (nkey + 1) >> 1;
2313
2314         memset(buf, 0, nbuf);
2315         tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
2316         tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
2317 }
2318
2319 static void
2320 tls12PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2321 {
2322         uchar seed[2*RandomSize];
2323
2324         assert(nseed0+nseed1 <= sizeof(seed));
2325         memmove(seed, seed0, nseed0);
2326         memmove(seed+nseed0, seed1, nseed1);
2327         p_sha256(buf, nbuf, key, nkey, (uchar*)label, strlen(label), seed, nseed0+nseed1);
2328 }
2329
2330 /*
2331  * for setting server session id's
2332  */
2333 static Lock     sidLock;
2334 static long     maxSid = 1;
2335
2336 /* the keys are verified to have the same public components
2337  * and to function correctly with pkcs 1 encryption and decryption. */
2338 static TlsSec*
2339 tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
2340 {
2341         TlsSec *sec = emalloc(sizeof(*sec));
2342
2343         USED(csid); USED(ncsid);  // ignore csid for now
2344
2345         memmove(sec->crandom, crandom, RandomSize);
2346         sec->clientVers = cvers;
2347
2348         put32(sec->srandom, time(0));
2349         genrandom(sec->srandom+4, RandomSize-4);
2350         memmove(srandom, sec->srandom, RandomSize);
2351
2352         /*
2353          * make up a unique sid: use our pid, and and incrementing id
2354          * can signal no sid by setting nssid to 0.
2355          */
2356         memset(ssid, 0, SidSize);
2357         put32(ssid, getpid());
2358         lock(&sidLock);
2359         put32(ssid+4, maxSid++);
2360         unlock(&sidLock);
2361         *nssid = SidSize;
2362         return sec;
2363 }
2364
2365 static int
2366 tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm)
2367 {
2368         if(epm != nil){
2369                 if(setVers(sec, vers) < 0)
2370                         goto Err;
2371                 serverMasterSecret(sec, epm);
2372         }else if(sec->vers != vers){
2373                 werrstr("mismatched session versions");
2374                 goto Err;
2375         }
2376         return 0;
2377 Err:
2378         sec->ok = -1;
2379         return -1;
2380 }
2381
2382 static TlsSec*
2383 tlsSecInitc(int cvers, uchar *crandom)
2384 {
2385         TlsSec *sec = emalloc(sizeof(*sec));
2386         sec->clientVers = cvers;
2387         put32(sec->crandom, time(0));
2388         genrandom(sec->crandom+4, RandomSize-4);
2389         memmove(crandom, sec->crandom, RandomSize);
2390         return sec;
2391 }
2392
2393 static Bytes*
2394 tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers)
2395 {
2396         RSApub *pub;
2397         Bytes *epm;
2398
2399         USED(sid);
2400         USED(nsid);
2401         
2402         memmove(sec->srandom, srandom, RandomSize);
2403         if(setVers(sec, vers) < 0)
2404                 goto Err;
2405         pub = X509toRSApub(cert, ncert, nil, 0);
2406         if(pub == nil){
2407                 werrstr("invalid x509/rsa certificate");
2408                 goto Err;
2409         }
2410         epm = clientMasterSecret(sec, pub);
2411         rsapubfree(pub);
2412         if(epm != nil)
2413                 return epm;
2414 Err:
2415         sec->ok = -1;
2416         return nil;
2417 }
2418
2419 static int
2420 tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient)
2421 {
2422         if(sec->nfin != nfin){
2423                 sec->ok = -1;
2424                 werrstr("invalid finished exchange");
2425                 return -1;
2426         }
2427         hsh.md5.malloced = 0;
2428         hsh.sha1.malloced = 0;
2429         hsh.sha2_256.malloced = 0;
2430         (*sec->setFinished)(sec, hsh, fin, isclient);
2431         return 1;
2432 }
2433
2434 static void
2435 tlsSecOk(TlsSec *sec)
2436 {
2437         if(sec->ok == 0)
2438                 sec->ok = 1;
2439 }
2440
2441 static void
2442 tlsSecKill(TlsSec *sec)
2443 {
2444         if(!sec)
2445                 return;
2446         factotum_rsa_close(sec->rpc);
2447         sec->ok = -1;
2448 }
2449
2450 static void
2451 tlsSecClose(TlsSec *sec)
2452 {
2453         if(!sec)
2454                 return;
2455         factotum_rsa_close(sec->rpc);
2456         free(sec->server);
2457         free(sec);
2458 }
2459
2460 static int
2461 setVers(TlsSec *sec, int v)
2462 {
2463         if(v == SSL3Version){
2464                 sec->setFinished = sslSetFinished;
2465                 sec->nfin = SSL3FinishedLen;
2466                 sec->prf = sslPRF;
2467         }else if(v < TLS12Version) {
2468                 sec->setFinished = tls10SetFinished;
2469                 sec->nfin = TLSFinishedLen;
2470                 sec->prf = tls10PRF;
2471         }else {
2472                 sec->setFinished = tls12SetFinished;
2473                 sec->nfin = TLSFinishedLen;
2474                 sec->prf = tls12PRF;
2475         }
2476         sec->vers = v;
2477         return 0;
2478 }
2479
2480 /*
2481  * generate secret keys from the master secret.
2482  *
2483  * different crypto selections will require different amounts
2484  * of key expansion and use of key expansion data,
2485  * but it's all generated using the same function.
2486  */
2487 static void
2488 setSecrets(TlsSec *sec, uchar *kd, int nkd)
2489 {
2490         (*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
2491                         sec->srandom, RandomSize, sec->crandom, RandomSize);
2492 }
2493
2494 /*
2495  * set the master secret from the pre-master secret,
2496  * destroys premaster.
2497  */
2498 static void
2499 setMasterSecret(TlsSec *sec, Bytes *pm)
2500 {
2501         (*sec->prf)(sec->sec, MasterSecretSize, pm->data, pm->len, "master secret",
2502                         sec->crandom, RandomSize, sec->srandom, RandomSize);
2503
2504         memset(pm->data, 0, pm->len);   
2505         freebytes(pm);
2506 }
2507
2508 static void
2509 serverMasterSecret(TlsSec *sec, Bytes *epm)
2510 {
2511         Bytes *pm;
2512
2513         pm = pkcs1_decrypt(sec, epm);
2514
2515         // if the client messed up, just continue as if everything is ok,
2516         // to prevent attacks to check for correctly formatted messages.
2517         // Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
2518         if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
2519                 fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
2520                         sec->ok, pm, pm != nil ? get16(pm->data) : -1, sec->clientVers, epm->len);
2521                 sec->ok = -1;
2522                 freebytes(pm);
2523                 pm = newbytes(MasterSecretSize);
2524                 genrandom(pm->data, MasterSecretSize);
2525         }
2526         assert(pm->len == MasterSecretSize);
2527         setMasterSecret(sec, pm);
2528 }
2529
2530 static Bytes*
2531 clientMasterSecret(TlsSec *sec, RSApub *pub)
2532 {
2533         Bytes *pm, *epm;
2534
2535         pm = newbytes(MasterSecretSize);
2536         put16(pm->data, sec->clientVers);
2537         genrandom(pm->data+2, MasterSecretSize - 2);
2538         epm = pkcs1_encrypt(pm, pub, 2);
2539         setMasterSecret(sec, pm);
2540         return epm;
2541 }
2542
2543 static void
2544 sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient)
2545 {
2546         DigestState *s;
2547         uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
2548         char *label;
2549
2550         if(isClient)
2551                 label = "CLNT";
2552         else
2553                 label = "SRVR";
2554
2555         md5((uchar*)label, 4, nil, &hsh.md5);
2556         md5(sec->sec, MasterSecretSize, nil, &hsh.md5);
2557         memset(pad, 0x36, 48);
2558         md5(pad, 48, nil, &hsh.md5);
2559         md5(nil, 0, h0, &hsh.md5);
2560         memset(pad, 0x5C, 48);
2561         s = md5(sec->sec, MasterSecretSize, nil, nil);
2562         s = md5(pad, 48, nil, s);
2563         md5(h0, MD5dlen, finished, s);
2564
2565         sha1((uchar*)label, 4, nil, &hsh.sha1);
2566         sha1(sec->sec, MasterSecretSize, nil, &hsh.sha1);
2567         memset(pad, 0x36, 40);
2568         sha1(pad, 40, nil, &hsh.sha1);
2569         sha1(nil, 0, h1, &hsh.sha1);
2570         memset(pad, 0x5C, 40);
2571         s = sha1(sec->sec, MasterSecretSize, nil, nil);
2572         s = sha1(pad, 40, nil, s);
2573         sha1(h1, SHA1dlen, finished + MD5dlen, s);
2574 }
2575
2576 // fill "finished" arg with md5(args)^sha1(args)
2577 static void
2578 tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient)
2579 {
2580         uchar h0[MD5dlen], h1[SHA1dlen];
2581         char *label;
2582
2583         // get current hash value, but allow further messages to be hashed in
2584         md5(nil, 0, h0, &hsh.md5);
2585         sha1(nil, 0, h1, &hsh.sha1);
2586
2587         if(isClient)
2588                 label = "client finished";
2589         else
2590                 label = "server finished";
2591         tls10PRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
2592 }
2593
2594 static void
2595 tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient)
2596 {
2597         uchar seed[SHA2_256dlen];
2598         char *label;
2599
2600         // get current hash value, but allow further messages to be hashed in
2601         sha2_256(nil, 0, seed, &hsh.sha2_256);
2602
2603         if(isClient)
2604                 label = "client finished";
2605         else
2606                 label = "server finished";
2607         p_sha256(finished, TLSFinishedLen, sec->sec, MasterSecretSize, (uchar*)label, strlen(label), seed, SHA2_256dlen);
2608 }
2609
2610 static void
2611 sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2612 {
2613         uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
2614         DigestState *s;
2615         int i, n, len;
2616
2617         USED(label);
2618         len = 1;
2619         while(nbuf > 0){
2620                 if(len > 26)
2621                         return;
2622                 for(i = 0; i < len; i++)
2623                         tmp[i] = 'A' - 1 + len;
2624                 s = sha1(tmp, len, nil, nil);
2625                 s = sha1(key, nkey, nil, s);
2626                 s = sha1(seed0, nseed0, nil, s);
2627                 sha1(seed1, nseed1, sha1dig, s);
2628                 s = md5(key, nkey, nil, nil);
2629                 md5(sha1dig, SHA1dlen, md5dig, s);
2630                 n = MD5dlen;
2631                 if(n > nbuf)
2632                         n = nbuf;
2633                 memmove(buf, md5dig, n);
2634                 buf += n;
2635                 nbuf -= n;
2636                 len++;
2637         }
2638 }
2639
2640 static mpint*
2641 bytestomp(Bytes* bytes)
2642 {
2643         return betomp(bytes->data, bytes->len, nil);
2644 }
2645
2646 /*
2647  * Convert mpint* to Bytes, putting high order byte first.
2648  */
2649 static Bytes*
2650 mptobytes(mpint* big)
2651 {
2652         Bytes* ans;
2653         int n;
2654
2655         n = (mpsignif(big)+7)/8;
2656         if(n == 0) n = 1;
2657         ans = newbytes(n);
2658         ans->len = mptobe(big, ans->data, n, nil);
2659         return ans;
2660 }
2661
2662 // Do RSA computation on block according to key, and pad
2663 // result on left with zeros to make it modlen long.
2664 static Bytes*
2665 rsacomp(Bytes* block, RSApub* key, int modlen)
2666 {
2667         mpint *x, *y;
2668         Bytes *a, *ybytes;
2669         int ylen;
2670
2671         x = bytestomp(block);
2672         y = rsaencrypt(key, x, nil);
2673         mpfree(x);
2674         ybytes = mptobytes(y);
2675         ylen = ybytes->len;
2676         mpfree(y);
2677
2678         if(ylen < modlen) {
2679                 a = newbytes(modlen);
2680                 memset(a->data, 0, modlen-ylen);
2681                 memmove(a->data+modlen-ylen, ybytes->data, ylen);
2682                 freebytes(ybytes);
2683                 ybytes = a;
2684         }
2685         else if(ylen > modlen) {
2686                 // assume it has leading zeros (mod should make it so)
2687                 a = newbytes(modlen);
2688                 memmove(a->data, ybytes->data, modlen);
2689                 freebytes(ybytes);
2690                 ybytes = a;
2691         }
2692         return ybytes;
2693 }
2694
2695 // encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2696 static Bytes*
2697 pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2698 {
2699         Bytes *pad, *eb, *ans;
2700         int i, dlen, padlen, modlen;
2701
2702         modlen = (mpsignif(key->n)+7)/8;
2703         dlen = data->len;
2704         if(modlen < 12 || dlen > modlen - 11)
2705                 return nil;
2706         padlen = modlen - 3 - dlen;
2707         pad = newbytes(padlen);
2708         genrandom(pad->data, padlen);
2709         for(i = 0; i < padlen; i++) {
2710                 if(blocktype == 0)
2711                         pad->data[i] = 0;
2712                 else if(blocktype == 1)
2713                         pad->data[i] = 255;
2714                 else if(pad->data[i] == 0)
2715                         pad->data[i] = 1;
2716         }
2717         eb = newbytes(modlen);
2718         eb->data[0] = 0;
2719         eb->data[1] = blocktype;
2720         memmove(eb->data+2, pad->data, padlen);
2721         eb->data[padlen+2] = 0;
2722         memmove(eb->data+padlen+3, data->data, dlen);
2723         ans = rsacomp(eb, key, modlen);
2724         freebytes(eb);
2725         freebytes(pad);
2726         return ans;
2727 }
2728
2729 // decrypt data according to PKCS#1, with given key.
2730 // expect a block type of 2.
2731 static Bytes*
2732 pkcs1_decrypt(TlsSec *sec, Bytes *cipher)
2733 {
2734         Bytes *eb, *ans = nil;
2735         int i, modlen;
2736         mpint *x, *y;
2737
2738         modlen = (mpsignif(sec->rsapub->n)+7)/8;
2739         if(cipher->len != modlen)
2740                 return nil;
2741         x = bytestomp(cipher);
2742         y = factotum_rsa_decrypt(sec->rpc, x);
2743         if(y == nil)
2744                 return nil;
2745         eb = mptobytes(y);
2746         mpfree(y);
2747         if(eb->len < modlen){ // pad on left with zeros
2748                 ans = newbytes(modlen);
2749                 memset(ans->data, 0, modlen-eb->len);
2750                 memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2751                 freebytes(eb);
2752                 eb = ans;
2753         }
2754         if(eb->data[0] == 0 && eb->data[1] == 2) {
2755                 for(i = 2; i < modlen; i++)
2756                         if(eb->data[i] == 0)
2757                                 break;
2758                 if(i < modlen - 1)
2759                         ans = makebytes(eb->data+i+1, modlen-(i+1));
2760         }
2761         freebytes(eb);
2762         return ans;
2763 }
2764
2765
2766 //================= general utility functions ========================
2767
2768 static void *
2769 emalloc(int n)
2770 {
2771         void *p;
2772         if(n==0)
2773                 n=1;
2774         p = malloc(n);
2775         if(p == nil)
2776                 sysfatal("out of memory");
2777         memset(p, 0, n);
2778         setmalloctag(p, getcallerpc(&n));
2779         return p;
2780 }
2781
2782 static void *
2783 erealloc(void *ReallocP, int ReallocN)
2784 {
2785         if(ReallocN == 0)
2786                 ReallocN = 1;
2787         if(ReallocP == nil)
2788                 ReallocP = emalloc(ReallocN);
2789         else if((ReallocP = realloc(ReallocP, ReallocN)) == nil)
2790                 sysfatal("out of memory");
2791         setrealloctag(ReallocP, getcallerpc(&ReallocP));
2792         return(ReallocP);
2793 }
2794
2795 static void
2796 put32(uchar *p, u32int x)
2797 {
2798         p[0] = x>>24;
2799         p[1] = x>>16;
2800         p[2] = x>>8;
2801         p[3] = x;
2802 }
2803
2804 static void
2805 put24(uchar *p, int x)
2806 {
2807         p[0] = x>>16;
2808         p[1] = x>>8;
2809         p[2] = x;
2810 }
2811
2812 static void
2813 put16(uchar *p, int x)
2814 {
2815         p[0] = x>>8;
2816         p[1] = x;
2817 }
2818
2819 static u32int
2820 get32(uchar *p)
2821 {
2822         return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2823 }
2824
2825 static int
2826 get24(uchar *p)
2827 {
2828         return (p[0]<<16)|(p[1]<<8)|p[2];
2829 }
2830
2831 static int
2832 get16(uchar *p)
2833 {
2834         return (p[0]<<8)|p[1];
2835 }
2836
/*
 * offsetof with the arguments swapped: member first, then type.
 * Used below to compute the header size of Bytes and Ints, i.e.
 * the offset of their trailing data[] member.
 */
#define OFFSET(x, s) offsetof(s, x)
2838
2839 static Bytes*
2840 newbytes(int len)
2841 {
2842         Bytes* ans;
2843
2844         if(len < 0)
2845                 abort();
2846         ans = (Bytes*)emalloc(OFFSET(data[0], Bytes) + len);
2847         ans->len = len;
2848         return ans;
2849 }
2850
2851 /*
2852  * newbytes(len), with data initialized from buf
2853  */
2854 static Bytes*
2855 makebytes(uchar* buf, int len)
2856 {
2857         Bytes* ans;
2858
2859         ans = newbytes(len);
2860         memmove(ans->data, buf, len);
2861         return ans;
2862 }
2863
2864 static void
2865 freebytes(Bytes* b)
2866 {
2867         free(b);
2868 }
2869
2870 /* len is number of ints */
2871 static Ints*
2872 newints(int len)
2873 {
2874         Ints* ans;
2875
2876         if(len < 0 || len > ((uint)-1>>1)/sizeof(int))
2877                 abort();
2878         ans = (Ints*)emalloc(OFFSET(data[0], Ints) + len*sizeof(int));
2879         ans->len = len;
2880         return ans;
2881 }
2882
2883 static void
2884 freeints(Ints* b)
2885 {
2886         free(b);
2887 }