]> git.lizzy.rs Git - plan9front.git/blobdiff - sys/src/libsec/port/tlshand.c
import E script from bell labs
[plan9front.git] / sys / src / libsec / port / tlshand.c
index 13baa6a628fec0dc7deadc5c53108f5dc14d031b..8151c60c4a61eccb100cc4a1452530732eb2043b 100644 (file)
@@ -63,6 +63,12 @@ typedef struct Finished{
        int n;
 } Finished;
 
+typedef struct HandshakeHash {
+       MD5state        md5;
+       SHAstate        sha1;
+       SHA2_256state   sha2_256;
+} HandshakeHash;
+
 typedef struct TlsConnection{
        TlsSec *sec;    // security management goo
        int hand, ctl;  // record layer file descriptors
@@ -95,8 +101,7 @@ typedef struct TlsConnection{
        int nsecret;    // amount of secret data to init keys
 
        // for finished messages
-       MD5state        hsmd5;  // handshake hash
-       SHAstate        hssha1; // handshake hash
+       HandshakeHash   handhash;
        Finished        finished;
 } TlsConnection;
 
@@ -157,7 +162,7 @@ typedef struct TlsSec{
        int vers;                       // final version
        // byte generation and handshake checksum
        void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
-       void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
+       void (*setFinished)(TlsSec*, HandshakeHash, uchar*, int);
        int nfin;
 } TlsSec;
 
@@ -166,7 +171,8 @@ enum {
        SSL3Version     = 0x0300,
        TLS10Version    = 0x0301,
        TLS11Version    = 0x0302,
-       ProtocolVersion = TLS11Version, // maximum version we speak
+       TLS12Version    = 0x0303,
+       ProtocolVersion = TLS12Version, // maximum version we speak
        MinProtoVersion = 0x0300,       // limits on version we accept
        MaxProtoVersion = 0x03ff,
 };
@@ -274,19 +280,18 @@ enum {
 };
 
 static Algs cipherAlgs[] = {
-       {"rc4_128", "md5",      2 * (16 + MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
-       {"rc4_128", "sha1",     2 * (16 + SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
-       {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
-       {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA},
-       {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_RSA_WITH_AES_128_CBC_SHA},
-       {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA},
-       {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_128_CBC_SHA},
-       {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_256_CBC_SHA},
-
-       {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
-       {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA},
-       {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA},
        {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
+       {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA},
+       {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA},
+       {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA},
+       {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_256_CBC_SHA},
+       {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_DHE_RSA_WITH_AES_128_CBC_SHA},
+       {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA},
+       {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_RSA_WITH_AES_128_CBC_SHA},
+       {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA},
+       {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA},
+       {"rc4_128", "sha1",     2*(16+SHA1dlen), TLS_RSA_WITH_RC4_128_SHA},
+       {"rc4_128", "md5",      2*(16+MD5dlen), TLS_RSA_WITH_RC4_128_MD5},
 };
 
 static uchar compressors[] = {
@@ -331,7 +336,7 @@ static TlsSec*      tlsSecInitc(int cvers, uchar *crandom);
 static Bytes*  tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers);
 static Bytes*  tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, Bytes *p, Bytes *g, Bytes *Ys);
 static Bytes*  tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys);
-static int     tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
+static int     tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient);
 static void    tlsSecOk(TlsSec *sec);
 static void    tlsSecKill(TlsSec *sec);
 static void    tlsSecClose(TlsSec *sec);
@@ -341,8 +346,9 @@ static void setSecrets(TlsSec *sec, uchar *kd, int nkd);
 static Bytes*  clientMasterSecret(TlsSec *sec, RSApub *pub);
 static Bytes*  pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
 static Bytes*  pkcs1_decrypt(TlsSec *sec, Bytes *cipher);
-static void    tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
-static void    sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
+static void    tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient);
+static void    tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient);
+static void    sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient);
 static void    sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
                        uchar *seed0, int nseed0, uchar *seed1, int nseed1);
 static int setVers(TlsSec *sec, int version);
@@ -693,7 +699,7 @@ tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, .
        msgClear(&m);
 
        /* no CertificateVerify; skip to Finished */
-       if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
+       if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){
                tlsError(c, EInternalError, "can't set finished: %r");
                goto Err;
        }
@@ -715,7 +721,7 @@ tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, .
                goto Err;
        }
 
-       if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
+       if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){
                tlsError(c, EInternalError, "can't set finished: %r");
                goto Err;
        }
@@ -848,99 +854,57 @@ ectobytes(int type, ECpoint *p)
 static Bytes*
 tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys)
 {
-       Namedcurve *nc, *enc;
        Bytes *epm;
-       ECdomain dom;
-       ECpoint G, K, Y;
-       ECpriv Q;
+       ECdomain *dom;
+       ECpoint K, *Y;
+       ECpriv *Q;
+
+       epm = nil;
+       Y = nil;
+       Q = nil;
 
        if(Ys == nil)
                return nil;
 
-       enc = &namedcurves[nelem(namedcurves)];
-       for(nc = namedcurves; nc != enc; nc++)
-               if(nc->tlsid == curve)
-                       break;
-
-       if(nc == enc)
-               return nil;
-               
        memmove(sec->srandom, srandom, RandomSize);
        if(setVers(sec, vers) < 0)
                return nil;
-       
-       epm = nil;
-
-       memset(&dom, 0, sizeof(dom));
-       dom.p = strtomp(nc->p, nil, 16, nil);
-       dom.a = strtomp(nc->a, nil, 16, nil);
-       dom.b = strtomp(nc->b, nil, 16, nil);
-       dom.n = strtomp(nc->n, nil, 16, nil);
-       dom.h = strtomp(nc->h, nil, 16, nil);
 
-       memset(&G, 0, sizeof(G));
-       G.x = mpnew(0);
-       G.y = mpnew(0);
+       dom = ecnamedcurve(curve);
+       if(dom == nil)
+               return nil;
 
-       memset(&Q, 0, sizeof(Q));
-       Q.x = mpnew(0);
-       Q.y = mpnew(0);
-       Q.d = mpnew(0);
 
        memset(&K, 0, sizeof(K));
        K.x = mpnew(0);
        K.y = mpnew(0);
 
-       memset(&Y, 0, sizeof(Y));
-       Y.x = mpnew(0);
-       Y.y = mpnew(0);
-
-       if(dom.p == nil || dom.a == nil || dom.b == nil || dom.n == nil || dom.h == nil)
-               goto Out;
-       if(Q.x == nil || Q.y == nil || Q.d == nil)
-               goto Out;
-       if(G.x == nil || G.y == nil)
-               goto Out;
        if(K.x == nil || K.y == nil)
                goto Out;
-       if(Y.x == nil || Y.y == nil)
-               goto Out;
 
-       dom.G = strtoec(&dom, nc->G, nil, &G);
-       if(dom.G == nil)
+       Y = betoec(dom, Ys->data, Ys->len, nil);
+       if(Y == nil)
                goto Out;
 
-       if(bytestoec(&dom, Ys, &Y) == nil)
+       Q = ecgen(dom, nil);
+       if(Q == nil)
                goto Out;
 
-       if(ecgen(&dom, &Q) == nil)
-               goto Out;
-
-       ecmul(&dom, &Y, Q.d, &K);
+       ecmul(dom, Y, Q->d, &K);
        setMasterSecret(sec, mptobytes(K.x));
 
        /* 0x04 = uncompressed public key */
-       epm = ectobytes(0x04, &Q);
+       epm = ectobytes(0x04, Q);
        
 Out:
-       mpfree(Y.x);
-       mpfree(Y.y);
+       ecfreepriv(Q);
+
+       ecfreepoint(Y);
 
        mpfree(K.x);
        mpfree(K.y);
 
-       mpfree(Q.x);
-       mpfree(Q.y);
-       mpfree(Q.d);
-
-       mpfree(G.x);
-       mpfree(G.y);
-
-       mpfree(dom.p);
-       mpfree(dom.a);
-       mpfree(dom.b);
-       mpfree(dom.n);
-       mpfree(dom.h);
+       ecfreedomain(dom);
 
        return epm;
 }
@@ -962,6 +926,11 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen,
        epm = nil;
        c = emalloc(sizeof(TlsConnection));
        c->version = ProtocolVersion;
+
+       // client certificate signature not implemented for TLS1.2
+       if(cert != nil && certlen > 0 && c->version >= TLS12Version)
+               c->version = TLS11Version;
+
        c->ctl = ctl;
        c->hand = hand;
        c->trace = trace;
@@ -1114,25 +1083,16 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen,
                goto Err;
        msgClear(&m);
 
-       /* CertificateVerify */
-       /*XXX I should only send this when it is not DH right? 
-               Also we need to know which TLS key 
-               we have to use in case there are more than one*/
-       if(cert){
-               m.tag = HCertificateVerify;
+       /* certificate verify */
+       if(creq && cert != nil && certlen > 0) {
                uchar hshashes[MD5dlen+SHA1dlen]; /* content of signature */
-               MD5state        hsmd5_save;
-               SHAstate        hssha1_save;
-       
-               /* save the state for the Finish message */
+               HandshakeHash hsave;
 
-               hsmd5_save = c->hsmd5;
-               hssha1_save = c->hssha1;
-               md5(nil, 0, hshashes, &c->hsmd5);
-               sha1(nil, 0, hshashes+MD5dlen, &c->hssha1);
-       
-               c->hsmd5 = hsmd5_save;
-               c->hssha1 = hssha1_save;
+               /* save the state for the Finish message */
+               hsave = c->handhash;
+               md5(nil, 0, hshashes, &c->handhash.md5);
+               sha1(nil, 0, hshashes+MD5dlen, &c->handhash.sha1);
+               c->handhash = hsave;
 
                c->sec->rpc = factotum_rsa_open(cert, certlen);
                if(c->sec->rpc == nil){
@@ -1154,6 +1114,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen,
                m.u.certificateVerify.signature = mptobytes(signedMP);
                mpfree(signedMP);
 
+               m.tag = HCertificateVerify;
                if(!msgSend(c, &m, AFlush))
                        goto Err;
                msgClear(&m);
@@ -1167,7 +1128,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen,
 
        // Cipherchange must occur immediately before Finished to avoid
        // potential hole;  see section 4.3 of Wagner Schneier 1996.
-       if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
+       if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 1) < 0){
                tlsError(c, EInternalError, "can't set finished 1: %r");
                goto Err;
        }
@@ -1179,7 +1140,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen,
        }
        msgClear(&m);
 
-       if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
+       if(tlsSecFinished(c->sec, c->handhash, c->finished.verify, c->finished.n, 0) < 0){
                tlsError(c, EInternalError, "can't set finished 0: %r");
                goto Err;
        }
@@ -1216,6 +1177,15 @@ Err:
 
 //================= message functions ========================
 
+static void
+msgHash(TlsConnection *c, uchar *p, int n)
+{
+       md5(p, n, 0, &c->handhash.md5);
+       sha1(p, n, 0, &c->handhash.sha1);
+       if(c->version >= TLS12Version)
+               sha2_256(p, n, 0, &c->handhash.sha2_256);
+}
+
 static int
 msgSend(TlsConnection *c, Msg *m, int act)
 {
@@ -1352,10 +1322,8 @@ msgSend(TlsConnection *c, Msg *m, int act)
        put24(c->sendp+1, n-4);
 
        // remember hash of Handshake messages
-       if(m->tag != HHelloRequest) {
-               md5(c->sendp, n, 0, &c->hsmd5);
-               sha1(c->sendp, n, 0, &c->hssha1);
-       }
+       if(m->tag != HHelloRequest)
+               msgHash(c, c->sendp, n);
 
        c->sendp = p;
        if(act == AFlush){
@@ -1430,8 +1398,7 @@ msgRecv(TlsConnection *c, Msg *m)
                p = tlsReadN(c, n);
                if(p == nil)
                        return 0;
-               md5(p, n, 0, &c->hsmd5);
-               sha1(p, n, 0, &c->hssha1);
+               msgHash(c, p, n);
                m->tag = HClientHello;
                if(n < 22)
                        goto Short;
@@ -1468,15 +1435,13 @@ msgRecv(TlsConnection *c, Msg *m)
                m->u.clientHello.compressors->data[0] = CompressionNull;
                goto Ok;
        }
-       md5(p, 4, 0, &c->hsmd5);
-       sha1(p, 4, 0, &c->hssha1);
+       msgHash(c, p, 4);
 
        p = tlsReadN(c, n);
        if(p == nil)
                return 0;
 
-       md5(p, n, 0, &c->hsmd5);
-       sha1(p, n, 0, &c->hssha1);
+       msgHash(c, p, n);
 
        m->tag = type;
 
@@ -1678,6 +1643,12 @@ msgRecv(TlsConnection *c, Msg *m)
                        break;
                }
                if(n >= 2){
+                       if(c->version >= TLS12Version){
+                               /* signature hash algorithm */
+                               p += 2, n -= 2;
+                               if(n < 2)
+                                       goto Short;
+                       }
                        nn = get16(p);
                        p += 2, n -= 2;
                        if(nn > 0 && nn <= n){
@@ -1944,7 +1915,7 @@ setVersion(TlsConnection *c, int version)
 static int
 finishedMatch(TlsConnection *c, Finished *f)
 {
-       return memcmp(f->verify, c->finished.verify, f->n) == 0;
+       return constcmp(f->verify, c->finished.verify, f->n) == 0;
 }
 
 // free memory associated with TlsConnection struct
@@ -2265,20 +2236,55 @@ tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, u
        }
 }
 
+static void
+p_sha256(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed, int nseed)
+{
+       uchar ai[SHA2_256dlen], tmp[SHA2_256dlen];
+       SHAstate *s;
+       int n;
+
+       // generate a1
+       s = hmac_sha2_256(label, nlabel, key, nkey, nil, nil);
+       hmac_sha2_256(seed, nseed, key, nkey, ai, s);
+
+       while(nbuf > 0) {
+               s = hmac_sha2_256(ai, SHA2_256dlen, key, nkey, nil, nil);
+               s = hmac_sha2_256(label, nlabel, key, nkey, nil, s);
+               hmac_sha2_256(seed, nseed, key, nkey, tmp, s);
+               n = SHA2_256dlen;
+               if(n > nbuf)
+                       n = nbuf;
+               memmove(buf, tmp, n);
+               buf += n;
+               nbuf -= n;
+               hmac_sha2_256(ai, SHA2_256dlen, key, nkey, tmp, nil);
+               memmove(ai, tmp, SHA2_256dlen);
+       }
+}
+
 // fill buf with md5(args)^sha1(args)
 static void
-tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
+tls10PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
 {
-       int i;
        int nlabel = strlen(label);
        int n = (nkey + 1) >> 1;
 
-       for(i = 0; i < nbuf; i++)
-               buf[i] = 0;
+       memset(buf, 0, nbuf);
        tlsPmd5(buf, nbuf, key, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
        tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
 }
 
+static void
+tls12PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
+{
+       uchar seed[2*RandomSize];
+
+       assert(nseed0+nseed1 <= sizeof(seed));
+       memmove(seed, seed0, nseed0);
+       memmove(seed+nseed0, seed1, nseed1);
+       p_sha256(buf, nbuf, key, nkey, (uchar*)label, strlen(label), seed, nseed0+nseed1);
+}
+
 /*
  * for setting server session id's
  */
@@ -2369,16 +2375,17 @@ Err:
 }
 
 static int
-tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
+tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient)
 {
        if(sec->nfin != nfin){
                sec->ok = -1;
                werrstr("invalid finished exchange");
                return -1;
        }
-       md5.malloced = 0;
-       sha1.malloced = 0;
-       (*sec->setFinished)(sec, md5, sha1, fin, isclient);
+       hsh.md5.malloced = 0;
+       hsh.sha1.malloced = 0;
+       hsh.sha2_256.malloced = 0;
+       (*sec->setFinished)(sec, hsh, fin, isclient);
        return 1;
 }
 
@@ -2415,10 +2422,14 @@ setVers(TlsSec *sec, int v)
                sec->setFinished = sslSetFinished;
                sec->nfin = SSL3FinishedLen;
                sec->prf = sslPRF;
-       }else{
-               sec->setFinished = tlsSetFinished;
+       }else if(v < TLS12Version) {
+               sec->setFinished = tls10SetFinished;
                sec->nfin = TLSFinishedLen;
-               sec->prf = tlsPRF;
+               sec->prf = tls10PRF;
+       }else {
+               sec->setFinished = tls12SetFinished;
+               sec->nfin = TLSFinishedLen;
+               sec->prf = tls12PRF;
        }
        sec->vers = v;
        return 0;
@@ -2488,7 +2499,7 @@ clientMasterSecret(TlsSec *sec, RSApub *pub)
 }
 
 static void
-sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
+sslSetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient)
 {
        DigestState *s;
        uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
@@ -2499,21 +2510,21 @@ sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, in
        else
                label = "SRVR";
 
-       md5((uchar*)label, 4, nil, &hsmd5);
-       md5(sec->sec, MasterSecretSize, nil, &hsmd5);
+       md5((uchar*)label, 4, nil, &hsh.md5);
+       md5(sec->sec, MasterSecretSize, nil, &hsh.md5);
        memset(pad, 0x36, 48);
-       md5(pad, 48, nil, &hsmd5);
-       md5(nil, 0, h0, &hsmd5);
+       md5(pad, 48, nil, &hsh.md5);
+       md5(nil, 0, h0, &hsh.md5);
        memset(pad, 0x5C, 48);
        s = md5(sec->sec, MasterSecretSize, nil, nil);
        s = md5(pad, 48, nil, s);
        md5(h0, MD5dlen, finished, s);
 
-       sha1((uchar*)label, 4, nil, &hssha1);
-       sha1(sec->sec, MasterSecretSize, nil, &hssha1);
+       sha1((uchar*)label, 4, nil, &hsh.sha1);
+       sha1(sec->sec, MasterSecretSize, nil, &hsh.sha1);
        memset(pad, 0x36, 40);
-       sha1(pad, 40, nil, &hssha1);
-       sha1(nil, 0, h1, &hssha1);
+       sha1(pad, 40, nil, &hsh.sha1);
+       sha1(nil, 0, h1, &hsh.sha1);
        memset(pad, 0x5C, 40);
        s = sha1(sec->sec, MasterSecretSize, nil, nil);
        s = sha1(pad, 40, nil, s);
@@ -2522,27 +2533,43 @@ sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, in
 
 // fill "finished" arg with md5(args)^sha1(args)
 static void
-tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
+tls10SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient)
 {
        uchar h0[MD5dlen], h1[SHA1dlen];
        char *label;
 
        // get current hash value, but allow further messages to be hashed in
-       md5(nil, 0, h0, &hsmd5);
-       sha1(nil, 0, h1, &hssha1);
+       md5(nil, 0, h0, &hsh.md5);
+       sha1(nil, 0, h1, &hsh.sha1);
+
+       if(isClient)
+               label = "client finished";
+       else
+               label = "server finished";
+       tls10PRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
+}
+
+static void
+tls12SetFinished(TlsSec *sec, HandshakeHash hsh, uchar *finished, int isClient)
+{
+       uchar seed[SHA2_256dlen];
+       char *label;
+
+       // get current hash value, but allow further messages to be hashed in
+       sha2_256(nil, 0, seed, &hsh.sha2_256);
 
        if(isClient)
                label = "client finished";
        else
                label = "server finished";
-       tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
+       p_sha256(finished, TLSFinishedLen, sec->sec, MasterSecretSize, (uchar*)label, strlen(label), seed, SHA2_256dlen);
 }
 
 static void
 sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
 {
-       DigestState *s;
        uchar sha1dig[SHA1dlen], md5dig[MD5dlen], tmp[26];
+       DigestState *s;
        int i, n, len;
 
        USED(label);