]> git.lizzy.rs Git - PAKEs.git/blobdiff - src/spake2.rs
fix incorrect tests
[PAKEs.git] / src / spake2.rs
index 4c24258dd302a01abfbb2b4b053780e3edd2e37e..11f4f6f8e1cfecc93aa2412837c8aad7ed99cc80 100644 (file)
@@ -5,10 +5,16 @@ use curve25519_dalek::curve::ExtendedPoint as c2_Element;
 use curve25519_dalek::constants::ED25519_BASEPOINT;
 use curve25519_dalek::curve::CompressedEdwardsY;
 use rand::{Rng, OsRng};
-use sha2::{Sha256, Sha512, Digest};
+//use sha2::{Sha256, Sha512, Digest};
+use crypto::sha2::Sha256;
+use crypto::digest::Digest;
+use crypto::hkdf;
+use num_bigint::BigUint;
+
+use hex::ToHex;
 
 #[derive(Debug)]
-pub struct SPAKEErr;
+pub struct SPAKEErr ( String );
 
 pub trait Group {
     type Scalar;
@@ -20,11 +26,13 @@ pub trait Group {
     type TranscriptHash;
     fn const_m() -> Self::Element;
     fn const_n() -> Self::Element;
+    fn const_s() -> Self::Element;
     fn hash_to_scalar(s: &[u8]) -> Self::Scalar;
     fn random_scalar<T: Rng>(cspring: &mut T) -> Self::Scalar;
     fn scalar_neg(s: &Self::Scalar) -> Self::Scalar;
     fn element_to_bytes(e: &Self::Element) -> Vec<u8>;
     fn bytes_to_element(b: &[u8]) -> Option<Self::Element>;
+    fn element_length() -> usize;
     fn basepoint_mult(s: &Self::Scalar) -> Self::Element;
     fn scalarmult(e: &Self::Element, s: &Self::Scalar) -> Self::Element;
     fn add(a: &Self::Element, b: &Self::Element) -> Self::Element;
@@ -61,8 +69,41 @@ impl Group for Ed25519Group {
 
     }
 
+    fn const_s() -> c2_Element {
+        // python -c "import binascii, spake2; b=binascii.hexlify(spake2.ParamsEd25519.S.to_bytes()); print(', '.join(['0x'+b[i:i+2] for i in range(0,len(b),2)]))"
+        // 6f00dae87c1be1a73b5922ef431cd8f57879569c222d22b1cd71e8546ab8e6f1
+        CompressedEdwardsY([
+            0x6f, 0x00, 0xda, 0xe8, 0x7c, 0x1b, 0xe1, 0xa7, 0x3b, 0x59, 0x22,
+            0xef, 0x43, 0x1c, 0xd8, 0xf5, 0x78, 0x79, 0x56, 0x9c, 0x22, 0x2d,
+            0x22, 0xb1, 0xcd, 0x71, 0xe8, 0x54, 0x6a, 0xb8, 0xe6, 0xf1,
+        ]).decompress().unwrap()
+
+    }
+
     fn hash_to_scalar(s: &[u8]) -> c2_Scalar {
-        c2_Scalar::hash_from_bytes::<Sha512>(&s)
+        //c2_Scalar::hash_from_bytes::<Sha512>(&s)
+        // spake2.py does:
+        //  h = HKDF(salt=b"", ikm=s, hash=SHA256, info=b"SPAKE2 pw", len=32+16)
+        //  i = int(h, 16)
+        //  i % q
+
+        let mut prk = [0u8; 32];
+        let digest = Sha256::new();
+        hkdf::hkdf_extract(digest, b"", s, &mut prk);
+        let mut okm = [0u8; 32+16];
+        hkdf::hkdf_expand(digest, &prk, b"SPAKE2 pw", &mut okm);
+        //okm[32+16-2] = 1;
+        println!("expanded:   {}{}", "................................", okm.iter().to_hex()); // ok
+
+        let mut reducible = [0u8; 64]; // little-endian
+        for i in 0..32+16 {
+            reducible[32+16-1-i] = okm[i];
+        }
+        println!("reducible:  {}", reducible.iter().to_hex());
+        let reduced = c2_Scalar::reduce(&reducible);
+        println!("reduced:    {}", reduced.as_bytes().to_hex());
+        println!("done");
+        reduced
     }
     fn random_scalar<T: Rng>(cspring: &mut T) -> c2_Scalar {
         c2_Scalar::random(cspring)
@@ -73,6 +114,9 @@ impl Group for Ed25519Group {
     fn element_to_bytes(s: &c2_Element) -> Vec<u8> {
         s.compress_edwards().as_bytes().to_vec()
     }
+    fn element_length() -> usize {
+        32
+    }
     fn bytes_to_element(b: &[u8]) -> Option<c2_Element> {
         if b.len() != 32 { return None; }
         //let mut bytes: [u8; 32] =
@@ -97,35 +141,50 @@ impl Group for Ed25519Group {
     }
 }
 
+fn decimal_to_scalar(d: &[u8]) -> c2_Scalar {
+    let bytes = BigUint::parse_bytes(d, 10).unwrap().to_bytes_le();
+    assert_eq!(bytes.len(), 32);
+    let mut s = c2_Scalar([0u8; 32]);
+    s.0.copy_from_slice(&bytes);
+    s
+}
+
 
 /* "session type pattern" */
 
+enum Side {
+    A,
+    B,
+    Symmetric,
+}
 pub struct SPAKE2<G: Group> { //where &G::Scalar: Neg {
-    i_am_a: bool,
+    side: Side,
     xy_scalar: G::Scalar,
     password_vec: Vec<u8>,
     id_a: Vec<u8>,
     id_b: Vec<u8>,
+    id_s: Vec<u8>,
     msg1: Vec<u8>,
     password_scalar: G::Scalar,
 }
 
 impl<G: Group> SPAKE2<G> {
-    fn start_internal<T: Rng>(i_am_a: bool,
-                            password: &[u8], id_a: &[u8], id_b: &[u8],
-                            rng: &mut T)
-                    -> (SPAKE2<G>, Vec<u8>) {
+    fn start_internal(side: Side,
+                      password: &[u8],
+                      id_a: &[u8], id_b: &[u8], id_s: &[u8],
+                      xy_scalar: G::Scalar) -> (SPAKE2<G>, Vec<u8>) {
         //let password_scalar: G::Scalar = hash_to_scalar::<G::Scalar>(password);
         let password_scalar: G::Scalar = G::hash_to_scalar(password);
-        let xy: G::Scalar = G::random_scalar(rng);
 
         // a: X = B*x + M*pw
         // b: Y = B*y + N*pw
-        let blinding = match i_am_a {
-            true => G::const_m(),
-            false => G::const_n(),
+        // sym: X = B*x * S*pw
+        let blinding = match side {
+            Side::A => G::const_m(),
+            Side::B => G::const_n(),
+            Side::Symmetric => G::const_s(),
         };
-        let m1: G::Element = G::add(&G::basepoint_mult(&xy),
+        let m1: G::Element = G::add(&G::basepoint_mult(&xy_scalar),
                                     &G::scalarmult(&blinding, &password_scalar));
         //let m1: G::Element = &G::basepoint_mult(&x) + &(blinding * &password_scalar);
         let msg1: Vec<u8> = G::element_to_bytes(&m1);
@@ -135,36 +194,101 @@ impl<G: Group> SPAKE2<G> {
         id_a_copy.extend_from_slice(id_a);
         let mut id_b_copy = Vec::new();
         id_b_copy.extend_from_slice(id_b);
+        let mut id_s_copy = Vec::new();
+        id_s_copy.extend_from_slice(id_s);
+
+        let mut msg_and_side = Vec::new();
+        msg_and_side.push(match side {
+            Side::A => 0x41, // 'A'
+            Side::B => 0x42, // 'B'
+            Side::Symmetric => 0x53, // 'S'
+        });
+        msg_and_side.extend_from_slice(&msg1);
+
         (SPAKE2 {
-            i_am_a: i_am_a,
-            xy_scalar: xy,
+            side: side,
+            xy_scalar: xy_scalar,
             password_vec: password_vec, // string
             id_a: id_a_copy,
             id_b: id_b_copy,
+            id_s: id_s_copy,
             msg1: msg1.clone(),
             password_scalar: password_scalar, // scalar
-        }, msg1)
+        }, msg_and_side)
+    }
+
+    fn start_a_internal(password: &[u8], id_a: &[u8], id_b: &[u8],
+                        xy_scalar: G::Scalar) -> (SPAKE2<G>, Vec<u8>) {
+        Self::start_internal(Side::A,
+                             password, id_a, id_b, b"", xy_scalar)
     }
 
+    fn start_b_internal(password: &[u8], id_a: &[u8], id_b: &[u8],
+                        xy_scalar: G::Scalar) -> (SPAKE2<G>, Vec<u8>) {
+        Self::start_internal(Side::B,
+                             password, id_a, id_b, b"", xy_scalar)
+    }
+
+    fn start_symmetric_internal(password: &[u8], id_s: &[u8],
+                                xy_scalar: G::Scalar) -> (SPAKE2<G>, Vec<u8>) {
+        Self::start_internal(Side::Symmetric,
+                             password, b"", b"", id_s, xy_scalar)
+    }
+
+
     pub fn start_a(password: &[u8], id_a: &[u8], id_b: &[u8])
                -> (SPAKE2<G>, Vec<u8>) {
         let mut cspring: OsRng = OsRng::new().unwrap();
-        Self::start_internal(true, password, id_a, id_b, &mut cspring)
+        let xy_scalar: G::Scalar = G::random_scalar(&mut cspring);
+        Self::start_a_internal(password, id_a, id_b, xy_scalar)
     }
 
     pub fn start_b(password: &[u8], id_a: &[u8], id_b: &[u8])
                -> (SPAKE2<G>, Vec<u8>) {
         let mut cspring: OsRng = OsRng::new().unwrap();
-        Self::start_internal(false, password, id_a, id_b, &mut cspring)
+        let xy_scalar: G::Scalar = G::random_scalar(&mut cspring);
+        Self::start_b_internal(password, id_a, id_b, xy_scalar)
+    }
+
+    pub fn start_symmetric(password: &[u8], id_s: &[u8])
+               -> (SPAKE2<G>, Vec<u8>) {
+        let mut cspring: OsRng = OsRng::new().unwrap();
+        let xy_scalar: G::Scalar = G::random_scalar(&mut cspring);
+        Self::start_symmetric_internal(password, id_s, xy_scalar)
     }
 
     pub fn finish(self, msg2: &[u8]) -> Result<Vec<u8>, SPAKEErr> {
+        if msg2.len() != 1 + G::element_length() {
+            return Err(SPAKEErr(String::from("inbound message is the wrong length")))
+        }
+        let msg_side = msg2[0];
+
+        match self.side {
+            Side::A => match msg_side {
+                0x42 => (), // 'B'
+                _ => return Err(SPAKEErr(String::from("bad side"))),
+            },
+            Side::B => match msg_side {
+                0x41 => (), // 'A'
+                _ => return Err(SPAKEErr(String::from("bad side"))),
+            },
+            Side::Symmetric => match msg_side {
+                0x53 => (), // 'S'
+                _ => return Err(SPAKEErr(String::from("bad side"))),
+            },
+        }
+
+        let msg2_element = match G::bytes_to_element(&msg2[1..]) {
+            Some(x) => x,
+            None => {return Err(SPAKEErr(String::from("message corrupted")))},
+        };
+
         // a: K = (Y+N*(-pw))*x
         // b: K = (X+M*(-pw))*y
-        let msg2_element = G::bytes_to_element(msg2).unwrap();
-        let unblinding = match self.i_am_a {
-            true => G::const_n(),
-            false => G::const_m(),
+        let unblinding = match self.side {
+            Side::A => G::const_n(),
+            Side::B => G::const_m(),
+            Side::Symmetric => G::const_s(),
         };
         let tmp1 = G::scalarmult(&unblinding,
                                  &G::scalar_neg(&self.password_scalar));
@@ -177,41 +301,213 @@ impl<G: Group> SPAKE2<G> {
         //                       X_msg, Y_msg, K_bytes])
         //key = sha256(transcript).digest()
         // note that both sides must use the same order
-        let mut transcript = Vec::<u8>::new();
+
+        Ok(match self.side {
+            Side::A => self.hash_ab(self.msg1.as_slice(), &msg2[1..], &key_element),
+            Side::B => self.hash_ab(&msg2[1..], self.msg1.as_slice(), &key_element),
+            Side::Symmetric => self.hash_symmetric(&msg2[1..], &key_element),
+        })
+    }
+
+    fn hash_ab(&self, first_msg: &[u8], second_msg: &[u8],
+               key_element: &G::Element) -> Vec<u8> {
+        assert_eq!(first_msg.len(), 32);
+        assert_eq!(second_msg.len(), 32);
+        // the transcript is fixed-length, made up of 6 32-byte values:
+        // byte 0-31   : sha256(pw)
+        // byte 32-63  : sha256(idA)
+        // byte 64-95  : sha256(idB)
+        // byte 96-127 : X_msg
+        // byte 128-159: Y_msg
+        // byte 160-191: K_bytes
+        let mut transcript = [0u8; 6*32];
 
         let mut pw_hash = Sha256::new();
         pw_hash.input(&self.password_vec);
-        transcript.extend_from_slice(pw_hash.result().as_slice());
+        pw_hash.result(&mut transcript[0..32]);
 
         let mut ida_hash = Sha256::new();
         ida_hash.input(&self.id_a);
-        transcript.extend_from_slice(ida_hash.result().as_slice());
+        ida_hash.result(&mut transcript[32..64]);
 
         let mut idb_hash = Sha256::new();
         idb_hash.input(&self.id_b);
-        transcript.extend_from_slice(idb_hash.result().as_slice());
+        idb_hash.result(&mut transcript[64..96]);
 
-        transcript.extend_from_slice(match self.i_am_a {
-            true => self.msg1.as_slice(),
-            false => msg2,
-        });
-        transcript.extend_from_slice(match self.i_am_a {
-            true => msg2,
-            false => self.msg1.as_slice(),
-        });
+        transcript[96..128].copy_from_slice(first_msg);
+        transcript[128..160].copy_from_slice(second_msg);
 
         let k_bytes = G::element_to_bytes(&key_element);
-        transcript.extend_from_slice(k_bytes.as_slice());
+        transcript[160..192].copy_from_slice(k_bytes.as_slice());
 
         //let mut hash = G::TranscriptHash::default();
-        let mut hash = Sha256::default();
-        hash.input(transcript.as_slice());
+        let mut hash = Sha256::new();
+        hash.input(&transcript);
+        let mut out = [0u8; 32];
+        hash.result(&mut out);
+        out.to_vec()
+    }
+
+    fn hash_symmetric(&self, msg2: &[u8], key_element: &G::Element) -> Vec<u8> {
+        assert_eq!(msg2.len(), 32);
+        // # since we don't know which side is which, we must sort the messages
+        // first_msg, second_msg = sorted([msg1, msg2])
+        // transcript = b"".join([sha256(pw).digest(),
+        //                        sha256(idSymmetric).digest(),
+        //                        first_msg, second_msg, K_bytes])
+
+        // the transcript is fixed-length, made up of 5 32-byte values:
+        // byte 0-31   : sha256(pw)
+        // byte 32-63  : sha256(idSymmetric)
+        // byte 64-95  : X_msg
+        // byte 96-127 : Y_msg
+        // byte 128-159: K_bytes
+        let mut transcript = [0u8; 5*32];
+
+        let mut pw_hash = Sha256::new();
+        pw_hash.input(&self.password_vec);
+        pw_hash.result(&mut transcript[0..32]);
 
-        Ok(hash.result().to_vec())
+        let mut ids_hash = Sha256::new();
+        ids_hash.input(&self.id_s);
+        ids_hash.result(&mut transcript[32..64]);
+
+        let msg_u = self.msg1.as_slice();
+        let msg_v = msg2;
+        if msg_u < msg_v {
+            transcript[64..96].copy_from_slice(&msg_u);
+            transcript[96..128].copy_from_slice(msg_v);
+        } else {
+            transcript[64..96].copy_from_slice(msg_v);
+            transcript[96..128].copy_from_slice(&msg_u);
+        }
+
+        let k_bytes = G::element_to_bytes(&key_element);
+        transcript[128..160].copy_from_slice(k_bytes.as_slice());
+
+        let mut hash = Sha256::new();
+        hash.input(&transcript);
+        let mut out = [0u8; 32];
+        hash.result(&mut out);
+        out.to_vec()
     }
 }
 
 
 #[cfg(test)]
 mod test {
+    /* This compares results against the python compatibility tests:
+    spake2.test.test_compat.SPAKE2.test_asymmetric . The python test passes a
+    deterministic RNG (used only for tests, of course) into the per-Group
+    "random_scalar()" function, which results in some particular scalar.
+     */
+    use curve25519_dalek::scalar::Scalar;
+    use curve25519_dalek::constants::ED25519_BASEPOINT;
+    use spake2::{SPAKE2, Ed25519Group};
+    use hex::ToHex;
+    use super::*;
+
+    // the python tests show the long-integer form of scalars. the rust code
+    // wants an array of bytes (little-endian). Make sure the way we convert
+    // things works correctly.
+
+    #[test]
+    fn test_convert() {
+        let t1_decimal = b"2238329342913194256032495932344128051776374960164957527413114840482143558222";
+        let t1_scalar = decimal_to_scalar(t1_decimal);
+        let expected: Scalar = Scalar(
+            [0x4e, 0x5a, 0xb4, 0x34, 0x5d, 0x47, 0x08, 0x84,
+             0x59, 0x13, 0xb4, 0x64, 0x1b, 0xc2, 0x7d, 0x52,
+             0x52, 0xa5, 0x85, 0x10, 0x1b, 0xcc, 0x42, 0x44,
+             0xd4, 0x49, 0xf4, 0xa8, 0x79, 0xd9, 0xf2, 0x04]);
+        assert_eq!(t1_scalar, expected);
+        //println!("t1_scalar is {:?}", t1_scalar);
+    }
+
+    #[test]
+    fn test_serialize_basepoint() {
+        // make sure elements are serialized same as the python library
+        let exp = "5866666666666666666666666666666666666666666666666666666666666666";
+        let base_vec = ED25519_BASEPOINT.compress_edwards().as_bytes().to_vec();
+        let base_hex = base_vec.to_hex();
+        println!("exp: {:?}", exp);
+        println!("got: {:?}", base_hex);
+        assert_eq!(exp, base_hex);
+    }
+
+    #[test]
+    fn test_password_to_scalar() {
+        let password = b"password";
+        let expected_pw_scalar = decimal_to_scalar(b"3515301705789368674385125653994241092664323519848410154015274772661223168839");
+        let pw_scalar = Ed25519Group::hash_to_scalar(password);
+        println!("exp: {:?}", expected_pw_scalar.as_bytes().to_hex());
+        println!("got: {:?}", pw_scalar.as_bytes().to_hex());
+        assert_eq!(&pw_scalar, &expected_pw_scalar);
+    }
+
+    #[test]
+    fn test_sizes() {
+        let (s1, msg1) = SPAKE2::<Ed25519Group>::start_a(b"password", b"idA",
+                                                         b"idB");
+        assert_eq!(msg1.len(), 1+32);
+        let (s2, msg2) = SPAKE2::<Ed25519Group>::start_b(b"password", b"idA",
+                                                         b"idB");
+        assert_eq!(msg2.len(), 1+32);
+        let key1 = s1.finish(&msg2).unwrap();
+        let key2 = s2.finish(&msg1).unwrap();
+        assert_eq!(key1.len(), 32);
+        assert_eq!(key2.len(), 32);
+
+        let (s1, msg1) = SPAKE2::<Ed25519Group>::start_symmetric(b"password",
+                                                                 b"idS");
+        assert_eq!(msg1.len(), 1+32);
+        let (s2, msg2) = SPAKE2::<Ed25519Group>::start_symmetric(b"password",
+                                                                 b"idS");
+        assert_eq!(msg2.len(), 1+32);
+        let key1 = s1.finish(&msg2).unwrap();
+        let key2 = s2.finish(&msg1).unwrap();
+        assert_eq!(key1.len(), 32);
+        assert_eq!(key2.len(), 32);
+    }
+
+    #[test]
+    fn test_asymmetric() {
+        let scalar_a = decimal_to_scalar(b"2611694063369306139794446498317402240796898290761098242657700742213257926693");
+        let scalar_b = decimal_to_scalar(b"7002393159576182977806091886122272758628412261510164356026361256515836884383");
+        let expected_pw_scalar = decimal_to_scalar(b"3515301705789368674385125653994241092664323519848410154015274772661223168839");
+
+        println!("scalar_a is {}", scalar_a.as_bytes().to_hex());
+
+        let (s1, msg1) = SPAKE2::<Ed25519Group>::start_a_internal(
+            b"password", b"idA", b"idB", scalar_a);
+        let expected_msg1 = "416fc960df73c9cf8ed7198b0c9534e2e96a5984bfc5edc023fd24dacf371f2af9";
+
+        println!();
+        println!("xys1: {:?}", s1.xy_scalar.as_bytes().to_hex());
+        println!();
+        println!("pws1: {:?}", s1.password_scalar.as_bytes().to_hex());
+        println!("exp : {:?}", expected_pw_scalar.as_bytes().to_hex());
+        println!();
+        println!("msg1: {:?}", msg1.to_hex());
+        println!("exp : {:?}", expected_msg1);
+        println!();
+
+        assert_eq!(expected_pw_scalar.as_bytes().to_hex(),
+                   s1.password_scalar.as_bytes().to_hex());
+        assert_eq!(msg1.to_hex(), expected_msg1);
+
+        let (s2, msg2) = SPAKE2::<Ed25519Group>::start_b_internal(
+            b"password", b"idA", b"idB", scalar_b);
+        assert_eq!(expected_pw_scalar, s2.password_scalar);
+        assert_eq!(msg2.to_hex(),
+                   "42354e97b88406922b1df4bea1d7870f17aed3dba7c720b313edae315b00959309");
+
+        let key1 = s1.finish(&msg2).unwrap();
+        let key2 = s2.finish(&msg1).unwrap();
+        assert_eq!(key1, key2);
+        assert_eq!(key1.to_hex(),
+                   "a480bca13fa04464bb644f10e340125e96c9494f7399fef7c2bda67eb0fdf06d");
+    }
+
+
 }