]> git.lizzy.rs Git - plan9front.git/blob - sys/src/libsec/port/tlshand.c
import E script from bell labs
[plan9front.git] / sys / src / libsec / port / tlshand.c
1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <auth.h>
5 #include <mp.h>
6 #include <libsec.h>
7
8 // The main groups of functions are:
9 //              client/server - main handshake protocol definition
10 //              message functions - formatting handshake messages
11 //              cipher choices - catalog of digest and encrypt algorithms
12 //              security functions - PKCS#1, sslHMAC, session keygen
13 //              general utility functions - malloc, serialization
14 // The handshake protocol builds on the TLS/SSL3 record layer protocol,
15 // which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.
16
// Size limits and msgSend() actions used throughout the handshake code.
//   TLSFinishedLen   - length of a TLS Finished verify_data (RFC 2246 7.4.9)
//   SSL3FinishedLen  - SSLv3 Finished length (MD5 digest + SHA1 digest); also
//                      sizes Finished.verify, so it must be the larger value
//   MaxKeyData       - size of the key-block buffer kd in tlsServer2; equals
//                      the largest nsecret in cipherAlgs: 2*(32+16+SHA1dlen)
//   MaxChunk         - size of TlsConnection.recvbuf and .sendbuf
//   RandomSize, SidSize, MasterSecretSize - sizes of the hello randoms,
//                      session id and master secret buffers
//   AQueue, AFlush   - `act` argument of msgSend(); tlsServer2 passes AQueue
//                      for messages followed by more in the same flight and
//                      AFlush for the last one (presumably "buffer" vs
//                      "write out now" - confirm in msgSend)
17 enum {
18         TLSFinishedLen = 12,
19         SSL3FinishedLen = MD5dlen+SHA1dlen,
20         MaxKeyData = 136,       // amount of secret we may need
21         MaxChunk = 1<<15,
22         RandomSize = 32,
23         SidSize = 32,
24         MasterSecretSize = 48,
25         AQueue = 0,
26         AFlush = 1,
27 };
28
// TlsSec is defined further down; forward-declared because TlsConnection
// and the prototypes refer to it first.
29 typedef struct TlsSec TlsSec;
30
// Length-prefixed byte string.  data is declared [1] but holds len bytes
// (see the [len] note); presumably allocated oversized by newbytes().
31 typedef struct Bytes{
32         int len;
33         uchar data[1];  // [len]
34 } Bytes;
35
// Length-prefixed int array with the same layout trick as Bytes
// (presumably allocated by newints()).
36 typedef struct Ints{
37         int len;
38         int data[1];  // [len]
39 } Ints;
40
// One supported cipher suite: record-layer encryption and digest names,
// amount of key material it needs, and its TLS cipher-suite id.
// cipherAlgs[] initializes only the first four fields; `ok` is
// presumably set at runtime by initCiphers() - confirm.
41 typedef struct Algs{
42         char *enc;
43         char *digest;
44         int nsecret;
45         int tlsid;
46         int ok;
47 } Algs;
48
// A named elliptic curve: TLS NamedCurve id, name, and the domain
// parameters as hex strings (prime p, coefficients a and b, base point G
// in uncompressed "04..." form, order n, cofactor h).
49 typedef struct Namedcurve{
50         int tlsid;
51         char *name;
52
53         char *p;
54         char *a;
55         char *b;
56         char *G;
57         char *n;
58         char *h;
59 } Namedcurve;
60
// Finished message payload; verify is sized for the larger SSLv3 form and
// n records how many bytes are actually used.
61 typedef struct Finished{
62         uchar verify[SSL3FinishedLen];
63         int n;
64 } Finished;
65
// Running hashes over the handshake messages, fed to tlsSecFinished() to
// compute Finished values.  Presumably MD5+SHA1 serve SSLv3/TLS1.0-1.1 and
// SHA-256 serves TLS1.2 - confirm in the setFinished implementations.
66 typedef struct HandshakeHash {
67         MD5state        md5;
68         SHAstate        sha1;
69         SHA2_256state   sha2_256;
70 } HandshakeHash;
71
// Per-handshake state shared by tlsServer2/tlsClient2 and the message
// layer.  Freed with tlsConnectionFree().
72 typedef struct TlsConnection{
73         TlsSec *sec;    // security management goo
74         int hand, ctl;  // record layer file descriptors
75         int erred;              // set when tlsError called
76         int (*trace)(char*fmt, ...); // for debugging
77         int version;    // protocol we are speaking
78         int verset;             // version has been set
79         int ver2hi;             // server got a version 2 hello
80         int isClient;   // is this the client or server?
81         Bytes *sid;             // SessionID
82         Bytes *cert;    // only last - no chain
83
84         Lock statelk;
85         int state;              // must be set using setstate
86
87         // input buffer for handshake messages
       // (rp/ep presumably mark the unread region of recvbuf - confirm in msgRecv)
88         uchar recvbuf[MaxChunk];
89         uchar *rp, *ep;
90
91         // output buffer
       // (sendp presumably points at the next free byte of sendbuf)
92         uchar sendbuf[MaxChunk];
93         uchar *sendp;
94
95         uchar crandom[RandomSize];      // client random
96         uchar srandom[RandomSize];      // server random
97         int clientVersion;      // version in ClientHello
98         int cipher;
       // digest/enc/nsecret end up in the "secret" ctl message in tlsServer2
99         char *digest;   // name of digest algorithm to use
100         char *enc;              // name of encryption algorithm to use
101         int nsecret;    // amount of secret data to init keys
102
103         // for finished messages
104         HandshakeHash   handhash;
105         Finished        finished;
106 } TlsConnection;
107
// A handshake message, either decoded by msgRecv() or built for msgSend().
// tag is one of the H* handshake types and selects the member of u that is
// valid (e.g. tag HServerHello -> u.serverHello).  Pointer members are
// owned by the Msg and released by msgClear().
108 typedef struct Msg{
109         int tag;
110         union {
111                 struct {
112                         int version;
113                         uchar   random[RandomSize];
114                         Bytes*  sid;
115                         Ints*   ciphers;
116                         Bytes*  compressors;
117                         Bytes*  extensions;
118                 } clientHello;
119                 struct {
120                         int version;
121                         uchar   random[RandomSize];
122                         Bytes*  sid;
123                         int     cipher;
124                         int     compressor;
125                         Bytes*  extensions;
126                 } serverHello;
                // ncert entries in certs; tlsServer2 puts its own cert first
127                 struct {
128                         int ncert;
129                         Bytes **certs;
130                 } certificate;
131                 struct {
132                         Bytes *types;
133                         int nca;
134                         Bytes **cas;
135                 } certificateRequest;
136                 struct {
137                         Bytes *key;
138                 } clientKeyExchange;
                // DHE uses dh_p/dh_g/dh_Ys; ECDHE presumably uses curve + dh_Ys
139                 struct {
140                         Bytes *dh_p;
141                         Bytes *dh_g;
142                         Bytes *dh_Ys;
143                         Bytes *dh_signature;
144                         int curve;
145                 } serverKeyExchange;
146                 struct {
147                         Bytes *signature;
148                 } certificateVerify;            
149                 Finished finished;
150         } u;
151 } Msg;
152
// Key-derivation state for one session: master secret, both hello randoms,
// negotiated version, and the version-specific PRF / Finished routines.
153 typedef struct TlsSec{
154         char *server;   // name of remote; nil for server
155         int ok; // <0 killed; == 0 in progress; >0 reusable
156         RSApub *rsapub;
157         AuthRpc *rpc;   // factotum for rsa private key
158         uchar sec[MasterSecretSize];    // master secret
159         uchar crandom[RandomSize];      // client random
160         uchar srandom[RandomSize];      // server random
161         int clientVers;         // version in ClientHello
162         int vers;                       // final version
163         // byte generation and handshake checksum
        // prf(out, nout, key, nkey, label, seed0, nseed0, seed1, nseed1);
        // tlsServer/tlsClient call it with the master secret as key and the
        // client then server randoms as seeds.
164         void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
165         void (*setFinished)(TlsSec*, HandshakeHash, uchar*, int);
166         int nfin;
167 } TlsSec;
168
169
// Wire protocol versions (major<<8 | minor).  ProtocolVersion is the value
// written to the tls ctl file ("fd %d 0x%x") by tlsServer and tlsClient.
170 enum {
171         SSL3Version     = 0x0300,
172         TLS10Version    = 0x0301,
173         TLS11Version    = 0x0302,
174         TLS12Version    = 0x0303,
175         ProtocolVersion = TLS12Version, // maximum version we speak
176         MinProtoVersion = 0x0300,       // limits on version we accept
177         MaxProtoVersion = 0x03ff,
178 };
179
180 // handshake type
// Implicit values: HHelloRequest 0, HClientHello 1, HServerHello 2,
// HCertificate 11 ... HClientKeyExchange 16, HFinished 20, HMax 21.
// These are the HandshakeType codes of RFC 2246 7.4 except for the
// local HSSL2ClientHello value.
181 enum {
182         HHelloRequest,
183         HClientHello,
184         HServerHello,
185         HSSL2ClientHello = 9,  /* local convention;  see devtls.c */
186         HCertificate = 11,
187         HServerKeyExchange,
188         HCertificateRequest,
189         HServerHelloDone,
190         HCertificateVerify,
191         HClientKeyExchange,
192         HFinished = 20,
193         HMax
194 };
195
196 // alerts
// AlertDescription codes (RFC 2246 7.2), passed as `err` to tlsError().
// EMax is one past the largest representable one-byte code.
197 enum {
198         ECloseNotify = 0,
199         EUnexpectedMessage = 10,
200         EBadRecordMac = 20,
201         EDecryptionFailed = 21,
202         ERecordOverflow = 22,
203         EDecompressionFailure = 30,
204         EHandshakeFailure = 40,
205         ENoCertificate = 41,
206         EBadCertificate = 42,
207         EUnsupportedCertificate = 43,
208         ECertificateRevoked = 44,
209         ECertificateExpired = 45,
210         ECertificateUnknown = 46,
211         EIllegalParameter = 47,
212         EUnknownCa = 48,
213         EAccessDenied = 49,
214         EDecodeError = 50,
215         EDecryptError = 51,
216         EExportRestriction = 60,
217         EProtocolVersion = 70,
218         EInsufficientSecurity = 71,
219         EInternalError = 80,
220         EUserCanceled = 90,
221         ENoRenegotiation = 100,
222         EMax = 256
223 };
224
225 // cipher suites
// Two-byte CipherSuite identifiers as sent on the wire.  Only the suites
// listed in cipherAlgs[] are actually offered/accepted; the rest are
// named here for reference (and for isDHE()).
226 enum {
227         TLS_NULL_WITH_NULL_NULL                 = 0x0000,
228         TLS_RSA_WITH_NULL_MD5                   = 0x0001,
229         TLS_RSA_WITH_NULL_SHA                   = 0x0002,
230         TLS_RSA_EXPORT_WITH_RC4_40_MD5          = 0x0003,
231         TLS_RSA_WITH_RC4_128_MD5                = 0x0004,
232         TLS_RSA_WITH_RC4_128_SHA                = 0x0005,
233         TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5      = 0X0006,
234         TLS_RSA_WITH_IDEA_CBC_SHA               = 0X0007,
235         TLS_RSA_EXPORT_WITH_DES40_CBC_SHA       = 0X0008,
236         TLS_RSA_WITH_DES_CBC_SHA                = 0X0009,
237         TLS_RSA_WITH_3DES_EDE_CBC_SHA           = 0X000A,
238         TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA    = 0X000B,
239         TLS_DH_DSS_WITH_DES_CBC_SHA             = 0X000C,
240         TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA        = 0X000D,
241         TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA    = 0X000E,
242         TLS_DH_RSA_WITH_DES_CBC_SHA             = 0X000F,
243         TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA        = 0X0010,
244         TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA   = 0X0011,
245         TLS_DHE_DSS_WITH_DES_CBC_SHA            = 0X0012,
246         TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA       = 0X0013,       // ZZZ must be implemented for tls1.0 compliance
247         TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA   = 0X0014,
248         TLS_DHE_RSA_WITH_DES_CBC_SHA            = 0X0015,
249         TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA       = 0X0016,
250         TLS_DH_anon_EXPORT_WITH_RC4_40_MD5      = 0x0017,
251         TLS_DH_anon_WITH_RC4_128_MD5            = 0x0018,
252         TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA   = 0X0019,
253         TLS_DH_anon_WITH_DES_CBC_SHA            = 0X001A,
254         TLS_DH_anon_WITH_3DES_EDE_CBC_SHA       = 0X001B,
255
256         TLS_RSA_WITH_AES_128_CBC_SHA            = 0X002f,       // aes, aka rijndael with 128 bit blocks
257         TLS_DH_DSS_WITH_AES_128_CBC_SHA         = 0X0030,
258         TLS_DH_RSA_WITH_AES_128_CBC_SHA         = 0X0031,
259         TLS_DHE_DSS_WITH_AES_128_CBC_SHA        = 0X0032,
260         TLS_DHE_RSA_WITH_AES_128_CBC_SHA        = 0X0033,
261         TLS_DH_anon_WITH_AES_128_CBC_SHA        = 0X0034,
262         TLS_RSA_WITH_AES_256_CBC_SHA            = 0X0035,
263         TLS_DH_DSS_WITH_AES_256_CBC_SHA         = 0X0036,
264         TLS_DH_RSA_WITH_AES_256_CBC_SHA         = 0X0037,
265         TLS_DHE_DSS_WITH_AES_256_CBC_SHA        = 0X0038,
266         TLS_DHE_RSA_WITH_AES_256_CBC_SHA        = 0X0039,
267         TLS_DH_anon_WITH_AES_256_CBC_SHA        = 0X003A,
268         
        // ECDHE suites (RFC 4492)
269         TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA      = 0xC013,
270         TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA      = 0xC014,
271         TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA  = 0xC009,
272         TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA = 0xC00A,
273         CipherMax
274 };
275
276 // compression methods
// Only the null method (0) is supported; see compressors[].
277 enum {
278         CompressionNull = 0,
279         CompressionMax
280 };
281
// Supported cipher suites, presumably in preference order (okCipher and
// makeciphers are not shown - confirm).  nsecret is
// 2 * (cipher key + IV + MAC key) bytes, i.e. one set per direction:
// AES-256 32+16+SHA1dlen, AES-128 16+16+SHA1dlen, 3DES 24+8 (written 4*8)
// +SHA1dlen, RC4 16 (no IV) + MAC.
282 static Algs cipherAlgs[] = {
283         {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
284         {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA},
285         {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA},
286         {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
287         {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_256_CBC_SHA},
288         {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_128_CBC_SHA},
289         {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA},
290         {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_RSA_WITH_AES_128_CBC_SHA},
291         {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA},
292         {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
293         {"rc4_128", "sha1",     2*(16+SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
294         {"rc4_128", "md5",      2*(16+MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
295 };
296
// Compression methods we accept/offer: null only.
297 static uchar compressors[] = {
298         CompressionNull,
299 };
300
// Curves advertised in the elliptic_curves extension (tlsClientExtensions).
// 0x0017 is the RFC 4492 NamedCurve id of secp256r1 (NIST P-256).
301 static Namedcurve namedcurves[] = {
302 {0x0017, "secp256r1",
303         "FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF",
304         "FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFC",
305         "5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B",
306         "046B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C2964FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5",
307         "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551",
308         "1"}
309 };
310
// EC point formats advertised in ec_point_formats.  CompressionNull is
// reused only for its value 0, which is the "uncompressed" point format.
311 static uchar pointformats[] = {
312         CompressionNull /* support of uncompressed point format is mandatory */
313 };
314
// ---- forward declarations: handshake driver and message layer ----
315 static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chain);
316 static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen, int (*trace)(char*fmt, ...));
317 static void     msgClear(Msg *m);
318 static char* msgPrint(char *buf, int n, Msg *m);
319 static int      msgRecv(TlsConnection *c, Msg *m);
320 static int      msgSend(TlsConnection *c, Msg *m, int act);
321 static void     tlsError(TlsConnection *c, int err, char *msg, ...);
// Plan 9 compiler pragma: argument 3 of tlsError is a print-style format,
// so its arguments are type-checked like print's.
322 #pragma varargck argpos tlsError 3
323 static int setVersion(TlsConnection *c, int version);
324 static int finishedMatch(TlsConnection *c, Finished *f);
325 static void tlsConnectionFree(TlsConnection *c);
326
// ---- cipher selection ----
327 static int setAlgs(TlsConnection *c, int a);
328 static int okCipher(Ints *cv);
329 static int okCompression(Bytes *cv);
330 static int initCiphers(void);
331 static Ints* makeciphers(void);
332
// ---- security / key derivation (TlsSec) ----
333 static TlsSec*  tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
334 static int      tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm);
335 static TlsSec*  tlsSecInitc(int cvers, uchar *crandom);
336 static Bytes*   tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers);
337 static Bytes*   tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, Bytes *p, Bytes *g, Bytes *Ys);
338 static Bytes*   tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys);
339 static int      tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient);
340 static void     tlsSecOk(TlsSec *sec);
341 static void     tlsSecKill(TlsSec *sec);
342 static void     tlsSecClose(TlsSec *sec);
343 static void     setMasterSecret(TlsSec *sec, Bytes *pm);
344 static void     serverMasterSecret(TlsSec *sec, Bytes *epm);
345 static void     setSecrets(TlsSec *sec, uchar *kd, int nkd);
346 static Bytes*   clientMasterSecret(TlsSec *sec, RSApub *pub);
347 static Bytes*   pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
348 static Bytes*   pkcs1_decrypt(TlsSec *sec, Bytes *cipher);
349 static void     tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient);
350 static void     tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient);
351 static void     sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient);
352 static void     sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
353                         uchar *seed0, int nseed0, uchar *seed1, int nseed1);
354 static int setVers(TlsSec *sec, int version);
355
// ---- factotum-backed RSA private-key operations ----
356 static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
357 static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
358 static void factotum_rsa_close(AuthRpc*rpc);
359
// ---- allocation and big-endian serialization helpers ----
360 static void* emalloc(int);
361 static void* erealloc(void*, int);
362 static void put32(uchar *p, u32int);
363 static void put24(uchar *p, int);
364 static void put16(uchar *p, int);
365 static u32int get32(uchar *p);
366 static int get24(uchar *p);
367 static int get16(uchar *p);
368 static Bytes* newbytes(int len);
369 static Bytes* makebytes(uchar* buf, int len);
370 static Bytes* mptobytes(mpint* big);
371 static mpint* bytestomp(Bytes* bytes);
372 static void freebytes(Bytes* b);
373 static Ints* newints(int len);
374 static void freeints(Ints* b);
375
376 /* x509.c */
377 extern mpint* pkcs1padbuf(uchar *buf, int len, mpint *modulus);
378
379 //================= client/server ========================
380
381 //      push TLS onto fd, returning new (application) file descriptor
382 //              or -1 if error.
383 int
384 tlsServer(int fd, TLSconn *conn)
385 {
386         char buf[8];
387         char dname[64];
388         int n, data, ctl, hand;
389         TlsConnection *tls;
390
391         if(conn == nil)
392                 return -1;
393         ctl = open("#a/tls/clone", ORDWR);
394         if(ctl < 0)
395                 return -1;
396         n = read(ctl, buf, sizeof(buf)-1);
397         if(n < 0){
398                 close(ctl);
399                 return -1;
400         }
401         buf[n] = 0;
402         snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", buf);
403         snprint(dname, sizeof(dname), "#a/tls/%s/hand", buf);
404         hand = open(dname, ORDWR);
405         if(hand < 0){
406                 close(ctl);
407                 return -1;
408         }
409         fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
410         tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
411         snprint(dname, sizeof(dname), "#a/tls/%s/data", buf);
412         data = open(dname, ORDWR);
413         close(hand);
414         close(ctl);
415         if(data < 0 || tls == nil){
416                 if(tls != nil)
417                         tlsConnectionFree(tls);
418                 return -1;
419         }
420         free(conn->cert);
421         conn->cert = 0;  // client certificates are not yet implemented
422         conn->certlen = 0;
423         conn->sessionIDlen = tls->sid->len;
424         conn->sessionID = emalloc(conn->sessionIDlen);
425         memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
426         if(conn->sessionKey != nil
427         && conn->sessionType != nil
428         && strcmp(conn->sessionType, "ttls") == 0)
429                 tls->sec->prf(
430                         conn->sessionKey, conn->sessionKeylen,
431                         tls->sec->sec, MasterSecretSize,
432                         conn->sessionConst, 
433                         tls->sec->crandom, RandomSize,
434                         tls->sec->srandom, RandomSize);
435         tlsConnectionFree(tls);
436         close(fd);
437         return data;
438 }
439
440 static uchar*
441 tlsClientExtensions(TLSconn *conn, int *plen)
442 {
443         uchar *b, *p;
444         int i, n, m;
445
446         p = b = nil;
447
448         // RFC6066 - Server Name Identification
449         if(conn->serverName != nil){
450                 n = strlen(conn->serverName);
451
452                 m = p - b;
453                 b = erealloc(b, m + 2+2+2+1+2+n);
454                 p = b + m;
455
456                 put16(p, 0), p += 2;            /* Type: server_name */
457                 put16(p, 2+1+2+n), p += 2;      /* Length */
458                 put16(p, 1+2+n), p += 2;        /* Server Name list length */
459                 *p++ = 0;                       /* Server Name Type: host_name */
460                 put16(p, n), p += 2;            /* Server Name length */
461                 memmove(p, conn->serverName, n);
462                 p += n;
463         }
464
465         // ECDHE
466         if(1){
467                 m = p - b;
468                 b = erealloc(b, m + 2+2+2+nelem(namedcurves)*2 + 2+2+1+nelem(pointformats));
469                 p = b + m;
470
471                 n = nelem(namedcurves);
472                 put16(p, 0x000a), p += 2;       /* Type: elliptic_curves */
473                 put16(p, (n+1)*2), p += 2;      /* Length */
474                 put16(p, n*2), p += 2;          /* Elliptic Curves Length */
475                 for(i=0; i < n; i++){           /* Elliptic curves */
476                         put16(p, namedcurves[i].tlsid);
477                         p += 2;
478                 }
479
480                 n = nelem(pointformats);
481                 put16(p, 0x000b), p += 2;       /* Type: ec_point_formats */
482                 put16(p, n+1), p += 2;          /* Length */
483                 *p++ = n;                       /* EC point formats Length */
484                 for(i=0; i < n; i++)            /* Elliptic curves point formats */
485                         *p++ = pointformats[i];
486         }
487         
488         *plen = p - b;
489         return b;
490 }
491
492 //      push TLS onto fd, returning new (application) file descriptor
493 //              or -1 if error.
494 int
495 tlsClient(int fd, TLSconn *conn)
496 {
497         char buf[8];
498         char dname[64];
499         int n, data, ctl, hand;
500         TlsConnection *tls;
501         uchar *ext;
502
503         if(conn == nil)
504                 return -1;
505         ctl = open("#a/tls/clone", ORDWR);
506         if(ctl < 0)
507                 return -1;
508         n = read(ctl, buf, sizeof(buf)-1);
509         if(n < 0){
510                 close(ctl);
511                 return -1;
512         }
513         buf[n] = 0;
514         snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", buf);
515         snprint(dname, sizeof(dname), "#a/tls/%s/hand", buf);
516         hand = open(dname, ORDWR);
517         if(hand < 0){
518                 close(ctl);
519                 return -1;
520         }
521         snprint(dname, sizeof(dname), "#a/tls/%s/data", buf);
522         data = open(dname, ORDWR);
523         if(data < 0){
524                 close(hand);
525                 close(ctl);
526                 return -1;
527         }
528         fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
529         ext = tlsClientExtensions(conn, &n);
530         tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen, 
531                 ext, n, conn->trace);
532         free(ext);
533         close(hand);
534         close(ctl);
535         if(tls == nil){
536                 close(data);
537                 return -1;
538         }
539         conn->certlen = tls->cert->len;
540         conn->cert = emalloc(conn->certlen);
541         memcpy(conn->cert, tls->cert->data, conn->certlen);
542         conn->sessionIDlen = tls->sid->len;
543         conn->sessionID = emalloc(conn->sessionIDlen);
544         memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
545         if(conn->sessionKey != nil
546         && conn->sessionType != nil
547         && strcmp(conn->sessionType, "ttls") == 0)
548                 tls->sec->prf(
549                         conn->sessionKey, conn->sessionKeylen,
550                         tls->sec->sec, MasterSecretSize,
551                         conn->sessionConst, 
552                         tls->sec->crandom, RandomSize,
553                         tls->sec->srandom, RandomSize);
554         tlsConnectionFree(tls);
555         close(fd);
556         return data;
557 }
558
559 static int
560 countchain(PEMChain *p)
561 {
562         int i = 0;
563
564         while (p) {
565                 i++;
566                 p = p->next;
567         }
568         return i;
569 }
570
// tlsServer2: run the server side of the handshake over the tls device's
// ctl (control) and hand (handshake) descriptors.
// Flight sent: ServerHello, Certificate (cert followed by the PEM chain
// chp), ServerHelloDone; no ServerKeyExchange or CertificateRequest.
// The key exchange performed is RSA (tlsSecRSAs) with the private key
// reached through factotum for cert.
// NOTE(review): confirm okCipher() never selects a DHE/ECDHE suite on the
// server side, since no ServerKeyExchange is ever sent here.
// Returns the connection on success; on failure an alert is sent via
// tlsError where applicable, c is freed, and 0 is returned.
571 static TlsConnection *
572 tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chp)
573 {
574         TlsConnection *c;
575         Msg m;
576         Bytes *csid;
577         uchar sid[SidSize], kd[MaxKeyData];
578         char *secrets;
579         int cipher, compressor, nsid, rv, numcerts, i;
580
581         if(trace)
582                 trace("tlsServer2\n");
583         if(!initCiphers())
584                 return nil;
585         c = emalloc(sizeof(TlsConnection));
586         c->ctl = ctl;
587         c->hand = hand;
588         c->trace = trace;
589         c->version = ProtocolVersion;
590
        // m is zeroed so the Err path can always msgClear it safely
591         memset(&m, 0, sizeof(m));
        // --- ClientHello: version, cipher suite and compression negotiation
592         if(!msgRecv(c, &m)){
593                 if(trace)
594                         trace("initial msgRecv failed\n");
595                 goto Err;
596         }
597         if(m.tag != HClientHello) {
598                 tlsError(c, EUnexpectedMessage, "expected a client hello");
599                 goto Err;
600         }
601         c->clientVersion = m.u.clientHello.version;
602         if(trace)
603                 trace("ClientHello version %x\n", c->clientVersion);
604         if(setVersion(c, c->clientVersion) < 0) {
605                 tlsError(c, EIllegalParameter, "incompatible version");
606                 goto Err;
607         }
608
609         memmove(c->crandom, m.u.clientHello.random, RandomSize);
610         cipher = okCipher(m.u.clientHello.ciphers);
611         if(cipher < 0) {
612                 // reply with EInsufficientSecurity if we know that's the case
613                 if(cipher == -2)
614                         tlsError(c, EInsufficientSecurity, "cipher suites too weak");
615                 else
616                         tlsError(c, EHandshakeFailure, "no matching cipher suite");
617                 goto Err;
618         }
619         if(!setAlgs(c, cipher)){
620                 tlsError(c, EHandshakeFailure, "no matching cipher suite");
621                 goto Err;
622         }
623         compressor = okCompression(m.u.clientHello.compressors);
624         if(compressor < 0) {
625                 tlsError(c, EHandshakeFailure, "no matching compressor");
626                 goto Err;
627         }
628
        // set up TlsSec (fills sid/nsid and our server random), then the
        // factotum RPC and public key for our certificate
629         csid = m.u.clientHello.sid;
630         if(trace)
631                 trace("  cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
632         c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
633         if(c->sec == nil){
634                 tlsError(c, EHandshakeFailure, "can't initialize security: %r");
635                 goto Err;
636         }
637         c->sec->rpc = factotum_rsa_open(cert, certlen);
638         if(c->sec->rpc == nil){
639                 tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
640                 goto Err;
641         }
642         c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
643         if(c->sec->rsapub == nil){
644                 tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
645                 goto Err;
646         }
647         msgClear(&m);
648
        // --- ServerHello (queued)
649         m.tag = HServerHello;
650         m.u.serverHello.version = c->version;
651         memmove(m.u.serverHello.random, c->srandom, RandomSize);
652         m.u.serverHello.cipher = cipher;
653         m.u.serverHello.compressor = compressor;
654         c->sid = makebytes(sid, nsid);
655         m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
656         if(!msgSend(c, &m, AQueue))
657                 goto Err;
658         msgClear(&m);
659
        // --- Certificate (queued): our cert first, then the chain
660         m.tag = HCertificate;
661         numcerts = countchain(chp);
662         m.u.certificate.ncert = 1 + numcerts;
663         m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
664         m.u.certificate.certs[0] = makebytes(cert, certlen);
665         for (i = 0; i < numcerts && chp; i++, chp = chp->next)
666                 m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
667         if(!msgSend(c, &m, AQueue))
668                 goto Err;
669         msgClear(&m);
670
        // --- ServerHelloDone ends our flight (AFlush)
671         m.tag = HServerHelloDone;
672         if(!msgSend(c, &m, AFlush))
673                 goto Err;
674         msgClear(&m);
675
        // --- ClientKeyExchange: RSA-encrypted premaster secret
676         if(!msgRecv(c, &m))
677                 goto Err;
678         if(m.tag != HClientKeyExchange) {
679                 tlsError(c, EUnexpectedMessage, "expected a client key exchange");
680                 goto Err;
681         }
682         if(tlsSecRSAs(c->sec, c->version, m.u.clientKeyExchange.key) < 0){
683                 tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
684                 goto Err;
685         }
        // derive nsecret bytes of key material into kd and hand them to the
        // record layer base64-encoded; both copies are wiped afterwards
686         setSecrets(c->sec, kd, c->nsecret);
687         if(trace)
688                 trace("tls secrets\n");
689         secrets = (char*)emalloc(2*c->nsecret);
690         enc64(secrets, 2*c->nsecret, kd, c->nsecret);
691         rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
692         memset(secrets, 0, 2*c->nsecret);
693         free(secrets);
694         memset(kd, 0, c->nsecret);
695         if(rv < 0){
696                 tlsError(c, EHandshakeFailure, "can't set keys: %r");
697                 goto Err;
698         }
699         msgClear(&m);
700
        // expected client Finished is computed from the handshake hash
        // *before* the client's Finished is received and hashed
701         /* no CertificateVerify; skip to Finished */
702         if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){
703                 tlsError(c, EInternalError, "can't set finished: %r");
704                 goto Err;
705         }
706         if(!msgRecv(c, &m))
707                 goto Err;
708         if(m.tag != HFinished) {
709                 tlsError(c, EUnexpectedMessage, "expected a finished");
710                 goto Err;
711         }
712         if(!finishedMatch(c, &m.u.finished)) {
713                 tlsError(c, EHandshakeFailure, "finished verification failed");
714                 goto Err;
715         }
716         msgClear(&m);
717
718         /* change cipher spec */
719         if(fprint(c->ctl, "changecipher") < 0){
720                 tlsError(c, EInternalError, "can't enable cipher: %r");
721                 goto Err;
722         }
723
        // --- our Finished (isclient = 0), sent under the new cipher
724         if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){
725                 tlsError(c, EInternalError, "can't set finished: %r");
726                 goto Err;
727         }
728         m.tag = HFinished;
729         m.u.finished = c->finished;
730         if(!msgSend(c, &m, AFlush))
731                 goto Err;
732         msgClear(&m);
733         if(trace)
734                 trace("tls finished\n");
735
736         if(fprint(c->ctl, "opened") < 0)
737                 goto Err;
738         tlsSecOk(c->sec);
739         return c;
740
741 Err:
742         msgClear(&m);
743         tlsConnectionFree(c);
744         return 0;
745 }
746
747 static int
748 isDHE(int tlsid)
749 {
750         switch(tlsid){
751         case TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA:
752         case TLS_DHE_DSS_WITH_DES_CBC_SHA:
753         case TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA:
754         case TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA:
755         case TLS_DHE_RSA_WITH_DES_CBC_SHA:
756         case TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA:
757         case TLS_DHE_DSS_WITH_AES_128_CBC_SHA:
758         case TLS_DHE_RSA_WITH_AES_128_CBC_SHA:
759         case TLS_DHE_DSS_WITH_AES_256_CBC_SHA:
760         case TLS_DHE_RSA_WITH_AES_256_CBC_SHA:
761                 return 1;
762         }
763         return 0;
764 }
765
766 static int
767 isECDHE(int tlsid)
768 {
769         switch(tlsid){
770         case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:
771         case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
772         case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA:
773         case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
774                 return 1;
775         }
776         return 0;
777 }
778
779 static Bytes*
780 tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, 
781         Bytes *p, Bytes *g, Bytes *Ys)
782 {
783         mpint *G, *P, *Y, *K;
784         Bytes *epm;
785         DHstate dh;
786
787         if(p == nil || g == nil || Ys == nil)
788                 return nil;
789
790         memmove(sec->srandom, srandom, RandomSize);
791         if(setVers(sec, vers) < 0)
792                 return nil;
793
794         epm = nil;
795         P = bytestomp(p);
796         G = bytestomp(g);
797         Y = bytestomp(Ys);
798         K = nil;
799
800         if(P == nil || G == nil || Y == nil || dh_new(&dh, P, G) == nil)
801                 goto Out;
802         epm = mptobytes(dh.y);
803         K = dh_finish(&dh, Y);
804         if(K == nil){
805                 freebytes(epm);
806                 epm = nil;
807                 goto Out;
808         }
809         setMasterSecret(sec, mptobytes(K));
810
811 Out:
812         mpfree(K);
813         mpfree(Y);
814         mpfree(G);
815         mpfree(P);
816
817         return epm;
818 }
819
820 static ECpoint*
821 bytestoec(ECdomain *dom, Bytes *bp, ECpoint *ret)
822 {
823         char *hex = "0123456789ABCDEF";
824         char *s;
825         int i;
826
827         s = emalloc(2*bp->len + 1);
828         for(i=0; i < bp->len; i++){
829                 s[2*i] = hex[bp->data[i]>>4 & 15];
830                 s[2*i+1] = hex[bp->data[i] & 15];
831         }
832         s[2*bp->len] = '\0';
833         ret = strtoec(dom, s, nil, ret);
834         free(s);
835         return ret;
836 }
837
838 static Bytes*
839 ectobytes(int type, ECpoint *p)
840 {
841         Bytes *bx, *by, *bp;
842
843         bx = mptobytes(p->x);
844         by = mptobytes(p->y);
845         bp = newbytes(bx->len + by->len + 1);
846         bp->data[0] =  type;
847         memmove(bp->data+1, bx->data, bx->len);
848         memmove(bp->data+1+bx->len, by->data, by->len);
849         freebytes(bx);
850         freebytes(by);
851         return bp;
852 }
853
854 static Bytes*
855 tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys)
856 {
857         Bytes *epm;
858         ECdomain *dom;
859         ECpoint K, *Y;
860         ECpriv *Q;
861
862         epm = nil;
863         Y = nil;
864         Q = nil;
865
866         if(Ys == nil)
867                 return nil;
868
869         memmove(sec->srandom, srandom, RandomSize);
870         if(setVers(sec, vers) < 0)
871                 return nil;
872
873         dom = ecnamedcurve(curve);
874         if(dom == nil)
875                 return nil;
876
877
878         memset(&K, 0, sizeof(K));
879         K.x = mpnew(0);
880         K.y = mpnew(0);
881
882         if(K.x == nil || K.y == nil)
883                 goto Out;
884
885         Y = betoec(dom, Ys->data, Ys->len, nil);
886         if(Y == nil)
887                 goto Out;
888
889         Q = ecgen(dom, nil);
890         if(Q == nil)
891                 goto Out;
892
893         ecmul(dom, Y, Q->d, &K);
894         setMasterSecret(sec, mptobytes(K.x));
895
896         /* 0x04 = uncompressed public key */
897         epm = ectobytes(0x04, Q);
898         
899 Out:
900         ecfreepriv(Q);
901
902         ecfreepoint(Y);
903
904         mpfree(K.x);
905         mpfree(K.y);
906
907         ecfreedomain(dom);
908
909         return epm;
910 }
911
912 static TlsConnection *
913 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen,
914         int (*trace)(char*fmt, ...))
915 {
916         TlsConnection *c;
917         Msg m;
918         uchar kd[MaxKeyData];
919         char *secrets;
920         int creq, dhx, rv, cipher;
921         mpint *signedMP, *paddedHashes;
922         Bytes *epm;
923
924         if(!initCiphers())
925                 return nil;
926         epm = nil;
927         c = emalloc(sizeof(TlsConnection));
928         c->version = ProtocolVersion;
929
930         // client certificate signature not implemented for TLS1.2
931         if(cert != nil && certlen > 0 && c->version >= TLS12Version)
932                 c->version = TLS11Version;
933
934         c->ctl = ctl;
935         c->hand = hand;
936         c->trace = trace;
937         c->isClient = 1;
938         c->clientVersion = c->version;
939
940         c->sec = tlsSecInitc(c->clientVersion, c->crandom);
941         if(c->sec == nil)
942                 goto Err;
943         /* client hello */
944         memset(&m, 0, sizeof(m));
945         m.tag = HClientHello;
946         m.u.clientHello.version = c->clientVersion;
947         memmove(m.u.clientHello.random, c->crandom, RandomSize);
948         m.u.clientHello.sid = makebytes(csid, ncsid);
949         m.u.clientHello.ciphers = makeciphers();
950         m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
951         m.u.clientHello.extensions = makebytes(ext, extlen);
952         if(!msgSend(c, &m, AFlush))
953                 goto Err;
954         msgClear(&m);
955
956         /* server hello */
957         if(!msgRecv(c, &m))
958                 goto Err;
959         if(m.tag != HServerHello) {
960                 tlsError(c, EUnexpectedMessage, "expected a server hello");
961                 goto Err;
962         }
963         if(setVersion(c, m.u.serverHello.version) < 0) {
964                 tlsError(c, EIllegalParameter, "incompatible version %r");
965                 goto Err;
966         }
967         memmove(c->srandom, m.u.serverHello.random, RandomSize);
968         c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
969         if(c->sid->len != 0 && c->sid->len != SidSize) {
970                 tlsError(c, EIllegalParameter, "invalid server session identifier");
971                 goto Err;
972         }
973         cipher = m.u.serverHello.cipher;
974         if(!setAlgs(c, cipher)) {
975                 tlsError(c, EIllegalParameter, "invalid cipher suite");
976                 goto Err;
977         }
978         if(m.u.serverHello.compressor != CompressionNull) {
979                 tlsError(c, EIllegalParameter, "invalid compression");
980                 goto Err;
981         }
982         msgClear(&m);
983
984         /* certificate */
985         if(!msgRecv(c, &m) || m.tag != HCertificate) {
986                 tlsError(c, EUnexpectedMessage, "expected a certificate");
987                 goto Err;
988         }
989         if(m.u.certificate.ncert < 1) {
990                 tlsError(c, EIllegalParameter, "runt certificate");
991                 goto Err;
992         }
993         c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
994         msgClear(&m);
995
996         /* server key exchange */
997         dhx = isDHE(cipher) || isECDHE(cipher);
998         if(!msgRecv(c, &m))
999                 goto Err;
1000         if(m.tag == HServerKeyExchange) {
1001                 if(!dhx){
1002                         tlsError(c, EUnexpectedMessage, "got an server key exchange");
1003                         goto Err;
1004                 }
1005                 if(isECDHE(cipher))
1006                         epm = tlsSecECDHEc(c->sec, c->srandom, c->version,
1007                                 m.u.serverKeyExchange.curve,
1008                                 m.u.serverKeyExchange.dh_Ys);
1009                 else
1010                         epm = tlsSecDHEc(c->sec, c->srandom, c->version,
1011                                 m.u.serverKeyExchange.dh_p, 
1012                                 m.u.serverKeyExchange.dh_g,
1013                                 m.u.serverKeyExchange.dh_Ys);
1014                 if(epm == nil)
1015                         goto Badcert;
1016                 msgClear(&m);
1017                 if(!msgRecv(c, &m))
1018                         goto Err;
1019         } else if(dhx){
1020                 tlsError(c, EUnexpectedMessage, "expected server key exchange");
1021                 goto Err;
1022         }
1023
1024         /* certificate request (optional) */
1025         creq = 0;
1026         if(m.tag == HCertificateRequest) {
1027                 creq = 1;
1028                 msgClear(&m);
1029                 if(!msgRecv(c, &m))
1030                         goto Err;
1031         }
1032
1033         if(m.tag != HServerHelloDone) {
1034                 tlsError(c, EUnexpectedMessage, "expected a server hello done");
1035                 goto Err;
1036         }
1037         msgClear(&m);
1038
1039         if(!dhx)
1040                 epm = tlsSecRSAc(c->sec, c->sid->data, c->sid->len, c->srandom,
1041                         c->cert->data, c->cert->len, c->version);
1042
1043         if(epm == nil){
1044         Badcert:
1045                 tlsError(c, EBadCertificate, "bad certificate: %r");
1046                 goto Err;
1047         }
1048
1049         setSecrets(c->sec, kd, c->nsecret);
1050         secrets = (char*)emalloc(2*c->nsecret);
1051         enc64(secrets, 2*c->nsecret, kd, c->nsecret);
1052         rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
1053         memset(secrets, 0, 2*c->nsecret);
1054         free(secrets);
1055         memset(kd, 0, c->nsecret);
1056         if(rv < 0){
1057                 tlsError(c, EHandshakeFailure, "can't set keys: %r");
1058                 goto Err;
1059         }
1060
1061         if(creq) {
1062                 if(cert != nil && certlen > 0){
1063                         m.u.certificate.ncert = 1;
1064                         m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
1065                         m.u.certificate.certs[0] = makebytes(cert, certlen);
1066                 }               
1067                 m.tag = HCertificate;
1068                 if(!msgSend(c, &m, AFlush))
1069                         goto Err;
1070                 msgClear(&m);
1071         }
1072
1073         /* client key exchange */
1074         m.tag = HClientKeyExchange;
1075         m.u.clientKeyExchange.key = epm;
1076         epm = nil;
1077         if(m.u.clientKeyExchange.key == nil) {
1078                 tlsError(c, EHandshakeFailure, "can't set secret: %r");
1079                 goto Err;
1080         }
1081          
1082         if(!msgSend(c, &m, AFlush))
1083                 goto Err;
1084         msgClear(&m);
1085
1086         /* certificate verify */
1087         if(creq && cert != nil && certlen > 0) {
1088                 uchar hshashes[MD5dlen+SHA1dlen]; /* content of signature */
1089                 HandshakeHash hsave;
1090
1091                 /* save the state for the Finish message */
1092                 hsave = c->handhash;
1093                 md5(nil, 0, hshashes, &c->handhash.md5);
1094                 sha1(nil, 0, hshashes+MD5dlen, &c->handhash.sha1);
1095                 c->handhash = hsave;
1096
1097                 c->sec->rpc = factotum_rsa_open(cert, certlen);
1098                 if(c->sec->rpc == nil){
1099                         tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
1100                         goto Err;
1101                 }
1102                 c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
1103                 if(c->sec->rsapub == nil){
1104                         tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
1105                         goto Err;
1106                 }
1107
1108                 paddedHashes = pkcs1padbuf(hshashes, MD5dlen+SHA1dlen, c->sec->rsapub->n);
1109                 signedMP = factotum_rsa_decrypt(c->sec->rpc, paddedHashes);
1110                 if(signedMP == nil){
1111                         tlsError(c, EHandshakeFailure, "factotum_rsa_decrypt: %r");
1112                         goto Err;
1113                 }
1114                 m.u.certificateVerify.signature = mptobytes(signedMP);
1115                 mpfree(signedMP);
1116
1117                 m.tag = HCertificateVerify;
1118                 if(!msgSend(c, &m, AFlush))
1119                         goto Err;
1120                 msgClear(&m);
1121         } 
1122
1123         /* change cipher spec */
1124         if(fprint(c->ctl, "changecipher") < 0){
1125                 tlsError(c, EInternalError, "can't enable cipher: %r");
1126                 goto Err;
1127         }
1128
1129         // Cipherchange must occur immediately before Finished to avoid
1130         // potential hole;  see section 4.3 of Wagner Schneier 1996.
1131         if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){
1132                 tlsError(c, EInternalError, "can't set finished 1: %r");
1133                 goto Err;
1134         }
1135         m.tag = HFinished;
1136         m.u.finished = c->finished;
1137         if(!msgSend(c, &m, AFlush)) {
1138                 tlsError(c, EInternalError, "can't flush after client Finished: %r");
1139                 goto Err;
1140         }
1141         msgClear(&m);
1142
1143         if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){
1144                 tlsError(c, EInternalError, "can't set finished 0: %r");
1145                 goto Err;
1146         }
1147         if(!msgRecv(c, &m)) {
1148                 tlsError(c, EInternalError, "can't read server Finished: %r");
1149                 goto Err;
1150         }
1151         if(m.tag != HFinished) {
1152                 tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
1153                 goto Err;
1154         }
1155
1156         if(!finishedMatch(c, &m.u.finished)) {
1157                 tlsError(c, EHandshakeFailure, "finished verification failed");
1158                 goto Err;
1159         }
1160         msgClear(&m);
1161
1162         if(fprint(c->ctl, "opened") < 0){
1163                 if(trace)
1164                         trace("unable to do final open: %r\n");
1165                 goto Err;
1166         }
1167         tlsSecOk(c->sec);
1168         return c;
1169
1170 Err:
1171         free(epm);
1172         msgClear(&m);
1173         tlsConnectionFree(c);
1174         return 0;
1175 }
1176
1177
1178 //================= message functions ========================
1179
1180 static void
1181 msgHash(TlsConnection *c, uchar *p, int n)
1182 {
1183         md5(p, n, 0, &c->handhash.md5);
1184         sha1(p, n, 0, &c->handhash.sha1);
1185         if(c->version >= TLS12Version)
1186                 sha2_256(p, n, 0, &c->handhash.sha2_256);
1187 }
1188
1189 static int
1190 msgSend(TlsConnection *c, Msg *m, int act)
1191 {
1192         uchar *p; // sendp = start of new message;  p = write pointer
1193         int nn, n, i;
1194
1195         if(c->sendp == nil)
1196                 c->sendp = c->sendbuf;
1197         p = c->sendp;
1198         if(c->trace)
1199                 c->trace("send %s", msgPrint((char*)p, (sizeof(c->sendbuf)) - (p - c->sendbuf), m));
1200
1201         p[0] = m->tag;  // header - fill in size later
1202         p += 4;
1203
1204         switch(m->tag) {
1205         default:
1206                 tlsError(c, EInternalError, "can't encode a %d", m->tag);
1207                 goto Err;
1208         case HClientHello:
1209                 // version
1210                 put16(p, m->u.clientHello.version);
1211                 p += 2;
1212
1213                 // random
1214                 memmove(p, m->u.clientHello.random, RandomSize);
1215                 p += RandomSize;
1216
1217                 // sid
1218                 n = m->u.clientHello.sid->len;
1219                 assert(n < 256);
1220                 p[0] = n;
1221                 memmove(p+1, m->u.clientHello.sid->data, n);
1222                 p += n+1;
1223
1224                 n = m->u.clientHello.ciphers->len;
1225                 assert(n > 0 && n < 200);
1226                 put16(p, n*2);
1227                 p += 2;
1228                 for(i=0; i<n; i++) {
1229                         put16(p, m->u.clientHello.ciphers->data[i]);
1230                         p += 2;
1231                 }
1232
1233                 n = m->u.clientHello.compressors->len;
1234                 assert(n > 0);
1235                 p[0] = n;
1236                 memmove(p+1, m->u.clientHello.compressors->data, n);
1237                 p += n+1;
1238
1239                 if(m->u.clientHello.extensions == nil)
1240                         break;
1241                 n = m->u.clientHello.extensions->len;
1242                 if(n == 0)
1243                         break;
1244                 put16(p, n);
1245                 memmove(p+2, m->u.clientHello.extensions->data, n);
1246                 p += n+2;
1247                 break;
1248         case HServerHello:
1249                 put16(p, m->u.serverHello.version);
1250                 p += 2;
1251
1252                 // random
1253                 memmove(p, m->u.serverHello.random, RandomSize);
1254                 p += RandomSize;
1255
1256                 // sid
1257                 n = m->u.serverHello.sid->len;
1258                 assert(n < 256);
1259                 p[0] = n;
1260                 memmove(p+1, m->u.serverHello.sid->data, n);
1261                 p += n+1;
1262
1263                 put16(p, m->u.serverHello.cipher);
1264                 p += 2;
1265                 p[0] = m->u.serverHello.compressor;
1266                 p += 1;
1267
1268                 if(m->u.serverHello.extensions == nil)
1269                         break;
1270                 n = m->u.serverHello.extensions->len;
1271                 if(n == 0)
1272                         break;
1273                 put16(p, n);
1274                 memmove(p+2, m->u.serverHello.extensions->data, n);
1275                 p += n+2;
1276                 break;
1277         case HServerHelloDone:
1278                 break;
1279         case HCertificate:
1280                 nn = 0;
1281                 for(i = 0; i < m->u.certificate.ncert; i++)
1282                         nn += 3 + m->u.certificate.certs[i]->len;
1283                 if(p + 3 + nn - c->sendbuf > sizeof(c->sendbuf)) {
1284                         tlsError(c, EInternalError, "output buffer too small for certificate");
1285                         goto Err;
1286                 }
1287                 put24(p, nn);
1288                 p += 3;
1289                 for(i = 0; i < m->u.certificate.ncert; i++){
1290                         put24(p, m->u.certificate.certs[i]->len);
1291                         p += 3;
1292                         memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
1293                         p += m->u.certificate.certs[i]->len;
1294                 }
1295                 break;
1296         case HCertificateVerify:
1297                 put16(p, m->u.certificateVerify.signature->len);
1298                 p += 2;
1299                 memmove(p, m->u.certificateVerify.signature->data, m->u.certificateVerify.signature->len);
1300                 p += m->u.certificateVerify.signature->len;
1301                 break;  
1302         case HClientKeyExchange:
1303                 n = m->u.clientKeyExchange.key->len;
1304                 if(c->version != SSL3Version){
1305                         if(isECDHE(c->cipher))
1306                                 *p++ = n;
1307                         else
1308                                 put16(p, n), p += 2;
1309                 }
1310                 memmove(p, m->u.clientKeyExchange.key->data, n);
1311                 p += n;
1312                 break;
1313         case HFinished:
1314                 memmove(p, m->u.finished.verify, m->u.finished.n);
1315                 p += m->u.finished.n;
1316                 break;
1317         }
1318
1319         // go back and fill in size
1320         n = p - c->sendp;
1321         assert(p <= c->sendbuf + sizeof(c->sendbuf));
1322         put24(c->sendp+1, n-4);
1323
1324         // remember hash of Handshake messages
1325         if(m->tag != HHelloRequest)
1326                 msgHash(c, c->sendp, n);
1327
1328         c->sendp = p;
1329         if(act == AFlush){
1330                 c->sendp = c->sendbuf;
1331                 if(write(c->hand, c->sendbuf, p - c->sendbuf) < 0){
1332                         fprint(2, "write error: %r\n");
1333                         goto Err;
1334                 }
1335         }
1336         msgClear(m);
1337         return 1;
1338 Err:
1339         msgClear(m);
1340         return 0;
1341 }
1342
1343 static uchar*
1344 tlsReadN(TlsConnection *c, int n)
1345 {
1346         uchar *p;
1347         int nn, nr;
1348
1349         nn = c->ep - c->rp;
1350         if(nn < n){
1351                 if(c->rp != c->recvbuf){
1352                         memmove(c->recvbuf, c->rp, nn);
1353                         c->rp = c->recvbuf;
1354                         c->ep = &c->recvbuf[nn];
1355                 }
1356                 for(; nn < n; nn += nr) {
1357                         nr = read(c->hand, &c->rp[nn], n - nn);
1358                         if(nr <= 0)
1359                                 return nil;
1360                         c->ep += nr;
1361                 }
1362         }
1363         p = c->rp;
1364         c->rp += n;
1365         return p;
1366 }
1367
1368 static int
1369 msgRecv(TlsConnection *c, Msg *m)
1370 {
1371         uchar *p;
1372         int type, n, nn, i, nsid, nrandom, nciph;
1373
1374         for(;;) {
1375                 p = tlsReadN(c, 4);
1376                 if(p == nil)
1377                         return 0;
1378                 type = p[0];
1379                 n = get24(p+1);
1380
1381                 if(type != HHelloRequest)
1382                         break;
1383                 if(n != 0) {
1384                         tlsError(c, EDecodeError, "invalid hello request during handshake");
1385                         return 0;
1386                 }
1387         }
1388
1389         if(n > sizeof(c->recvbuf)) {
1390                 tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->recvbuf));
1391                 return 0;
1392         }
1393
1394         if(type == HSSL2ClientHello){
1395                 /* Cope with an SSL3 ClientHello expressed in SSL2 record format.
1396                         This is sent by some clients that we must interoperate
1397                         with, such as Java's JSSE and Microsoft's Internet Explorer. */
1398                 p = tlsReadN(c, n);
1399                 if(p == nil)
1400                         return 0;
1401                 msgHash(c, p, n);
1402                 m->tag = HClientHello;
1403                 if(n < 22)
1404                         goto Short;
1405                 m->u.clientHello.version = get16(p+1);
1406                 p += 3;
1407                 n -= 3;
1408                 nn = get16(p); /* cipher_spec_len */
1409                 nsid = get16(p + 2);
1410                 nrandom = get16(p + 4);
1411                 p += 6;
1412                 n -= 6;
1413                 if(nsid != 0    /* no sid's, since shouldn't restart using ssl2 header */
1414                                 || nrandom < 16 || nn % 3)
1415                         goto Err;
1416                 if(c->trace && (n - nrandom != nn))
1417                         c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
1418                 /* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
1419                 nciph = 0;
1420                 for(i = 0; i < nn; i += 3)
1421                         if(p[i] == 0)
1422                                 nciph++;
1423                 m->u.clientHello.ciphers = newints(nciph);
1424                 nciph = 0;
1425                 for(i = 0; i < nn; i += 3)
1426                         if(p[i] == 0)
1427                                 m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1428                 p += nn;
1429                 m->u.clientHello.sid = makebytes(nil, 0);
1430                 if(nrandom > RandomSize)
1431                         nrandom = RandomSize;
1432                 memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1433                 memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1434                 m->u.clientHello.compressors = newbytes(1);
1435                 m->u.clientHello.compressors->data[0] = CompressionNull;
1436                 goto Ok;
1437         }
1438         msgHash(c, p, 4);
1439
1440         p = tlsReadN(c, n);
1441         if(p == nil)
1442                 return 0;
1443
1444         msgHash(c, p, n);
1445
1446         m->tag = type;
1447
1448         switch(type) {
1449         default:
1450                 tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1451                 goto Err;
1452         case HClientHello:
1453                 if(n < 2)
1454                         goto Short;
1455                 m->u.clientHello.version = get16(p);
1456                 p += 2;
1457                 n -= 2;
1458
1459                 if(n < RandomSize)
1460                         goto Short;
1461                 memmove(m->u.clientHello.random, p, RandomSize);
1462                 p += RandomSize;
1463                 n -= RandomSize;
1464                 if(n < 1 || n < p[0]+1)
1465                         goto Short;
1466                 m->u.clientHello.sid = makebytes(p+1, p[0]);
1467                 p += m->u.clientHello.sid->len+1;
1468                 n -= m->u.clientHello.sid->len+1;
1469
1470                 if(n < 2)
1471                         goto Short;
1472                 nn = get16(p);
1473                 p += 2;
1474                 n -= 2;
1475
1476                 if((nn & 1) || n < nn || nn < 2)
1477                         goto Short;
1478                 m->u.clientHello.ciphers = newints(nn >> 1);
1479                 for(i = 0; i < nn; i += 2)
1480                         m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1481                 p += nn;
1482                 n -= nn;
1483
1484                 if(n < 1 || n < p[0]+1 || p[0] == 0)
1485                         goto Short;
1486                 nn = p[0];
1487                 m->u.clientHello.compressors = makebytes(p+1, nn);
1488                 p += nn + 1;
1489                 n -= nn + 1;
1490
1491                 if(n < 2)
1492                         break;
1493                 nn = get16(p);
1494                 if(nn > n-2)
1495                         goto Short;
1496                 m->u.clientHello.extensions = makebytes(p+2, nn);
1497                 n -= nn + 2;
1498                 break;
1499         case HServerHello:
1500                 if(n < 2)
1501                         goto Short;
1502                 m->u.serverHello.version = get16(p);
1503                 p += 2;
1504                 n -= 2;
1505
1506                 if(n < RandomSize)
1507                         goto Short;
1508                 memmove(m->u.serverHello.random, p, RandomSize);
1509                 p += RandomSize;
1510                 n -= RandomSize;
1511
1512                 if(n < 1 || n < p[0]+1)
1513                         goto Short;
1514                 m->u.serverHello.sid = makebytes(p+1, p[0]);
1515                 p += m->u.serverHello.sid->len+1;
1516                 n -= m->u.serverHello.sid->len+1;
1517
1518                 if(n < 3)
1519                         goto Short;
1520                 m->u.serverHello.cipher = get16(p);
1521                 m->u.serverHello.compressor = p[2];
1522                 p += 3;
1523                 n -= 3;
1524
1525                 if(n < 2)
1526                         break;
1527                 nn = get16(p);
1528                 if(nn > n-2)
1529                         goto Short;
1530                 m->u.serverHello.extensions = makebytes(p+2, nn);
1531                 n -= nn + 2;
1532                 break;
1533         case HCertificate:
1534                 if(n < 3)
1535                         goto Short;
1536                 nn = get24(p);
1537                 p += 3;
1538                 n -= 3;
1539                 if(nn == 0 && n > 0)
1540                         goto Short;
1541                 /* certs */
1542                 i = 0;
1543                 while(n > 0) {
1544                         if(n < 3)
1545                                 goto Short;
1546                         nn = get24(p);
1547                         p += 3;
1548                         n -= 3;
1549                         if(nn > n)
1550                                 goto Short;
1551                         m->u.certificate.ncert = i+1;
1552                         m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes*));
1553                         m->u.certificate.certs[i] = makebytes(p, nn);
1554                         p += nn;
1555                         n -= nn;
1556                         i++;
1557                 }
1558                 break;
1559         case HCertificateRequest:
1560                 if(n < 1)
1561                         goto Short;
1562                 nn = p[0];
1563                 p += 1;
1564                 n -= 1;
1565                 if(nn > n)
1566                         goto Short;
1567                 m->u.certificateRequest.types = makebytes(p, nn);
1568                 p += nn;
1569                 n -= nn;
1570                 if(n < 2)
1571                         goto Short;
1572                 nn = get16(p);
1573                 p += 2;
1574                 n -= 2;
1575                 /* nn == 0 can happen; yahoo's servers do it */
1576                 if(nn != n)
1577                         goto Short;
1578                 /* cas */
1579                 i = 0;
1580                 while(n > 0) {
1581                         if(n < 2)
1582                                 goto Short;
1583                         nn = get16(p);
1584                         p += 2;
1585                         n -= 2;
1586                         if(nn < 1 || nn > n)
1587                                 goto Short;
1588                         m->u.certificateRequest.nca = i+1;
1589                         m->u.certificateRequest.cas = erealloc(
1590                                 m->u.certificateRequest.cas, (i+1)*sizeof(Bytes*));
1591                         m->u.certificateRequest.cas[i] = makebytes(p, nn);
1592                         p += nn;
1593                         n -= nn;
1594                         i++;
1595                 }
1596                 break;
1597         case HServerHelloDone:
1598                 break;
1599         case HServerKeyExchange:
1600                 if(n < 2)
1601                         goto Short;
1602                 if(isECDHE(c->cipher)){
1603                         nn = *p;
1604                         p++, n--;
1605                         if(nn != 3 || nn > n) /* not a named curve */
1606                                 goto Short;
1607                         nn = get16(p);
1608                         p += 2, n -= 2;
1609                         m->u.serverKeyExchange.curve = nn;
1610
1611                         nn = *p++, n--;
1612                         if(nn < 1 || nn > n)
1613                                 goto Short;
1614                         m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);
1615                         p += nn, n -= nn;
1616                 }else if(isDHE(c->cipher)){
1617                         nn = get16(p);
1618                         p += 2, n -= 2;
1619                         if(nn < 1 || nn > n)
1620                                 goto Short;
1621                         m->u.serverKeyExchange.dh_p = makebytes(p, nn);
1622                         p += nn, n -= nn;
1623         
1624                         if(n < 2)
1625                                 goto Short;
1626                         nn = get16(p);
1627                         p += 2, n -= 2;
1628                         if(nn < 1 || nn > n)
1629                                 goto Short;
1630                         m->u.serverKeyExchange.dh_g = makebytes(p, nn);
1631                         p += nn, n -= nn;
1632         
1633                         if(n < 2)
1634                                 goto Short;
1635                         nn = get16(p);
1636                         p += 2, n -= 2;
1637                         if(nn < 1 || nn > n)
1638                                 goto Short;
1639                         m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);
1640                         p += nn, n -= nn;
1641                 } else {
1642                         /* should not happen */
1643                         break;
1644                 }
1645                 if(n >= 2){
1646                         if(c->version >= TLS12Version){
1647                                 /* signature hash algorithm */
1648                                 p += 2, n -= 2;
1649                                 if(n < 2)
1650                                         goto Short;
1651                         }
1652                         nn = get16(p);
1653                         p += 2, n -= 2;
1654                         if(nn > 0 && nn <= n){
1655                                 m->u.serverKeyExchange.dh_signature = makebytes(p, nn);
1656                                 n -= nn;
1657                         }
1658                 }
1659                 break;          
1660         case HClientKeyExchange:
1661                 /*
1662                  * this message depends upon the encryption selected
1663                  * assume rsa.
1664                  */
1665                 if(c->version == SSL3Version)
1666                         nn = n;
1667                 else{
1668                         if(n < 2)
1669                                 goto Short;
1670                         nn = get16(p);
1671                         p += 2;
1672                         n -= 2;
1673                 }
1674                 if(n < nn)
1675                         goto Short;
1676                 m->u.clientKeyExchange.key = makebytes(p, nn);
1677                 n -= nn;
1678                 break;
1679         case HFinished:
1680                 m->u.finished.n = c->finished.n;
1681                 if(n < m->u.finished.n)
1682                         goto Short;
1683                 memmove(m->u.finished.verify, p, m->u.finished.n);
1684                 n -= m->u.finished.n;
1685                 break;
1686         }
1687
1688         if(type != HClientHello && type != HServerHello && n != 0)
1689                 goto Short;
1690 Ok:
1691         if(c->trace){
1692                 char *buf;
1693                 buf = emalloc(8000);
1694                 c->trace("recv %s", msgPrint(buf, 8000, m));
1695                 free(buf);
1696         }
1697         return 1;
1698 Short:
1699         tlsError(c, EDecodeError, "handshake message (%d) has invalid length", type);
1700 Err:
1701         msgClear(m);
1702         return 0;
1703 }
1704
1705 static void
1706 msgClear(Msg *m)
1707 {
1708         int i;
1709
1710         switch(m->tag) {
1711         default:
1712                 sysfatal("msgClear: unknown message type: %d", m->tag);
1713         case HHelloRequest:
1714                 break;
1715         case HClientHello:
1716                 freebytes(m->u.clientHello.sid);
1717                 freeints(m->u.clientHello.ciphers);
1718                 freebytes(m->u.clientHello.compressors);
1719                 freebytes(m->u.clientHello.extensions);
1720                 break;
1721         case HServerHello:
1722                 freebytes(m->u.serverHello.sid);
1723                 freebytes(m->u.serverHello.extensions);
1724                 break;
1725         case HCertificate:
1726                 for(i=0; i<m->u.certificate.ncert; i++)
1727                         freebytes(m->u.certificate.certs[i]);
1728                 free(m->u.certificate.certs);
1729                 break;
1730         case HCertificateRequest:
1731                 freebytes(m->u.certificateRequest.types);
1732                 for(i=0; i<m->u.certificateRequest.nca; i++)
1733                         freebytes(m->u.certificateRequest.cas[i]);
1734                 free(m->u.certificateRequest.cas);
1735                 break;
1736         case HCertificateVerify:
1737                 freebytes(m->u.certificateVerify.signature);
1738                 break;
1739         case HServerHelloDone:
1740                 break;
1741         case HServerKeyExchange:
1742                 freebytes(m->u.serverKeyExchange.dh_p);
1743                 freebytes(m->u.serverKeyExchange.dh_g);
1744                 freebytes(m->u.serverKeyExchange.dh_Ys);
1745                 freebytes(m->u.serverKeyExchange.dh_signature);
1746                 break;
1747         case HClientKeyExchange:
1748                 freebytes(m->u.clientKeyExchange.key);
1749                 break;
1750         case HFinished:
1751                 break;
1752         }
1753         memset(m, 0, sizeof(Msg));
1754 }
1755
1756 static char *
1757 bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1758 {
1759         int i;
1760
1761         if(s0)
1762                 bs = seprint(bs, be, "%s", s0);
1763         if(b == nil)
1764                 bs = seprint(bs, be, "nil");
1765         else {
1766                 bs = seprint(bs, be, "<%d> [", b->len);
1767                 for(i=0; i<b->len; i++)
1768                         bs = seprint(bs, be, "%.2x ", b->data[i]);
1769         }
1770         bs = seprint(bs, be, "]");
1771         if(s1)
1772                 bs = seprint(bs, be, "%s", s1);
1773         return bs;
1774 }
1775
1776 static char *
1777 intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1778 {
1779         int i;
1780
1781         if(s0)
1782                 bs = seprint(bs, be, "%s", s0);
1783         bs = seprint(bs, be, "[");
1784         if(b == nil)
1785                 bs = seprint(bs, be, "nil");
1786         else
1787                 for(i=0; i<b->len; i++)
1788                         bs = seprint(bs, be, "%x ", b->data[i]);
1789         bs = seprint(bs, be, "]");
1790         if(s1)
1791                 bs = seprint(bs, be, "%s", s1);
1792         return bs;
1793 }
1794
1795 static char*
1796 msgPrint(char *buf, int n, Msg *m)
1797 {
1798         int i;
1799         char *bs = buf, *be = buf+n;
1800
1801         switch(m->tag) {
1802         default:
1803                 bs = seprint(bs, be, "unknown %d\n", m->tag);
1804                 break;
1805         case HClientHello:
1806                 bs = seprint(bs, be, "ClientHello\n");
1807                 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1808                 bs = seprint(bs, be, "\trandom: ");
1809                 for(i=0; i<RandomSize; i++)
1810                         bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1811                 bs = seprint(bs, be, "\n");
1812                 bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1813                 bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1814                 bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1815                 if(m->u.clientHello.extensions != nil)
1816                         bs = bytesPrint(bs, be, "\textensions: ", m->u.clientHello.extensions, "\n");
1817                 break;
1818         case HServerHello:
1819                 bs = seprint(bs, be, "ServerHello\n");
1820                 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1821                 bs = seprint(bs, be, "\trandom: ");
1822                 for(i=0; i<RandomSize; i++)
1823                         bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1824                 bs = seprint(bs, be, "\n");
1825                 bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1826                 bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1827                 bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1828                 if(m->u.serverHello.extensions != nil)
1829                         bs = bytesPrint(bs, be, "\textensions: ", m->u.serverHello.extensions, "\n");
1830                 break;
1831         case HCertificate:
1832                 bs = seprint(bs, be, "Certificate\n");
1833                 for(i=0; i<m->u.certificate.ncert; i++)
1834                         bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1835                 break;
1836         case HCertificateRequest:
1837                 bs = seprint(bs, be, "CertificateRequest\n");
1838                 bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1839                 bs = seprint(bs, be, "\tcertificateauthorities\n");
1840                 for(i=0; i<m->u.certificateRequest.nca; i++)
1841                         bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1842                 break;
1843         case HCertificateVerify:
1844                 bs = seprint(bs, be, "HCertificateVerify\n");
1845                 bs = bytesPrint(bs, be, "\tsignature: ", m->u.certificateVerify.signature,"\n");
1846                 break;  
1847         case HServerHelloDone:
1848                 bs = seprint(bs, be, "ServerHelloDone\n");
1849                 break;
1850         case HServerKeyExchange:
1851                 bs = seprint(bs, be, "HServerKeyExchange\n");
1852                 if(m->u.serverKeyExchange.curve != 0){
1853                         bs = seprint(bs, be, "\tcurve: %.4x\n", m->u.serverKeyExchange.curve);
1854                 } else {
1855                         bs = bytesPrint(bs, be, "\tdh_p: ", m->u.serverKeyExchange.dh_p, "\n");
1856                         bs = bytesPrint(bs, be, "\tdh_g: ", m->u.serverKeyExchange.dh_g, "\n");
1857                 }
1858                 bs = bytesPrint(bs, be, "\tdh_Ys: ", m->u.serverKeyExchange.dh_Ys, "\n");
1859                 bs = bytesPrint(bs, be, "\tdh_signature: ", m->u.serverKeyExchange.dh_signature, "\n");
1860                 break;
1861         case HClientKeyExchange:
1862                 bs = seprint(bs, be, "HClientKeyExchange\n");
1863                 bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1864                 break;
1865         case HFinished:
1866                 bs = seprint(bs, be, "HFinished\n");
1867                 for(i=0; i<m->u.finished.n; i++)
1868                         bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1869                 bs = seprint(bs, be, "\n");
1870                 break;
1871         }
1872         USED(bs);
1873         return buf;
1874 }
1875
1876 static void
1877 tlsError(TlsConnection *c, int err, char *fmt, ...)
1878 {
1879         char msg[512];
1880         va_list arg;
1881
1882         va_start(arg, fmt);
1883         vseprint(msg, msg+sizeof(msg), fmt, arg);
1884         va_end(arg);
1885         if(c->trace)
1886                 c->trace("tlsError: %s\n", msg);
1887         else if(c->erred)
1888                 fprint(2, "double error: %r, %s", msg);
1889         else
1890                 werrstr("tls: local %s", msg);
1891         c->erred = 1;
1892         fprint(c->ctl, "alert %d", err);
1893 }
1894
1895 // commit to specific version number
1896 static int
1897 setVersion(TlsConnection *c, int version)
1898 {
1899         if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1900                 return -1;
1901         if(version > c->version)
1902                 version = c->version;
1903         if(version == SSL3Version) {
1904                 c->version = version;
1905                 c->finished.n = SSL3FinishedLen;
1906         }else {
1907                 c->version = version;
1908                 c->finished.n = TLSFinishedLen;
1909         }
1910         c->verset = 1;
1911         return fprint(c->ctl, "version 0x%x", version);
1912 }
1913
1914 // confirm that received Finished message matches the expected value
1915 static int
1916 finishedMatch(TlsConnection *c, Finished *f)
1917 {
1918         return constcmp(f->verify, c->finished.verify, f->n) == 0;
1919 }
1920
1921 // free memory associated with TlsConnection struct
1922 //              (but don't close the TLS channel itself)
1923 static void
1924 tlsConnectionFree(TlsConnection *c)
1925 {
1926         tlsSecClose(c->sec);
1927         freebytes(c->sid);
1928         freebytes(c->cert);
1929         memset(c, 0, sizeof(c));
1930         free(c);
1931 }
1932
1933
1934 //================= cipher choices ========================
1935
1936 static int weakCipher[CipherMax] =
1937 {
1938         1,      /* TLS_NULL_WITH_NULL_NULL */
1939         1,      /* TLS_RSA_WITH_NULL_MD5 */
1940         1,      /* TLS_RSA_WITH_NULL_SHA */
1941         1,      /* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
1942         0,      /* TLS_RSA_WITH_RC4_128_MD5 */
1943         0,      /* TLS_RSA_WITH_RC4_128_SHA */
1944         1,      /* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
1945         0,      /* TLS_RSA_WITH_IDEA_CBC_SHA */
1946         1,      /* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
1947         0,      /* TLS_RSA_WITH_DES_CBC_SHA */
1948         0,      /* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
1949         1,      /* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
1950         0,      /* TLS_DH_DSS_WITH_DES_CBC_SHA */
1951         0,      /* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
1952         1,      /* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
1953         0,      /* TLS_DH_RSA_WITH_DES_CBC_SHA */
1954         0,      /* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
1955         1,      /* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
1956         0,      /* TLS_DHE_DSS_WITH_DES_CBC_SHA */
1957         0,      /* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
1958         1,      /* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
1959         0,      /* TLS_DHE_RSA_WITH_DES_CBC_SHA */
1960         0,      /* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
1961         1,      /* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
1962         1,      /* TLS_DH_anon_WITH_RC4_128_MD5 */
1963         1,      /* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
1964         1,      /* TLS_DH_anon_WITH_DES_CBC_SHA */
1965         1,      /* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
1966 };
1967
1968 static int
1969 setAlgs(TlsConnection *c, int a)
1970 {
1971         int i;
1972
1973         for(i = 0; i < nelem(cipherAlgs); i++){
1974                 if(cipherAlgs[i].tlsid == a){
1975                         c->cipher = a;
1976                         c->enc = cipherAlgs[i].enc;
1977                         c->digest = cipherAlgs[i].digest;
1978                         c->nsecret = cipherAlgs[i].nsecret;
1979                         if(c->nsecret > MaxKeyData)
1980                                 return 0;
1981                         return 1;
1982                 }
1983         }
1984         return 0;
1985 }
1986
1987 static int
1988 okCipher(Ints *cv)
1989 {
1990         int weak, i, j, c;
1991
1992         weak = 1;
1993         for(i = 0; i < cv->len; i++) {
1994                 c = cv->data[i];
1995                 if(c >= CipherMax)
1996                         weak = 0;
1997                 else
1998                         weak &= weakCipher[c];
1999                 if(isDHE(c) || isECDHE(c))
2000                         continue;       /* TODO: not implemented for server */
2001                 for(j = 0; j < nelem(cipherAlgs); j++)
2002                         if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
2003                                 return c;
2004         }
2005         if(weak)
2006                 return -2;
2007         return -1;
2008 }
2009
2010 static int
2011 okCompression(Bytes *cv)
2012 {
2013         int i, j, c;
2014
2015         for(i = 0; i < cv->len; i++) {
2016                 c = cv->data[i];
2017                 for(j = 0; j < nelem(compressors); j++) {
2018                         if(compressors[j] == c)
2019                                 return c;
2020                 }
2021         }
2022         return -1;
2023 }
2024
2025 static Lock     ciphLock;
2026 static int      nciphers;
2027
2028 static int
2029 initCiphers(void)
2030 {
2031         enum {MaxAlgF = 1024, MaxAlgs = 10};
2032         char s[MaxAlgF], *flds[MaxAlgs];
2033         int i, j, n, ok;
2034
2035         lock(&ciphLock);
2036         if(nciphers){
2037                 unlock(&ciphLock);
2038                 return nciphers;
2039         }
2040         j = open("#a/tls/encalgs", OREAD);
2041         if(j < 0){
2042                 werrstr("can't open #a/tls/encalgs: %r");
2043                 return 0;
2044         }
2045         n = read(j, s, MaxAlgF-1);
2046         close(j);
2047         if(n <= 0){
2048                 werrstr("nothing in #a/tls/encalgs: %r");
2049                 return 0;
2050         }
2051         s[n] = 0;
2052         n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
2053         for(i = 0; i < nelem(cipherAlgs); i++){
2054                 ok = 0;
2055                 for(j = 0; j < n; j++){
2056                         if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
2057                                 ok = 1;
2058                                 break;
2059                         }
2060                 }
2061                 cipherAlgs[i].ok = ok;
2062         }
2063
2064         j = open("#a/tls/hashalgs", OREAD);
2065         if(j < 0){
2066                 werrstr("can't open #a/tls/hashalgs: %r");
2067                 return 0;
2068         }
2069         n = read(j, s, MaxAlgF-1);
2070         close(j);
2071         if(n <= 0){
2072                 werrstr("nothing in #a/tls/hashalgs: %r");
2073                 return 0;
2074         }
2075         s[n] = 0;
2076         n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
2077         for(i = 0; i < nelem(cipherAlgs); i++){
2078                 ok = 0;
2079                 for(j = 0; j < n; j++){
2080                         if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
2081                                 ok = 1;
2082                                 break;
2083                         }
2084                 }
2085                 cipherAlgs[i].ok &= ok;
2086                 if(cipherAlgs[i].ok)
2087                         nciphers++;
2088         }
2089         unlock(&ciphLock);
2090         return nciphers;
2091 }
2092
2093 static Ints*
2094 makeciphers(void)
2095 {
2096         Ints *is;
2097         int i, j;
2098
2099         is = newints(nciphers);
2100         j = 0;
2101         for(i = 0; i < nelem(cipherAlgs); i++){
2102                 if(cipherAlgs[i].ok)
2103                         is->data[j++] = cipherAlgs[i].tlsid;
2104         }
2105         return is;
2106 }
2107
2108
2109
2110 //================= security functions ========================
2111
2112 // given X.509 certificate, set up connection to factotum
2113 //      for using corresponding private key
2114 static AuthRpc*
2115 factotum_rsa_open(uchar *cert, int certlen)
2116 {
2117         int afd;
2118         char *s;
2119         mpint *pub = nil;
2120         RSApub *rsapub;
2121         AuthRpc *rpc;
2122
2123         // start talking to factotum
2124         if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
2125                 return nil;
2126         if((rpc = auth_allocrpc(afd)) == nil){
2127                 close(afd);
2128                 return nil;
2129         }
2130         s = "proto=rsa service=tls role=client";
2131         if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
2132                 factotum_rsa_close(rpc);
2133                 return nil;
2134         }
2135
2136         // roll factotum keyring around to match certificate
2137         rsapub = X509toRSApub(cert, certlen, nil, 0);
2138         while(1){
2139                 if(auth_rpc(rpc, "read", nil, 0) != ARok){
2140                         factotum_rsa_close(rpc);
2141                         rpc = nil;
2142                         goto done;
2143                 }
2144                 pub = strtomp(rpc->arg, nil, 16, nil);
2145                 assert(pub != nil);
2146                 if(mpcmp(pub,rsapub->n) == 0)
2147                         break;
2148         }
2149 done:
2150         mpfree(pub);
2151         rsapubfree(rsapub);
2152         return rpc;
2153 }
2154
2155 static mpint*
2156 factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
2157 {
2158         char *p;
2159         int rv;
2160
2161         p = mptoa(cipher, 16, nil, 0);
2162         mpfree(cipher);
2163         if(p == nil)
2164                 return nil;
2165         rv = auth_rpc(rpc, "write", p, strlen(p));
2166         free(p);
2167         if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
2168                 return nil;
2169         return strtomp(rpc->arg, nil, 16, nil);
2170 }
2171
2172 static void
2173 factotum_rsa_close(AuthRpc*rpc)
2174 {
2175         if(!rpc)
2176                 return;
2177         close(rpc->afd);
2178         auth_freerpc(rpc);
2179 }
2180
2181 static void
2182 tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2183 {
2184         uchar ai[MD5dlen], tmp[MD5dlen];
2185         int i, n;
2186         MD5state *s;
2187
2188         // generate a1
2189         s = hmac_md5(label, nlabel, key, nkey, nil, nil);
2190         s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
2191         hmac_md5(seed1, nseed1, key, nkey, ai, s);
2192
2193         while(nbuf > 0) {
2194                 s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
2195                 s = hmac_md5(label, nlabel, key, nkey, nil, s);
2196                 s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
2197                 hmac_md5(seed1, nseed1, key, nkey, tmp, s);
2198                 n = MD5dlen;
2199                 if(n > nbuf)
2200                         n = nbuf;
2201                 for(i = 0; i < n; i++)
2202                         buf[i] ^= tmp[i];
2203                 buf += n;
2204                 nbuf -= n;
2205                 hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
2206                 memmove(ai, tmp, MD5dlen);
2207         }
2208 }
2209
2210 static void
2211 tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2212 {
2213         uchar ai[SHA1dlen], tmp[SHA1dlen];
2214         int i, n;
2215         SHAstate *s;
2216
2217         // generate a1
2218         s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
2219         s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
2220         hmac_sha1(seed1, nseed1, key, nkey, ai, s);
2221
2222         while(nbuf > 0) {
2223                 s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
2224                 s = hmac_sha1(label, nlabel, key, nkey, nil, s);
2225                 s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
2226                 hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
2227                 n = SHA1dlen;
2228                 if(n > nbuf)
2229                         n = nbuf;
2230                 for(i = 0; i < n; i++)
2231                         buf[i] ^= tmp[i];
2232                 buf += n;
2233                 nbuf -= n;
2234                 hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
2235                 memmove(ai, tmp, SHA1dlen);
2236         }
2237 }
2238
2239 static void
2240 p_sha256(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed, int nseed)
2241 {
2242         uchar ai[SHA2_256dlen], tmp[SHA2_256dlen];
2243         SHAstate *s;
2244         int n;
2245
2246         // generate a1
2247         s = hmac_sha2_256(label, nlabel, key, nkey, nil, nil);
2248         hmac_sha2_256(seed, nseed, key, nkey, ai, s);
2249
2250         while(nbuf > 0) {
2251                 s = hmac_sha2_256(ai, SHA2_256dlen, key, nkey, nil, nil);
2252                 s = hmac_sha2_256(label, nlabel, key, nkey, nil, s);
2253                 hmac_sha2_256(seed, nseed, key, nkey, tmp, s);
2254                 n = SHA2_256dlen;
2255                 if(n > nbuf)
2256                         n = nbuf;
2257                 memmove(buf, tmp, n);
2258                 buf += n;
2259                 nbuf -= n;
2260                 hmac_sha2_256(ai, SHA2_256dlen, key, nkey, tmp, nil);
2261                 memmove(ai, tmp, SHA2_256dlen);
2262         }
2263 }
2264
2265 // fill buf with md5(args)^sha1(args)
2266 static void
2267 tls10PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2268 {
2269         int nlabel = strlen(label);
2270         int n = (nkey + 1) >> 1;
2271
2272         memset(buf, 0, nbuf);
2273         tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
2274         tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
2275 }
2276
2277 static void
2278 tls12PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2279 {
2280         uchar seed[2*RandomSize];
2281
2282         assert(nseed0+nseed1 <= sizeof(seed));
2283         memmove(seed, seed0, nseed0);
2284         memmove(seed+nseed0, seed1, nseed1);
2285         p_sha256(buf, nbuf, key, nkey, (uchar*)label, strlen(label), seed, nseed0+nseed1);
2286 }
2287
2288 /*
2289  * for setting server session id's
2290  */
2291 static Lock     sidLock;
2292 static long     maxSid = 1;
2293
2294 /* the keys are verified to have the same public components
2295  * and to function correctly with pkcs 1 encryption and decryption. */
2296 static TlsSec*
2297 tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
2298 {
2299         TlsSec *sec = emalloc(sizeof(*sec));
2300
2301         USED(csid); USED(ncsid);  // ignore csid for now
2302
2303         memmove(sec->crandom, crandom, RandomSize);
2304         sec->clientVers = cvers;
2305
2306         put32(sec->srandom, time(0));
2307         genrandom(sec->srandom+4, RandomSize-4);
2308         memmove(srandom, sec->srandom, RandomSize);
2309
2310         /*
2311          * make up a unique sid: use our pid, and and incrementing id
2312          * can signal no sid by setting nssid to 0.
2313          */
2314         memset(ssid, 0, SidSize);
2315         put32(ssid, getpid());
2316         lock(&sidLock);
2317         put32(ssid+4, maxSid++);
2318         unlock(&sidLock);
2319         *nssid = SidSize;
2320         return sec;
2321 }
2322
2323 static int
2324 tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm)
2325 {
2326         if(epm != nil){
2327                 if(setVers(sec, vers) < 0)
2328                         goto Err;
2329                 serverMasterSecret(sec, epm);
2330         }else if(sec->vers != vers){
2331                 werrstr("mismatched session versions");
2332                 goto Err;
2333         }
2334         return 0;
2335 Err:
2336         sec->ok = -1;
2337         return -1;
2338 }
2339
2340 static TlsSec*
2341 tlsSecInitc(int cvers, uchar *crandom)
2342 {
2343         TlsSec *sec = emalloc(sizeof(*sec));
2344         sec->clientVers = cvers;
2345         put32(sec->crandom, time(0));
2346         genrandom(sec->crandom+4, RandomSize-4);
2347         memmove(crandom, sec->crandom, RandomSize);
2348         return sec;
2349 }
2350
2351 static Bytes*
2352 tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers)
2353 {
2354         RSApub *pub;
2355         Bytes *epm;
2356
2357         USED(sid);
2358         USED(nsid);
2359         
2360         memmove(sec->srandom, srandom, RandomSize);
2361         if(setVers(sec, vers) < 0)
2362                 goto Err;
2363         pub = X509toRSApub(cert, ncert, nil, 0);
2364         if(pub == nil){
2365                 werrstr("invalid x509/rsa certificate");
2366                 goto Err;
2367         }
2368         epm = clientMasterSecret(sec, pub);
2369         rsapubfree(pub);
2370         if(epm != nil)
2371                 return epm;
2372 Err:
2373         sec->ok = -1;
2374         return nil;
2375 }
2376
2377 static int
2378 tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient)
2379 {
2380         if(sec->nfin != nfin){
2381                 sec->ok = -1;
2382                 werrstr("invalid finished exchange");
2383                 return -1;
2384         }
2385         hsh.md5.malloced = 0;
2386         hsh.sha1.malloced = 0;
2387         hsh.sha2_256.malloced = 0;
2388         (*sec->setFinished)(sec, hsh, fin, isclient);
2389         return 1;
2390 }
2391
2392 static void
2393 tlsSecOk(TlsSec *sec)
2394 {
2395         if(sec->ok == 0)
2396                 sec->ok = 1;
2397 }
2398
2399 static void
2400 tlsSecKill(TlsSec *sec)
2401 {
2402         if(!sec)
2403                 return;
2404         factotum_rsa_close(sec->rpc);
2405         sec->ok = -1;
2406 }
2407
2408 static void
2409 tlsSecClose(TlsSec *sec)
2410 {
2411         if(!sec)
2412                 return;
2413         factotum_rsa_close(sec->rpc);
2414         free(sec->server);
2415         free(sec);
2416 }
2417
2418 static int
2419 setVers(TlsSec *sec, int v)
2420 {
2421         if(v == SSL3Version){
2422                 sec->setFinished = sslSetFinished;
2423                 sec->nfin = SSL3FinishedLen;
2424                 sec->prf = sslPRF;
2425         }else if(v < TLS12Version) {
2426                 sec->setFinished = tls10SetFinished;
2427                 sec->nfin = TLSFinishedLen;
2428                 sec->prf = tls10PRF;
2429         }else {
2430                 sec->setFinished = tls12SetFinished;
2431                 sec->nfin = TLSFinishedLen;
2432                 sec->prf = tls12PRF;
2433         }
2434         sec->vers = v;
2435         return 0;
2436 }
2437
2438 /*
2439  * generate secret keys from the master secret.
2440  *
2441  * different crypto selections will require different amounts
2442  * of key expansion and use of key expansion data,
2443  * but it's all generated using the same function.
2444  */
2445 static void
2446 setSecrets(TlsSec *sec, uchar *kd, int nkd)
2447 {
2448         (*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
2449                         sec->srandom, RandomSize, sec->crandom, RandomSize);
2450 }
2451
2452 /*
2453  * set the master secret from the pre-master secret,
2454  * destroys premaster.
2455  */
2456 static void
2457 setMasterSecret(TlsSec *sec, Bytes *pm)
2458 {
2459         (*sec->prf)(sec->sec, MasterSecretSize, pm->data, pm->len, "master secret",
2460                         sec->crandom, RandomSize, sec->srandom, RandomSize);
2461
2462         memset(pm->data, 0, pm->len);   
2463         freebytes(pm);
2464 }
2465
2466 static void
2467 serverMasterSecret(TlsSec *sec, Bytes *epm)
2468 {
2469         Bytes *pm;
2470
2471         pm = pkcs1_decrypt(sec, epm);
2472
2473         // if the client messed up, just continue as if everything is ok,
2474         // to prevent attacks to check for correctly formatted messages.
2475         // Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
2476         if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
2477                 fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
2478                         sec->ok, pm, pm != nil ? get16(pm->data) : -1, sec->clientVers, epm->len);
2479                 sec->ok = -1;
2480                 freebytes(pm);
2481                 pm = newbytes(MasterSecretSize);
2482                 genrandom(pm->data, MasterSecretSize);
2483         }
2484         assert(pm->len == MasterSecretSize);
2485         setMasterSecret(sec, pm);
2486 }
2487
2488 static Bytes*
2489 clientMasterSecret(TlsSec *sec, RSApub *pub)
2490 {
2491         Bytes *pm, *epm;
2492
2493         pm = newbytes(MasterSecretSize);
2494         put16(pm->data, sec->clientVers);
2495         genrandom(pm->data+2, MasterSecretSize - 2);
2496         epm = pkcs1_encrypt(pm, pub, 2);
2497         setMasterSecret(sec, pm);
2498         return epm;
2499 }
2500
2501 static void
2502 sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient)
2503 {
2504         DigestState *s;
2505         uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
2506         char *label;
2507
2508         if(isClient)
2509                 label = "CLNT";
2510         else
2511                 label = "SRVR";
2512
2513         md5((uchar*)label, 4, nil, &hsh.md5);
2514         md5(sec->sec, MasterSecretSize, nil, &hsh.md5);
2515         memset(pad, 0x36, 48);
2516         md5(pad, 48, nil, &hsh.md5);
2517         md5(nil, 0, h0, &hsh.md5);
2518         memset(pad, 0x5C, 48);
2519         s = md5(sec->sec, MasterSecretSize, nil, nil);
2520         s = md5(pad, 48, nil, s);
2521         md5(h0, MD5dlen, finished, s);
2522
2523         sha1((uchar*)label, 4, nil, &hsh.sha1);
2524         sha1(sec->sec, MasterSecretSize, nil, &hsh.sha1);
2525         memset(pad, 0x36, 40);
2526         sha1(pad, 40, nil, &hsh.sha1);
2527         sha1(nil, 0, h1, &hsh.sha1);
2528         memset(pad, 0x5C, 40);
2529         s = sha1(sec->sec, MasterSecretSize, nil, nil);
2530         s = sha1(pad, 40, nil, s);
2531         sha1(h1, SHA1dlen, finished + MD5dlen, s);
2532 }
2533
2534 // fill "finished" arg with md5(args)^sha1(args)
2535 static void
2536 tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient)
2537 {
2538         uchar h0[MD5dlen], h1[SHA1dlen];
2539         char *label;
2540
2541         // get current hash value, but allow further messages to be hashed in
2542         md5(nil, 0, h0, &hsh.md5);
2543         sha1(nil, 0, h1, &hsh.sha1);
2544
2545         if(isClient)
2546                 label = "client finished";
2547         else
2548                 label = "server finished";
2549         tls10PRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
2550 }
2551
2552 static void
2553 tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient)
2554 {
2555         uchar seed[SHA2_256dlen];
2556         char *label;
2557
2558         // get current hash value, but allow further messages to be hashed in
2559         sha2_256(nil, 0, seed, &hsh.sha2_256);
2560
2561         if(isClient)
2562                 label = "client finished";
2563         else
2564                 label = "server finished";
2565         p_sha256(finished, TLSFinishedLen, sec->sec, MasterSecretSize, (uchar*)label, strlen(label), seed, SHA2_256dlen);
2566 }
2567
2568 static void
2569 sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2570 {
2571         uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
2572         DigestState *s;
2573         int i, n, len;
2574
2575         USED(label);
2576         len = 1;
2577         while(nbuf > 0){
2578                 if(len > 26)
2579                         return;
2580                 for(i = 0; i < len; i++)
2581                         tmp[i] = 'A' - 1 + len;
2582                 s = sha1(tmp, len, nil, nil);
2583                 s = sha1(key, nkey, nil, s);
2584                 s = sha1(seed0, nseed0, nil, s);
2585                 sha1(seed1, nseed1, sha1dig, s);
2586                 s = md5(key, nkey, nil, nil);
2587                 md5(sha1dig, SHA1dlen, md5dig, s);
2588                 n = MD5dlen;
2589                 if(n > nbuf)
2590                         n = nbuf;
2591                 memmove(buf, md5dig, n);
2592                 buf += n;
2593                 nbuf -= n;
2594                 len++;
2595         }
2596 }
2597
2598 static mpint*
2599 bytestomp(Bytes* bytes)
2600 {
2601         return betomp(bytes->data, bytes->len, nil);
2602 }
2603
2604 /*
2605  * Convert mpint* to Bytes, putting high order byte first.
2606  */
2607 static Bytes*
2608 mptobytes(mpint* big)
2609 {
2610         Bytes* ans;
2611         int n;
2612
2613         n = (mpsignif(big)+7)/8;
2614         if(n == 0) n = 1;
2615         ans = newbytes(n);
2616         ans->len = mptobe(big, ans->data, n, nil);
2617         return ans;
2618 }
2619
2620 // Do RSA computation on block according to key, and pad
2621 // result on left with zeros to make it modlen long.
2622 static Bytes*
2623 rsacomp(Bytes* block, RSApub* key, int modlen)
2624 {
2625         mpint *x, *y;
2626         Bytes *a, *ybytes;
2627         int ylen;
2628
2629         x = bytestomp(block);
2630         y = rsaencrypt(key, x, nil);
2631         mpfree(x);
2632         ybytes = mptobytes(y);
2633         ylen = ybytes->len;
2634         mpfree(y);
2635
2636         if(ylen < modlen) {
2637                 a = newbytes(modlen);
2638                 memset(a->data, 0, modlen-ylen);
2639                 memmove(a->data+modlen-ylen, ybytes->data, ylen);
2640                 freebytes(ybytes);
2641                 ybytes = a;
2642         }
2643         else if(ylen > modlen) {
2644                 // assume it has leading zeros (mod should make it so)
2645                 a = newbytes(modlen);
2646                 memmove(a->data, ybytes->data, modlen);
2647                 freebytes(ybytes);
2648                 ybytes = a;
2649         }
2650         return ybytes;
2651 }
2652
2653 // encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2654 static Bytes*
2655 pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2656 {
2657         Bytes *pad, *eb, *ans;
2658         int i, dlen, padlen, modlen;
2659
2660         modlen = (mpsignif(key->n)+7)/8;
2661         dlen = data->len;
2662         if(modlen < 12 || dlen > modlen - 11)
2663                 return nil;
2664         padlen = modlen - 3 - dlen;
2665         pad = newbytes(padlen);
2666         genrandom(pad->data, padlen);
2667         for(i = 0; i < padlen; i++) {
2668                 if(blocktype == 0)
2669                         pad->data[i] = 0;
2670                 else if(blocktype == 1)
2671                         pad->data[i] = 255;
2672                 else if(pad->data[i] == 0)
2673                         pad->data[i] = 1;
2674         }
2675         eb = newbytes(modlen);
2676         eb->data[0] = 0;
2677         eb->data[1] = blocktype;
2678         memmove(eb->data+2, pad->data, padlen);
2679         eb->data[padlen+2] = 0;
2680         memmove(eb->data+padlen+3, data->data, dlen);
2681         ans = rsacomp(eb, key, modlen);
2682         freebytes(eb);
2683         freebytes(pad);
2684         return ans;
2685 }
2686
2687 // decrypt data according to PKCS#1, with given key.
2688 // expect a block type of 2.
2689 static Bytes*
2690 pkcs1_decrypt(TlsSec *sec, Bytes *cipher)
2691 {
2692         Bytes *eb, *ans = nil;
2693         int i, modlen;
2694         mpint *x, *y;
2695
2696         modlen = (mpsignif(sec->rsapub->n)+7)/8;
2697         if(cipher->len != modlen)
2698                 return nil;
2699         x = bytestomp(cipher);
2700         y = factotum_rsa_decrypt(sec->rpc, x);
2701         if(y == nil)
2702                 return nil;
2703         eb = mptobytes(y);
2704         mpfree(y);
2705         if(eb->len < modlen){ // pad on left with zeros
2706                 ans = newbytes(modlen);
2707                 memset(ans->data, 0, modlen-eb->len);
2708                 memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2709                 freebytes(eb);
2710                 eb = ans;
2711         }
2712         if(eb->data[0] == 0 && eb->data[1] == 2) {
2713                 for(i = 2; i < modlen; i++)
2714                         if(eb->data[i] == 0)
2715                                 break;
2716                 if(i < modlen - 1)
2717                         ans = makebytes(eb->data+i+1, modlen-(i+1));
2718         }
2719         freebytes(eb);
2720         return ans;
2721 }
2722
2723
2724 //================= general utility functions ========================
2725
2726 static void *
2727 emalloc(int n)
2728 {
2729         void *p;
2730         if(n==0)
2731                 n=1;
2732         p = malloc(n);
2733         if(p == nil)
2734                 sysfatal("out of memory");
2735         memset(p, 0, n);
2736         setmalloctag(p, getcallerpc(&n));
2737         return p;
2738 }
2739
2740 static void *
2741 erealloc(void *ReallocP, int ReallocN)
2742 {
2743         if(ReallocN == 0)
2744                 ReallocN = 1;
2745         if(ReallocP == nil)
2746                 ReallocP = emalloc(ReallocN);
2747         else if((ReallocP = realloc(ReallocP, ReallocN)) == nil)
2748                 sysfatal("out of memory");
2749         setrealloctag(ReallocP, getcallerpc(&ReallocP));
2750         return(ReallocP);
2751 }
2752
2753 static void
2754 put32(uchar *p, u32int x)
2755 {
2756         p[0] = x>>24;
2757         p[1] = x>>16;
2758         p[2] = x>>8;
2759         p[3] = x;
2760 }
2761
2762 static void
2763 put24(uchar *p, int x)
2764 {
2765         p[0] = x>>16;
2766         p[1] = x>>8;
2767         p[2] = x;
2768 }
2769
2770 static void
2771 put16(uchar *p, int x)
2772 {
2773         p[0] = x>>8;
2774         p[1] = x;
2775 }
2776
2777 static u32int
2778 get32(uchar *p)
2779 {
2780         return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2781 }
2782
2783 static int
2784 get24(uchar *p)
2785 {
2786         return (p[0]<<16)|(p[1]<<8)|p[2];
2787 }
2788
2789 static int
2790 get16(uchar *p)
2791 {
2792         return (p[0]<<8)|p[1];
2793 }
2794
/* offsetof with the member first: OFFSET(member, type) */
#define OFFSET(field, type) offsetof(type, field)
2796
2797 static Bytes*
2798 newbytes(int len)
2799 {
2800         Bytes* ans;
2801
2802         if(len < 0)
2803                 abort();
2804         ans = (Bytes*)emalloc(OFFSET(data[0], Bytes) + len);
2805         ans->len = len;
2806         return ans;
2807 }
2808
2809 /*
2810  * newbytes(len), with data initialized from buf
2811  */
2812 static Bytes*
2813 makebytes(uchar* buf, int len)
2814 {
2815         Bytes* ans;
2816
2817         ans = newbytes(len);
2818         memmove(ans->data, buf, len);
2819         return ans;
2820 }
2821
2822 static void
2823 freebytes(Bytes* b)
2824 {
2825         free(b);
2826 }
2827
2828 /* len is number of ints */
2829 static Ints*
2830 newints(int len)
2831 {
2832         Ints* ans;
2833
2834         if(len < 0 || len > ((uint)-1>>1)/sizeof(int))
2835                 abort();
2836         ans = (Ints*)emalloc(OFFSET(data[0], Ints) + len*sizeof(int));
2837         ans->len = len;
2838         return ans;
2839 }
2840
2841 static void
2842 freeints(Ints* b)
2843 {
2844         free(b);
2845 }