]> git.lizzy.rs Git - plan9front.git/blob - sys/src/libsec/port/tlshand.c
libsec/tlshand: implement client side ECDHE (many thanks to pr!)
[plan9front.git] / sys / src / libsec / port / tlshand.c
1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include <auth.h>
5 #include <mp.h>
6 #include <libsec.h>
7
8 // The main groups of functions are:
9 //              client/server - main handshake protocol definition
10 //              message functions - formatting handshake messages
11 //              cipher choices - catalog of digest and encrypt algorithms
12 //              security functions - PKCS#1, sslHMAC, session keygen
13 //              general utility functions - malloc, serialization
14 // The handshake protocol builds on the TLS/SSL3 record layer protocol,
15 // which is implemented in kernel device #a.  See also /lib/rfc/rfc2246.
16
17 enum {
18         TLSFinishedLen = 12,
19         SSL3FinishedLen = MD5dlen+SHA1dlen,
20         MaxKeyData = 136,       // amount of secret we may need
21         MaxChunk = 1<<14,
22         RandomSize = 32,
23         SidSize = 32,
24         MasterSecretSize = 48,
25         AQueue = 0,
26         AFlush = 1,
27 };
28
29 typedef struct TlsSec TlsSec;
30
31 typedef struct Bytes{
32         int len;
33         uchar data[1];  // [len]
34 } Bytes;
35
/* Length-counted int array; data is used as data[len]. */
typedef struct Ints Ints;
struct Ints {
	int	len;
	int	data[1];	/* [len] */
};
40
/* One negotiable cipher suite: algorithm names plus key-material size. */
typedef struct Algs Algs;
struct Algs {
	char	*enc;		/* encryption algorithm name */
	char	*digest;	/* MAC digest name */
	int	nsecret;	/* bytes of key material to generate */
	int	tlsid;		/* TLS cipher-suite identifier */
	int	ok;		/* NOTE(review): presumably set by initCiphers when usable — confirm */
};
48
/*
 * A named elliptic curve.  tlsid is its TLS NamedCurve identifier; the
 * domain parameters are hexadecimal strings (parsed with base 16).
 */
typedef struct Namedcurve Namedcurve;
struct Namedcurve {
	int	tlsid;
	char	*name;

	char	*p;	/* field prime */
	char	*a;	/* curve coefficient a */
	char	*b;	/* curve coefficient b */
	char	*G;	/* base point, octet-string form in hex */
	char	*n;	/* order of G */
	char	*h;	/* cofactor */
};
60
61 typedef struct Finished{
62         uchar verify[SSL3FinishedLen];
63         int n;
64 } Finished;
65
66 typedef struct TlsConnection{
67         TlsSec *sec;    // security management goo
68         int hand, ctl;  // record layer file descriptors
69         int erred;              // set when tlsError called
70         int (*trace)(char*fmt, ...); // for debugging
71         int version;    // protocol we are speaking
72         int verset;             // version has been set
73         int ver2hi;             // server got a version 2 hello
74         int isClient;   // is this the client or server?
75         Bytes *sid;             // SessionID
76         Bytes *cert;    // only last - no chain
77
78         Lock statelk;
79         int state;              // must be set using setstate
80
81         // input buffer for handshake messages
82         uchar recvbuf[MaxChunk];
83         uchar *rp, *ep;
84
85         // output buffer
86         uchar sendbuf[MaxChunk];
87         uchar *sendp;
88
89         uchar crandom[RandomSize];      // client random
90         uchar srandom[RandomSize];      // server random
91         int clientVersion;      // version in ClientHello
92         int cipher;
93         char *digest;   // name of digest algorithm to use
94         char *enc;              // name of encryption algorithm to use
95         int nsecret;    // amount of secret data to init keys
96
97         // for finished messages
98         MD5state        hsmd5;  // handshake hash
99         SHAstate        hssha1; // handshake hash
100         Finished        finished;
101 } TlsConnection;
102
103 typedef struct Msg{
104         int tag;
105         union {
106                 struct {
107                         int version;
108                         uchar   random[RandomSize];
109                         Bytes*  sid;
110                         Ints*   ciphers;
111                         Bytes*  compressors;
112                         Bytes*  extensions;
113                 } clientHello;
114                 struct {
115                         int version;
116                         uchar   random[RandomSize];
117                         Bytes*  sid;
118                         int     cipher;
119                         int     compressor;
120                         Bytes*  extensions;
121                 } serverHello;
122                 struct {
123                         int ncert;
124                         Bytes **certs;
125                 } certificate;
126                 struct {
127                         Bytes *types;
128                         int nca;
129                         Bytes **cas;
130                 } certificateRequest;
131                 struct {
132                         Bytes *key;
133                 } clientKeyExchange;
134                 struct {
135                         Bytes *dh_p;
136                         Bytes *dh_g;
137                         Bytes *dh_Ys;
138                         Bytes *dh_signature;
139                         int curve;
140                 } serverKeyExchange;
141                 struct {
142                         Bytes *signature;
143                 } certificateVerify;            
144                 Finished finished;
145         } u;
146 } Msg;
147
148 typedef struct TlsSec{
149         char *server;   // name of remote; nil for server
150         int ok; // <0 killed; == 0 in progress; >0 reusable
151         RSApub *rsapub;
152         AuthRpc *rpc;   // factotum for rsa private key
153         uchar sec[MasterSecretSize];    // master secret
154         uchar crandom[RandomSize];      // client random
155         uchar srandom[RandomSize];      // server random
156         int clientVers;         // version in ClientHello
157         int vers;                       // final version
158         // byte generation and handshake checksum
159         void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
160         void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
161         int nfin;
162 } TlsSec;
163
164
/* Protocol version numbers, as carried on the wire. */
enum {
	SSL3Version	= 0x0300,
	TLSVersion	= 0x0301,
	ProtocolVersion	= TLSVersion,	/* maximum version we speak */
	MinProtoVersion	= SSL3Version,	/* limits on version we accept */
	MaxProtoVersion	= 0x03ff,
};
172
// handshake message types (all values spelled out)
enum {
	HHelloRequest		= 0,
	HClientHello		= 1,
	HServerHello		= 2,
	HSSL2ClientHello	= 9,	/* local convention;  see devtls.c */
	HCertificate		= 11,
	HServerKeyExchange	= 12,
	HCertificateRequest	= 13,
	HServerHelloDone	= 14,
	HCertificateVerify	= 15,
	HClientKeyExchange	= 16,
	HFinished		= 20,
	HMax			= 21,
};
188
// alert descriptions, passed as the err argument of tlsError
enum {
	/* general */
	ECloseNotify		= 0,
	EUnexpectedMessage	= 10,
	EBadRecordMac		= 20,
	EDecryptionFailed	= 21,
	ERecordOverflow		= 22,
	EDecompressionFailure	= 30,
	EHandshakeFailure	= 40,

	/* certificates */
	ENoCertificate		= 41,
	EBadCertificate		= 42,
	EUnsupportedCertificate	= 43,
	ECertificateRevoked	= 44,
	ECertificateExpired	= 45,
	ECertificateUnknown	= 46,

	/* parameters and policy */
	EIllegalParameter	= 47,
	EUnknownCa		= 48,
	EAccessDenied		= 49,
	EDecodeError		= 50,
	EDecryptError		= 51,
	EExportRestriction	= 60,
	EProtocolVersion	= 70,
	EInsufficientSecurity	= 71,
	EInternalError		= 80,
	EUserCanceled		= 90,
	ENoRenegotiation	= 100,

	EMax			= 256
};
217
// cipher suite identifiers (TLS CipherSuite values)
enum {
	TLS_NULL_WITH_NULL_NULL			= 0x0000,
	TLS_RSA_WITH_NULL_MD5			= 0x0001,
	TLS_RSA_WITH_NULL_SHA			= 0x0002,
	TLS_RSA_EXPORT_WITH_RC4_40_MD5		= 0x0003,
	TLS_RSA_WITH_RC4_128_MD5		= 0x0004,
	TLS_RSA_WITH_RC4_128_SHA		= 0x0005,
	TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5	= 0x0006,
	TLS_RSA_WITH_IDEA_CBC_SHA		= 0x0007,
	TLS_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0008,
	TLS_RSA_WITH_DES_CBC_SHA		= 0x0009,
	TLS_RSA_WITH_3DES_EDE_CBC_SHA		= 0x000A,
	TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x000B,
	TLS_DH_DSS_WITH_DES_CBC_SHA		= 0x000C,
	TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA	= 0x000D,
	TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x000E,
	TLS_DH_RSA_WITH_DES_CBC_SHA		= 0x000F,
	TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0010,
	TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA	= 0x0011,
	TLS_DHE_DSS_WITH_DES_CBC_SHA		= 0x0012,
	TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA	= 0x0013,	// ZZZ must be implemented for tls1.0 compliance
	TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA	= 0x0014,
	TLS_DHE_RSA_WITH_DES_CBC_SHA		= 0x0015,
	TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA	= 0x0016,
	TLS_DH_anon_EXPORT_WITH_RC4_40_MD5	= 0x0017,
	TLS_DH_anon_WITH_RC4_128_MD5		= 0x0018,
	TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA	= 0x0019,
	TLS_DH_anon_WITH_DES_CBC_SHA		= 0x001A,
	TLS_DH_anon_WITH_3DES_EDE_CBC_SHA	= 0x001B,

	/* aes, aka rijndael with 128 bit blocks */
	TLS_RSA_WITH_AES_128_CBC_SHA		= 0x002f,
	TLS_DH_DSS_WITH_AES_128_CBC_SHA		= 0x0030,
	TLS_DH_RSA_WITH_AES_128_CBC_SHA		= 0x0031,
	TLS_DHE_DSS_WITH_AES_128_CBC_SHA	= 0x0032,
	TLS_DHE_RSA_WITH_AES_128_CBC_SHA	= 0x0033,
	TLS_DH_anon_WITH_AES_128_CBC_SHA	= 0x0034,
	TLS_RSA_WITH_AES_256_CBC_SHA		= 0x0035,
	TLS_DH_DSS_WITH_AES_256_CBC_SHA		= 0x0036,
	TLS_DH_RSA_WITH_AES_256_CBC_SHA		= 0x0037,
	TLS_DHE_DSS_WITH_AES_256_CBC_SHA	= 0x0038,
	TLS_DHE_RSA_WITH_AES_256_CBC_SHA	= 0x0039,
	TLS_DH_anon_WITH_AES_256_CBC_SHA	= 0x003A,

	/* elliptic curve ephemeral Diffie-Hellman */
	TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA	= 0xC013,
	TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA	= 0xC014,
	TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA	= 0xC009,
	TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA	= 0xC00A,

	/*
	 * NOTE(review): CipherMax is the last entry above plus one (0xC00B).
	 * It is NOT larger than every identifier (0xC013, 0xC014 exceed it),
	 * so it must not be used as a bound on cipher ids.
	 */
	CipherMax
};
268
// compression methods; only the null method exists here
enum {
	CompressionNull	= 0,
	CompressionMax	= 1
};
274
275 static Algs cipherAlgs[] = {
276         {"rc4_128", "md5",      2 * (16 + MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
277         {"rc4_128", "sha1",     2 * (16 + SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
278         {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
279         {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA},
280         {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_RSA_WITH_AES_128_CBC_SHA},
281         {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA},
282         {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_128_CBC_SHA},
283         {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_256_CBC_SHA},
284
285         {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
286         {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA},
287         {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA},
288         {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
289 };
290
291 static uchar compressors[] = {
292         CompressionNull,
293 };
294
295 static Namedcurve namedcurves[] = {
296 {0x0017, "secp256r1",
297         "FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF",
298         "FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFC",
299         "5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B",
300         "046B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C2964FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5",
301         "FFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551",
302         "1"}
303 };
304
305 static uchar pointformats[] = {
306         CompressionNull /* support of uncompressed point format is mandatory */
307 };
308
309 static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chain);
310 static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen, int (*trace)(char*fmt, ...));
311 static void     msgClear(Msg *m);
312 static char* msgPrint(char *buf, int n, Msg *m);
313 static int      msgRecv(TlsConnection *c, Msg *m);
314 static int      msgSend(TlsConnection *c, Msg *m, int act);
315 static void     tlsError(TlsConnection *c, int err, char *msg, ...);
316 #pragma varargck argpos tlsError 3
317 static int setVersion(TlsConnection *c, int version);
318 static int finishedMatch(TlsConnection *c, Finished *f);
319 static void tlsConnectionFree(TlsConnection *c);
320
321 static int setAlgs(TlsConnection *c, int a);
322 static int okCipher(Ints *cv);
323 static int okCompression(Bytes *cv);
324 static int initCiphers(void);
325 static Ints* makeciphers(void);
326
327 static TlsSec*  tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom);
328 static int      tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm);
329 static TlsSec*  tlsSecInitc(int cvers, uchar *crandom);
330 static Bytes*   tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers);
331 static Bytes*   tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, Bytes *p, Bytes *g, Bytes *Ys);
332 static Bytes*   tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys);
333 static int      tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
334 static void     tlsSecOk(TlsSec *sec);
335 static void     tlsSecKill(TlsSec *sec);
336 static void     tlsSecClose(TlsSec *sec);
337 static void     setMasterSecret(TlsSec *sec, Bytes *pm);
338 static void     serverMasterSecret(TlsSec *sec, Bytes *epm);
339 static void     setSecrets(TlsSec *sec, uchar *kd, int nkd);
340 static Bytes*   clientMasterSecret(TlsSec *sec, RSApub *pub);
341 static Bytes*   pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
342 static Bytes*   pkcs1_decrypt(TlsSec *sec, Bytes *cipher);
343 static void     tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
344 static void     sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
345 static void     sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
346                         uchar *seed0, int nseed0, uchar *seed1, int nseed1);
347 static int setVers(TlsSec *sec, int version);
348
349 static AuthRpc* factotum_rsa_open(uchar *cert, int certlen);
350 static mpint* factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher);
351 static void factotum_rsa_close(AuthRpc*rpc);
352
353 static void* emalloc(int);
354 static void* erealloc(void*, int);
355 static void put32(uchar *p, u32int);
356 static void put24(uchar *p, int);
357 static void put16(uchar *p, int);
358 static u32int get32(uchar *p);
359 static int get24(uchar *p);
360 static int get16(uchar *p);
361 static Bytes* newbytes(int len);
362 static Bytes* makebytes(uchar* buf, int len);
363 static Bytes* mptobytes(mpint* big);
364 static mpint* bytestomp(Bytes* bytes);
365 static void freebytes(Bytes* b);
366 static Ints* newints(int len);
367 static void freeints(Ints* b);
368
369 /* x509.c */
370 extern mpint* pkcs1padbuf(uchar *buf, int len, mpint *modulus);
371
372 //================= client/server ========================
373
374 //      push TLS onto fd, returning new (application) file descriptor
375 //              or -1 if error.
376 int
377 tlsServer(int fd, TLSconn *conn)
378 {
379         char buf[8];
380         char dname[64];
381         int n, data, ctl, hand;
382         TlsConnection *tls;
383
384         if(conn == nil)
385                 return -1;
386         ctl = open("#a/tls/clone", ORDWR);
387         if(ctl < 0)
388                 return -1;
389         n = read(ctl, buf, sizeof(buf)-1);
390         if(n < 0){
391                 close(ctl);
392                 return -1;
393         }
394         buf[n] = 0;
395         snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", buf);
396         snprint(dname, sizeof(dname), "#a/tls/%s/hand", buf);
397         hand = open(dname, ORDWR);
398         if(hand < 0){
399                 close(ctl);
400                 return -1;
401         }
402         fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
403         tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain);
404         snprint(dname, sizeof(dname), "#a/tls/%s/data", buf);
405         data = open(dname, ORDWR);
406         close(hand);
407         close(ctl);
408         if(data < 0 || tls == nil){
409                 if(tls != nil)
410                         tlsConnectionFree(tls);
411                 return -1;
412         }
413         free(conn->cert);
414         conn->cert = 0;  // client certificates are not yet implemented
415         conn->certlen = 0;
416         conn->sessionIDlen = tls->sid->len;
417         conn->sessionID = emalloc(conn->sessionIDlen);
418         memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
419         if(conn->sessionKey != nil
420         && conn->sessionType != nil
421         && strcmp(conn->sessionType, "ttls") == 0)
422                 tls->sec->prf(
423                         conn->sessionKey, conn->sessionKeylen,
424                         tls->sec->sec, MasterSecretSize,
425                         conn->sessionConst, 
426                         tls->sec->crandom, RandomSize,
427                         tls->sec->srandom, RandomSize);
428         tlsConnectionFree(tls);
429         close(fd);
430         return data;
431 }
432
433 static uchar*
434 tlsClientExtensions(TLSconn *conn, int *plen)
435 {
436         uchar *b, *p;
437         int i, n, m;
438
439         p = b = nil;
440
441         // RFC6066 - Server Name Identification
442         if(conn->serverName != nil){
443                 n = strlen(conn->serverName);
444
445                 m = p - b;
446                 b = erealloc(b, m + 2+2+2+1+2+n);
447                 p = b + m;
448
449                 put16(p, 0), p += 2;            /* Type: server_name */
450                 put16(p, 2+1+2+n), p += 2;      /* Length */
451                 put16(p, 1+2+n), p += 2;        /* Server Name list length */
452                 *p++ = 0;                       /* Server Name Type: host_name */
453                 put16(p, n), p += 2;            /* Server Name length */
454                 memmove(p, conn->serverName, n);
455                 p += n;
456         }
457
458         // ECDHE
459         if(1){
460                 m = p - b;
461                 b = erealloc(b, m + 2+2+2+nelem(namedcurves)*2 + 2+2+1+nelem(pointformats));
462                 p = b + m;
463
464                 n = nelem(namedcurves);
465                 put16(p, 0x000a), p += 2;       /* Type: elliptic_curves */
466                 put16(p, (n+1)*2), p += 2;      /* Length */
467                 put16(p, n*2), p += 2;          /* Elliptic Curves Length */
468                 for(i=0; i < n; i++){           /* Elliptic curves */
469                         put16(p, namedcurves[i].tlsid);
470                         p += 2;
471                 }
472
473                 n = nelem(pointformats);
474                 put16(p, 0x000b), p += 2;       /* Type: ec_point_formats */
475                 put16(p, n+1), p += 2;          /* Length */
476                 *p++ = n;                       /* EC point formats Length */
477                 for(i=0; i < n; i++)            /* Elliptic curves point formats */
478                         *p++ = pointformats[i];
479         }
480         
481         *plen = p - b;
482         return b;
483 }
484
485 //      push TLS onto fd, returning new (application) file descriptor
486 //              or -1 if error.
487 int
488 tlsClient(int fd, TLSconn *conn)
489 {
490         char buf[8];
491         char dname[64];
492         int n, data, ctl, hand;
493         TlsConnection *tls;
494         uchar *ext;
495
496         if(conn == nil)
497                 return -1;
498         ctl = open("#a/tls/clone", ORDWR);
499         if(ctl < 0)
500                 return -1;
501         n = read(ctl, buf, sizeof(buf)-1);
502         if(n < 0){
503                 close(ctl);
504                 return -1;
505         }
506         buf[n] = 0;
507         snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", buf);
508         snprint(dname, sizeof(dname), "#a/tls/%s/hand", buf);
509         hand = open(dname, ORDWR);
510         if(hand < 0){
511                 close(ctl);
512                 return -1;
513         }
514         snprint(dname, sizeof(dname), "#a/tls/%s/data", buf);
515         data = open(dname, ORDWR);
516         if(data < 0){
517                 close(hand);
518                 close(ctl);
519                 return -1;
520         }
521         fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion);
522         ext = tlsClientExtensions(conn, &n);
523         tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen, 
524                 ext, n, conn->trace);
525         free(ext);
526         close(hand);
527         close(ctl);
528         if(tls == nil){
529                 close(data);
530                 return -1;
531         }
532         conn->certlen = tls->cert->len;
533         conn->cert = emalloc(conn->certlen);
534         memcpy(conn->cert, tls->cert->data, conn->certlen);
535         conn->sessionIDlen = tls->sid->len;
536         conn->sessionID = emalloc(conn->sessionIDlen);
537         memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen);
538         if(conn->sessionKey != nil
539         && conn->sessionType != nil
540         && strcmp(conn->sessionType, "ttls") == 0)
541                 tls->sec->prf(
542                         conn->sessionKey, conn->sessionKeylen,
543                         tls->sec->sec, MasterSecretSize,
544                         conn->sessionConst, 
545                         tls->sec->crandom, RandomSize,
546                         tls->sec->srandom, RandomSize);
547         tlsConnectionFree(tls);
548         close(fd);
549         return data;
550 }
551
552 static int
553 countchain(PEMChain *p)
554 {
555         int i = 0;
556
557         while (p) {
558                 i++;
559                 p = p->next;
560         }
561         return i;
562 }
563
// Run the server side of the handshake over the record-layer fds ctl
// and hand.  Sequence, as implemented below:
//	receive ClientHello; pick version, cipher and compressor;
//	set up TlsSec with the factotum RSA key matching cert;
//	send ServerHello, Certificate (cert then the chp chain), ServerHelloDone;
//	receive ClientKeyExchange (RSA-encrypted premaster) and load keys;
//	verify the client's Finished; change cipher; send our Finished.
// No ServerKeyExchange, CertificateRequest or CertificateVerify is
// handled, so only RSA key exchange works on this path.
// NOTE(review): okCipher is defined elsewhere; confirm it cannot pick a
// DHE/ECDHE suite here, since this function never sends ServerKeyExchange.
// Returns the connection on success, nil on failure; most failures
// report an alert via tlsError first.
static TlsConnection *
tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chp)
{
	TlsConnection *c;
	Msg m;
	Bytes *csid;
	uchar sid[SidSize], kd[MaxKeyData];
	char *secrets;
	int cipher, compressor, nsid, rv, numcerts, i;

	if(trace)
		trace("tlsServer2\n");
	if(!initCiphers())
		return nil;
	c = emalloc(sizeof(TlsConnection));
	c->ctl = ctl;
	c->hand = hand;
	c->trace = trace;
	c->version = ProtocolVersion;

	// ClientHello: version, random, cipher suites, compression
	memset(&m, 0, sizeof(m));
	if(!msgRecv(c, &m)){
		if(trace)
			trace("initial msgRecv failed\n");
		goto Err;
	}
	if(m.tag != HClientHello) {
		tlsError(c, EUnexpectedMessage, "expected a client hello");
		goto Err;
	}
	c->clientVersion = m.u.clientHello.version;
	if(trace)
		trace("ClientHello version %x\n", c->clientVersion);
	if(setVersion(c, m.u.clientHello.version) < 0) {
		tlsError(c, EIllegalParameter, "incompatible version");
		goto Err;
	}

	memmove(c->crandom, m.u.clientHello.random, RandomSize);
	cipher = okCipher(m.u.clientHello.ciphers);
	if(cipher < 0) {
		// reply with EInsufficientSecurity if we know that's the case
		if(cipher == -2)
			tlsError(c, EInsufficientSecurity, "cipher suites too weak");
		else
			tlsError(c, EHandshakeFailure, "no matching cipher suite");
		goto Err;
	}
	if(!setAlgs(c, cipher)){
		tlsError(c, EHandshakeFailure, "no matching cipher suite");
		goto Err;
	}
	compressor = okCompression(m.u.clientHello.compressors);
	if(compressor < 0) {
		tlsError(c, EHandshakeFailure, "no matching compressor");
		goto Err;
	}

	// security state: session id and server random come back from
	// tlsSecInits; the private key stays in factotum
	csid = m.u.clientHello.sid;
	if(trace)
		trace("  cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len);
	c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom);
	if(c->sec == nil){
		tlsError(c, EHandshakeFailure, "can't initialize security: %r");
		goto Err;
	}
	c->sec->rpc = factotum_rsa_open(cert, certlen);
	if(c->sec->rpc == nil){
		tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
		goto Err;
	}
	c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
	if(c->sec->rsapub == nil){
		tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
		goto Err;
	}
	msgClear(&m);

	// ServerHello and Certificate are queued; ServerHelloDone flushes
	m.tag = HServerHello;
	m.u.serverHello.version = c->version;
	memmove(m.u.serverHello.random, c->srandom, RandomSize);
	m.u.serverHello.cipher = cipher;
	m.u.serverHello.compressor = compressor;
	c->sid = makebytes(sid, nsid);
	m.u.serverHello.sid = makebytes(c->sid->data, c->sid->len);
	if(!msgSend(c, &m, AQueue))
		goto Err;
	msgClear(&m);

	// our certificate first, then any chain certificates
	m.tag = HCertificate;
	numcerts = countchain(chp);
	m.u.certificate.ncert = 1 + numcerts;
	m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
	m.u.certificate.certs[0] = makebytes(cert, certlen);
	for (i = 0; i < numcerts && chp; i++, chp = chp->next)
		m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen);
	if(!msgSend(c, &m, AQueue))
		goto Err;
	msgClear(&m);

	m.tag = HServerHelloDone;
	if(!msgSend(c, &m, AFlush))
		goto Err;
	msgClear(&m);

	// ClientKeyExchange: recover the premaster secret and derive the
	// key block (kd holds c->nsecret bytes; every cipherAlgs entry
	// needs at most MaxKeyData)
	if(!msgRecv(c, &m))
		goto Err;
	if(m.tag != HClientKeyExchange) {
		tlsError(c, EUnexpectedMessage, "expected a client key exchange");
		goto Err;
	}
	if(tlsSecRSAs(c->sec, c->version, m.u.clientKeyExchange.key) < 0){
		tlsError(c, EHandshakeFailure, "couldn't set secrets: %r");
		goto Err;
	}
	setSecrets(c->sec, kd, c->nsecret);
	if(trace)
		trace("tls secrets\n");
	// hand the key block to the record layer as base64, then wipe both copies
	secrets = (char*)emalloc(2*c->nsecret);
	enc64(secrets, 2*c->nsecret, kd, c->nsecret);
	rv = fprint(c->ctl, "secret %s %s 0 %s", c->digest, c->enc, secrets);
	memset(secrets, 0, 2*c->nsecret);
	free(secrets);
	memset(kd, 0, c->nsecret);
	if(rv < 0){
		tlsError(c, EHandshakeFailure, "can't set keys: %r");
		goto Err;
	}
	msgClear(&m);

	/* no CertificateVerify; skip to Finished */
	// compute the Finished expected from the client (isclient=1) over the
	// handshake hashes so far, before its Finished message is received
	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
		tlsError(c, EInternalError, "can't set finished: %r");
		goto Err;
	}
	if(!msgRecv(c, &m))
		goto Err;
	if(m.tag != HFinished) {
		tlsError(c, EUnexpectedMessage, "expected a finished");
		goto Err;
	}
	if(!finishedMatch(c, &m.u.finished)) {
		tlsError(c, EHandshakeFailure, "finished verification failed");
		goto Err;
	}
	msgClear(&m);

	/* change cipher spec */
	if(fprint(c->ctl, "changecipher") < 0){
		tlsError(c, EInternalError, "can't enable cipher: %r");
		goto Err;
	}

	// our own Finished (isclient=0), now covering the client's Finished
	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
		tlsError(c, EInternalError, "can't set finished: %r");
		goto Err;
	}
	m.tag = HFinished;
	m.u.finished = c->finished;
	if(!msgSend(c, &m, AFlush))
		goto Err;
	msgClear(&m);
	if(trace)
		trace("tls finished\n");

	if(fprint(c->ctl, "opened") < 0)
		goto Err;
	tlsSecOk(c->sec);
	return c;

Err:
	msgClear(&m);
	tlsConnectionFree(c);
	return 0;
}
739
740 static int
741 isDHE(int tlsid)
742 {
743         switch(tlsid){
744         case TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA:
745         case TLS_DHE_DSS_WITH_DES_CBC_SHA:
746         case TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA:
747         case TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA:
748         case TLS_DHE_RSA_WITH_DES_CBC_SHA:
749         case TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA:
750         case TLS_DHE_DSS_WITH_AES_128_CBC_SHA:
751         case TLS_DHE_RSA_WITH_AES_128_CBC_SHA:
752         case TLS_DHE_DSS_WITH_AES_256_CBC_SHA:
753         case TLS_DHE_RSA_WITH_AES_256_CBC_SHA:
754                 return 1;
755         }
756         return 0;
757 }
758
759 static int
760 isECDHE(int tlsid)
761 {
762         switch(tlsid){
763         case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:
764         case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
765         case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA:
766         case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:
767                 return 1;
768         }
769         return 0;
770 }
771
772 static Bytes*
773 tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, 
774         Bytes *p, Bytes *g, Bytes *Ys)
775 {
776         mpint *G, *P, *Y, *K;
777         Bytes *epm;
778         DHstate dh;
779
780         if(p == nil || g == nil || Ys == nil)
781                 return nil;
782
783         memmove(sec->srandom, srandom, RandomSize);
784         if(setVers(sec, vers) < 0)
785                 return nil;
786
787         epm = nil;
788         P = bytestomp(p);
789         G = bytestomp(g);
790         Y = bytestomp(Ys);
791         K = nil;
792
793         if(P == nil || G == nil || Y == nil || dh_new(&dh, P, G) == nil)
794                 goto Out;
795         epm = mptobytes(dh.y);
796         K = dh_finish(&dh, Y);
797         if(K == nil){
798                 freebytes(epm);
799                 epm = nil;
800                 goto Out;
801         }
802         setMasterSecret(sec, mptobytes(K));
803
804 Out:
805         mpfree(K);
806         mpfree(Y);
807         mpfree(G);
808         mpfree(P);
809
810         return epm;
811 }
812
813 static ECpoint*
814 bytestoec(ECdomain *dom, Bytes *bp, ECpoint *ret)
815 {
816         char *hex = "0123456789ABCDEF";
817         char *s;
818         int i;
819
820         s = emalloc(2*bp->len + 1);
821         for(i=0; i < bp->len; i++){
822                 s[2*i] = hex[bp->data[i]>>4 & 15];
823                 s[2*i+1] = hex[bp->data[i] & 15];
824         }
825         s[2*bp->len] = '\0';
826         ret = strtoec(dom, s, nil, ret);
827         free(s);
828         return ret;
829 }
830
831 static Bytes*
832 ectobytes(int type, ECpoint *p)
833 {
834         Bytes *bx, *by, *bp;
835
836         bx = mptobytes(p->x);
837         by = mptobytes(p->y);
838         bp = newbytes(bx->len + by->len + 1);
839         bp->data[0] =  type;
840         memmove(bp->data+1, bx->data, bx->len);
841         memmove(bp->data+1+bx->len, by->data, by->len);
842         freebytes(bx);
843         freebytes(by);
844         return bp;
845 }
846
847 static Bytes*
848 tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys)
849 {
850         Namedcurve *nc, *enc;
851         Bytes *epm;
852         ECdomain dom;
853         ECpoint G, K, Y;
854         ECpriv Q;
855
856         if(Ys == nil)
857                 return nil;
858
859         enc = &namedcurves[nelem(namedcurves)];
860         for(nc = namedcurves; nc != enc; nc++)
861                 if(nc->tlsid == curve)
862                         break;
863
864         if(nc == enc)
865                 return nil;
866                 
867         memmove(sec->srandom, srandom, RandomSize);
868         if(setVers(sec, vers) < 0)
869                 return nil;
870         
871         epm = nil;
872
873         memset(&dom, 0, sizeof(dom));
874         dom.p = strtomp(nc->p, nil, 16, nil);
875         dom.a = strtomp(nc->a, nil, 16, nil);
876         dom.b = strtomp(nc->b, nil, 16, nil);
877         dom.n = strtomp(nc->n, nil, 16, nil);
878         dom.h = strtomp(nc->h, nil, 16, nil);
879
880         memset(&G, 0, sizeof(G));
881         G.x = mpnew(0);
882         G.y = mpnew(0);
883
884         memset(&Q, 0, sizeof(Q));
885         Q.x = mpnew(0);
886         Q.y = mpnew(0);
887         Q.d = mpnew(0);
888
889         memset(&K, 0, sizeof(K));
890         K.x = mpnew(0);
891         K.y = mpnew(0);
892
893         memset(&Y, 0, sizeof(Y));
894         Y.x = mpnew(0);
895         Y.y = mpnew(0);
896
897         if(dom.p == nil || dom.a == nil || dom.b == nil || dom.n == nil || dom.h == nil)
898                 goto Out;
899         if(Q.x == nil || Q.y == nil || Q.d == nil)
900                 goto Out;
901         if(G.x == nil || G.y == nil)
902                 goto Out;
903         if(K.x == nil || K.y == nil)
904                 goto Out;
905         if(Y.x == nil || Y.y == nil)
906                 goto Out;
907
908         dom.G = strtoec(&dom, nc->G, nil, &G);
909         if(dom.G == nil)
910                 goto Out;
911
912         if(bytestoec(&dom, Ys, &Y) == nil)
913                 goto Out;
914
915         if(ecgen(&dom, &Q) == nil)
916                 goto Out;
917
918         ecmul(&dom, &Y, Q.d, &K);
919         setMasterSecret(sec, mptobytes(K.x));
920
921         /* 0x04 = uncompressed public key */
922         epm = ectobytes(0x04, &Q);
923         
924 Out:
925         mpfree(Y.x);
926         mpfree(Y.y);
927
928         mpfree(K.x);
929         mpfree(K.y);
930
931         mpfree(Q.x);
932         mpfree(Q.y);
933         mpfree(Q.d);
934
935         mpfree(G.x);
936         mpfree(G.y);
937
938         mpfree(dom.p);
939         mpfree(dom.a);
940         mpfree(dom.b);
941         mpfree(dom.n);
942         mpfree(dom.h);
943
944         return epm;
945 }
946
947 static TlsConnection *
948 tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen,
949         int (*trace)(char*fmt, ...))
950 {
951         TlsConnection *c;
952         Msg m;
953         uchar kd[MaxKeyData];
954         char *secrets;
955         int creq, dhx, rv, cipher;
956         mpint *signedMP, *paddedHashes;
957         Bytes *epm;
958
959         if(!initCiphers())
960                 return nil;
961         epm = nil;
962         c = emalloc(sizeof(TlsConnection));
963         c->version = ProtocolVersion;
964         c->ctl = ctl;
965         c->hand = hand;
966         c->trace = trace;
967         c->isClient = 1;
968         c->clientVersion = c->version;
969
970         c->sec = tlsSecInitc(c->clientVersion, c->crandom);
971         if(c->sec == nil)
972                 goto Err;
973
974         /* client hello */
975         memset(&m, 0, sizeof(m));
976         m.tag = HClientHello;
977         m.u.clientHello.version = c->clientVersion;
978         memmove(m.u.clientHello.random, c->crandom, RandomSize);
979         m.u.clientHello.sid = makebytes(csid, ncsid);
980         m.u.clientHello.ciphers = makeciphers();
981         m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors));
982         m.u.clientHello.extensions = makebytes(ext, extlen);
983         if(!msgSend(c, &m, AFlush))
984                 goto Err;
985         msgClear(&m);
986
987         /* server hello */
988         if(!msgRecv(c, &m))
989                 goto Err;
990         if(m.tag != HServerHello) {
991                 tlsError(c, EUnexpectedMessage, "expected a server hello");
992                 goto Err;
993         }
994         if(setVersion(c, m.u.serverHello.version) < 0) {
995                 tlsError(c, EIllegalParameter, "incompatible version %r");
996                 goto Err;
997         }
998         memmove(c->srandom, m.u.serverHello.random, RandomSize);
999         c->sid = makebytes(m.u.serverHello.sid->data, m.u.serverHello.sid->len);
1000         if(c->sid->len != 0 && c->sid->len != SidSize) {
1001                 tlsError(c, EIllegalParameter, "invalid server session identifier");
1002                 goto Err;
1003         }
1004         cipher = m.u.serverHello.cipher;
1005         if(!setAlgs(c, cipher)) {
1006                 tlsError(c, EIllegalParameter, "invalid cipher suite");
1007                 goto Err;
1008         }
1009         if(m.u.serverHello.compressor != CompressionNull) {
1010                 tlsError(c, EIllegalParameter, "invalid compression");
1011                 goto Err;
1012         }
1013         msgClear(&m);
1014
1015         /* certificate */
1016         if(!msgRecv(c, &m) || m.tag != HCertificate) {
1017                 tlsError(c, EUnexpectedMessage, "expected a certificate");
1018                 goto Err;
1019         }
1020         if(m.u.certificate.ncert < 1) {
1021                 tlsError(c, EIllegalParameter, "runt certificate");
1022                 goto Err;
1023         }
1024         c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len);
1025         msgClear(&m);
1026
1027         /* server key exchange */
1028         dhx = isDHE(cipher) || isECDHE(cipher);
1029         if(!msgRecv(c, &m))
1030                 goto Err;
1031         if(m.tag == HServerKeyExchange) {
1032                 if(!dhx){
1033                         tlsError(c, EUnexpectedMessage, "got an server key exchange");
1034                         goto Err;
1035                 }
1036                 if(isECDHE(cipher))
1037                         epm = tlsSecECDHEc(c->sec, c->srandom, c->version,
1038                                 m.u.serverKeyExchange.curve,
1039                                 m.u.serverKeyExchange.dh_Ys);
1040                 else
1041                         epm = tlsSecDHEc(c->sec, c->srandom, c->version,
1042                                 m.u.serverKeyExchange.dh_p, 
1043                                 m.u.serverKeyExchange.dh_g,
1044                                 m.u.serverKeyExchange.dh_Ys);
1045                 if(epm == nil)
1046                         goto Badcert;
1047                 msgClear(&m);
1048                 if(!msgRecv(c, &m))
1049                         goto Err;
1050         } else if(dhx){
1051                 tlsError(c, EUnexpectedMessage, "expected server key exchange");
1052                 goto Err;
1053         }
1054
1055         /* certificate request (optional) */
1056         creq = 0;
1057         if(m.tag == HCertificateRequest) {
1058                 creq = 1;
1059                 msgClear(&m);
1060                 if(!msgRecv(c, &m))
1061                         goto Err;
1062         }
1063
1064         if(m.tag != HServerHelloDone) {
1065                 tlsError(c, EUnexpectedMessage, "expected a server hello done");
1066                 goto Err;
1067         }
1068         msgClear(&m);
1069
1070         if(!dhx)
1071                 epm = tlsSecRSAc(c->sec, c->sid->data, c->sid->len, c->srandom,
1072                         c->cert->data, c->cert->len, c->version);
1073
1074         if(epm == nil){
1075         Badcert:
1076                 tlsError(c, EBadCertificate, "bad certificate: %r");
1077                 goto Err;
1078         }
1079
1080         setSecrets(c->sec, kd, c->nsecret);
1081         secrets = (char*)emalloc(2*c->nsecret);
1082         enc64(secrets, 2*c->nsecret, kd, c->nsecret);
1083         rv = fprint(c->ctl, "secret %s %s 1 %s", c->digest, c->enc, secrets);
1084         memset(secrets, 0, 2*c->nsecret);
1085         free(secrets);
1086         memset(kd, 0, c->nsecret);
1087         if(rv < 0){
1088                 tlsError(c, EHandshakeFailure, "can't set keys: %r");
1089                 goto Err;
1090         }
1091
1092         if(creq) {
1093                 if(cert != nil && certlen > 0){
1094                         m.u.certificate.ncert = 1;
1095                         m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*));
1096                         m.u.certificate.certs[0] = makebytes(cert, certlen);
1097                 }               
1098                 m.tag = HCertificate;
1099                 if(!msgSend(c, &m, AFlush))
1100                         goto Err;
1101                 msgClear(&m);
1102         }
1103
1104         /* client key exchange */
1105         m.tag = HClientKeyExchange;
1106         m.u.clientKeyExchange.key = epm;
1107         epm = nil;
1108         if(m.u.clientKeyExchange.key == nil) {
1109                 tlsError(c, EHandshakeFailure, "can't set secret: %r");
1110                 goto Err;
1111         }
1112          
1113         if(!msgSend(c, &m, AFlush))
1114                 goto Err;
1115         msgClear(&m);
1116
1117         /* CertificateVerify */
1118         /*XXX I should only send this when it is not DH right? 
1119                 Also we need to know which TLS key 
1120                 we have to use in case there are more than one*/
1121         if(cert){
1122                 m.tag = HCertificateVerify;
1123                 uchar hshashes[MD5dlen+SHA1dlen]; /* content of signature */
1124                 MD5state        hsmd5_save;
1125                 SHAstate        hssha1_save;
1126         
1127                 /* save the state for the Finish message */
1128
1129                 hsmd5_save = c->hsmd5;
1130                 hssha1_save = c->hssha1;
1131                 md5(nil, 0, hshashes, &c->hsmd5);
1132                 sha1(nil, 0, hshashes+MD5dlen, &c->hssha1);
1133         
1134                 c->hsmd5 = hsmd5_save;
1135                 c->hssha1 = hssha1_save;
1136
1137                 c->sec->rpc = factotum_rsa_open(cert, certlen);
1138                 if(c->sec->rpc == nil){
1139                         tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r");
1140                         goto Err;
1141                 }
1142                 c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0);
1143                 if(c->sec->rsapub == nil){
1144                         tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate");
1145                         goto Err;
1146                 }
1147
1148                 paddedHashes = pkcs1padbuf(hshashes, MD5dlen+SHA1dlen, c->sec->rsapub->n);
1149                 signedMP = factotum_rsa_decrypt(c->sec->rpc, paddedHashes);
1150                 if(signedMP == nil){
1151                         tlsError(c, EHandshakeFailure, "factotum_rsa_decrypt: %r");
1152                         goto Err;
1153                 }
1154                 m.u.certificateVerify.signature = mptobytes(signedMP);
1155                 mpfree(signedMP);
1156
1157                 if(!msgSend(c, &m, AFlush))
1158                         goto Err;
1159                 msgClear(&m);
1160         } 
1161
1162         /* change cipher spec */
1163         if(fprint(c->ctl, "changecipher") < 0){
1164                 tlsError(c, EInternalError, "can't enable cipher: %r");
1165                 goto Err;
1166         }
1167
1168         // Cipherchange must occur immediately before Finished to avoid
1169         // potential hole;  see section 4.3 of Wagner Schneier 1996.
1170         if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
1171                 tlsError(c, EInternalError, "can't set finished 1: %r");
1172                 goto Err;
1173         }
1174         m.tag = HFinished;
1175         m.u.finished = c->finished;
1176         if(!msgSend(c, &m, AFlush)) {
1177                 tlsError(c, EInternalError, "can't flush after client Finished: %r");
1178                 goto Err;
1179         }
1180         msgClear(&m);
1181
1182         if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
1183                 tlsError(c, EInternalError, "can't set finished 0: %r");
1184                 goto Err;
1185         }
1186         if(!msgRecv(c, &m)) {
1187                 tlsError(c, EInternalError, "can't read server Finished: %r");
1188                 goto Err;
1189         }
1190         if(m.tag != HFinished) {
1191                 tlsError(c, EUnexpectedMessage, "expected a Finished msg from server");
1192                 goto Err;
1193         }
1194
1195         if(!finishedMatch(c, &m.u.finished)) {
1196                 tlsError(c, EHandshakeFailure, "finished verification failed");
1197                 goto Err;
1198         }
1199         msgClear(&m);
1200
1201         if(fprint(c->ctl, "opened") < 0){
1202                 if(trace)
1203                         trace("unable to do final open: %r\n");
1204                 goto Err;
1205         }
1206         tlsSecOk(c->sec);
1207         return c;
1208
1209 Err:
1210         free(epm);
1211         msgClear(&m);
1212         tlsConnectionFree(c);
1213         return 0;
1214 }
1215
1216
1217 //================= message functions ========================
1218
1219 static int
1220 msgSend(TlsConnection *c, Msg *m, int act)
1221 {
1222         uchar *p; // sendp = start of new message;  p = write pointer
1223         int nn, n, i;
1224
1225         if(c->sendp == nil)
1226                 c->sendp = c->sendbuf;
1227         p = c->sendp;
1228         if(c->trace)
1229                 c->trace("send %s", msgPrint((char*)p, (sizeof(c->sendbuf)) - (p - c->sendbuf), m));
1230
1231         p[0] = m->tag;  // header - fill in size later
1232         p += 4;
1233
1234         switch(m->tag) {
1235         default:
1236                 tlsError(c, EInternalError, "can't encode a %d", m->tag);
1237                 goto Err;
1238         case HClientHello:
1239                 // version
1240                 put16(p, m->u.clientHello.version);
1241                 p += 2;
1242
1243                 // random
1244                 memmove(p, m->u.clientHello.random, RandomSize);
1245                 p += RandomSize;
1246
1247                 // sid
1248                 n = m->u.clientHello.sid->len;
1249                 assert(n < 256);
1250                 p[0] = n;
1251                 memmove(p+1, m->u.clientHello.sid->data, n);
1252                 p += n+1;
1253
1254                 n = m->u.clientHello.ciphers->len;
1255                 assert(n > 0 && n < 200);
1256                 put16(p, n*2);
1257                 p += 2;
1258                 for(i=0; i<n; i++) {
1259                         put16(p, m->u.clientHello.ciphers->data[i]);
1260                         p += 2;
1261                 }
1262
1263                 n = m->u.clientHello.compressors->len;
1264                 assert(n > 0);
1265                 p[0] = n;
1266                 memmove(p+1, m->u.clientHello.compressors->data, n);
1267                 p += n+1;
1268
1269                 if(m->u.clientHello.extensions == nil)
1270                         break;
1271                 n = m->u.clientHello.extensions->len;
1272                 if(n == 0)
1273                         break;
1274                 put16(p, n);
1275                 memmove(p+2, m->u.clientHello.extensions->data, n);
1276                 p += n+2;
1277                 break;
1278         case HServerHello:
1279                 put16(p, m->u.serverHello.version);
1280                 p += 2;
1281
1282                 // random
1283                 memmove(p, m->u.serverHello.random, RandomSize);
1284                 p += RandomSize;
1285
1286                 // sid
1287                 n = m->u.serverHello.sid->len;
1288                 assert(n < 256);
1289                 p[0] = n;
1290                 memmove(p+1, m->u.serverHello.sid->data, n);
1291                 p += n+1;
1292
1293                 put16(p, m->u.serverHello.cipher);
1294                 p += 2;
1295                 p[0] = m->u.serverHello.compressor;
1296                 p += 1;
1297
1298                 if(m->u.serverHello.extensions == nil)
1299                         break;
1300                 n = m->u.serverHello.extensions->len;
1301                 if(n == 0)
1302                         break;
1303                 put16(p, n);
1304                 memmove(p+2, m->u.serverHello.extensions->data, n);
1305                 p += n+2;
1306                 break;
1307         case HServerHelloDone:
1308                 break;
1309         case HCertificate:
1310                 nn = 0;
1311                 for(i = 0; i < m->u.certificate.ncert; i++)
1312                         nn += 3 + m->u.certificate.certs[i]->len;
1313                 if(p + 3 + nn - c->sendbuf > sizeof(c->sendbuf)) {
1314                         tlsError(c, EInternalError, "output buffer too small for certificate");
1315                         goto Err;
1316                 }
1317                 put24(p, nn);
1318                 p += 3;
1319                 for(i = 0; i < m->u.certificate.ncert; i++){
1320                         put24(p, m->u.certificate.certs[i]->len);
1321                         p += 3;
1322                         memmove(p, m->u.certificate.certs[i]->data, m->u.certificate.certs[i]->len);
1323                         p += m->u.certificate.certs[i]->len;
1324                 }
1325                 break;
1326         case HCertificateVerify:
1327                 put16(p, m->u.certificateVerify.signature->len);
1328                 p += 2;
1329                 memmove(p, m->u.certificateVerify.signature->data, m->u.certificateVerify.signature->len);
1330                 p += m->u.certificateVerify.signature->len;
1331                 break;  
1332         case HClientKeyExchange:
1333                 n = m->u.clientKeyExchange.key->len;
1334                 if(c->version != SSL3Version){
1335                         if(isECDHE(c->cipher))
1336                                 *p++ = n;
1337                         else
1338                                 put16(p, n), p += 2;
1339                 }
1340                 memmove(p, m->u.clientKeyExchange.key->data, n);
1341                 p += n;
1342                 break;
1343         case HFinished:
1344                 memmove(p, m->u.finished.verify, m->u.finished.n);
1345                 p += m->u.finished.n;
1346                 break;
1347         }
1348
1349         // go back and fill in size
1350         n = p - c->sendp;
1351         assert(p <= c->sendbuf + sizeof(c->sendbuf));
1352         put24(c->sendp+1, n-4);
1353
1354         // remember hash of Handshake messages
1355         if(m->tag != HHelloRequest) {
1356                 md5(c->sendp, n, 0, &c->hsmd5);
1357                 sha1(c->sendp, n, 0, &c->hssha1);
1358         }
1359
1360         c->sendp = p;
1361         if(act == AFlush){
1362                 c->sendp = c->sendbuf;
1363                 if(write(c->hand, c->sendbuf, p - c->sendbuf) < 0){
1364                         fprint(2, "write error: %r\n");
1365                         goto Err;
1366                 }
1367         }
1368         msgClear(m);
1369         return 1;
1370 Err:
1371         msgClear(m);
1372         return 0;
1373 }
1374
1375 static uchar*
1376 tlsReadN(TlsConnection *c, int n)
1377 {
1378         uchar *p;
1379         int nn, nr;
1380
1381         nn = c->ep - c->rp;
1382         if(nn < n){
1383                 if(c->rp != c->recvbuf){
1384                         memmove(c->recvbuf, c->rp, nn);
1385                         c->rp = c->recvbuf;
1386                         c->ep = &c->recvbuf[nn];
1387                 }
1388                 for(; nn < n; nn += nr) {
1389                         nr = read(c->hand, &c->rp[nn], n - nn);
1390                         if(nr <= 0)
1391                                 return nil;
1392                         c->ep += nr;
1393                 }
1394         }
1395         p = c->rp;
1396         c->rp += n;
1397         return p;
1398 }
1399
1400 static int
1401 msgRecv(TlsConnection *c, Msg *m)
1402 {
1403         uchar *p;
1404         int type, n, nn, i, nsid, nrandom, nciph;
1405
1406         for(;;) {
1407                 p = tlsReadN(c, 4);
1408                 if(p == nil)
1409                         return 0;
1410                 type = p[0];
1411                 n = get24(p+1);
1412
1413                 if(type != HHelloRequest)
1414                         break;
1415                 if(n != 0) {
1416                         tlsError(c, EDecodeError, "invalid hello request during handshake");
1417                         return 0;
1418                 }
1419         }
1420
1421         if(n > sizeof(c->recvbuf)) {
1422                 tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->recvbuf));
1423                 return 0;
1424         }
1425
1426         if(type == HSSL2ClientHello){
1427                 /* Cope with an SSL3 ClientHello expressed in SSL2 record format.
1428                         This is sent by some clients that we must interoperate
1429                         with, such as Java's JSSE and Microsoft's Internet Explorer. */
1430                 p = tlsReadN(c, n);
1431                 if(p == nil)
1432                         return 0;
1433                 md5(p, n, 0, &c->hsmd5);
1434                 sha1(p, n, 0, &c->hssha1);
1435                 m->tag = HClientHello;
1436                 if(n < 22)
1437                         goto Short;
1438                 m->u.clientHello.version = get16(p+1);
1439                 p += 3;
1440                 n -= 3;
1441                 nn = get16(p); /* cipher_spec_len */
1442                 nsid = get16(p + 2);
1443                 nrandom = get16(p + 4);
1444                 p += 6;
1445                 n -= 6;
1446                 if(nsid != 0    /* no sid's, since shouldn't restart using ssl2 header */
1447                                 || nrandom < 16 || nn % 3)
1448                         goto Err;
1449                 if(c->trace && (n - nrandom != nn))
1450                         c->trace("n-nrandom!=nn: n=%d nrandom=%d nn=%d\n", n, nrandom, nn);
1451                 /* ignore ssl2 ciphers and look for {0x00, ssl3 cipher} */
1452                 nciph = 0;
1453                 for(i = 0; i < nn; i += 3)
1454                         if(p[i] == 0)
1455                                 nciph++;
1456                 m->u.clientHello.ciphers = newints(nciph);
1457                 nciph = 0;
1458                 for(i = 0; i < nn; i += 3)
1459                         if(p[i] == 0)
1460                                 m->u.clientHello.ciphers->data[nciph++] = get16(&p[i + 1]);
1461                 p += nn;
1462                 m->u.clientHello.sid = makebytes(nil, 0);
1463                 if(nrandom > RandomSize)
1464                         nrandom = RandomSize;
1465                 memset(m->u.clientHello.random, 0, RandomSize - nrandom);
1466                 memmove(&m->u.clientHello.random[RandomSize - nrandom], p, nrandom);
1467                 m->u.clientHello.compressors = newbytes(1);
1468                 m->u.clientHello.compressors->data[0] = CompressionNull;
1469                 goto Ok;
1470         }
1471         md5(p, 4, 0, &c->hsmd5);
1472         sha1(p, 4, 0, &c->hssha1);
1473
1474         p = tlsReadN(c, n);
1475         if(p == nil)
1476                 return 0;
1477
1478         md5(p, n, 0, &c->hsmd5);
1479         sha1(p, n, 0, &c->hssha1);
1480
1481         m->tag = type;
1482
1483         switch(type) {
1484         default:
1485                 tlsError(c, EUnexpectedMessage, "can't decode a %d", type);
1486                 goto Err;
1487         case HClientHello:
1488                 if(n < 2)
1489                         goto Short;
1490                 m->u.clientHello.version = get16(p);
1491                 p += 2;
1492                 n -= 2;
1493
1494                 if(n < RandomSize)
1495                         goto Short;
1496                 memmove(m->u.clientHello.random, p, RandomSize);
1497                 p += RandomSize;
1498                 n -= RandomSize;
1499                 if(n < 1 || n < p[0]+1)
1500                         goto Short;
1501                 m->u.clientHello.sid = makebytes(p+1, p[0]);
1502                 p += m->u.clientHello.sid->len+1;
1503                 n -= m->u.clientHello.sid->len+1;
1504
1505                 if(n < 2)
1506                         goto Short;
1507                 nn = get16(p);
1508                 p += 2;
1509                 n -= 2;
1510
1511                 if((nn & 1) || n < nn || nn < 2)
1512                         goto Short;
1513                 m->u.clientHello.ciphers = newints(nn >> 1);
1514                 for(i = 0; i < nn; i += 2)
1515                         m->u.clientHello.ciphers->data[i >> 1] = get16(&p[i]);
1516                 p += nn;
1517                 n -= nn;
1518
1519                 if(n < 1 || n < p[0]+1 || p[0] == 0)
1520                         goto Short;
1521                 nn = p[0];
1522                 m->u.clientHello.compressors = makebytes(p+1, nn);
1523                 p += nn + 1;
1524                 n -= nn + 1;
1525
1526                 if(n < 2)
1527                         break;
1528                 nn = get16(p);
1529                 if(nn > n-2)
1530                         goto Short;
1531                 m->u.clientHello.extensions = makebytes(p+2, nn);
1532                 n -= nn + 2;
1533                 break;
1534         case HServerHello:
1535                 if(n < 2)
1536                         goto Short;
1537                 m->u.serverHello.version = get16(p);
1538                 p += 2;
1539                 n -= 2;
1540
1541                 if(n < RandomSize)
1542                         goto Short;
1543                 memmove(m->u.serverHello.random, p, RandomSize);
1544                 p += RandomSize;
1545                 n -= RandomSize;
1546
1547                 if(n < 1 || n < p[0]+1)
1548                         goto Short;
1549                 m->u.serverHello.sid = makebytes(p+1, p[0]);
1550                 p += m->u.serverHello.sid->len+1;
1551                 n -= m->u.serverHello.sid->len+1;
1552
1553                 if(n < 3)
1554                         goto Short;
1555                 m->u.serverHello.cipher = get16(p);
1556                 m->u.serverHello.compressor = p[2];
1557                 p += 3;
1558                 n -= 3;
1559
1560                 if(n < 2)
1561                         break;
1562                 nn = get16(p);
1563                 if(nn > n-2)
1564                         goto Short;
1565                 m->u.serverHello.extensions = makebytes(p+2, nn);
1566                 n -= nn + 2;
1567                 break;
1568         case HCertificate:
1569                 if(n < 3)
1570                         goto Short;
1571                 nn = get24(p);
1572                 p += 3;
1573                 n -= 3;
1574                 if(nn == 0 && n > 0)
1575                         goto Short;
1576                 /* certs */
1577                 i = 0;
1578                 while(n > 0) {
1579                         if(n < 3)
1580                                 goto Short;
1581                         nn = get24(p);
1582                         p += 3;
1583                         n -= 3;
1584                         if(nn > n)
1585                                 goto Short;
1586                         m->u.certificate.ncert = i+1;
1587                         m->u.certificate.certs = erealloc(m->u.certificate.certs, (i+1)*sizeof(Bytes*));
1588                         m->u.certificate.certs[i] = makebytes(p, nn);
1589                         p += nn;
1590                         n -= nn;
1591                         i++;
1592                 }
1593                 break;
1594         case HCertificateRequest:
1595                 if(n < 1)
1596                         goto Short;
1597                 nn = p[0];
1598                 p += 1;
1599                 n -= 1;
1600                 if(nn > n)
1601                         goto Short;
1602                 m->u.certificateRequest.types = makebytes(p, nn);
1603                 p += nn;
1604                 n -= nn;
1605                 if(n < 2)
1606                         goto Short;
1607                 nn = get16(p);
1608                 p += 2;
1609                 n -= 2;
1610                 /* nn == 0 can happen; yahoo's servers do it */
1611                 if(nn != n)
1612                         goto Short;
1613                 /* cas */
1614                 i = 0;
1615                 while(n > 0) {
1616                         if(n < 2)
1617                                 goto Short;
1618                         nn = get16(p);
1619                         p += 2;
1620                         n -= 2;
1621                         if(nn < 1 || nn > n)
1622                                 goto Short;
1623                         m->u.certificateRequest.nca = i+1;
1624                         m->u.certificateRequest.cas = erealloc(
1625                                 m->u.certificateRequest.cas, (i+1)*sizeof(Bytes*));
1626                         m->u.certificateRequest.cas[i] = makebytes(p, nn);
1627                         p += nn;
1628                         n -= nn;
1629                         i++;
1630                 }
1631                 break;
1632         case HServerHelloDone:
1633                 break;
1634         case HServerKeyExchange:
1635                 if(n < 2)
1636                         goto Short;
1637                 if(isECDHE(c->cipher)){
1638                         nn = *p;
1639                         p++, n--;
1640                         if(nn != 3 || nn > n) /* not a named curve */
1641                                 goto Short;
1642                         nn = get16(p);
1643                         p += 2, n -= 2;
1644                         m->u.serverKeyExchange.curve = nn;
1645
1646                         nn = *p++, n--;
1647                         if(nn < 1 || nn > n)
1648                                 goto Short;
1649                         m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);
1650                         p += nn, n -= nn;
1651                 }else if(isDHE(c->cipher)){
1652                         nn = get16(p);
1653                         p += 2, n -= 2;
1654                         if(nn < 1 || nn > n)
1655                                 goto Short;
1656                         m->u.serverKeyExchange.dh_p = makebytes(p, nn);
1657                         p += nn, n -= nn;
1658         
1659                         if(n < 2)
1660                                 goto Short;
1661                         nn = get16(p);
1662                         p += 2, n -= 2;
1663                         if(nn < 1 || nn > n)
1664                                 goto Short;
1665                         m->u.serverKeyExchange.dh_g = makebytes(p, nn);
1666                         p += nn, n -= nn;
1667         
1668                         if(n < 2)
1669                                 goto Short;
1670                         nn = get16(p);
1671                         p += 2, n -= 2;
1672                         if(nn < 1 || nn > n)
1673                                 goto Short;
1674                         m->u.serverKeyExchange.dh_Ys = makebytes(p, nn);
1675                         p += nn, n -= nn;
1676                 } else {
1677                         /* should not happen */
1678                         break;
1679                 }
1680                 if(n >= 2){
1681                         nn = get16(p);
1682                         p += 2, n -= 2;
1683                         if(nn > 0 && nn <= n){
1684                                 m->u.serverKeyExchange.dh_signature = makebytes(p, nn);
1685                                 n -= nn;
1686                         }
1687                 }
1688                 break;          
1689         case HClientKeyExchange:
1690                 /*
1691                  * this message depends upon the encryption selected
1692                  * assume rsa.
1693                  */
1694                 if(c->version == SSL3Version)
1695                         nn = n;
1696                 else{
1697                         if(n < 2)
1698                                 goto Short;
1699                         nn = get16(p);
1700                         p += 2;
1701                         n -= 2;
1702                 }
1703                 if(n < nn)
1704                         goto Short;
1705                 m->u.clientKeyExchange.key = makebytes(p, nn);
1706                 n -= nn;
1707                 break;
1708         case HFinished:
1709                 m->u.finished.n = c->finished.n;
1710                 if(n < m->u.finished.n)
1711                         goto Short;
1712                 memmove(m->u.finished.verify, p, m->u.finished.n);
1713                 n -= m->u.finished.n;
1714                 break;
1715         }
1716
1717         if(type != HClientHello && type != HServerHello && n != 0)
1718                 goto Short;
1719 Ok:
1720         if(c->trace){
1721                 char *buf;
1722                 buf = emalloc(8000);
1723                 c->trace("recv %s", msgPrint(buf, 8000, m));
1724                 free(buf);
1725         }
1726         return 1;
1727 Short:
1728         tlsError(c, EDecodeError, "handshake message (%d) has invalid length", type);
1729 Err:
1730         msgClear(m);
1731         return 0;
1732 }
1733
1734 static void
1735 msgClear(Msg *m)
1736 {
1737         int i;
1738
1739         switch(m->tag) {
1740         default:
1741                 sysfatal("msgClear: unknown message type: %d", m->tag);
1742         case HHelloRequest:
1743                 break;
1744         case HClientHello:
1745                 freebytes(m->u.clientHello.sid);
1746                 freeints(m->u.clientHello.ciphers);
1747                 freebytes(m->u.clientHello.compressors);
1748                 freebytes(m->u.clientHello.extensions);
1749                 break;
1750         case HServerHello:
1751                 freebytes(m->u.serverHello.sid);
1752                 freebytes(m->u.serverHello.extensions);
1753                 break;
1754         case HCertificate:
1755                 for(i=0; i<m->u.certificate.ncert; i++)
1756                         freebytes(m->u.certificate.certs[i]);
1757                 free(m->u.certificate.certs);
1758                 break;
1759         case HCertificateRequest:
1760                 freebytes(m->u.certificateRequest.types);
1761                 for(i=0; i<m->u.certificateRequest.nca; i++)
1762                         freebytes(m->u.certificateRequest.cas[i]);
1763                 free(m->u.certificateRequest.cas);
1764                 break;
1765         case HCertificateVerify:
1766                 freebytes(m->u.certificateVerify.signature);
1767                 break;
1768         case HServerHelloDone:
1769                 break;
1770         case HServerKeyExchange:
1771                 freebytes(m->u.serverKeyExchange.dh_p);
1772                 freebytes(m->u.serverKeyExchange.dh_g);
1773                 freebytes(m->u.serverKeyExchange.dh_Ys);
1774                 freebytes(m->u.serverKeyExchange.dh_signature);
1775                 break;
1776         case HClientKeyExchange:
1777                 freebytes(m->u.clientKeyExchange.key);
1778                 break;
1779         case HFinished:
1780                 break;
1781         }
1782         memset(m, 0, sizeof(Msg));
1783 }
1784
1785 static char *
1786 bytesPrint(char *bs, char *be, char *s0, Bytes *b, char *s1)
1787 {
1788         int i;
1789
1790         if(s0)
1791                 bs = seprint(bs, be, "%s", s0);
1792         if(b == nil)
1793                 bs = seprint(bs, be, "nil");
1794         else {
1795                 bs = seprint(bs, be, "<%d> [", b->len);
1796                 for(i=0; i<b->len; i++)
1797                         bs = seprint(bs, be, "%.2x ", b->data[i]);
1798         }
1799         bs = seprint(bs, be, "]");
1800         if(s1)
1801                 bs = seprint(bs, be, "%s", s1);
1802         return bs;
1803 }
1804
1805 static char *
1806 intsPrint(char *bs, char *be, char *s0, Ints *b, char *s1)
1807 {
1808         int i;
1809
1810         if(s0)
1811                 bs = seprint(bs, be, "%s", s0);
1812         bs = seprint(bs, be, "[");
1813         if(b == nil)
1814                 bs = seprint(bs, be, "nil");
1815         else
1816                 for(i=0; i<b->len; i++)
1817                         bs = seprint(bs, be, "%x ", b->data[i]);
1818         bs = seprint(bs, be, "]");
1819         if(s1)
1820                 bs = seprint(bs, be, "%s", s1);
1821         return bs;
1822 }
1823
1824 static char*
1825 msgPrint(char *buf, int n, Msg *m)
1826 {
1827         int i;
1828         char *bs = buf, *be = buf+n;
1829
1830         switch(m->tag) {
1831         default:
1832                 bs = seprint(bs, be, "unknown %d\n", m->tag);
1833                 break;
1834         case HClientHello:
1835                 bs = seprint(bs, be, "ClientHello\n");
1836                 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.clientHello.version);
1837                 bs = seprint(bs, be, "\trandom: ");
1838                 for(i=0; i<RandomSize; i++)
1839                         bs = seprint(bs, be, "%.2x", m->u.clientHello.random[i]);
1840                 bs = seprint(bs, be, "\n");
1841                 bs = bytesPrint(bs, be, "\tsid: ", m->u.clientHello.sid, "\n");
1842                 bs = intsPrint(bs, be, "\tciphers: ", m->u.clientHello.ciphers, "\n");
1843                 bs = bytesPrint(bs, be, "\tcompressors: ", m->u.clientHello.compressors, "\n");
1844                 if(m->u.clientHello.extensions != nil)
1845                         bs = bytesPrint(bs, be, "\textensions: ", m->u.clientHello.extensions, "\n");
1846                 break;
1847         case HServerHello:
1848                 bs = seprint(bs, be, "ServerHello\n");
1849                 bs = seprint(bs, be, "\tversion: %.4x\n", m->u.serverHello.version);
1850                 bs = seprint(bs, be, "\trandom: ");
1851                 for(i=0; i<RandomSize; i++)
1852                         bs = seprint(bs, be, "%.2x", m->u.serverHello.random[i]);
1853                 bs = seprint(bs, be, "\n");
1854                 bs = bytesPrint(bs, be, "\tsid: ", m->u.serverHello.sid, "\n");
1855                 bs = seprint(bs, be, "\tcipher: %.4x\n", m->u.serverHello.cipher);
1856                 bs = seprint(bs, be, "\tcompressor: %.2x\n", m->u.serverHello.compressor);
1857                 if(m->u.serverHello.extensions != nil)
1858                         bs = bytesPrint(bs, be, "\textensions: ", m->u.serverHello.extensions, "\n");
1859                 break;
1860         case HCertificate:
1861                 bs = seprint(bs, be, "Certificate\n");
1862                 for(i=0; i<m->u.certificate.ncert; i++)
1863                         bs = bytesPrint(bs, be, "\t", m->u.certificate.certs[i], "\n");
1864                 break;
1865         case HCertificateRequest:
1866                 bs = seprint(bs, be, "CertificateRequest\n");
1867                 bs = bytesPrint(bs, be, "\ttypes: ", m->u.certificateRequest.types, "\n");
1868                 bs = seprint(bs, be, "\tcertificateauthorities\n");
1869                 for(i=0; i<m->u.certificateRequest.nca; i++)
1870                         bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n");
1871                 break;
1872         case HCertificateVerify:
1873                 bs = seprint(bs, be, "HCertificateVerify\n");
1874                 bs = bytesPrint(bs, be, "\tsignature: ", m->u.certificateVerify.signature,"\n");
1875                 break;  
1876         case HServerHelloDone:
1877                 bs = seprint(bs, be, "ServerHelloDone\n");
1878                 break;
1879         case HServerKeyExchange:
1880                 bs = seprint(bs, be, "HServerKeyExchange\n");
1881                 if(m->u.serverKeyExchange.curve != 0){
1882                         bs = seprint(bs, be, "\tcurve: %.4x\n", m->u.serverKeyExchange.curve);
1883                 } else {
1884                         bs = bytesPrint(bs, be, "\tdh_p: ", m->u.serverKeyExchange.dh_p, "\n");
1885                         bs = bytesPrint(bs, be, "\tdh_g: ", m->u.serverKeyExchange.dh_g, "\n");
1886                 }
1887                 bs = bytesPrint(bs, be, "\tdh_Ys: ", m->u.serverKeyExchange.dh_Ys, "\n");
1888                 bs = bytesPrint(bs, be, "\tdh_signature: ", m->u.serverKeyExchange.dh_signature, "\n");
1889                 break;
1890         case HClientKeyExchange:
1891                 bs = seprint(bs, be, "HClientKeyExchange\n");
1892                 bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n");
1893                 break;
1894         case HFinished:
1895                 bs = seprint(bs, be, "HFinished\n");
1896                 for(i=0; i<m->u.finished.n; i++)
1897                         bs = seprint(bs, be, "%.2x", m->u.finished.verify[i]);
1898                 bs = seprint(bs, be, "\n");
1899                 break;
1900         }
1901         USED(bs);
1902         return buf;
1903 }
1904
1905 static void
1906 tlsError(TlsConnection *c, int err, char *fmt, ...)
1907 {
1908         char msg[512];
1909         va_list arg;
1910
1911         va_start(arg, fmt);
1912         vseprint(msg, msg+sizeof(msg), fmt, arg);
1913         va_end(arg);
1914         if(c->trace)
1915                 c->trace("tlsError: %s\n", msg);
1916         else if(c->erred)
1917                 fprint(2, "double error: %r, %s", msg);
1918         else
1919                 werrstr("tls: local %s", msg);
1920         c->erred = 1;
1921         fprint(c->ctl, "alert %d", err);
1922 }
1923
1924 // commit to specific version number
1925 static int
1926 setVersion(TlsConnection *c, int version)
1927 {
1928         if(c->verset || version > MaxProtoVersion || version < MinProtoVersion)
1929                 return -1;
1930         if(version > c->version)
1931                 version = c->version;
1932         if(version == SSL3Version) {
1933                 c->version = version;
1934                 c->finished.n = SSL3FinishedLen;
1935         }else if(version == TLSVersion){
1936                 c->version = version;
1937                 c->finished.n = TLSFinishedLen;
1938         }else
1939                 return -1;
1940         c->verset = 1;
1941         return fprint(c->ctl, "version 0x%x", version);
1942 }
1943
1944 // confirm that received Finished message matches the expected value
1945 static int
1946 finishedMatch(TlsConnection *c, Finished *f)
1947 {
1948         return memcmp(f->verify, c->finished.verify, f->n) == 0;
1949 }
1950
1951 // free memory associated with TlsConnection struct
1952 //              (but don't close the TLS channel itself)
1953 static void
1954 tlsConnectionFree(TlsConnection *c)
1955 {
1956         tlsSecClose(c->sec);
1957         freebytes(c->sid);
1958         freebytes(c->cert);
1959         memset(c, 0, sizeof(c));
1960         free(c);
1961 }
1962
1963
1964 //================= cipher choices ========================
1965
/*
 * Weakness flags indexed by TLS cipher suite id (ids < CipherMax).
 * okCipher reads weakCipher[c] for each suite c offered by the client;
 * 1 marks a suite it considers weak.  Ids not listed below are
 * implicitly 0 (static initialization).
 */
static int weakCipher[CipherMax] =
{
	1,	/* TLS_NULL_WITH_NULL_NULL */
	1,	/* TLS_RSA_WITH_NULL_MD5 */
	1,	/* TLS_RSA_WITH_NULL_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_RC4_40_MD5 */
	0,	/* TLS_RSA_WITH_RC4_128_MD5 */
	0,	/* TLS_RSA_WITH_RC4_128_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 */
	0,	/* TLS_RSA_WITH_IDEA_CBC_SHA */
	1,	/* TLS_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DH_DSS_WITH_DES_CBC_SHA */
	0,	/* TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DH_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DHE_DSS_WITH_DES_CBC_SHA */
	0,	/* TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA */
	0,	/* TLS_DHE_RSA_WITH_DES_CBC_SHA */
	0,	/* TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA */
	1,	/* TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 */
	1,	/* TLS_DH_anon_WITH_RC4_128_MD5 */
	1,	/* TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA */
	1,	/* TLS_DH_anon_WITH_DES_CBC_SHA */
	1,	/* TLS_DH_anon_WITH_3DES_EDE_CBC_SHA */
};
1997
1998 static int
1999 setAlgs(TlsConnection *c, int a)
2000 {
2001         int i;
2002
2003         for(i = 0; i < nelem(cipherAlgs); i++){
2004                 if(cipherAlgs[i].tlsid == a){
2005                         c->cipher = a;
2006                         c->enc = cipherAlgs[i].enc;
2007                         c->digest = cipherAlgs[i].digest;
2008                         c->nsecret = cipherAlgs[i].nsecret;
2009                         if(c->nsecret > MaxKeyData)
2010                                 return 0;
2011                         return 1;
2012                 }
2013         }
2014         return 0;
2015 }
2016
2017 static int
2018 okCipher(Ints *cv)
2019 {
2020         int weak, i, j, c;
2021
2022         weak = 1;
2023         for(i = 0; i < cv->len; i++) {
2024                 c = cv->data[i];
2025                 if(c >= CipherMax)
2026                         weak = 0;
2027                 else
2028                         weak &= weakCipher[c];
2029                 if(isDHE(c) || isECDHE(c))
2030                         continue;       /* TODO: not implemented for server */
2031                 for(j = 0; j < nelem(cipherAlgs); j++)
2032                         if(cipherAlgs[j].ok && cipherAlgs[j].tlsid == c)
2033                                 return c;
2034         }
2035         if(weak)
2036                 return -2;
2037         return -1;
2038 }
2039
2040 static int
2041 okCompression(Bytes *cv)
2042 {
2043         int i, j, c;
2044
2045         for(i = 0; i < cv->len; i++) {
2046                 c = cv->data[i];
2047                 for(j = 0; j < nelem(compressors); j++) {
2048                         if(compressors[j] == c)
2049                                 return c;
2050                 }
2051         }
2052         return -1;
2053 }
2054
static Lock	ciphLock;	// held by initCiphers while it probes #a/tls and fills cipherAlgs[].ok
static int	nciphers;	// count of cipherAlgs entries marked ok; 0 until initCiphers finds some
2057
2058 static int
2059 initCiphers(void)
2060 {
2061         enum {MaxAlgF = 1024, MaxAlgs = 10};
2062         char s[MaxAlgF], *flds[MaxAlgs];
2063         int i, j, n, ok;
2064
2065         lock(&ciphLock);
2066         if(nciphers){
2067                 unlock(&ciphLock);
2068                 return nciphers;
2069         }
2070         j = open("#a/tls/encalgs", OREAD);
2071         if(j < 0){
2072                 werrstr("can't open #a/tls/encalgs: %r");
2073                 return 0;
2074         }
2075         n = read(j, s, MaxAlgF-1);
2076         close(j);
2077         if(n <= 0){
2078                 werrstr("nothing in #a/tls/encalgs: %r");
2079                 return 0;
2080         }
2081         s[n] = 0;
2082         n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
2083         for(i = 0; i < nelem(cipherAlgs); i++){
2084                 ok = 0;
2085                 for(j = 0; j < n; j++){
2086                         if(strcmp(cipherAlgs[i].enc, flds[j]) == 0){
2087                                 ok = 1;
2088                                 break;
2089                         }
2090                 }
2091                 cipherAlgs[i].ok = ok;
2092         }
2093
2094         j = open("#a/tls/hashalgs", OREAD);
2095         if(j < 0){
2096                 werrstr("can't open #a/tls/hashalgs: %r");
2097                 return 0;
2098         }
2099         n = read(j, s, MaxAlgF-1);
2100         close(j);
2101         if(n <= 0){
2102                 werrstr("nothing in #a/tls/hashalgs: %r");
2103                 return 0;
2104         }
2105         s[n] = 0;
2106         n = getfields(s, flds, MaxAlgs, 1, " \t\r\n");
2107         for(i = 0; i < nelem(cipherAlgs); i++){
2108                 ok = 0;
2109                 for(j = 0; j < n; j++){
2110                         if(strcmp(cipherAlgs[i].digest, flds[j]) == 0){
2111                                 ok = 1;
2112                                 break;
2113                         }
2114                 }
2115                 cipherAlgs[i].ok &= ok;
2116                 if(cipherAlgs[i].ok)
2117                         nciphers++;
2118         }
2119         unlock(&ciphLock);
2120         return nciphers;
2121 }
2122
2123 static Ints*
2124 makeciphers(void)
2125 {
2126         Ints *is;
2127         int i, j;
2128
2129         is = newints(nciphers);
2130         j = 0;
2131         for(i = 0; i < nelem(cipherAlgs); i++){
2132                 if(cipherAlgs[i].ok)
2133                         is->data[j++] = cipherAlgs[i].tlsid;
2134         }
2135         return is;
2136 }
2137
2138
2139
2140 //================= security functions ========================
2141
2142 // given X.509 certificate, set up connection to factotum
2143 //      for using corresponding private key
2144 static AuthRpc*
2145 factotum_rsa_open(uchar *cert, int certlen)
2146 {
2147         int afd;
2148         char *s;
2149         mpint *pub = nil;
2150         RSApub *rsapub;
2151         AuthRpc *rpc;
2152
2153         // start talking to factotum
2154         if((afd = open("/mnt/factotum/rpc", ORDWR)) < 0)
2155                 return nil;
2156         if((rpc = auth_allocrpc(afd)) == nil){
2157                 close(afd);
2158                 return nil;
2159         }
2160         s = "proto=rsa service=tls role=client";
2161         if(auth_rpc(rpc, "start", s, strlen(s)) != ARok){
2162                 factotum_rsa_close(rpc);
2163                 return nil;
2164         }
2165
2166         // roll factotum keyring around to match certificate
2167         rsapub = X509toRSApub(cert, certlen, nil, 0);
2168         while(1){
2169                 if(auth_rpc(rpc, "read", nil, 0) != ARok){
2170                         factotum_rsa_close(rpc);
2171                         rpc = nil;
2172                         goto done;
2173                 }
2174                 pub = strtomp(rpc->arg, nil, 16, nil);
2175                 assert(pub != nil);
2176                 if(mpcmp(pub,rsapub->n) == 0)
2177                         break;
2178         }
2179 done:
2180         mpfree(pub);
2181         rsapubfree(rsapub);
2182         return rpc;
2183 }
2184
2185 static mpint*
2186 factotum_rsa_decrypt(AuthRpc *rpc, mpint *cipher)
2187 {
2188         char *p;
2189         int rv;
2190
2191         p = mptoa(cipher, 16, nil, 0);
2192         mpfree(cipher);
2193         if(p == nil)
2194                 return nil;
2195         rv = auth_rpc(rpc, "write", p, strlen(p));
2196         free(p);
2197         if(rv != ARok || auth_rpc(rpc, "read", nil, 0) != ARok)
2198                 return nil;
2199         return strtomp(rpc->arg, nil, 16, nil);
2200 }
2201
2202 static void
2203 factotum_rsa_close(AuthRpc*rpc)
2204 {
2205         if(!rpc)
2206                 return;
2207         close(rpc->afd);
2208         auth_freerpc(rpc);
2209 }
2210
2211 static void
2212 tlsPmd5(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2213 {
2214         uchar ai[MD5dlen], tmp[MD5dlen];
2215         int i, n;
2216         MD5state *s;
2217
2218         // generate a1
2219         s = hmac_md5(label, nlabel, key, nkey, nil, nil);
2220         s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
2221         hmac_md5(seed1, nseed1, key, nkey, ai, s);
2222
2223         while(nbuf > 0) {
2224                 s = hmac_md5(ai, MD5dlen, key, nkey, nil, nil);
2225                 s = hmac_md5(label, nlabel, key, nkey, nil, s);
2226                 s = hmac_md5(seed0, nseed0, key, nkey, nil, s);
2227                 hmac_md5(seed1, nseed1, key, nkey, tmp, s);
2228                 n = MD5dlen;
2229                 if(n > nbuf)
2230                         n = nbuf;
2231                 for(i = 0; i < n; i++)
2232                         buf[i] ^= tmp[i];
2233                 buf += n;
2234                 nbuf -= n;
2235                 hmac_md5(ai, MD5dlen, key, nkey, tmp, nil);
2236                 memmove(ai, tmp, MD5dlen);
2237         }
2238 }
2239
2240 static void
2241 tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2242 {
2243         uchar ai[SHA1dlen], tmp[SHA1dlen];
2244         int i, n;
2245         SHAstate *s;
2246
2247         // generate a1
2248         s = hmac_sha1(label, nlabel, key, nkey, nil, nil);
2249         s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
2250         hmac_sha1(seed1, nseed1, key, nkey, ai, s);
2251
2252         while(nbuf > 0) {
2253                 s = hmac_sha1(ai, SHA1dlen, key, nkey, nil, nil);
2254                 s = hmac_sha1(label, nlabel, key, nkey, nil, s);
2255                 s = hmac_sha1(seed0, nseed0, key, nkey, nil, s);
2256                 hmac_sha1(seed1, nseed1, key, nkey, tmp, s);
2257                 n = SHA1dlen;
2258                 if(n > nbuf)
2259                         n = nbuf;
2260                 for(i = 0; i < n; i++)
2261                         buf[i] ^= tmp[i];
2262                 buf += n;
2263                 nbuf -= n;
2264                 hmac_sha1(ai, SHA1dlen, key, nkey, tmp, nil);
2265                 memmove(ai, tmp, SHA1dlen);
2266         }
2267 }
2268
2269 // fill buf with md5(args)^sha1(args)
2270 static void
2271 tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2272 {
2273         int i;
2274         int nlabel = strlen(label);
2275         int n = (nkey + 1) >> 1;
2276
2277         for(i = 0; i < nbuf; i++)
2278                 buf[i] = 0;
2279         tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
2280         tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
2281 }
2282
/*
 * for setting server session id's
 */
static Lock	sidLock;	// serializes the read-and-increment of maxSid in tlsSecInits
static long	maxSid = 1;	// next counter value tlsSecInits writes into bytes 4-7 of a new server sid
2288
2289 /* the keys are verified to have the same public components
2290  * and to function correctly with pkcs 1 encryption and decryption. */
2291 static TlsSec*
2292 tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom)
2293 {
2294         TlsSec *sec = emalloc(sizeof(*sec));
2295
2296         USED(csid); USED(ncsid);  // ignore csid for now
2297
2298         memmove(sec->crandom, crandom, RandomSize);
2299         sec->clientVers = cvers;
2300
2301         put32(sec->srandom, time(0));
2302         genrandom(sec->srandom+4, RandomSize-4);
2303         memmove(srandom, sec->srandom, RandomSize);
2304
2305         /*
2306          * make up a unique sid: use our pid, and and incrementing id
2307          * can signal no sid by setting nssid to 0.
2308          */
2309         memset(ssid, 0, SidSize);
2310         put32(ssid, getpid());
2311         lock(&sidLock);
2312         put32(ssid+4, maxSid++);
2313         unlock(&sidLock);
2314         *nssid = SidSize;
2315         return sec;
2316 }
2317
2318 static int
2319 tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm)
2320 {
2321         if(epm != nil){
2322                 if(setVers(sec, vers) < 0)
2323                         goto Err;
2324                 serverMasterSecret(sec, epm);
2325         }else if(sec->vers != vers){
2326                 werrstr("mismatched session versions");
2327                 goto Err;
2328         }
2329         return 0;
2330 Err:
2331         sec->ok = -1;
2332         return -1;
2333 }
2334
2335 static TlsSec*
2336 tlsSecInitc(int cvers, uchar *crandom)
2337 {
2338         TlsSec *sec = emalloc(sizeof(*sec));
2339         sec->clientVers = cvers;
2340         put32(sec->crandom, time(0));
2341         genrandom(sec->crandom+4, RandomSize-4);
2342         memmove(crandom, sec->crandom, RandomSize);
2343         return sec;
2344 }
2345
2346 static Bytes*
2347 tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers)
2348 {
2349         RSApub *pub;
2350         Bytes *epm;
2351
2352         USED(sid);
2353         USED(nsid);
2354         
2355         memmove(sec->srandom, srandom, RandomSize);
2356         if(setVers(sec, vers) < 0)
2357                 goto Err;
2358         pub = X509toRSApub(cert, ncert, nil, 0);
2359         if(pub == nil){
2360                 werrstr("invalid x509/rsa certificate");
2361                 goto Err;
2362         }
2363         epm = clientMasterSecret(sec, pub);
2364         rsapubfree(pub);
2365         if(epm != nil)
2366                 return epm;
2367 Err:
2368         sec->ok = -1;
2369         return nil;
2370 }
2371
2372 static int
2373 tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
2374 {
2375         if(sec->nfin != nfin){
2376                 sec->ok = -1;
2377                 werrstr("invalid finished exchange");
2378                 return -1;
2379         }
2380         md5.malloced = 0;
2381         sha1.malloced = 0;
2382         (*sec->setFinished)(sec, md5, sha1, fin, isclient);
2383         return 1;
2384 }
2385
2386 static void
2387 tlsSecOk(TlsSec *sec)
2388 {
2389         if(sec->ok == 0)
2390                 sec->ok = 1;
2391 }
2392
2393 static void
2394 tlsSecKill(TlsSec *sec)
2395 {
2396         if(!sec)
2397                 return;
2398         factotum_rsa_close(sec->rpc);
2399         sec->ok = -1;
2400 }
2401
2402 static void
2403 tlsSecClose(TlsSec *sec)
2404 {
2405         if(!sec)
2406                 return;
2407         factotum_rsa_close(sec->rpc);
2408         free(sec->server);
2409         free(sec);
2410 }
2411
2412 static int
2413 setVers(TlsSec *sec, int v)
2414 {
2415         if(v == SSL3Version){
2416                 sec->setFinished = sslSetFinished;
2417                 sec->nfin = SSL3FinishedLen;
2418                 sec->prf = sslPRF;
2419         }else if(v == TLSVersion){
2420                 sec->setFinished = tlsSetFinished;
2421                 sec->nfin = TLSFinishedLen;
2422                 sec->prf = tlsPRF;
2423         }else{
2424                 werrstr("invalid version");
2425                 return -1;
2426         }
2427         sec->vers = v;
2428         return 0;
2429 }
2430
2431 /*
2432  * generate secret keys from the master secret.
2433  *
2434  * different crypto selections will require different amounts
2435  * of key expansion and use of key expansion data,
2436  * but it's all generated using the same function.
2437  */
2438 static void
2439 setSecrets(TlsSec *sec, uchar *kd, int nkd)
2440 {
2441         (*sec->prf)(kd, nkd, sec->sec, MasterSecretSize, "key expansion",
2442                         sec->srandom, RandomSize, sec->crandom, RandomSize);
2443 }
2444
2445 /*
2446  * set the master secret from the pre-master secret,
2447  * destroys premaster.
2448  */
2449 static void
2450 setMasterSecret(TlsSec *sec, Bytes *pm)
2451 {
2452         (*sec->prf)(sec->sec, MasterSecretSize, pm->data, pm->len, "master secret",
2453                         sec->crandom, RandomSize, sec->srandom, RandomSize);
2454
2455         memset(pm->data, 0, pm->len);   
2456         freebytes(pm);
2457 }
2458
2459 static void
2460 serverMasterSecret(TlsSec *sec, Bytes *epm)
2461 {
2462         Bytes *pm;
2463
2464         pm = pkcs1_decrypt(sec, epm);
2465
2466         // if the client messed up, just continue as if everything is ok,
2467         // to prevent attacks to check for correctly formatted messages.
2468         // Hence the fprint(2,) can't be replaced by tlsError(), which sends an Alert msg to the client.
2469         if(sec->ok < 0 || pm == nil || get16(pm->data) != sec->clientVers){
2470                 fprint(2, "serverMasterSecret failed ok=%d pm=%p pmvers=%x cvers=%x nepm=%d\n",
2471                         sec->ok, pm, pm != nil ? get16(pm->data) : -1, sec->clientVers, epm->len);
2472                 sec->ok = -1;
2473                 freebytes(pm);
2474                 pm = newbytes(MasterSecretSize);
2475                 genrandom(pm->data, MasterSecretSize);
2476         }
2477         assert(pm->len == MasterSecretSize);
2478         setMasterSecret(sec, pm);
2479 }
2480
2481 static Bytes*
2482 clientMasterSecret(TlsSec *sec, RSApub *pub)
2483 {
2484         Bytes *pm, *epm;
2485
2486         pm = newbytes(MasterSecretSize);
2487         put16(pm->data, sec->clientVers);
2488         genrandom(pm->data+2, MasterSecretSize - 2);
2489         epm = pkcs1_encrypt(pm, pub, 2);
2490         setMasterSecret(sec, pm);
2491         return epm;
2492 }
2493
/*
 * SSLv3 Finished computation: writes MD5dlen bytes of an MD5-based hash
 * followed by SHA1dlen bytes of a SHA1-based hash into finished.
 * Each hash is computed in two stages:
 *	inner = H(handshake || label || master || pad1)
 *	outer = H(master || pad2 || inner)
 * The pads are 0x36 and 0x5C repeated 48 times for MD5 and 40 times for SHA1.
 * hsmd5/hssha1 are by-value copies of the running handshake hashes.
 */
static void
sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
{
	DigestState *s;
	uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
	char *label;

	// 4-byte sender label
	if(isClient)
		label = "CLNT";
	else
		label = "SRVR";

	// MD5 inner hash: continue the handshake hash with label, master secret, pad1
	md5((uchar*)label, 4, nil, &hsmd5);
	md5(sec->sec, MasterSecretSize, nil, &hsmd5);
	memset(pad, 0x36, 48);
	md5(pad, 48, nil, &hsmd5);
	md5(nil, 0, h0, &hsmd5);
	// MD5 outer hash -> finished[0 .. MD5dlen)
	memset(pad, 0x5C, 48);
	s = md5(sec->sec, MasterSecretSize, nil, nil);
	s = md5(pad, 48, nil, s);
	md5(h0, MD5dlen, finished, s);

	// SHA1 inner hash, same construction with 40-byte pads
	sha1((uchar*)label, 4, nil, &hssha1);
	sha1(sec->sec, MasterSecretSize, nil, &hssha1);
	memset(pad, 0x36, 40);
	sha1(pad, 40, nil, &hssha1);
	sha1(nil, 0, h1, &hssha1);
	// SHA1 outer hash -> finished[MD5dlen .. MD5dlen+SHA1dlen)
	memset(pad, 0x5C, 40);
	s = sha1(sec->sec, MasterSecretSize, nil, nil);
	s = sha1(pad, 40, nil, s);
	sha1(h1, SHA1dlen, finished + MD5dlen, s);
}
2526
2527 // fill "finished" arg with md5(args)^sha1(args)
2528 static void
2529 tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
2530 {
2531         uchar h0[MD5dlen], h1[SHA1dlen];
2532         char *label;
2533
2534         // get current hash value, but allow further messages to be hashed in
2535         md5(nil, 0, h0, &hsmd5);
2536         sha1(nil, 0, h1, &hssha1);
2537
2538         if(isClient)
2539                 label = "client finished";
2540         else
2541                 label = "server finished";
2542         tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
2543 }
2544
2545 static void
2546 sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
2547 {
2548         DigestState *s;
2549         uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
2550         int i, n, len;
2551
2552         USED(label);
2553         len = 1;
2554         while(nbuf > 0){
2555                 if(len > 26)
2556                         return;
2557                 for(i = 0; i < len; i++)
2558                         tmp[i] = 'A' - 1 + len;
2559                 s = sha1(tmp, len, nil, nil);
2560                 s = sha1(key, nkey, nil, s);
2561                 s = sha1(seed0, nseed0, nil, s);
2562                 sha1(seed1, nseed1, sha1dig, s);
2563                 s = md5(key, nkey, nil, nil);
2564                 md5(sha1dig, SHA1dlen, md5dig, s);
2565                 n = MD5dlen;
2566                 if(n > nbuf)
2567                         n = nbuf;
2568                 memmove(buf, md5dig, n);
2569                 buf += n;
2570                 nbuf -= n;
2571                 len++;
2572         }
2573 }
2574
2575 static mpint*
2576 bytestomp(Bytes* bytes)
2577 {
2578         return betomp(bytes->data, bytes->len, nil);
2579 }
2580
2581 /*
2582  * Convert mpint* to Bytes, putting high order byte first.
2583  */
2584 static Bytes*
2585 mptobytes(mpint* big)
2586 {
2587         Bytes* ans;
2588         int n;
2589
2590         n = (mpsignif(big)+7)/8;
2591         if(n == 0) n = 1;
2592         ans = newbytes(n);
2593         ans->len = mptobe(big, ans->data, n, nil);
2594         return ans;
2595 }
2596
2597 // Do RSA computation on block according to key, and pad
2598 // result on left with zeros to make it modlen long.
2599 static Bytes*
2600 rsacomp(Bytes* block, RSApub* key, int modlen)
2601 {
2602         mpint *x, *y;
2603         Bytes *a, *ybytes;
2604         int ylen;
2605
2606         x = bytestomp(block);
2607         y = rsaencrypt(key, x, nil);
2608         mpfree(x);
2609         ybytes = mptobytes(y);
2610         ylen = ybytes->len;
2611         mpfree(y);
2612
2613         if(ylen < modlen) {
2614                 a = newbytes(modlen);
2615                 memset(a->data, 0, modlen-ylen);
2616                 memmove(a->data+modlen-ylen, ybytes->data, ylen);
2617                 freebytes(ybytes);
2618                 ybytes = a;
2619         }
2620         else if(ylen > modlen) {
2621                 // assume it has leading zeros (mod should make it so)
2622                 a = newbytes(modlen);
2623                 memmove(a->data, ybytes->data, modlen);
2624                 freebytes(ybytes);
2625                 ybytes = a;
2626         }
2627         return ybytes;
2628 }
2629
2630 // encrypt data according to PKCS#1, /lib/rfc/rfc2437 9.1.2.1
2631 static Bytes*
2632 pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype)
2633 {
2634         Bytes *pad, *eb, *ans;
2635         int i, dlen, padlen, modlen;
2636
2637         modlen = (mpsignif(key->n)+7)/8;
2638         dlen = data->len;
2639         if(modlen < 12 || dlen > modlen - 11)
2640                 return nil;
2641         padlen = modlen - 3 - dlen;
2642         pad = newbytes(padlen);
2643         genrandom(pad->data, padlen);
2644         for(i = 0; i < padlen; i++) {
2645                 if(blocktype == 0)
2646                         pad->data[i] = 0;
2647                 else if(blocktype == 1)
2648                         pad->data[i] = 255;
2649                 else if(pad->data[i] == 0)
2650                         pad->data[i] = 1;
2651         }
2652         eb = newbytes(modlen);
2653         eb->data[0] = 0;
2654         eb->data[1] = blocktype;
2655         memmove(eb->data+2, pad->data, padlen);
2656         eb->data[padlen+2] = 0;
2657         memmove(eb->data+padlen+3, data->data, dlen);
2658         ans = rsacomp(eb, key, modlen);
2659         freebytes(eb);
2660         freebytes(pad);
2661         return ans;
2662 }
2663
2664 // decrypt data according to PKCS#1, with given key.
2665 // expect a block type of 2.
2666 static Bytes*
2667 pkcs1_decrypt(TlsSec *sec, Bytes *cipher)
2668 {
2669         Bytes *eb, *ans = nil;
2670         int i, modlen;
2671         mpint *x, *y;
2672
2673         modlen = (mpsignif(sec->rsapub->n)+7)/8;
2674         if(cipher->len != modlen)
2675                 return nil;
2676         x = bytestomp(cipher);
2677         y = factotum_rsa_decrypt(sec->rpc, x);
2678         if(y == nil)
2679                 return nil;
2680         eb = mptobytes(y);
2681         mpfree(y);
2682         if(eb->len < modlen){ // pad on left with zeros
2683                 ans = newbytes(modlen);
2684                 memset(ans->data, 0, modlen-eb->len);
2685                 memmove(ans->data+modlen-eb->len, eb->data, eb->len);
2686                 freebytes(eb);
2687                 eb = ans;
2688         }
2689         if(eb->data[0] == 0 && eb->data[1] == 2) {
2690                 for(i = 2; i < modlen; i++)
2691                         if(eb->data[i] == 0)
2692                                 break;
2693                 if(i < modlen - 1)
2694                         ans = makebytes(eb->data+i+1, modlen-(i+1));
2695         }
2696         freebytes(eb);
2697         return ans;
2698 }
2699
2700
2701 //================= general utility functions ========================
2702
2703 static void *
2704 emalloc(int n)
2705 {
2706         void *p;
2707         if(n==0)
2708                 n=1;
2709         p = malloc(n);
2710         if(p == nil)
2711                 sysfatal("out of memory");
2712         memset(p, 0, n);
2713         setmalloctag(p, getcallerpc(&n));
2714         return p;
2715 }
2716
2717 static void *
2718 erealloc(void *ReallocP, int ReallocN)
2719 {
2720         if(ReallocN == 0)
2721                 ReallocN = 1;
2722         if(ReallocP == nil)
2723                 ReallocP = emalloc(ReallocN);
2724         else if((ReallocP = realloc(ReallocP, ReallocN)) == nil)
2725                 sysfatal("out of memory");
2726         setrealloctag(ReallocP, getcallerpc(&ReallocP));
2727         return(ReallocP);
2728 }
2729
2730 static void
2731 put32(uchar *p, u32int x)
2732 {
2733         p[0] = x>>24;
2734         p[1] = x>>16;
2735         p[2] = x>>8;
2736         p[3] = x;
2737 }
2738
2739 static void
2740 put24(uchar *p, int x)
2741 {
2742         p[0] = x>>16;
2743         p[1] = x>>8;
2744         p[2] = x;
2745 }
2746
2747 static void
2748 put16(uchar *p, int x)
2749 {
2750         p[0] = x>>8;
2751         p[1] = x;
2752 }
2753
2754 static u32int
2755 get32(uchar *p)
2756 {
2757         return (p[0]<<24)|(p[1]<<16)|(p[2]<<8)|p[3];
2758 }
2759
2760 static int
2761 get24(uchar *p)
2762 {
2763         return (p[0]<<16)|(p[1]<<8)|p[2];
2764 }
2765
2766 static int
2767 get16(uchar *p)
2768 {
2769         return (p[0]<<8)|p[1];
2770 }
2771
// OFFSET(member, type): byte offset of member within type (offsetof with
// the arguments swapped).  Used to size header-plus-trailing-array
// allocations such as Bytes and Ints.
#define OFFSET(x, s) offsetof(s, x)
2773
2774 static Bytes*
2775 newbytes(int len)
2776 {
2777         Bytes* ans;
2778
2779         if(len < 0)
2780                 abort();
2781         ans = (Bytes*)emalloc(OFFSET(data[0], Bytes) + len);
2782         ans->len = len;
2783         return ans;
2784 }
2785
2786 /*
2787  * newbytes(len), with data initialized from buf
2788  */
2789 static Bytes*
2790 makebytes(uchar* buf, int len)
2791 {
2792         Bytes* ans;
2793
2794         ans = newbytes(len);
2795         memmove(ans->data, buf, len);
2796         return ans;
2797 }
2798
2799 static void
2800 freebytes(Bytes* b)
2801 {
2802         free(b);
2803 }
2804
2805 /* len is number of ints */
2806 static Ints*
2807 newints(int len)
2808 {
2809         Ints* ans;
2810
2811         if(len < 0 || len > ((uint)-1>>1)/sizeof(int))
2812                 abort();
2813         ans = (Ints*)emalloc(OFFSET(data[0], Ints) + len*sizeof(int));
2814         ans->len = len;
2815         return ans;
2816 }
2817
2818 static void
2819 freeints(Ints* b)
2820 {
2821         free(b);
2822 }