]> git.lizzy.rs Git - plan9front.git/blob - sys/src/libsec/port/tlshand.c
libsec: declare aes_setupEnc static
[plan9front.git] / sys / src / libsec / port / tlshand.c
1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <auth.h>
5 #include <mp.h>
6 #include <libsec.h>
7
8 // The main groups of functions are:
9 //              client/server - main handshake protocol definition
10 //              message functions - formatting handshake messages
11 //              cipher choices - catalog of digest and encrypt algorithms
12 //              security functions - PKCS#1, sslHMAC, session keygen
13 //              general utility functions - malloc, serialization
14 // The handshake protocol builds on the TLS/SSL3 record layer protocol,
15 // which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.
16
17 enum {
18         TLSFinishedLen = 12,
19         SSL3FinishedLen = MD5dlen+SHA1dlen,
20         MaxKeyData = 160,       // amount of secret we may need
21         MaxChunk = 1<<15,
22         RandomSize = 32,
23         SidSize = 32,
24         MasterSecretSize = 48,
25         AQueue = 0,
26         AFlush = 1,
27 };
28
29 typedef struct TlsSec TlsSec;
30
31 typedef struct Bytes{
32         int len;
33         uchar data[1];  // [len]
34 } Bytes;
35
36 typedef struct Ints{
37         int len;
38         int data[1];  // [len]
39 } Ints;
40
41 typedef struct Algs{
42         char *enc;
43         char *digest;
44         int nsecret;
45         int tlsid;
46         int ok;
47 } Algs;
48
49 typedef struct Namedcurve{
50         int tlsid;
51         char *name;
52
53         char *p;
54         char *a;
55         char *b;
56         char *G;
57         char *n;
58         char *h;
59 } Namedcurve;
60
61 typedef struct Finished{
62         uchar verify[SSL3FinishedLen];
63         int n;
64 } Finished;
65
66 typedef struct HandshakeHash {
67         MD5state        md5;
68         SHAstate        sha1;
69         SHA2_256state   sha2_256;
70 } HandshakeHash;
71
72 typedef struct TlsConnection{
73         TlsSec *sec;    // security management goo
74         int hand, ctl;  // record layer file descriptors
75         int erred;              // set when tlsError called
76         int (*trace)(char*fmt, ...); // for debugging
77         int version;    // protocol we are speaking
78         int verset;             // version has been set
79         int ver2hi;             // server got a version 2 hello
80         int isClient;   // is this the client or server?
81         Bytes *sid;             // SessionID
82         Bytes *cert;    // only last - no chain
83
84         Lock statelk;
85         int state;              // must be set using setstate
86
87         // input buffer for handshake messages
88         uchar recvbuf[MaxChunk];
89         uchar *rp, *ep;
90
91         // output buffer
92         uchar sendbuf[MaxChunk];
93         uchar *sendp;
94
95         uchar crandom[RandomSize];      // client random
96         uchar srandom[RandomSize];      // server random
97         int clientVersion;      // version in ClientHello
98         int cipher;
99         char *digest;   // name of digest algorithm to use
100         char *enc;              // name of encryption algorithm to use
101         int nsecret;    // amount of secret data to init keys
102
103         // for finished messages
104         HandshakeHash   handhash;
105         Finished        finished;
106 } TlsConnection;
107
108 typedef struct Msg{
109         int tag;
110         union {
111                 struct {
112                         int version;
113                         uchar   random[RandomSize];
114                         Bytes*  sid;
115                         Ints*   ciphers;
116                         Bytes*  compressors;
117                         Bytes*  extensions;
118                 } clientHello;
119                 struct {
120                         int version;
121                         uchar   random[RandomSize];
122                         Bytes*  sid;
123                         int     cipher;
124                         int     compressor;
125                         Bytes*  extensions;
126                 } serverHello;
127                 struct {
128                         int ncert;
129                         Bytes **certs;
130                 } certificate;
131                 struct {
132                         Bytes *types;
133                         int nca;
134                         Bytes **cas;
135                 } certificateRequest;
136                 struct {
137                         Bytes *key;
138                 } clientKeyExchange;
139                 struct {
140                         Bytes *dh_p;
141                         Bytes *dh_g;
142                         Bytes *dh_Ys;
143                         Bytes *dh_signature;
144                         int curve;
145                 } serverKeyExchange;
146                 struct {
147                         Bytes *signature;
148                 } certificateVerify;            
149                 Finished finished;
150         } u;
151 } Msg;
152
153 typedef struct TlsSec{
154         char *server;   // name of remote; nil for server
155         int ok; // <0 killed; == 0 in progress; >0 reusable
156         RSApub *rsapub;
157         AuthRpc *rpc;   // factotum for rsa private key
158         uchar sec[MasterSecretSize];    // master secret
159         uchar crandom[RandomSize];      // client random
160         uchar srandom[RandomSize];      // server random
161         int clientVers;         // version in ClientHello
162         int vers;                       // final version
163         // byte generation and handshake checksum
164         void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
165         void (*setFinished)(TlsSec*, HandshakeHash, uchar*, int);
166         int nfin;
167 } TlsSec;
168
169
/* protocol versions, encoded as major<<8 | minor */
enum {
	SSL3Version	= 0x0300,
	TLS10Version	= 0x0301,
	TLS11Version	= 0x0302,
	TLS12Version	= 0x0303,
	ProtocolVersion	= TLS12Version,	/* highest version we speak */
	MinProtoVersion	= 0x0300,	/* range of versions we accept */
	MaxProtoVersion	= 0x03ff,
};

/* handshake message types */
enum {
	HHelloRequest		= 0,
	HClientHello		= 1,
	HServerHello		= 2,
	HSSL2ClientHello	= 9,	/* local convention;  see devtls.c */
	HCertificate		= 11,
	HServerKeyExchange	= 12,
	HCertificateRequest	= 13,
	HServerHelloDone	= 14,
	HCertificateVerify	= 15,
	HClientKeyExchange	= 16,
	HFinished		= 20,
	HMax			= 21,
};

/* alert descriptions */
enum {
	ECloseNotify		= 0,
	EUnexpectedMessage	= 10,
	EBadRecordMac		= 20,
	EDecryptionFailed	= 21,
	ERecordOverflow		= 22,
	EDecompressionFailure	= 30,
	EHandshakeFailure	= 40,
	ENoCertificate		= 41,
	EBadCertificate		= 42,
	EUnsupportedCertificate	= 43,
	ECertificateRevoked	= 44,
	ECertificateExpired	= 45,
	ECertificateUnknown	= 46,
	EIllegalParameter	= 47,
	EUnknownCa		= 48,
	EAccessDenied		= 49,
	EDecodeError		= 50,
	EDecryptError		= 51,
	EExportRestriction	= 60,
	EProtocolVersion	= 70,
	EInsufficientSecurity	= 71,
	EInternalError		= 80,
	EUserCanceled		= 90,
	ENoRenegotiation	= 100,
	EMax			= 256
};

/* cipher suite numbers */
enum {
	TLS_NULL_WITH_NULL_NULL			= 0x0000,
	TLS_RSA_WITH_NULL_MD5			= 0x0001,
	TLS_RSA_WITH_NULL_SHA			= 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5		= 0x0003,
	TLS_RSA_WITH_RC4_128_MD5		= 0x0004,
	TLS_RSA_WITH_RC4_128_SHA		= 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5	= 0x0006,
	TLS_RSA_WITH_IDEA_CBC_SHA		= 0x0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0008,
	TLS_RSA_WITH_DES_CBC_SHA		= 0x0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA		= 0x000A,
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x000B,
	TLS_DH_DSS_WITH_DES_CBC_SHA		= 0x000C,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA	= 0x000D,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x000E,
	TLS_DH_RSA_WITH_DES_CBC_SHA		= 0x000F,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0010,
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA		= 0x0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA	= 0x0013,	/* mandatory for TLS 1.0 compliance; not implemented */
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA		= 0x0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0016,
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5	= 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5		= 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA	= 0x0019,
	TLS_DH_anon_WITH_DES_CBC_SHA		= 0x001A,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA	= 0x001B,

	/* AES (Rijndael with 128-bit blocks) */
	TLS_RSA_WITH_AES_128_CBC_SHA		= 0x002F,
	TLS_DH_DSS_WITH_AES_128_CBC_SHA		= 0x0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA		= 0x0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA	= 0x0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA	= 0x0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA	= 0x0034,
	TLS_RSA_WITH_AES_256_CBC_SHA		= 0x0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA		= 0x0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA		= 0x0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA	= 0x0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA	= 0x0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA	= 0x003A,

	TLS_RSA_WITH_AES_128_CBC_SHA256		= 0x003C,
	TLS_RSA_WITH_AES_256_CBC_SHA256		= 0x003D,

	TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA	= 0xC013,
	TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA	= 0xC014,
	TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA	= 0xC009,
	TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA	= 0xC00A,
	CipherMax	/* one past the last entry listed, NOT the largest suite number */
};

/* compression methods */
enum {
	CompressionNull	= 0,
	CompressionMax	= 1,
};
284
285 static Algs cipherAlgs[] = {
286         {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
287         {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA},
288         {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA},
289         {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
290         {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_256_CBC_SHA},
291         {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_128_CBC_SHA},
292         {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA},
293         {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_RSA_WITH_AES_128_CBC_SHA},
294         {"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_RSA_WITH_AES_128_CBC_SHA256},
295         {"aes_256_cbc", "sha256", 2*(32+16+SHA2_256dlen), TLS_RSA_WITH_AES_256_CBC_SHA256},
296         {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA},
297         {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
298         {"rc4_128", "sha1",     2*(16+SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
299         {"rc4_128", "md5",      2*(16+MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
300 };
301
302 static uchar compressors[] = {
303         CompressionNull,
304 };
305
306 static Namedcurve namedcurves[] = {
307 {0x0017, "secp256r1",
308         "FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF",
309         "FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFC",
310         "5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B",
311         "046B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C2964FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5",
312         "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551",
313         "1"}
314 };
315
316 static uchar pointformats[] = {
317         CompressionNull /* support of uncompressed point format is mandatory */
318 };
319
320 static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chain);
321 static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen, int (*trace)(char*fmt, ...));
322 static void     msgClear(Msg *m);
323 static char* msgPrint(char *buf, int n, Msg *m);
324 static int      msgRecv(TlsConnection *c, Msg *m);
325 static int      msgSend(TlsConnection *c, Msg *m, int act);
326 static void     tlsError(TlsConnection *c, int err, char *msg, ...);
327 #pragma varargck argpos tlsError 3
328 static int setVersion(TlsConnection *c, int version);
329 static int finishedMatch(TlsConnection *c, Finished *f);
330 static void tlsConnectionFree(TlsConnection *c);
331
332 static int setAlgs(TlsConnection *c, int a);
333 static int okCipher(Ints *cv);
334 static int okCompression(Bytes *cv);
335 static int initCiphers(void);
336 static Ints* makeciphers(void);
337
338 static TlsSec*  tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
339 static int      tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm);
340 static TlsSec*  tlsSecInitc(int cvers, uchar *crandom);
341 static Bytes*   tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers);
342 static Bytes*   tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, Bytes *p, Bytes *g, Bytes *Ys);
343 static Bytes*   tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys);
344 static int      tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient);
345 static void     tlsSecOk(TlsSec *sec);
346 static void     tlsSecKill(TlsSec *sec);
347 static void     tlsSecClose(TlsSec *sec);
348 static void     setMasterSecret(TlsSec *sec, Bytes *pm);
349 static void     serverMasterSecret(TlsSec *sec, Bytes *epm);
350 static void     setSecrets(TlsSec *sec, uchar *kd, int nkd);
351 static Bytes*   clientMasterSecret(TlsSec *sec, RSApub *pub);
352 static Bytes*   pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
353 static Bytes*   pkcs1_decrypt(TlsSec *sec, Bytes *cipher);
354 static void     tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient);
355 static void     tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient);
356 static void     sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient);
357 static void     sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
358                         uchar *seed0, int nseed0, uchar *seed1, int nseed1);
359 static int setVers(TlsSec *sec, int version);
360
361 static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
362 static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
363 static void factotum_rsa_close(AuthRpc*rpc);
364
365 static void* emalloc(int);
366 static void* erealloc(void*, int);
367 static void put32(uchar *p, u32int);
368 static void put24(uchar *p, int);
369 static void put16(uchar *p, int);
370 static u32int get32(uchar *p);
371 static int get24(uchar *p);
372 static int get16(uchar *p);
373 static Bytes* newbytes(int len);
374 static Bytes* makebytes(uchar* buf, int len);
375 static Bytes* mptobytes(mpint* big);
376 static mpint* bytestomp(Bytes* bytes);
377 static void freebytes(Bytes* b);
378 static Ints* newints(int len);
379 static void freeints(Ints* b);
380
381 /* x509.c */
382 extern mpint* pkcs1padbuf(uchar *buf, int len, mpint *modulus);
383
384 //================= client/server ========================
385
386 //      push TLS onto fd, returning new (application) file descriptor
387 //              or -1 if error.
388 int
389 tlsServer(int fd, TLSconn *conn)
390 {
391         char buf[8];
392         char dname[64];
393         int n, data, ctl, hand;
394         TlsConnection *tls;
395
396         if(conn == nil)
397                 return -1;
398         ctl = open("#a/tls/clone", ORDWR);
399         if(ctl < 0)
400                 return -1;
401         n = read(ctl, buf, sizeof(buf)-1);
402         if(n < 0){
403                 close(ctl);
404                 return -1;
405         }
406         buf[n] = 0;
407         snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", buf);
408         snprint(dname, sizeof(dname), "#a/tls/%s/hand", buf);
409         hand = open(dname, ORDWR);
410         if(hand < 0){
411                 close(ctl);
412                 return -1;
413         }
414         fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
415         tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
416         snprint(dname, sizeof(dname), "#a/tls/%s/data", buf);
417         data = open(dname, ORDWR);
418         close(hand);
419         close(ctl);
420         if(data < 0 || tls == nil){
421                 if(tls != nil)
422                         tlsConnectionFree(tls);
423                 return -1;
424         }
425         free(conn->cert);
426         conn->cert = 0;  // client certificates are not yet implemented
427         conn->certlen = 0;
428         conn->sessionIDlen = tls->sid->len;
429         conn->sessionID = emalloc(conn->sessionIDlen);
430         memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
431         if(conn->sessionKey != nil
432         && conn->sessionType != nil
433         && strcmp(conn->sessionType, "ttls") == 0)
434                 tls->sec->prf(
435                         conn->sessionKey, conn->sessionKeylen,
436                         tls->sec->sec, MasterSecretSize,
437                         conn->sessionConst, 
438                         tls->sec->crandom, RandomSize,
439                         tls->sec->srandom, RandomSize);
440         tlsConnectionFree(tls);
441         close(fd);
442         return data;
443 }
444
445 static uchar*
446 tlsClientExtensions(TLSconn *conn, int *plen)
447 {
448         uchar *b, *p;
449         int i, n, m;
450
451         p = b = nil;
452
453         // RFC6066 - Server Name Identification
454         if(conn->serverName != nil){
455                 n = strlen(conn->serverName);
456
457                 m = p - b;
458                 b = erealloc(b, m + 2+2+2+1+2+n);
459                 p = b + m;
460
461                 put16(p, 0), p += 2;            /* Type: server_name */
462                 put16(p, 2+1+2+n), p += 2;      /* Length */
463                 put16(p, 1+2+n), p += 2;        /* Server Name list length */
464                 *p++ = 0;                       /* Server Name Type: host_name */
465                 put16(p, n), p += 2;            /* Server Name length */
466                 memmove(p, conn->serverName, n);
467                 p += n;
468         }
469
470         // ECDHE
471         if(1){
472                 m = p - b;
473                 b = erealloc(b, m + 2+2+2+nelem(namedcurves)*2 + 2+2+1+nelem(pointformats));
474                 p = b + m;
475
476                 n = nelem(namedcurves);
477                 put16(p, 0x000a), p += 2;       /* Type: elliptic_curves */
478                 put16(p, (n+1)*2), p += 2;      /* Length */
479                 put16(p, n*2), p += 2;          /* Elliptic Curves Length */
480                 for(i=0; i < n; i++){           /* Elliptic curves */
481                         put16(p, namedcurves[i].tlsid);
482                         p += 2;
483                 }
484
485                 n = nelem(pointformats);
486                 put16(p, 0x000b), p += 2;       /* Type: ec_point_formats */
487                 put16(p, n+1), p += 2;          /* Length */
488                 *p++ = n;                       /* EC point formats Length */
489                 for(i=0; i < n; i++)            /* Elliptic curves point formats */
490                         *p++ = pointformats[i];
491         }
492         
493         *plen = p - b;
494         return b;
495 }
496
497 //      push TLS onto fd, returning new (application) file descriptor
498 //              or -1 if error.
499 int
500 tlsClient(int fd, TLSconn *conn)
501 {
502         char buf[8];
503         char dname[64];
504         int n, data, ctl, hand;
505         TlsConnection *tls;
506         uchar *ext;
507
508         if(conn == nil)
509                 return -1;
510         ctl = open("#a/tls/clone", ORDWR);
511         if(ctl < 0)
512                 return -1;
513         n = read(ctl, buf, sizeof(buf)-1);
514         if(n < 0){
515                 close(ctl);
516                 return -1;
517         }
518         buf[n] = 0;
519         snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", buf);
520         snprint(dname, sizeof(dname), "#a/tls/%s/hand", buf);
521         hand = open(dname, ORDWR);
522         if(hand < 0){
523                 close(ctl);
524                 return -1;
525         }
526         snprint(dname, sizeof(dname), "#a/tls/%s/data", buf);
527         data = open(dname, ORDWR);
528         if(data < 0){
529                 close(hand);
530                 close(ctl);
531                 return -1;
532         }
533         fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
534         ext = tlsClientExtensions(conn, &n);
535         tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen, 
536                 ext, n, conn->trace);
537         free(ext);
538         close(hand);
539         close(ctl);
540         if(tls == nil){
541                 close(data);
542                 return -1;
543         }
544         conn->certlen = tls->cert->len;
545         conn->cert = emalloc(conn->certlen);
546         memcpy(conn->cert, tls->cert->data, conn->certlen);
547         conn->sessionIDlen = tls->sid->len;
548         conn->sessionID = emalloc(conn->sessionIDlen);
549         memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
550         if(conn->sessionKey != nil
551         && conn->sessionType != nil
552         && strcmp(conn->sessionType, "ttls") == 0)
553                 tls->sec->prf(
554                         conn->sessionKey, conn->sessionKeylen,
555                         tls->sec->sec, MasterSecretSize,
556                         conn->sessionConst, 
557                         tls->sec->crandom, RandomSize,
558                         tls->sec->srandom, RandomSize);
559         tlsConnectionFree(tls);
560         close(fd);
561         return data;
562 }
563
564 static int
565 countchain(PEMChain *p)
566 {
567         int i = 0;
568
569         while (p) {
570                 i++;
571                 p = p->next;
572         }
573         return i;
574 }
575
576 static TlsConnection *
577 tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chp)
578 {
579         TlsConnection *c;
580         Msg m;
581         Bytes *csid;
582         uchar sid[SidSize], kd[MaxKeyData];
583         char *secrets;
584         int cipher, compressor, nsid, rv, numcerts, i;
585
586         if(trace)
587                 trace("tlsServer2\n");
588         if(!initCiphers())
589                 return nil;
590         c = emalloc(sizeof(TlsConnection));
591         c->ctl = ctl;
592         c->hand = hand;
593         c->trace = trace;
594         c->version = ProtocolVersion;
595
596         memset(&m, 0, sizeof(m));
597         if(!msgRecv(c, &m)){
598                 if(trace)
599                         trace("initial msgRecv failed\n");
600                 goto Err;
601         }
602         if(m.tag != HClientHello) {
603                 tlsError(c, EUnexpectedMessage, "expected a client hello");
604                 goto Err;
605         }
606         c->clientVersion = m.u.clientHello.version;
607         if(trace)
608                 trace("ClientHello version %x\n", c->clientVersion);
609         if(setVersion(c, c->clientVersion) < 0) {
610                 tlsError(c, EIllegalParameter, "incompatible version");
611                 goto Err;
612         }
613
614         memmove(c->crandom, m.u.clientHello.random, RandomSize);
615         cipher = okCipher(m.u.clientHello.ciphers);
616         if(cipher < 0) {
617                 // reply with EInsufficientSecurity if we know that's the case
618                 if(cipher == -2)
619                         tlsError(c, EInsufficientSecurity, "cipher suites too weak");
620                 else
621                         tlsError(c, EHandshakeFailure, "no matching cipher suite");
622                 goto Err;
623         }
624         if(!setAlgs(c, cipher)){
625                 tlsError(c, EHandshakeFailure, "no matching cipher suite");
626                 goto Err;
627         }
628         compressor = okCompression(m.u.clientHello.compressors);
629         if(compressor < 0) {
630                 tlsError(c, EHandshakeFailure, "no matching compressor");
631                 goto Err;
632         }
633
634         csid = m.u.clientHello.sid;
635         if(trace)
636                 trace("  cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
637         c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
638         if(c->sec == nil){
639                 tlsError(c, EHandshakeFailure, "can't initialize security: %r");
640                 goto Err;
641         }
642         c->sec->rpc = factotum_rsa_open(cert, certlen);
643         if(c->sec->rpc == nil){
644                 tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
645                 goto Err;
646         }
647         c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
648         if(c->sec->rsapub == nil){
649                 tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
650                 goto Err;
651         }
652         msgClear(&m);
653
654         m.tag = HServerHello;
655         m.u.serverHello.version = c->version;
656         memmove(m.u.serverHello.random, c->srandom, RandomSize);
657         m.u.serverHello.cipher = cipher;
658         m.u.serverHello.compressor = compressor;
659         c->sid = makebytes(sid, nsid);
660         m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
661         if(!msgSend(c, &m, AQueue))
662                 goto Err;
663         msgClear(&m);
664
665         m.tag = HCertificate;
666         numcerts = countchain(chp);
667         m.u.certificate.ncert = 1 + numcerts;
668         m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
669         m.u.certificate.certs[0] = makebytes(cert, certlen);
670         for (i = 0; i < numcerts && chp; i++, chp = chp->next)
671                 m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
672         if(!msgSend(c, &m, AQueue))
673                 goto Err;
674         msgClear(&m);
675
676         m.tag = HServerHelloDone;
677         if(!msgSend(c, &m, AFlush))
678                 goto Err;
679         msgClear(&m);
680
681         if(!msgRecv(c, &m))
682                 goto Err;
683         if(m.tag != HClientKeyExchange) {
684                 tlsError(c, EUnexpectedMessage, "expected a client key exchange");
685                 goto Err;
686         }
687         if(tlsSecRSAs(c->sec, c->version, m.u.clientKeyExchange.key) < 0){
688                 tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
689                 goto Err;
690         }
691         setSecrets(c->sec, kd, c->nsecret);
692         if(trace)
693                 trace("tls secrets\n");
694         secrets = (char*)emalloc(2*c->nsecret);
695         enc64(secrets, 2*c->nsecret, kd, c->nsecret);
696         rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
697         memset(secrets, 0, 2*c->nsecret);
698         free(secrets);
699         memset(kd, 0, c->nsecret);
700         if(rv < 0){
701                 tlsError(c, EHandshakeFailure, "can't set keys: %r");
702                 goto Err;
703         }
704         msgClear(&m);
705
706         /* no CertificateVerify; skip to Finished */
707         if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){
708                 tlsError(c, EInternalError, "can't set finished: %r");
709                 goto Err;
710         }
711         if(!msgRecv(c, &m))
712                 goto Err;
713         if(m.tag != HFinished) {
714                 tlsError(c, EUnexpectedMessage, "expected a finished");
715                 goto Err;
716         }
717         if(!finishedMatch(c, &m.u.finished)) {
718                 tlsError(c, EHandshakeFailure, "finished verification failed");
719                 goto Err;
720         }
721         msgClear(&m);
722
723         /* change cipher spec */
724         if(fprint(c->ctl, "changecipher") < 0){
725                 tlsError(c, EInternalError, "can't enable cipher: %r");
726                 goto Err;
727         }
728
729         if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){
730                 tlsError(c, EInternalError, "can't set finished: %r");
731                 goto Err;
732         }
733         m.tag = HFinished;
734         m.u.finished = c->finished;
735         if(!msgSend(c, &m, AFlush))
736                 goto Err;
737         msgClear(&m);
738         if(trace)
739                 trace("tls finished\n");
740
741         if(fprint(c->ctl, "opened") < 0)
742                 goto Err;
743         tlsSecOk(c->sec);
744         return c;
745
746 Err:
747         msgClear(&m);
748         tlsConnectionFree(c);
749         return 0;
750 }
751
752 static int
753 isDHE(int tlsid)
754 {
755         switch(tlsid){
756         case TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA:
757         case TLS_DHE_DSS_WITH_DES_CBC_SHA:
758         case TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA:
759         case TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA:
760         case TLS_DHE_RSA_WITH_DES_CBC_SHA:
761         case TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA:
762         case TLS_DHE_DSS_WITH_AES_128_CBC_SHA:
763         case TLS_DHE_RSA_WITH_AES_128_CBC_SHA:
764         case TLS_DHE_DSS_WITH_AES_256_CBC_SHA:
765         case TLS_DHE_RSA_WITH_AES_256_CBC_SHA:
766                 return 1;
767         }
768         return 0;
769 }
770
771 static int
772 isECDHE(int tlsid)
773 {
774         switch(tlsid){
775         case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:
776         case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
777         case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA:
778         case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
779                 return 1;
780         }
781         return 0;
782 }
783
784 static Bytes*
785 tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, 
786         Bytes *p, Bytes *g, Bytes *Ys)
787 {
788         mpint *G, *P, *Y, *K;
789         Bytes *epm;
790         DHstate dh;
791
792         if(p == nil || g == nil || Ys == nil)
793                 return nil;
794
795         memmove(sec->srandom, srandom, RandomSize);
796         if(setVers(sec, vers) < 0)
797                 return nil;
798
799         epm = nil;
800         P = bytestomp(p);
801         G = bytestomp(g);
802         Y = bytestomp(Ys);
803         K = nil;
804
805         if(P == nil || G == nil || Y == nil || dh_new(&dh, P, nil, G) == nil)
806                 goto Out;
807         epm = mptobytes(dh.y);
808         K = dh_finish(&dh, Y);
809         if(K == nil){
810                 freebytes(epm);
811                 epm = nil;
812                 goto Out;
813         }
814         setMasterSecret(sec, mptobytes(K));
815
816 Out:
817         mpfree(K);
818         mpfree(Y);
819         mpfree(G);
820         mpfree(P);
821
822         return epm;
823 }
824
825 static ECpoint*
826 bytestoec(ECdomain *dom, Bytes *bp, ECpoint *ret)
827 {
828         char *hex = "0123456789ABCDEF";
829         char *s;
830         int i;
831
832         s = emalloc(2*bp->len + 1);
833         for(i=0; i < bp->len; i++){
834                 s[2*i] = hex[bp->data[i]>>4 & 15];
835                 s[2*i+1] = hex[bp->data[i] & 15];
836         }
837         s[2*bp->len] = '\0';
838         ret = strtoec(dom, s, nil, ret);
839         free(s);
840         return ret;
841 }
842
843 static Bytes*
844 ectobytes(int type, ECpoint *p)
845 {
846         Bytes *bx, *by, *bp;
847
848         bx = mptobytes(p->x);
849         by = mptobytes(p->y);
850         bp = newbytes(bx->len + by->len + 1);
851         bp->data[0] =  type;
852         memmove(bp->data+1, bx->data, bx->len);
853         memmove(bp->data+1+bx->len, by->data, by->len);
854         freebytes(bx);
855         freebytes(by);
856         return bp;
857 }
858
859 static Bytes*
860 tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys)
861 {
862         Namedcurve *nc, *enc;
863         Bytes *epm;
864         ECdomain dom;
865         ECpoint G, K, Y;
866         ECpriv Q;
867
868         if(Ys == nil)
869                 return nil;
870
871         enc = &namedcurves[nelem(namedcurves)];
872         for(nc = namedcurves; nc != enc; nc++)
873                 if(nc->tlsid == curve)
874                         break;
875
876         if(nc == enc)
877                 return nil;
878                 
879         memmove(sec->srandom, srandom, RandomSize);
880         if(setVers(sec, vers) < 0)
881                 return nil;
882         
883         epm = nil;
884
885         memset(&dom, 0, sizeof(dom));
886         dom.p = strtomp(nc->p, nil, 16, nil);
887         dom.a = strtomp(nc->a, nil, 16, nil);
888         dom.b = strtomp(nc->b, nil, 16, nil);
889         dom.n = strtomp(nc->n, nil, 16, nil);
890         dom.h = strtomp(nc->h, nil, 16, nil);
891
892         memset(&G, 0, sizeof(G));
893         G.x = mpnew(0);
894         G.y = mpnew(0);
895
896         memset(&Q, 0, sizeof(Q));
897         Q.x = mpnew(0);
898         Q.y = mpnew(0);
899         Q.d = mpnew(0);
900
901         memset(&K, 0, sizeof(K));
902         K.x = mpnew(0);
903         K.y = mpnew(0);
904
905         memset(&Y, 0, sizeof(Y));
906         Y.x = mpnew(0);
907         Y.y = mpnew(0);
908
909         if(dom.p == nil || dom.a == nil || dom.b == nil || dom.n == nil || dom.h == nil)
910                 goto Out;
911         if(Q.x == nil || Q.y == nil || Q.d == nil)
912                 goto Out;
913         if(G.x == nil || G.y == nil)
914                 goto Out;
915         if(K.x == nil || K.y == nil)
916                 goto Out;
917         if(Y.x == nil || Y.y == nil)
918                 goto Out;
919
920         dom.G = strtoec(&dom, nc->G, nil, &G);
921         if(dom.G == nil)
922                 goto Out;
923
924         if(bytestoec(&dom, Ys, &Y) == nil)
925                 goto Out;
926
927         if(ecgen(&dom, &Q) == nil)
928                 goto Out;
929
930         ecmul(&dom, &Y, Q.d, &K);
931         setMasterSecret(sec, mptobytes(K.x));
932
933         /* 0x04 = uncompressed public key */
934         epm = ectobytes(0x04, &Q);
935         
936 Out:
937         mpfree(Y.x);
938         mpfree(Y.y);
939
940         mpfree(K.x);
941         mpfree(K.y);
942
943         mpfree(Q.x);
944         mpfree(Q.y);
945         mpfree(Q.d);
946
947         mpfree(G.x);
948         mpfree(G.y);
949
950         mpfree(dom.p);
951         mpfree(dom.a);
952         mpfree(dom.b);
953         mpfree(dom.n);
954         mpfree(dom.h);
955
956         return epm;
957 }
958
959 static TlsConnection *
960 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen,
961         int (*trace)(char*fmt, ...))
962 {
963         TlsConnection *c;
964         Msg m;
965         uchar kd[MaxKeyData];
966         char *secrets;
967         int creq, dhx, rv, cipher;
968         mpint *signedMP, *paddedHashes;
969         Bytes *epm;
970
971         if(!initCiphers())
972                 return nil;
973         epm = nil;
974         c = emalloc(sizeof(TlsConnection));
975         c->version = ProtocolVersion;
976
977         // client certificate signature not implemented for TLS1.2
978         if(cert != nil && certlen > 0 && c->version >= TLS12Version)
979                 c->version = TLS11Version;
980
981         c->ctl = ctl;
982         c->hand = hand;
983         c->trace = trace;
984         c->isClient = 1;
985         c->clientVersion = c->version;
986
987         c->sec = tlsSecInitc(c->clientVersion, c->crandom);
988         if(c->sec == nil)
989                 goto Err;
990         /* client hello */
991         memset(&m, 0, sizeof(m));
992         m.tag = HClientHello;
993         m.u.clientHello.version = c->clientVersion;
994         memmove(m.u.clientHello.random, c->crandom, RandomSize);
995         m.u.clientHello.sid = makebytes(csid, ncsid);
996         m.u.clientHello.ciphers = makeciphers();
997         m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
998         m.u.clientHello.extensions = makebytes(ext, extlen);
999         if(!msgSend(c, &m, AFlush))
1000                 goto Err;
1001         msgClear(&m);
1002
1003         /* server hello */
1004         if(!msgRecv(c, &m))
1005                 goto Err;
1006         if(m.tag != HServerHello) {
1007                 tlsError(c, EUnexpectedMessage, "expected a server hello");
1008                 goto Err;
1009         }
1010         if(setVersion(c, m.u.serverHello.version) < 0) {
1011                 tlsError(c, EIllegalParameter, "incompatible version %r");
1012                 goto Err;
1013         }
1014         memmove(c->srandom, m.u.serverHello.random, RandomSize);
1015         c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
1016         if(c->sid->len != 0 && c->sid->len != SidSize) {
1017                 tlsError(c, EIllegalParameter, "invalid server session identifier");
1018                 goto Err;
1019         }
1020         cipher = m.u.serverHello.cipher;
1021         if(!setAlgs(c, cipher)) {
1022                 tlsError(c, EIllegalParameter, "invalid cipher suite");
1023                 goto Err;
1024         }
1025         if(m.u.serverHello.compressor != CompressionNull) {
1026                 tlsError(c, EIllegalParameter, "invalid compression");
1027                 goto Err;
1028         }
1029         msgClear(&m);
1030
1031         /* certificate */
1032         if(!msgRecv(c, &m) || m.tag != HCertificate) {
1033                 tlsError(c, EUnexpectedMessage, "expected a certificate");
1034                 goto Err;
1035         }
1036         if(m.u.certificate.ncert < 1) {
1037                 tlsError(c, EIllegalParameter, "runt certificate");
1038                 goto Err;
1039         }
1040         c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
1041         msgClear(&m);
1042
1043         /* server key exchange */
1044         dhx = isDHE(cipher) || isECDHE(cipher);
1045         if(!msgRecv(c, &m))
1046                 goto Err;
1047         if(m.tag == HServerKeyExchange) {
1048                 if(!dhx){
1049                         tlsError(c, EUnexpectedMessage, "got an server key exchange");
1050                         goto Err;
1051                 }
1052                 if(isECDHE(cipher))
1053                         epm = tlsSecECDHEc(c->sec, c->srandom, c->version,
1054                                 m.u.serverKeyExchange.curve,
1055                                 m.u.serverKeyExchange.dh_Ys);
1056                 else
1057                         epm = tlsSecDHEc(c->sec, c->srandom, c->version,
1058                                 m.u.serverKeyExchange.dh_p, 
1059                                 m.u.serverKeyExchange.dh_g,
1060                                 m.u.serverKeyExchange.dh_Ys);
1061                 if(epm == nil)
1062                         goto Badcert;
1063                 msgClear(&m);
1064                 if(!msgRecv(c, &m))
1065                         goto Err;
1066         } else if(dhx){
1067                 tlsError(c, EUnexpectedMessage, "expected server key exchange");
1068                 goto Err;
1069         }
1070
1071         /* certificate request (optional) */
1072         creq = 0;
1073         if(m.tag == HCertificateRequest) {
1074                 creq = 1;
1075                 msgClear(&m);
1076                 if(!msgRecv(c, &m))
1077                         goto Err;
1078         }
1079
1080         if(m.tag != HServerHelloDone) {
1081                 tlsError(c, EUnexpectedMessage, "expected a server hello done");
1082                 goto Err;
1083         }
1084         msgClear(&m);
1085
1086         if(!dhx)
1087                 epm = tlsSecRSAc(c->sec, c->sid->data, c->sid->len, c->srandom,
1088                         c->cert->data, c->cert->len, c->version);
1089
1090         if(epm == nil){
1091         Badcert:
1092                 tlsError(c, EBadCertificate, "bad certificate: %r");
1093                 goto Err;
1094         }
1095
1096         setSecrets(c->sec, kd, c->nsecret);
1097         secrets = (char*)emalloc(2*c->nsecret);
1098         enc64(secrets, 2*c->nsecret, kd, c->nsecret);
1099         rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
1100         memset(secrets, 0, 2*c->nsecret);
1101         free(secrets);
1102         memset(kd, 0, c->nsecret);
1103         if(rv < 0){
1104                 tlsError(c, EHandshakeFailure, "can't set keys: %r");
1105                 goto Err;
1106         }
1107
1108         if(creq) {
1109                 if(cert != nil && certlen > 0){
1110                         m.u.certificate.ncert = 1;
1111                         m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
1112                         m.u.certificate.certs[0] = makebytes(cert, certlen);
1113                 }               
1114                 m.tag = HCertificate;
1115                 if(!msgSend(c, &m, AFlush))
1116                         goto Err;
1117                 msgClear(&m);
1118         }
1119
1120         /* client key exchange */
1121         m.tag = HClientKeyExchange;
1122         m.u.clientKeyExchange.key = epm;
1123         epm = nil;
1124         if(m.u.clientKeyExchange.key == nil) {
1125                 tlsError(c, EHandshakeFailure, "can't set secret: %r");
1126                 goto Err;
1127         }
1128          
1129         if(!msgSend(c, &m, AFlush))
1130                 goto Err;
1131         msgClear(&m);
1132
1133         /* certificate verify */
1134         if(creq && cert != nil && certlen > 0) {
1135                 uchar hshashes[MD5dlen+SHA1dlen]; /* content of signature */
1136                 HandshakeHash hsave;
1137
1138                 /* save the state for the Finish message */
1139                 hsave = c->handhash;
1140                 md5(nil, 0, hshashes, &c->handhash.md5);
1141                 sha1(nil, 0, hshashes+MD5dlen, &c->handhash.sha1);
1142                 c->handhash = hsave;
1143
1144                 c->sec->rpc = factotum_rsa_open(cert, certlen);
1145                 if(c->sec->rpc == nil){
1146                         tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
1147                         goto Err;
1148                 }
1149                 c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
1150                 if(c->sec->rsapub == nil){
1151                         tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
1152                         goto Err;
1153                 }
1154
1155                 paddedHashes = pkcs1padbuf(hshashes, MD5dlen+SHA1dlen, c->sec->rsapub->n);
1156                 signedMP = factotum_rsa_decrypt(c->sec->rpc, paddedHashes);
1157                 if(signedMP == nil){
1158                         tlsError(c, EHandshakeFailure, "factotum_rsa_decrypt: %r");
1159                         goto Err;
1160                 }
1161                 m.u.certificateVerify.signature = mptobytes(signedMP);
1162                 mpfree(signedMP);
1163
1164                 m.tag = HCertificateVerify;
1165                 if(!msgSend(c, &m, AFlush))
1166                         goto Err;
1167                 msgClear(&m);
1168         } 
1169
1170         /* change cipher spec */
1171         if(fprint(c->ctl, "changecipher") < 0){
1172                 tlsError(c, EInternalError, "can't enable cipher: %r");
1173                 goto Err;
1174         }
1175
1176         // Cipherchange must occur immediately before Finished to avoid
1177         // potential hole;  see section 4.3 of Wagner Schneier 1996.
1178         if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){
1179                 tlsError(c, EInternalError, "can't set finished 1: %r");
1180                 goto Err;
1181         }
1182         m.tag = HFinished;
1183         m.u.finished = c->finished;
1184         if(!msgSend(c, &m, AFlush)) {
1185                 tlsError(c, EInternalError, "can't flush after client Finished: %r");
1186                 goto Err;
1187         }
1188         msgClear(&m);
1189
1190         if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){
1191                 tlsError(c, EInternalError, "can't set finished 0: %r");
1192                 goto Err;
1193         }
1194         if(!msgRecv(c, &m)) {
1195                 tlsError(c, EInternalError, "can't read server Finished: %r");
1196                 goto Err;
1197         }
1198         if(m.tag != HFinished) {
1199                 tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
1200                 goto Err;
1201         }
1202
1203         if(!finishedMatch(c, &m.u.finished)) {
1204                 tlsError(c, EHandshakeFailure, "finished verification failed");
1205                 goto Err;
1206         }
1207         msgClear(&m);
1208
1209         if(fprint(c->ctl, "opened") < 0){
1210                 if(trace)
1211                         trace("unable to do final open: %r\n");
1212                 goto Err;
1213         }
1214         tlsSecOk(c->sec);
1215         return c;
1216
1217 Err:
1218         free(epm);
1219         msgClear(&m);
1220         tlsConnectionFree(c);
1221         return 0;
1222 }
1223
1224
1225 //================= message functions ========================
1226
1227 static void
1228 msgHash(TlsConnection *c, uchar *p, int n)
1229 {
1230         md5(p, n, 0, &c->handhash.md5);
1231         sha1(p, n, 0, &c->handhash.sha1);
1232         if(c->version >= TLS12Version)
1233                 sha2_256(p, n, 0, &c->handhash.sha2_256);
1234 }
1235
1236 static int
1237 msgSend(TlsConnection *c, Msg *m, int act)
1238 {
1239         uchar *p; // sendp = start of new message;  p = write pointer
1240         int nn, n, i;
1241
1242         if(c->sendp == nil)
1243                 c->sendp = c->sendbuf;
1244         p = c->sendp;
1245         if(c->trace)
1246                 c->trace("send %s", msgPrint((char*)p, (sizeof(c->sendbuf)) - (p - c->sendbuf), m));
1247
1248         p[0] = m->tag;  // header - fill in size later
1249         p += 4;
1250
1251         switch(m->tag) {
1252         default:
1253                 tlsError(c, EInternalError, "can't encode a %d", m->tag);
1254                 goto Err;
1255         case HClientHello:
1256                 // version
1257                 put16(p, m->u.clientHello.version);
1258                 p += 2;
1259
1260                 // random
1261                 memmove(p, m->u.clientHello.random, RandomSize);
1262                 p += RandomSize;
1263
1264                 // sid
1265                 n = m->u.clientHello.sid->len;
1266                 assert(n < 256);
1267                 p[0] = n;
1268                 memmove(p+1, m->u.clientHello.sid->data, n);
1269                 p += n+1;
1270
1271                 n = m->u.clientHello.ciphers->len;
1272                 assert(n > 0 && n < 200);
1273                 put16(p, n*2);
1274                 p += 2;
1275                 for(i=0; i<n; i++) {
1276                         put16(p, m->u.clientHello.ciphers->data[i]);
1277                         p += 2;
1278                 }
1279
1280                 n = m->u.clientHello.compressors->len;
1281                 assert(n > 0);
1282                 p[0] = n;
1283                 memmove(p+1, m->u.clientHello.compressors->data, n);
1284                 p += n+1;
1285
1286                 if(m->u.clientHello.extensions == nil)
1287                         break;
1288                 n = m->u.clientHello.extensions->len;
1289                 if(n == 0)
1290                         break;
1291                 put16(p, n);
1292                 memmove(p+2, m->u.clientHello.extensions->data, n);
1293                 p += n+2;
1294                 break;
1295         case HServerHello:
1296                 put16(p, m->u.serverHello.version);
1297                 p += 2;
1298
1299                 // random
1300                 memmove(p, m->u.serverHello.random, RandomSize);
1301                 p += RandomSize;
1302
1303                 // sid
1304                 n = m->u.serverHello.sid->len;
1305                 assert(n < 256);
1306                 p[0] = n;
1307                 memmove(p+1, m->u.serverHello.sid->data, n);
1308                 p += n+1;
1309
1310                 put16(p, m->u.serverHello.cipher);
1311                 p += 2;
1312                 p[0] = m->u.serverHello.compressor;
1313                 p += 1;
1314
1315                 if(m->u.serverHello.extensions == nil)
1316                         break;
1317                 n = m->u.serverHello.extensions->len;
1318                 if(n == 0)
1319                         break;
1320                 put16(p, n);
1321                 memmove(p+2, m->u.serverHello.extensions->data, n);
1322                 p += n+2;
1323                 break;
1324         case HServerHelloDone:
1325                 break;
1326         case HCertificate:
1327                 nn = 0;
1328                 for(i = 0; i < m->u.certificate.ncert; i++)
1329                         nn += 3 + m->u.certificate.certs[i]->len;
1330                 if(p + 3 + nn - c->sendbuf > sizeof(c->sendbuf)) {
1331                         tlsError(c, EInternalError, "output buffer too small for certificate");
1332                         goto Err;
1333                 }
1334                 put24(p, nn);
1335                 p += 3;
1336                 for(i = 0; i < m->u.certificate.ncert; i++){
1337                         put24(p, m->u.certificate.certs[i]->len);
1338                         p += 3;
1339                         memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
1340                         p += m->u.certificate.certs[i]->len;
1341                 }
1342                 break;
1343         case HCertificateVerify:
1344                 put16(p, m->u.certificateVerify.signature->len);
1345                 p += 2;
1346                 memmove(p, m->u.certificateVerify.signature->data, m->u.certificateVerify.signature->len);
1347                 p += m->u.certificateVerify.signature->len;
1348                 break;  
1349         case HClientKeyExchange:
1350                 n = m->u.clientKeyExchange.key->len;
1351                 if(c->version != SSL3Version){
1352                         if(isECDHE(c->cipher))
1353                                 *p++ = n;
1354                         else
1355                                 put16(p, n), p += 2;
1356                 }
1357                 memmove(p, m->u.clientKeyExchange.key->data, n);
1358                 p += n;
1359                 break;
1360         case HFinished:
1361                 memmove(p, m->u.finished.verify, m->u.finished.n);
1362                 p += m->u.finished.n;
1363                 break;
1364         }
1365
1366         // go back and fill in size
1367         n = p - c->sendp;
1368         assert(p <= c->sendbuf + sizeof(c->sendbuf));
1369         put24(c->sendp+1, n-4);
1370
1371         // remember hash of Handshake messages
1372         if(m->tag != HHelloRequest)
1373                 msgHash(c, c->sendp, n);
1374
1375         c->sendp = p;
1376         if(act == AFlush){
1377                 c->sendp = c->sendbuf;
1378                 if(write(c->hand, c->sendbuf, p - c->sendbuf) < 0){
1379                         fprint(2, "write error: %r\n");
1380                         goto Err;
1381                 }
1382         }
1383         msgClear(m);
1384         return 1;
1385 Err:
1386         msgClear(m);
1387         return 0;
1388 }
1389
1390 static uchar*
1391 tlsReadN(TlsConnection *c, int n)
1392 {
1393         uchar *p;
1394         int nn, nr;
1395
1396         nn = c->ep - c->rp;
1397         if(nn < n){
1398                 if(c->rp != c->recvbuf){
1399                         memmove(c->recvbuf, c->rp, nn);
1400                         c->rp = c->recvbuf;
1401                         c->ep = &c->recvbuf[nn];
1402                 }
1403                 for(; nn < n; nn += nr) {
1404                         nr = read(c->hand, &c->rp[nn], n - nn);
1405                         if(nr <= 0)
1406                                 return nil;
1407                         c->ep += nr;
1408                 }
1409         }
1410         p = c->rp;
1411         c->rp += n;
1412         return p;
1413 }
1414
1415 static int
1416 msgRecv(TlsConnection *c, Msg *m)
1417 {
1418         uchar *p;
1419         int type, n, nn, i, nsid, nrandom, nciph;
1420
1421         for(;;) {
1422                 p = tlsReadN(c, 4);
1423                 if(p == nil)
1424                         return 0;
1425                 type = p[0];
1426                 n = get24(p+1);
1427
1428                 if(type != HHelloRequest)
1429                         break;
1430                 if(n != 0) {
1431                         tlsError(c, EDecodeError, "invalid hello request during handshake");
1432                         return 0;
1433                 }
1434         }
1435
1436         if(n > sizeof(c->recvbuf)) {
1437                 tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->recvbuf));
1438                 return 0;
1439         }
1440
1441         if(type == HSSL2ClientHello){
1442                 /* Cope with an SSL3 ClientHello expressed in SSL2 record format.
1443                         This is sent by some clients that we must interoperate
1444                         with, such as Java's JSSE and Microsoft's Internet Explorer. */
1445                 p = tlsReadN(c, n);
1446                 if(p == nil)
1447                         return 0;
1448                 msgHash(c, p, n);
1449                 m->tag = HClientHello;
1450                 if(n < 22)
1451                         goto Short;
1452                 m->u.clientHello.version = get16(p+1);
1453                 p += 3;
1454                 n -= 3;
1455                 nn = get16(p); /* cipher_spec_len */
1456                 nsid = get16(p + 2);
1457                 nrandom = get16(p + 4);
1458                 p += 6;
1459                 n -= 6;
1460                 if(nsid != 0    /* no sid's, since shouldn't restart using ssl2 header */
1461                                 || nrandom < 16 || nn % 3)
1462                         goto Err;
1463                 if(c->trace && (n - nrandom != nn))
1464                         c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
1465                 /* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
1466                 nciph = 0;
1467                 for(i = 0; i < nn; i += 3)
1468                         if(p[i] == 0)
1469                                 nciph++;
1470                 m->u.clientHello.ciphers = newints(nciph);
1471                 nciph = 0;
1472                 for(i = 0; i < nn; i += 3)
1473                         if(p[i] == 0)
1474                                 m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1475                 p += nn;
1476                 m->u.clientHello.sid = makebytes(nil, 0);
1477                 if(nrandom > RandomSize)
1478                         nrandom = RandomSize;
1479                 memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1480                 memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1481                 m->u.clientHello.compressors = newbytes(1);
1482                 m->u.clientHello.compressors->data[0] = CompressionNull;
1483                 goto Ok;
1484         }
1485         msgHash(c, p, 4);
1486
1487         p = tlsReadN(c, n);
1488         if(p == nil)
1489                 return 0;
1490
1491         msgHash(c, p, n);
1492
1493         m->tag = type;
1494
1495         switch(type) {
1496         default:
1497                 tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1498                 goto Err;
1499         case HClientHello:
1500                 if(n < 2)
1501                         goto Short;
1502                 m->u.clientHello.version = get16(p);
1503                 p += 2;
1504                 n -= 2;
1505
1506                 if(n < RandomSize)
1507                         goto Short;
1508                 memmove(m->u.clientHello.random, p, RandomSize);
1509                 p += RandomSize;
1510                 n -= RandomSize;
1511                 if(n < 1 || n < p[0]+1)
1512                         goto Short;
1513                 m->u.clientHello.sid = makebytes(p+1, p[0]);
1514                 p += m->u.clientHello.sid->len+1;
1515                 n -= m->u.clientHello.sid->len+1;
1516
1517                 if(n < 2)
1518                         goto Short;
1519                 nn = get16(p);
1520                 p += 2;
1521                 n -= 2;
1522
1523                 if((nn & 1) || n < nn || nn < 2)
1524                         goto Short;
1525                 m->u.clientHello.ciphers = newints(nn >> 1);
1526                 for(i = 0; i < nn; i += 2)
1527                         m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1528                 p += nn;
1529                 n -= nn;
1530
1531                 if(n < 1 || n < p[0]+1 || p[0] == 0)
1532                         goto Short;
1533                 nn = p[0];
1534                 m->u.clientHello.compressors = makebytes(p+1, nn);
1535                 p += nn + 1;
1536                 n -= nn + 1;
1537
1538                 if(n < 2)
1539                         break;
1540                 nn = get16(p);
1541                 if(nn > n-2)
1542                         goto Short;
1543                 m->u.clientHello.extensions = makebytes(p+2, nn);
1544                 n -= nn + 2;
1545                 break;
1546         case HServerHello:
1547                 if(n < 2)
1548                         goto Short;
1549                 m->u.serverHello.version = get16(p);
1550                 p += 2;
1551                 n -= 2;
1552
1553                 if(n < RandomSize)
1554                         goto Short;
1555                 memmove(m->u.serverHello.random, p, RandomSize);
1556                 p += RandomSize;
1557                 n -= RandomSize;
1558
1559                 if(n < 1 || n < p[0]+1)
1560                         goto Short;
1561                 m->u.serverHello.sid = makebytes(p+1, p[0]);
1562                 p += m->u.serverHello.sid->len+1;
1563                 n -= m->u.serverHello.sid->len+1;
1564
1565                 if(n < 3)
1566                         goto Short;
1567                 m->u.serverHello.cipher = get16(p);
1568                 m->u.serverHello.compressor = p[2];
1569                 p += 3;
1570                 n -= 3;
1571
1572                 if(n < 2)
1573                         break;
1574                 nn = get16(p);
1575                 if(nn > n-2)
1576                         goto Short;
1577                 m->u.serverHello.extensions = makebytes(p+2, nn);
1578                 n -= nn + 2;
1579                 break;
1580         case HCertificate:
1581                 if(n < 3)
1582                         goto Short;
1583                 nn = get24(p);
1584                 p += 3;
1585                 n -= 3;
1586                 if(nn == 0 && n > 0)
1587                         goto Short;
1588                 /* certs */
1589                 i = 0;
1590                 while(n > 0) {
1591                         if(n < 3)
1592                                 goto Short;
1593                         nn = get24(p);
1594                         p += 3;
1595                         n -= 3;
1596                         if(nn > n)
1597                                 goto Short;
1598                         m->u.certificate.ncert = i+1;
1599                         m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes*));
1600                         m->u.certificate.certs[i] = makebytes(p, nn);
1601                         p += nn;
1602                         n -= nn;
1603                         i++;
1604                 }
1605                 break;
1606         case HCertificateRequest:
1607                 if(n < 1)
1608                         goto Short;
1609                 nn = p[0];
1610                 p += 1;
1611                 n -= 1;
1612                 if(nn > n)
1613                         goto Short;
1614                 m->u.certificateRequest.types = makebytes(p, nn);
1615                 p += nn;
1616                 n -= nn;
1617                 if(n < 2)
1618                         goto Short;
1619                 nn = get16(p);
1620                 p += 2;
1621                 n -= 2;
1622                 /* nn == 0 can happen; yahoo's servers do it */
1623                 if(nn != n)
1624                         goto Short;
1625                 /* cas */
1626                 i = 0;
1627                 while(n > 0) {
1628                         if(n < 2)
1629                                 goto Short;
1630                         nn = get16(p);
1631                         p += 2;
1632                         n -= 2;
1633                         if(nn < 1 || nn > n)
1634                                 goto Short;
1635                         m->u.certificateRequest.nca = i+1;
1636                         m->u.certificateRequest.cas = erealloc(
1637                                 m->u.certificateRequest.cas, (i+1)*sizeof(Bytes*));
1638                         m->u.certificateRequest.cas[i] = makebytes(p, nn);
1639                         p += nn;
1640                         n -= nn;
1641                         i++;
1642                 }
1643                 break;
1644         case HServerHelloDone:
1645                 break;
1646         case HServerKeyExchange:
1647                 if(n < 2)
1648                         goto Short;
1649                 if(isECDHE(c->cipher)){
1650                         nn = *p;
1651                         p++, n--;
1652                         if(nn != 3 || nn > n) /* not a named curve */
1653                                 goto Short;
1654                         nn = get16(p);
1655                         p += 2, n -= 2;
1656                         m->u.serverKeyExchange.curve = nn;
1657
1658                         nn = *p++, n--;
1659                         if(nn < 1 || nn > n)
1660                                 goto Short;
1661                         m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);
1662                         p += nn, n -= nn;
1663                 }else if(isDHE(c->cipher)){
1664                         nn = get16(p);
1665                         p += 2, n -= 2;
1666                         if(nn < 1 || nn > n)
1667                                 goto Short;
1668                         m->u.serverKeyExchange.dh_p = makebytes(p, nn);
1669                         p += nn, n -= nn;
1670         
1671                         if(n < 2)
1672                                 goto Short;
1673                         nn = get16(p);
1674                         p += 2, n -= 2;
1675                         if(nn < 1 || nn > n)
1676                                 goto Short;
1677                         m->u.serverKeyExchange.dh_g = makebytes(p, nn);
1678                         p += nn, n -= nn;
1679         
1680                         if(n < 2)
1681                                 goto Short;
1682                         nn = get16(p);
1683                         p += 2, n -= 2;
1684                         if(nn < 1 || nn > n)
1685                                 goto Short;
1686                         m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);
1687                         p += nn, n -= nn;
1688                 } else {
1689                         /* should not happen */
1690                         break;
1691                 }
1692                 if(n >= 2){
1693                         if(c->version >= TLS12Version){
1694                                 /* signature hash algorithm */
1695                                 p += 2, n -= 2;
1696                                 if(n < 2)
1697                                         goto Short;
1698                         }
1699                         nn = get16(p);
1700                         p += 2, n -= 2;
1701                         if(nn > 0 && nn <= n){
1702                                 m->u.serverKeyExchange.dh_signature = makebytes(p, nn);
1703                                 n -= nn;
1704                         }
1705                 }
1706                 break;          
1707         case HClientKeyExchange:
1708                 /*
1709                  * this message depends upon the encryption selected
1710                  * assume rsa.
1711                  */
1712                 if(c->version == SSL3Version)
1713                         nn = n;
1714                 else{
1715                         if(n < 2)
1716                                 goto Short;
1717                         nn = get16(p);
1718                         p += 2;
1719                         n -= 2;
1720                 }
1721                 if(n < nn)
1722                         goto Short;
1723                 m->u.clientKeyExchange.key = makebytes(p, nn);
1724                 n -= nn;
1725                 break;
1726         case HFinished:
1727                 m->u.finished.n = c->finished.n;
1728                 if(n < m->u.finished.n)
1729                         goto Short;
1730                 memmove(m->u.finished.verify, p, m->u.finished.n);
1731                 n -= m->u.finished.n;
1732                 break;
1733         }
1734
1735         if(type != HClientHello && type != HServerHello && n != 0)
1736                 goto Short;
1737 Ok:
1738         if(c->trace){
1739                 char *buf;
1740                 buf = emalloc(8000);
1741                 c->trace("recv %s", msgPrint(buf, 8000, m));
1742                 free(buf);
1743         }
1744         return 1;
1745 Short:
1746         tlsError(c, EDecodeError, "handshake message (%d) has invalid length", type);
1747 Err:
1748         msgClear(m);
1749         return 0;
1750 }
1751
1752 static void
1753 msgClear(Msg *m)
1754 {
1755         int i;
1756
1757         switch(m->tag) {
1758         default:
1759                 sysfatal("msgClear: unknown message type: %d", m->tag);
1760         case HHelloRequest:
1761                 break;
1762         case HClientHello:
1763                 freebytes(m->u.clientHello.sid);
1764                 freeints(m->u.clientHello.ciphers);
1765                 freebytes(m->u.clientHello.compressors);
1766                 freebytes(m->u.clientHello.extensions);
1767                 break;
1768         case HServerHello:
1769                 freebytes(m->u.serverHello.sid);
1770                 freebytes(m->u.serverHello.extensions);
1771                 break;
1772         case HCertificate:
1773                 for(i=0; i<m->u.certificate.ncert; i++)
1774                         freebytes(m->u.certificate.certs[i]);
1775                 free(m->u.certificate.certs);
1776                 break;
1777         case HCertificateRequest:
1778                 freebytes(m->u.certificateRequest.types);
1779                 for(i=0; i<m->u.certificateRequest.nca; i++)
1780                         freebytes(m->u.certificateRequest.cas[i]);
1781                 free(m->u.certificateRequest.cas);
1782                 break;
1783         case HCertificateVerify:
1784                 freebytes(m->u.certificateVerify.signature);
1785                 break;
1786         case HServerHelloDone:
1787                 break;
1788         case HServerKeyExchange:
1789                 freebytes(m->u.serverKeyExchange.dh_p);
1790                 freebytes(m->u.serverKeyExchange.dh_g);
1791                 freebytes(m->u.serverKeyExchange.dh_Ys);
1792                 freebytes(m->u.serverKeyExchange.dh_signature);
1793                 break;
1794         case HClientKeyExchange:
1795                 freebytes(m->u.clientKeyExchange.key);
1796                 break;
1797         case HFinished:
1798                 break;
1799         }
1800         memset(m, 0, sizeof(Msg));
1801 }
1802
1803 static char *
1804 bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1805 {
1806         int i;
1807
1808         if(s0)
1809                 bs = seprint(bs, be, "%s", s0);
1810         if(b == nil)
1811                 bs = seprint(bs, be, "nil");
1812         else {
1813                 bs = seprint(bs, be, "<%d> [", b->len);
1814                 for(i=0; i<b->len; i++)
1815                         bs = seprint(bs, be, "%.2x ", b->data[i]);
1816         }
1817         bs = seprint(bs, be, "]");
1818         if(s1)
1819                 bs = seprint(bs, be, "%s", s1);
1820         return bs;
1821 }
1822
1823 static char *
1824 intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1825 {
1826         int i;
1827
1828         if(s0)
1829                 bs = seprint(bs, be, "%s", s0);
1830         bs = seprint(bs, be, "[");
1831         if(b == nil)
1832                 bs = seprint(bs, be, "nil");
1833         else
1834                 for(i=0; i<b->len; i++)
1835                         bs = seprint(bs, be, "%x ", b->data[i]);
1836         bs = seprint(bs, be, "]");
1837         if(s1)
1838                 bs = seprint(bs, be, "%s", s1);
1839         return bs;
1840 }
1841
1842 static char*
1843 msgPrint(char *buf, int n, Msg *m)
1844 {
1845         int i;
1846         char *bs = buf, *be = buf+n;
1847
1848         switch(m->tag) {
1849         default:
1850                 bs = seprint(bs, be, "unknown %d\n", m->tag);
1851                 break;
1852         case HClientHello:
1853                 bs = seprint(bs, be, "ClientHello\n");
1854                 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1855                 bs = seprint(bs, be, "\trandom: ");
1856                 for(i=0; i<RandomSize; i++)
1857                         bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1858                 bs = seprint(bs, be, "\n");
1859                 bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1860                 bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1861                 bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1862                 if(m->u.clientHello.extensions != nil)
1863                         bs = bytesPrint(bs, be, "\textensions: ", m->u.clientHello.extensions, "\n");
1864                 break;
1865         case HServerHello:
1866                 bs = seprint(bs, be, "ServerHello\n");
1867                 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1868                 bs = seprint(bs, be, "\trandom: ");
1869                 for(i=0; i<RandomSize; i++)
1870                         bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1871                 bs = seprint(bs, be, "\n");
1872                 bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1873                 bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1874                 bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1875                 if(m->u.serverHello.extensions != nil)
1876                         bs = bytesPrint(bs, be, "\textensions: ", m->u.serverHello.extensions, "\n");
1877                 break;
1878         case HCertificate:
1879                 bs = seprint(bs, be, "Certificate\n");
1880                 for(i=0; i<m->u.certificate.ncert; i++)
1881                         bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1882                 break;
1883         case HCertificateRequest:
1884                 bs = seprint(bs, be, "CertificateRequest\n");
1885                 bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1886                 bs = seprint(bs, be, "\tcertificateauthorities\n");
1887                 for(i=0; i<m->u.certificateRequest.nca; i++)
1888                         bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1889                 break;
1890         case HCertificateVerify:
1891                 bs = seprint(bs, be, "HCertificateVerify\n");
1892                 bs = bytesPrint(bs, be, "\tsignature: ", m->u.certificateVerify.signature,"\n");
1893                 break;  
1894         case HServerHelloDone:
1895                 bs = seprint(bs, be, "ServerHelloDone\n");
1896                 break;
1897         case HServerKeyExchange:
1898                 bs = seprint(bs, be, "HServerKeyExchange\n");
1899                 if(m->u.serverKeyExchange.curve != 0){
1900                         bs = seprint(bs, be, "\tcurve: %.4x\n", m->u.serverKeyExchange.curve);
1901                 } else {
1902                         bs = bytesPrint(bs, be, "\tdh_p: ", m->u.serverKeyExchange.dh_p, "\n");
1903                         bs = bytesPrint(bs, be, "\tdh_g: ", m->u.serverKeyExchange.dh_g, "\n");
1904                 }
1905                 bs = bytesPrint(bs, be, "\tdh_Ys: ", m->u.serverKeyExchange.dh_Ys, "\n");
1906                 bs = bytesPrint(bs, be, "\tdh_signature: ", m->u.serverKeyExchange.dh_signature, "\n");
1907                 break;
1908         case HClientKeyExchange:
1909                 bs = seprint(bs, be, "HClientKeyExchange\n");
1910                 bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1911                 break;
1912         case HFinished:
1913                 bs = seprint(bs, be, "HFinished\n");
1914                 for(i=0; i<m->u.finished.n; i++)
1915                         bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1916                 bs = seprint(bs, be, "\n");
1917                 break;
1918         }
1919         USED(bs);
1920         return buf;
1921 }
1922
1923 static void
1924 tlsError(TlsConnection *c, int err, char *fmt, ...)
1925 {
1926         char msg[512];
1927         va_list arg;
1928
1929         va_start(arg, fmt);
1930         vseprint(msg, msg+sizeof(msg), fmt, arg);
1931         va_end(arg);
1932         if(c->trace)
1933                 c->trace("tlsError: %s\n", msg);
1934         else if(c->erred)
1935                 fprint(2, "double error: %r, %s", msg);
1936         else
1937                 werrstr("tls: local %s", msg);
1938         c->erred = 1;
1939         fprint(c->ctl, "alert %d", err);
1940 }
1941
1942 // commit to specific version number
1943 static int
1944 setVersion(TlsConnection *c, int version)
1945 {
1946         if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1947                 return -1;
1948         if(version > c->version)
1949                 version = c->version;
1950         if(version == SSL3Version) {
1951                 c->version = version;
1952                 c->finished.n = SSL3FinishedLen;
1953         }else {
1954                 c->version = version;
1955                 c->finished.n = TLSFinishedLen;
1956         }
1957         c->verset = 1;
1958         return fprint(c->ctl, "version 0x%x", version);
1959 }
1960
1961 // confirm that received Finished message matches the expected value
1962 static int
1963 finishedMatch(TlsConnection *c, Finished *f)
1964 {
1965         return memcmp(f->verify, c->finished.verify, f->n) == 0;
1966 }
1967
1968 // free memory associated with TlsConnection struct
1969 //              (but don't close the TLS channel itself)
1970 static void
1971 tlsConnectionFree(TlsConnection *c)
1972 {
1973         tlsSecClose(c->sec);
1974         freebytes(c->sid);
1975         freebytes(c->cert);
1976         memset(c, 0, sizeof(c));
1977         free(c);
1978 }
1979
1980
1981 //================= cipher choices ========================
1982
1983 static int weakCipher[CipherMax] =
1984 {
1985         1,      /* TLS_NULL_WITH_NULL_NULL */
1986         1,      /* TLS_RSA_WITH_NULL_MD5 */
1987         1,      /* TLS_RSA_WITH_NULL_SHA */
1988         1,      /* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
1989         0,      /* TLS_RSA_WITH_RC4_128_MD5 */
1990         0,      /* TLS_RSA_WITH_RC4_128_SHA */
1991         1,      /* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
1992         0,      /* TLS_RSA_WITH_IDEA_CBC_SHA */
1993         1,      /* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
1994         0,      /* TLS_RSA_WITH_DES_CBC_SHA */
1995         0,      /* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
1996         1,      /* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
1997         0,      /* TLS_DH_DSS_WITH_DES_CBC_SHA */
1998         0,      /* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
1999         1,      /* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
2000         0,      /* TLS_DH_RSA_WITH_DES_CBC_SHA */
2001         0,      /* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
2002         1,      /* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
2003         0,      /* TLS_DHE_DSS_WITH_DES_CBC_SHA */
2004         0,      /* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
2005         1,      /* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
2006         0,      /* TLS_DHE_RSA_WITH_DES_CBC_SHA */
2007         0,      /* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
2008         1,      /* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
2009         1,      /* TLS_DH_anon_WITH_RC4_128_MD5 */
2010         1,      /* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
2011         1,      /* TLS_DH_anon_WITH_DES_CBC_SHA */
2012         1,      /* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
2013 };
2014
2015 static int
2016 setAlgs(TlsConnection *c, int a)
2017 {
2018         int i;
2019
2020         for(i = 0; i < nelem(cipherAlgs); i++){
2021                 if(cipherAlgs[i].tlsid == a){
2022                         c->cipher = a;
2023                         c->enc = cipherAlgs[i].enc;
2024                         c->digest = cipherAlgs[i].digest;
2025                         c->nsecret = cipherAlgs[i].nsecret;
2026                         if(c->nsecret > MaxKeyData)
2027                                 return 0;
2028                         return 1;
2029                 }
2030         }
2031         return 0;
2032 }
2033
2034 static int
2035 okCipher(Ints *cv)
2036 {
2037         int weak, i, j, c;
2038
2039         weak = 1;
2040         for(i = 0; i < cv->len; i++) {
2041                 c = cv->data[i];
2042                 if(c >= CipherMax)
2043                         weak = 0;
2044                 else
2045                         weak &= weakCipher[c];
2046                 if(isDHE(c) || isECDHE(c))
2047                         continue;       /* TODO: not implemented for server */
2048                 for(j = 0; j < nelem(cipherAlgs); j++)
2049                         if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
2050                                 return c;
2051         }
2052         if(weak)
2053                 return -2;
2054         return -1;
2055 }
2056
2057 static int
2058 okCompression(Bytes *cv)
2059 {
2060         int i, j, c;
2061
2062         for(i = 0; i < cv->len; i++) {
2063                 c = cv->data[i];
2064                 for(j = 0; j < nelem(compressors); j++) {
2065                         if(compressors[j] == c)
2066                                 return c;
2067                 }
2068         }
2069         return -1;
2070 }
2071
2072 static Lock     ciphLock;
2073 static int      nciphers;
2074
2075 static int
2076 initCiphers(void)
2077 {
2078         enum {MaxAlgF = 1024, MaxAlgs = 10};
2079         char s[MaxAlgF], *flds[MaxAlgs];
2080         int i, j, n, ok;
2081
2082         lock(&ciphLock);
2083         if(nciphers){
2084                 unlock(&ciphLock);
2085                 return nciphers;
2086         }
2087         j = open("#a/tls/encalgs", OREAD);
2088         if(j < 0){
2089                 werrstr("can't open #a/tls/encalgs: %r");
2090                 return 0;
2091         }
2092         n = read(j, s, MaxAlgF-1);
2093         close(j);
2094         if(n <= 0){
2095                 werrstr("nothing in #a/tls/encalgs: %r");
2096                 return 0;
2097         }
2098         s[n] = 0;
2099         n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
2100         for(i = 0; i < nelem(cipherAlgs); i++){
2101                 ok = 0;
2102                 for(j = 0; j < n; j++){
2103                         if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
2104                                 ok = 1;
2105                                 break;
2106                         }
2107                 }
2108                 cipherAlgs[i].ok = ok;
2109         }
2110
2111         j = open("#a/tls/hashalgs", OREAD);
2112         if(j < 0){
2113                 werrstr("can't open #a/tls/hashalgs: %r");
2114                 return 0;
2115         }
2116         n = read(j, s, MaxAlgF-1);
2117         close(j);
2118         if(n <= 0){
2119                 werrstr("nothing in #a/tls/hashalgs: %r");
2120                 return 0;
2121         }
2122         s[n] = 0;
2123         n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
2124         for(i = 0; i < nelem(cipherAlgs); i++){
2125                 ok = 0;
2126                 for(j = 0; j < n; j++){
2127                         if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
2128                                 ok = 1;
2129                                 break;
2130                         }
2131                 }
2132                 cipherAlgs[i].ok &= ok;
2133                 if(cipherAlgs[i].ok)
2134                         nciphers++;
2135         }
2136         unlock(&ciphLock);
2137         return nciphers;
2138 }
2139
2140 static Ints*
2141 makeciphers(void)
2142 {
2143         Ints *is;
2144         int i, j;
2145
2146         is = newints(nciphers);
2147         j = 0;
2148         for(i = 0; i < nelem(cipherAlgs); i++){
2149                 if(cipherAlgs[i].ok)
2150                         is->data[j++] = cipherAlgs[i].tlsid;
2151         }
2152         return is;
2153 }
2154
2155
2156
2157 //================= security functions ========================
2158
2159 // given X.509 certificate, set up connection to factotum
2160 //      for using corresponding private key
2161 static AuthRpc*
2162 factotum_rsa_open(uchar *cert, int certlen)
2163 {
2164         int afd;
2165         char *s;
2166         mpint *pub = nil;
2167         RSApub *rsapub;
2168         AuthRpc *rpc;
2169
2170         // start talking to factotum
2171         if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
2172                 return nil;
2173         if((rpc = auth_allocrpc(afd)) == nil){
2174                 close(afd);
2175                 return nil;
2176         }
2177         s = "proto=rsa service=tls role=client";
2178         if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
2179                 factotum_rsa_close(rpc);
2180                 return nil;
2181         }
2182
2183         // roll factotum keyring around to match certificate
2184         rsapub = X509toRSApub(cert, certlen, nil, 0);
2185         while(1){
2186                 if(auth_rpc(rpc, "read", nil, 0) != ARok){
2187                         factotum_rsa_close(rpc);
2188                         rpc = nil;
2189                         goto done;
2190                 }
2191                 pub = strtomp(rpc->arg, nil, 16, nil);
2192                 assert(pub != nil);
2193                 if(mpcmp(pub,rsapub->n) == 0)
2194                         break;
2195         }
2196 done:
2197         mpfree(pub);
2198         rsapubfree(rsapub);
2199         return rpc;
2200 }
2201
2202 static mpint*
2203 factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
2204 {
2205         char *p;
2206         int rv;
2207
2208         p = mptoa(cipher, 16, nil, 0);
2209         mpfree(cipher);
2210         if(p == nil)
2211                 return nil;
2212         rv = auth_rpc(rpc, "write", p, strlen(p));
2213         free(p);
2214         if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
2215                 return nil;
2216         return strtomp(rpc->arg, nil, 16, nil);
2217 }
2218
2219 static void
2220 factotum_rsa_close(AuthRpc*rpc)
2221 {
2222         if(!rpc)
2223                 return;
2224         close(rpc->afd);
2225         auth_freerpc(rpc);
2226 }
2227
2228 static void
2229 tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2230 {
2231         uchar ai[MD5dlen], tmp[MD5dlen];
2232         int i, n;
2233         MD5state *s;
2234
2235         // generate a1
2236         s = hmac_md5(label, nlabel, key, nkey, nil, nil);
2237         s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
2238         hmac_md5(seed1, nseed1, key, nkey, ai, s);
2239
2240         while(nbuf > 0) {
2241                 s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
2242                 s = hmac_md5(label, nlabel, key, nkey, nil, s);
2243                 s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
2244                 hmac_md5(seed1, nseed1, key, nkey, tmp, s);
2245                 n = MD5dlen;
2246                 if(n > nbuf)
2247                         n = nbuf;
2248                 for(i = 0; i < n; i++)
2249                         buf[i] ^= tmp[i];
2250                 buf += n;
2251                 nbuf -= n;
2252                 hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
2253                 memmove(ai, tmp, MD5dlen);
2254         }
2255 }
2256
2257 static void
2258 tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2259 {
2260         uchar ai[SHA1dlen], tmp[SHA1dlen];
2261         int i, n;
2262         SHAstate *s;
2263
2264         // generate a1
2265         s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
2266         s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
2267         hmac_sha1(seed1, nseed1, key, nkey, ai, s);
2268
2269         while(nbuf > 0) {
2270                 s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
2271                 s = hmac_sha1(label, nlabel, key, nkey, nil, s);
2272                 s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
2273                 hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
2274                 n = SHA1dlen;
2275                 if(n > nbuf)
2276                         n = nbuf;
2277                 for(i = 0; i < n; i++)
2278                         buf[i] ^= tmp[i];
2279                 buf += n;
2280                 nbuf -= n;
2281                 hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
2282                 memmove(ai, tmp, SHA1dlen);
2283         }
2284 }
2285
2286 static void
2287 p_sha256(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed, int nseed)
2288 {
2289         uchar ai[SHA2_256dlen], tmp[SHA2_256dlen];
2290         SHAstate *s;
2291         int n;
2292
2293         // generate a1
2294         s = hmac_sha2_256(label, nlabel, key, nkey, nil, nil);
2295         hmac_sha2_256(seed, nseed, key, nkey, ai, s);
2296
2297         while(nbuf > 0) {
2298                 s = hmac_sha2_256(ai, SHA2_256dlen, key, nkey, nil, nil);
2299                 s = hmac_sha2_256(label, nlabel, key, nkey, nil, s);
2300                 hmac_sha2_256(seed, nseed, key, nkey, tmp, s);
2301                 n = SHA2_256dlen;
2302                 if(n > nbuf)
2303                         n = nbuf;
2304                 memmove(buf, tmp, n);
2305                 buf += n;
2306                 nbuf -= n;
2307                 hmac_sha2_256(ai, SHA2_256dlen, key, nkey, tmp, nil);
2308                 memmove(ai, tmp, SHA2_256dlen);
2309         }
2310 }
2311
2312 // fill buf with md5(args)^sha1(args)
2313 static void
2314 tls10PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2315 {
2316         int nlabel = strlen(label);
2317         int n = (nkey + 1) >> 1;
2318
2319         memset(buf, 0, nbuf);
2320         tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
2321         tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
2322 }
2323
2324 static void
2325 tls12PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2326 {
2327         uchar seed[2*RandomSize];
2328
2329         assert(nseed0+nseed1 <= sizeof(seed));
2330         memmove(seed, seed0, nseed0);
2331         memmove(seed+nseed0, seed1, nseed1);
2332         p_sha256(buf, nbuf, key, nkey, (uchar*)label, strlen(label), seed, nseed0+nseed1);
2333 }
2334
2335 /*
2336  * for setting server session id's
2337  */
2338 static Lock     sidLock;
2339 static long     maxSid = 1;
2340
2341 /* the keys are verified to have the same public components
2342  * and to function correctly with pkcs 1 encryption and decryption. */
2343 static TlsSec*
2344 tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
2345 {
2346         TlsSec *sec = emalloc(sizeof(*sec));
2347
2348         USED(csid); USED(ncsid);  // ignore csid for now
2349
2350         memmove(sec->crandom, crandom, RandomSize);
2351         sec->clientVers = cvers;
2352
2353         put32(sec->srandom, time(0));
2354         genrandom(sec->srandom+4, RandomSize-4);
2355         memmove(srandom, sec->srandom, RandomSize);
2356
2357         /*
2358          * make up a unique sid: use our pid, and and incrementing id
2359          * can signal no sid by setting nssid to 0.
2360          */
2361         memset(ssid, 0, SidSize);
2362         put32(ssid, getpid());
2363         lock(&sidLock);
2364         put32(ssid+4, maxSid++);
2365         unlock(&sidLock);
2366         *nssid = SidSize;
2367         return sec;
2368 }
2369
2370 static int
2371 tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm)
2372 {
2373         if(epm != nil){
2374                 if(setVers(sec, vers) < 0)
2375                         goto Err;
2376                 serverMasterSecret(sec, epm);
2377         }else if(sec->vers != vers){
2378                 werrstr("mismatched session versions");
2379                 goto Err;
2380         }
2381         return 0;
2382 Err:
2383         sec->ok = -1;
2384         return -1;
2385 }
2386
2387 static TlsSec*
2388 tlsSecInitc(int cvers, uchar *crandom)
2389 {
2390         TlsSec *sec = emalloc(sizeof(*sec));
2391         sec->clientVers = cvers;
2392         put32(sec->crandom, time(0));
2393         genrandom(sec->crandom+4, RandomSize-4);
2394         memmove(crandom, sec->crandom, RandomSize);
2395         return sec;
2396 }
2397
2398 static Bytes*
2399 tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers)
2400 {
2401         RSApub *pub;
2402         Bytes *epm;
2403
2404         USED(sid);
2405         USED(nsid);
2406         
2407         memmove(sec->srandom, srandom, RandomSize);
2408         if(setVers(sec, vers) < 0)
2409                 goto Err;
2410         pub = X509toRSApub(cert, ncert, nil, 0);
2411         if(pub == nil){
2412                 werrstr("invalid x509/rsa certificate");
2413                 goto Err;
2414         }
2415         epm = clientMasterSecret(sec, pub);
2416         rsapubfree(pub);
2417         if(epm != nil)
2418                 return epm;
2419 Err:
2420         sec->ok = -1;
2421         return nil;
2422 }
2423
2424 static int
2425 tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient)
2426 {
2427         if(sec->nfin != nfin){
2428                 sec->ok = -1;
2429                 werrstr("invalid finished exchange");
2430                 return -1;
2431         }
2432         hsh.md5.malloced = 0;
2433         hsh.sha1.malloced = 0;
2434         hsh.sha2_256.malloced = 0;
2435         (*sec->setFinished)(sec, hsh, fin, isclient);
2436         return 1;
2437 }
2438
2439 static void
2440 tlsSecOk(TlsSec *sec)
2441 {
2442         if(sec->ok == 0)
2443                 sec->ok = 1;
2444 }
2445
2446 static void
2447 tlsSecKill(TlsSec *sec)
2448 {
2449         if(!sec)
2450                 return;
2451         factotum_rsa_close(sec->rpc);
2452         sec->ok = -1;
2453 }
2454
2455 static void
2456 tlsSecClose(TlsSec *sec)
2457 {
2458         if(!sec)
2459                 return;
2460         factotum_rsa_close(sec->rpc);
2461         free(sec->server);
2462         free(sec);
2463 }
2464
2465 static int
2466 setVers(TlsSec *sec, int v)
2467 {
2468         if(v == SSL3Version){
2469                 sec->setFinished = sslSetFinished;
2470                 sec->nfin = SSL3FinishedLen;
2471                 sec->prf = sslPRF;
2472         }else if(v < TLS12Version) {
2473                 sec->setFinished = tls10SetFinished;
2474                 sec->nfin = TLSFinishedLen;
2475                 sec->prf = tls10PRF;
2476         }else {
2477                 sec->setFinished = tls12SetFinished;
2478                 sec->nfin = TLSFinishedLen;
2479                 sec->prf = tls12PRF;
2480         }
2481         sec->vers = v;
2482         return 0;
2483 }
2484
2485 /*
2486  * generate secret keys from the master secret.
2487  *
2488  * different crypto selections will require different amounts
2489  * of key expansion and use of key expansion data,
2490  * but it's all generated using the same function.
2491  */
2492 static void
2493 setSecrets(TlsSec *sec, uchar *kd, int nkd)
2494 {
2495         (*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
2496                         sec->srandom, RandomSize, sec->crandom, RandomSize);
2497 }
2498
2499 /*
2500  * set the master secret from the pre-master secret,
2501  * destroys premaster.
2502  */
2503 static void
2504 setMasterSecret(TlsSec *sec, Bytes *pm)
2505 {
2506         (*sec->prf)(sec->sec, MasterSecretSize, pm->data, pm->len, "master secret",
2507                         sec->crandom, RandomSize, sec->srandom, RandomSize);
2508
2509         memset(pm->data, 0, pm->len);   
2510         freebytes(pm);
2511 }
2512
2513 static void
2514 serverMasterSecret(TlsSec *sec, Bytes *epm)
2515 {
2516         Bytes *pm;
2517
2518         pm = pkcs1_decrypt(sec, epm);
2519
2520         // if the client messed up, just continue as if everything is ok,
2521         // to prevent attacks to check for correctly formatted messages.
2522         // Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
2523         if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
2524                 fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
2525                         sec->ok, pm, pm != nil ? get16(pm->data) : -1, sec->clientVers, epm->len);
2526                 sec->ok = -1;
2527                 freebytes(pm);
2528                 pm = newbytes(MasterSecretSize);
2529                 genrandom(pm->data, MasterSecretSize);
2530         }
2531         assert(pm->len == MasterSecretSize);
2532         setMasterSecret(sec, pm);
2533 }
2534
2535 static Bytes*
2536 clientMasterSecret(TlsSec *sec, RSApub *pub)
2537 {
2538         Bytes *pm, *epm;
2539
2540         pm = newbytes(MasterSecretSize);
2541         put16(pm->data, sec->clientVers);
2542         genrandom(pm->data+2, MasterSecretSize - 2);
2543         epm = pkcs1_encrypt(pm, pub, 2);
2544         setMasterSecret(sec, pm);
2545         return epm;
2546 }
2547
2548 static void
2549 sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient)
2550 {
2551         DigestState *s;
2552         uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
2553         char *label;
2554
2555         if(isClient)
2556                 label = "CLNT";
2557         else
2558                 label = "SRVR";
2559
2560         md5((uchar*)label, 4, nil, &hsh.md5);
2561         md5(sec->sec, MasterSecretSize, nil, &hsh.md5);
2562         memset(pad, 0x36, 48);
2563         md5(pad, 48, nil, &hsh.md5);
2564         md5(nil, 0, h0, &hsh.md5);
2565         memset(pad, 0x5C, 48);
2566         s = md5(sec->sec, MasterSecretSize, nil, nil);
2567         s = md5(pad, 48, nil, s);
2568         md5(h0, MD5dlen, finished, s);
2569
2570         sha1((uchar*)label, 4, nil, &hsh.sha1);
2571         sha1(sec->sec, MasterSecretSize, nil, &hsh.sha1);
2572         memset(pad, 0x36, 40);
2573         sha1(pad, 40, nil, &hsh.sha1);
2574         sha1(nil, 0, h1, &hsh.sha1);
2575         memset(pad, 0x5C, 40);
2576         s = sha1(sec->sec, MasterSecretSize, nil, nil);
2577         s = sha1(pad, 40, nil, s);
2578         sha1(h1, SHA1dlen, finished + MD5dlen, s);
2579 }
2580
2581 // fill "finished" arg with md5(args)^sha1(args)
2582 static void
2583 tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient)
2584 {
2585         uchar h0[MD5dlen], h1[SHA1dlen];
2586         char *label;
2587
2588         // get current hash value, but allow further messages to be hashed in
2589         md5(nil, 0, h0, &hsh.md5);
2590         sha1(nil, 0, h1, &hsh.sha1);
2591
2592         if(isClient)
2593                 label = "client finished";
2594         else
2595                 label = "server finished";
2596         tls10PRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
2597 }
2598
2599 static void
2600 tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient)
2601 {
2602         uchar seed[SHA2_256dlen];
2603         char *label;
2604
2605         // get current hash value, but allow further messages to be hashed in
2606         sha2_256(nil, 0, seed, &hsh.sha2_256);
2607
2608         if(isClient)
2609                 label = "client finished";
2610         else
2611                 label = "server finished";
2612         p_sha256(finished, TLSFinishedLen, sec->sec, MasterSecretSize, (uchar*)label, strlen(label), seed, SHA2_256dlen);
2613 }
2614
2615 static void
2616 sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2617 {
2618         uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
2619         DigestState *s;
2620         int i, n, len;
2621
2622         USED(label);
2623         len = 1;
2624         while(nbuf > 0){
2625                 if(len > 26)
2626                         return;
2627                 for(i = 0; i < len; i++)
2628                         tmp[i] = 'A' - 1 + len;
2629                 s = sha1(tmp, len, nil, nil);
2630                 s = sha1(key, nkey, nil, s);
2631                 s = sha1(seed0, nseed0, nil, s);
2632                 sha1(seed1, nseed1, sha1dig, s);
2633                 s = md5(key, nkey, nil, nil);
2634                 md5(sha1dig, SHA1dlen, md5dig, s);
2635                 n = MD5dlen;
2636                 if(n > nbuf)
2637                         n = nbuf;
2638                 memmove(buf, md5dig, n);
2639                 buf += n;
2640                 nbuf -= n;
2641                 len++;
2642         }
2643 }
2644
2645 static mpint*
2646 bytestomp(Bytes* bytes)
2647 {
2648         return betomp(bytes->data, bytes->len, nil);
2649 }
2650
2651 /*
2652  * Convert mpint* to Bytes, putting high order byte first.
2653  */
2654 static Bytes*
2655 mptobytes(mpint* big)
2656 {
2657         Bytes* ans;
2658         int n;
2659
2660         n = (mpsignif(big)+7)/8;
2661         if(n == 0) n = 1;
2662         ans = newbytes(n);
2663         ans->len = mptobe(big, ans->data, n, nil);
2664         return ans;
2665 }
2666
2667 // Do RSA computation on block according to key, and pad
2668 // result on left with zeros to make it modlen long.
2669 static Bytes*
2670 rsacomp(Bytes* block, RSApub* key, int modlen)
2671 {
2672         mpint *x, *y;
2673         Bytes *a, *ybytes;
2674         int ylen;
2675
2676         x = bytestomp(block);
2677         y = rsaencrypt(key, x, nil);
2678         mpfree(x);
2679         ybytes = mptobytes(y);
2680         ylen = ybytes->len;
2681         mpfree(y);
2682
2683         if(ylen < modlen) {
2684                 a = newbytes(modlen);
2685                 memset(a->data, 0, modlen-ylen);
2686                 memmove(a->data+modlen-ylen, ybytes->data, ylen);
2687                 freebytes(ybytes);
2688                 ybytes = a;
2689         }
2690         else if(ylen > modlen) {
2691                 // assume it has leading zeros (mod should make it so)
2692                 a = newbytes(modlen);
2693                 memmove(a->data, ybytes->data, modlen);
2694                 freebytes(ybytes);
2695                 ybytes = a;
2696         }
2697         return ybytes;
2698 }
2699
2700 // encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2701 static Bytes*
2702 pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2703 {
2704         Bytes *pad, *eb, *ans;
2705         int i, dlen, padlen, modlen;
2706
2707         modlen = (mpsignif(key->n)+7)/8;
2708         dlen = data->len;
2709         if(modlen < 12 || dlen > modlen - 11)
2710                 return nil;
2711         padlen = modlen - 3 - dlen;
2712         pad = newbytes(padlen);
2713         genrandom(pad->data, padlen);
2714         for(i = 0; i < padlen; i++) {
2715                 if(blocktype == 0)
2716                         pad->data[i] = 0;
2717                 else if(blocktype == 1)
2718                         pad->data[i] = 255;
2719                 else if(pad->data[i] == 0)
2720                         pad->data[i] = 1;
2721         }
2722         eb = newbytes(modlen);
2723         eb->data[0] = 0;
2724         eb->data[1] = blocktype;
2725         memmove(eb->data+2, pad->data, padlen);
2726         eb->data[padlen+2] = 0;
2727         memmove(eb->data+padlen+3, data->data, dlen);
2728         ans = rsacomp(eb, key, modlen);
2729         freebytes(eb);
2730         freebytes(pad);
2731         return ans;
2732 }
2733
2734 // decrypt data according to PKCS#1, with given key.
2735 // expect a block type of 2.
2736 static Bytes*
2737 pkcs1_decrypt(TlsSec *sec, Bytes *cipher)
2738 {
2739         Bytes *eb, *ans = nil;
2740         int i, modlen;
2741         mpint *x, *y;
2742
2743         modlen = (mpsignif(sec->rsapub->n)+7)/8;
2744         if(cipher->len != modlen)
2745                 return nil;
2746         x = bytestomp(cipher);
2747         y = factotum_rsa_decrypt(sec->rpc, x);
2748         if(y == nil)
2749                 return nil;
2750         eb = mptobytes(y);
2751         mpfree(y);
2752         if(eb->len < modlen){ // pad on left with zeros
2753                 ans = newbytes(modlen);
2754                 memset(ans->data, 0, modlen-eb->len);
2755                 memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2756                 freebytes(eb);
2757                 eb = ans;
2758         }
2759         if(eb->data[0] == 0 && eb->data[1] == 2) {
2760                 for(i = 2; i < modlen; i++)
2761                         if(eb->data[i] == 0)
2762                                 break;
2763                 if(i < modlen - 1)
2764                         ans = makebytes(eb->data+i+1, modlen-(i+1));
2765         }
2766         freebytes(eb);
2767         return ans;
2768 }
2769
2770
2771 //================= general utility functions ========================
2772
2773 static void *
2774 emalloc(int n)
2775 {
2776         void *p;
2777         if(n==0)
2778                 n=1;
2779         p = malloc(n);
2780         if(p == nil)
2781                 sysfatal("out of memory");
2782         memset(p, 0, n);
2783         setmalloctag(p, getcallerpc(&n));
2784         return p;
2785 }
2786
2787 static void *
2788 erealloc(void *ReallocP, int ReallocN)
2789 {
2790         if(ReallocN == 0)
2791                 ReallocN = 1;
2792         if(ReallocP == nil)
2793                 ReallocP = emalloc(ReallocN);
2794         else if((ReallocP = realloc(ReallocP, ReallocN)) == nil)
2795                 sysfatal("out of memory");
2796         setrealloctag(ReallocP, getcallerpc(&ReallocP));
2797         return(ReallocP);
2798 }
2799
2800 static void
2801 put32(uchar *p, u32int x)
2802 {
2803         p[0] = x>>24;
2804         p[1] = x>>16;
2805         p[2] = x>>8;
2806         p[3] = x;
2807 }
2808
2809 static void
2810 put24(uchar *p, int x)
2811 {
2812         p[0] = x>>16;
2813         p[1] = x>>8;
2814         p[2] = x;
2815 }
2816
2817 static void
2818 put16(uchar *p, int x)
2819 {
2820         p[0] = x>>8;
2821         p[1] = x;
2822 }
2823
2824 static u32int
2825 get32(uchar *p)
2826 {
2827         return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2828 }
2829
2830 static int
2831 get24(uchar *p)
2832 {
2833         return (p[0]<<16)|(p[1]<<8)|p[2];
2834 }
2835
2836 static int
2837 get16(uchar *p)
2838 {
2839         return (p[0]<<8)|p[1];
2840 }
2841
/* byte offset of member field within struct type (arguments in field, type order) */
#define OFFSET(field, type) offsetof(type, field)
2843
2844 static Bytes*
2845 newbytes(int len)
2846 {
2847         Bytes* ans;
2848
2849         if(len < 0)
2850                 abort();
2851         ans = (Bytes*)emalloc(OFFSET(data[0], Bytes) + len);
2852         ans->len = len;
2853         return ans;
2854 }
2855
2856 /*
2857  * newbytes(len), with data initialized from buf
2858  */
2859 static Bytes*
2860 makebytes(uchar* buf, int len)
2861 {
2862         Bytes* ans;
2863
2864         ans = newbytes(len);
2865         memmove(ans->data, buf, len);
2866         return ans;
2867 }
2868
2869 static void
2870 freebytes(Bytes* b)
2871 {
2872         free(b);
2873 }
2874
2875 /* len is number of ints */
2876 static Ints*
2877 newints(int len)
2878 {
2879         Ints* ans;
2880
2881         if(len < 0 || len > ((uint)-1>>1)/sizeof(int))
2882                 abort();
2883         ans = (Ints*)emalloc(OFFSET(data[0], Ints) + len*sizeof(int));
2884         ans->len = len;
2885         return ans;
2886 }
2887
2888 static void
2889 freeints(Ints* b)
2890 {
2891         free(b);
2892 }