]> git.lizzy.rs Git - PAKEs.git/blobdiff - srp/src/types.rs
srp: replace custom powm with modpow (#78)
[PAKEs.git] / srp / src / types.rs
index 6ae8595093341199566b2daeb3899fcbb02fe1c4..41742d53589af270b1c00bc9ff0e1a45c21ad58a 100644 (file)
@@ -1,10 +1,9 @@
 //! Additional SRP types.
-use crate::tools::powm;
 use digest::Digest;
-use num::BigUint;
+use num_bigint::BigUint;
 use std::{error, fmt};
 
-/// SRP authentification error.
+/// SRP authentication error.
 #[derive(Debug, Copy, Clone, Eq, PartialEq)]
 pub struct SrpAuthError {
     pub(crate) description: &'static str,
@@ -12,7 +11,7 @@ pub struct SrpAuthError {
 
 impl fmt::Display for SrpAuthError {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "SRP authentification error")
+        write!(f, "SRP authentication error")
     }
 }
 
@@ -32,8 +31,8 @@ pub struct SrpGroup {
 }
 
 impl SrpGroup {
-    pub(crate) fn powm(&self, v: &BigUint) -> BigUint {
-        powm(&self.g, v, &self.n)
+    pub(crate) fn modpow(&self, v: &BigUint) -> BigUint {
+        self.g.modpow(v, &self.n)
     }
 
     /// Compute `k` with given hash function and return SRP parameters
@@ -45,9 +44,31 @@ impl SrpGroup {
         buf[l..].copy_from_slice(&g_bytes);
 
         let mut d = D::new();
-        d.input(&n);
-        d.input(&buf);
-        BigUint::from_bytes_be(&d.result())
+        d.update(&n);
+        d.update(&buf);
+        BigUint::from_bytes_be(&d.finalize().as_slice())
+    }
+
+    /// Compute `Hash(N) xor Hash(g)` with given hash function and return SRP parameters
+    pub(crate) fn compute_hash_n_xor_hash_g<D: Digest>(&self) -> Vec<u8> {
+        let n = self.n.to_bytes_be();
+        let g_bytes = self.g.to_bytes_be();
+        let mut buf = vec![0u8; n.len()];
+        let l = n.len() - g_bytes.len();
+        buf[l..].copy_from_slice(&g_bytes);
+
+        let mut d = D::new();
+        d.update(&n);
+        let h = d.finalize_reset();
+        let h_n: &[u8] = h.as_slice();
+        d.update(&buf);
+        let h = d.finalize_reset();
+        let h_g: &[u8] = h.as_slice();
+
+        h_n.iter()
+            .zip(h_g.iter())
+            .map(|(&x1, &x2)| x1 ^ x2)
+            .collect()
     }
 }