]> git.lizzy.rs Git - PAKEs.git/blob - spake2/src/lib.rs
update spake2 deps (sha2-0.8, hkdf-0.7)
[PAKEs.git] / spake2 / src / lib.rs
1 #![doc(html_logo_url = "https://raw.githubusercontent.com/RustCrypto/meta/master/logo_small.png")]
2 #![deny(warnings)]
3 #![forbid(unsafe_code)]
4
5 extern crate curve25519_dalek;
6 extern crate hex;
7 extern crate hkdf;
8 extern crate num_bigint;
9 extern crate rand;
10 extern crate sha2;
11
12 use curve25519_dalek::constants::ED25519_BASEPOINT_POINT;
13 use curve25519_dalek::edwards::CompressedEdwardsY;
14 use curve25519_dalek::edwards::EdwardsPoint as c2_Element;
15 use curve25519_dalek::scalar::Scalar as c2_Scalar;
16
17 use hkdf::Hkdf;
18 use rand::{CryptoRng, OsRng, Rng};
19 use sha2::{Digest, Sha256};
20 use std::fmt;
21 use std::ops::Deref;
22
23 /* "newtype pattern": it's a Vec<u8>, but only used for a specific argument
24  * type, to distinguish between ones that are meant as passwords, and ones
25  * that are meant as identity strings */
26
27 #[derive(PartialEq, Eq, Clone)]
28 pub struct Password(Vec<u8>);
29 impl Password {
30     pub fn new(p: &[u8]) -> Password {
31         Password(p.to_vec())
32     }
33 }
34 impl Deref for Password {
35     type Target = Vec<u8>;
36     fn deref(&self) -> &Vec<u8> {
37         &self.0
38     }
39 }
40
41 #[derive(PartialEq, Eq, Clone)]
42 pub struct Identity(Vec<u8>);
43 impl Deref for Identity {
44     type Target = Vec<u8>;
45     fn deref(&self) -> &Vec<u8> {
46         &self.0
47     }
48 }
49 impl Identity {
50     pub fn new(p: &[u8]) -> Identity {
51         Identity(p.to_vec())
52     }
53 }
54
55 #[derive(Debug, PartialEq, Eq)]
56 pub enum ErrorType {
57     BadSide,
58     WrongLength,
59     CorruptMessage,
60 }
61
62 #[derive(Debug, PartialEq, Eq)]
63 pub struct SPAKEErr {
64     pub kind: ErrorType,
65 }
66
67 pub trait Group {
68     type Scalar;
69     type Element;
70     //type Element: Add<Output=Self::Element>
71     //    + Mul<Self::Scalar, Output=Self::Element>;
72     // const element_length: usize; // in unstable, or u8
73     //type ElementBytes : Index<usize, Output=u8>+IndexMut<usize>; // later
74     type TranscriptHash;
75     fn const_m() -> Self::Element;
76     fn const_n() -> Self::Element;
77     fn const_s() -> Self::Element;
78     fn hash_to_scalar(s: &[u8]) -> Self::Scalar;
79     fn random_scalar<T>(cspring: &mut T) -> Self::Scalar
80     where
81         T: Rng + CryptoRng;
82     fn scalar_neg(s: &Self::Scalar) -> Self::Scalar;
83     fn element_to_bytes(e: &Self::Element) -> Vec<u8>;
84     fn bytes_to_element(b: &[u8]) -> Option<Self::Element>;
85     fn element_length() -> usize;
86     fn basepoint_mult(s: &Self::Scalar) -> Self::Element;
87     fn scalarmult(e: &Self::Element, s: &Self::Scalar) -> Self::Element;
88     fn add(a: &Self::Element, b: &Self::Element) -> Self::Element;
89 }
90
91 #[derive(Debug, PartialEq, Eq)]
92 pub struct Ed25519Group;
93
94 impl Group for Ed25519Group {
95     type Scalar = c2_Scalar;
96     type Element = c2_Element;
97     //type ElementBytes = Vec<u8>;
98     //type ElementBytes = [u8; 32];
99     //type ScalarBytes
100     type TranscriptHash = Sha256;
101
102     fn const_m() -> c2_Element {
103         // python -c "import binascii, spake2; b=binascii.hexlify(spake2.ParamsEd25519.M.to_bytes()); print(', '.join(['0x'+b[i:i+2] for i in range(0,len(b),2)]))"
104         // 15cfd18e385952982b6a8f8c7854963b58e34388c8e6dae891db756481a02312
105         CompressedEdwardsY([
106             0x15, 0xcf, 0xd1, 0x8e, 0x38, 0x59, 0x52, 0x98, 0x2b, 0x6a, 0x8f, 0x8c, 0x78, 0x54,
107             0x96, 0x3b, 0x58, 0xe3, 0x43, 0x88, 0xc8, 0xe6, 0xda, 0xe8, 0x91, 0xdb, 0x75, 0x64,
108             0x81, 0xa0, 0x23, 0x12,
109         ]).decompress()
110         .unwrap()
111     }
112
113     fn const_n() -> c2_Element {
114         // python -c "import binascii, spake2; b=binascii.hexlify(spake2.ParamsEd25519.N.to_bytes()); print(', '.join(['0x'+b[i:i+2] for i in range(0,len(b),2)]))"
115         // f04f2e7eb734b2a8f8b472eaf9c3c632576ac64aea650b496a8a20ff00e583c3
116         CompressedEdwardsY([
117             0xf0, 0x4f, 0x2e, 0x7e, 0xb7, 0x34, 0xb2, 0xa8, 0xf8, 0xb4, 0x72, 0xea, 0xf9, 0xc3,
118             0xc6, 0x32, 0x57, 0x6a, 0xc6, 0x4a, 0xea, 0x65, 0x0b, 0x49, 0x6a, 0x8a, 0x20, 0xff,
119             0x00, 0xe5, 0x83, 0xc3,
120         ]).decompress()
121         .unwrap()
122     }
123
124     fn const_s() -> c2_Element {
125         // python -c "import binascii, spake2; b=binascii.hexlify(spake2.ParamsEd25519.S.to_bytes()); print(', '.join(['0x'+b[i:i+2] for i in range(0,len(b),2)]))"
126         // 6f00dae87c1be1a73b5922ef431cd8f57879569c222d22b1cd71e8546ab8e6f1
127         CompressedEdwardsY([
128             0x6f, 0x00, 0xda, 0xe8, 0x7c, 0x1b, 0xe1, 0xa7, 0x3b, 0x59, 0x22, 0xef, 0x43, 0x1c,
129             0xd8, 0xf5, 0x78, 0x79, 0x56, 0x9c, 0x22, 0x2d, 0x22, 0xb1, 0xcd, 0x71, 0xe8, 0x54,
130             0x6a, 0xb8, 0xe6, 0xf1,
131         ]).decompress()
132         .unwrap()
133     }
134
135     fn hash_to_scalar(s: &[u8]) -> c2_Scalar {
136         ed25519_hash_to_scalar(s)
137     }
138     fn random_scalar<T>(cspring: &mut T) -> c2_Scalar
139     where
140         T: Rng + CryptoRng,
141     {
142         c2_Scalar::random(cspring)
143     }
144     fn scalar_neg(s: &c2_Scalar) -> c2_Scalar {
145         -s
146     }
147     fn element_to_bytes(s: &c2_Element) -> Vec<u8> {
148         s.compress().as_bytes().to_vec()
149     }
150     fn element_length() -> usize {
151         32
152     }
153     fn bytes_to_element(b: &[u8]) -> Option<c2_Element> {
154         if b.len() != 32 {
155             return None;
156         }
157         //let mut bytes: [u8; 32] =
158         let mut bytes = [0u8; 32];
159         bytes.copy_from_slice(b);
160         let cey = CompressedEdwardsY(bytes);
161         // CompressedEdwardsY::new(b)
162         cey.decompress()
163     }
164
165     fn basepoint_mult(s: &c2_Scalar) -> c2_Element {
166         //c2_Element::basepoint_mult(s)
167         ED25519_BASEPOINT_POINT * s
168     }
169     fn scalarmult(e: &c2_Element, s: &c2_Scalar) -> c2_Element {
170         e * s
171         //e.scalar_mult(s)
172     }
173     fn add(a: &c2_Element, b: &c2_Element) -> c2_Element {
174         a + b
175         //a.add(b)
176     }
177 }
178
179 fn ed25519_hash_to_scalar(s: &[u8]) -> c2_Scalar {
180     //c2_Scalar::hash_from_bytes::<Sha512>(&s)
181     // spake2.py does:
182     //  h = HKDF(salt=b"", ikm=s, hash=SHA256, info=b"SPAKE2 pw", len=32+16)
183     //  i = int(h, 16)
184     //  i % q
185
186     let mut okm = [0u8; 32 + 16];
187     Hkdf::<Sha256>::extract(Some(b""), s)
188         .expand(b"SPAKE2 pw", &mut okm)
189         .unwrap();
190     //println!("expanded:   {}{}", "................................", okm.iter().to_hex()); // ok
191
192     let mut reducible = [0u8; 64]; // little-endian
193     for (i, x) in okm.iter().enumerate().take(32 + 16) {
194         reducible[32 + 16 - 1 - i] = *x;
195     }
196     //println!("reducible:  {}", reducible.iter().to_hex());
197     c2_Scalar::from_bytes_mod_order_wide(&reducible)
198     //let reduced = c2_Scalar::reduce(&reducible);
199     //println!("reduced:    {}", reduced.as_bytes().to_hex());
200     //println!("done");
201     //reduced
202 }
203
204 fn ed25519_hash_ab(
205     password_vec: &[u8],
206     id_a: &[u8],
207     id_b: &[u8],
208     first_msg: &[u8],
209     second_msg: &[u8],
210     key_bytes: &[u8],
211 ) -> Vec<u8> {
212     assert_eq!(first_msg.len(), 32);
213     assert_eq!(second_msg.len(), 32);
214     // the transcript is fixed-length, made up of 6 32-byte values:
215     // byte 0-31   : sha256(pw)
216     // byte 32-63  : sha256(idA)
217     // byte 64-95  : sha256(idB)
218     // byte 96-127 : X_msg
219     // byte 128-159: Y_msg
220     // byte 160-191: K_bytes
221     let mut transcript = [0u8; 6 * 32];
222
223     let mut pw_hash = Sha256::new();
224     pw_hash.input(password_vec);
225     transcript[0..32].copy_from_slice(&pw_hash.result());
226
227     let mut ida_hash = Sha256::new();
228     ida_hash.input(id_a);
229     transcript[32..64].copy_from_slice(&ida_hash.result());
230
231     let mut idb_hash = Sha256::new();
232     idb_hash.input(id_b);
233     transcript[64..96].copy_from_slice(&idb_hash.result());
234
235     transcript[96..128].copy_from_slice(first_msg);
236     transcript[128..160].copy_from_slice(second_msg);
237     transcript[160..192].copy_from_slice(key_bytes);
238
239     //println!("transcript: {:?}", transcript.iter().to_hex());
240
241     //let mut hash = G::TranscriptHash::default();
242     let mut hash = Sha256::new();
243     hash.input(transcript.to_vec());
244     hash.result().to_vec()
245 }
246
247 fn ed25519_hash_symmetric(
248     password_vec: &[u8],
249     id_s: &[u8],
250     msg_u: &[u8],
251     msg_v: &[u8],
252     key_bytes: &[u8],
253 ) -> Vec<u8> {
254     assert_eq!(msg_u.len(), 32);
255     assert_eq!(msg_v.len(), 32);
256     // # since we don't know which side is which, we must sort the messages
257     // first_msg, second_msg = sorted([msg1, msg2])
258     // transcript = b"".join([sha256(pw).digest(),
259     //                        sha256(idSymmetric).digest(),
260     //                        first_msg, second_msg, K_bytes])
261
262     // the transcript is fixed-length, made up of 5 32-byte values:
263     // byte 0-31   : sha256(pw)
264     // byte 32-63  : sha256(idSymmetric)
265     // byte 64-95  : X_msg
266     // byte 96-127 : Y_msg
267     // byte 128-159: K_bytes
268     let mut transcript = [0u8; 5 * 32];
269
270     let mut pw_hash = Sha256::new();
271     pw_hash.input(password_vec);
272     transcript[0..32].copy_from_slice(&pw_hash.result());
273
274     let mut ids_hash = Sha256::new();
275     ids_hash.input(id_s);
276     transcript[32..64].copy_from_slice(&ids_hash.result());
277
278     if msg_u < msg_v {
279         transcript[64..96].copy_from_slice(msg_u);
280         transcript[96..128].copy_from_slice(msg_v);
281     } else {
282         transcript[64..96].copy_from_slice(msg_v);
283         transcript[96..128].copy_from_slice(msg_u);
284     }
285     transcript[128..160].copy_from_slice(key_bytes);
286
287     let mut hash = Sha256::new();
288     hash.input(transcript.to_vec());
289     hash.result().to_vec()
290 }
291
292 /* "session type pattern" */
293
294 #[derive(Debug, PartialEq, Eq)]
295 enum Side {
296     A,
297     B,
298     Symmetric,
299 }
300
301 // we implement a custom Debug below, to avoid revealing secrets in a dump
302 #[derive(PartialEq, Eq)]
303 pub struct SPAKE2<G: Group> {
304     //where &G::Scalar: Neg {
305     side: Side,
306     xy_scalar: G::Scalar,
307     password_vec: Vec<u8>,
308     id_a: Vec<u8>,
309     id_b: Vec<u8>,
310     id_s: Vec<u8>,
311     msg1: Vec<u8>,
312     password_scalar: G::Scalar,
313 }
314
315 impl<G: Group> SPAKE2<G> {
316     fn start_internal(
317         side: Side,
318         password: &Password,
319         id_a: &Identity,
320         id_b: &Identity,
321         id_s: &Identity,
322         xy_scalar: G::Scalar,
323     ) -> (SPAKE2<G>, Vec<u8>) {
324         //let password_scalar: G::Scalar = hash_to_scalar::<G::Scalar>(password);
325         let password_scalar: G::Scalar = G::hash_to_scalar(&password);
326
327         // a: X = B*x + M*pw
328         // b: Y = B*y + N*pw
329         // sym: X = B*x * S*pw
330         let blinding = match side {
331             Side::A => G::const_m(),
332             Side::B => G::const_n(),
333             Side::Symmetric => G::const_s(),
334         };
335         let m1: G::Element = G::add(
336             &G::basepoint_mult(&xy_scalar),
337             &G::scalarmult(&blinding, &password_scalar),
338         );
339         //let m1: G::Element = &G::basepoint_mult(&x) + &(blinding * &password_scalar);
340         let msg1: Vec<u8> = G::element_to_bytes(&m1);
341         let mut password_vec = Vec::new();
342         password_vec.extend_from_slice(&password);
343         let mut id_a_copy = Vec::new();
344         id_a_copy.extend_from_slice(&id_a);
345         let mut id_b_copy = Vec::new();
346         id_b_copy.extend_from_slice(&id_b);
347         let mut id_s_copy = Vec::new();
348         id_s_copy.extend_from_slice(&id_s);
349
350         let mut msg_and_side = Vec::new();
351         msg_and_side.push(match side {
352             Side::A => 0x41,         // 'A'
353             Side::B => 0x42,         // 'B'
354             Side::Symmetric => 0x53, // 'S'
355         });
356         msg_and_side.extend_from_slice(&msg1);
357
358         (
359             SPAKE2 {
360                 side,
361                 xy_scalar,
362                 password_vec, // string
363                 id_a: id_a_copy,
364                 id_b: id_b_copy,
365                 id_s: id_s_copy,
366                 msg1: msg1.clone(),
367                 password_scalar, // scalar
368             },
369             msg_and_side,
370         )
371     }
372
373     fn start_a_internal(
374         password: &Password,
375         id_a: &Identity,
376         id_b: &Identity,
377         xy_scalar: G::Scalar,
378     ) -> (SPAKE2<G>, Vec<u8>) {
379         Self::start_internal(
380             Side::A,
381             &password,
382             &id_a,
383             &id_b,
384             &Identity::new(b""),
385             xy_scalar,
386         )
387     }
388
389     fn start_b_internal(
390         password: &Password,
391         id_a: &Identity,
392         id_b: &Identity,
393         xy_scalar: G::Scalar,
394     ) -> (SPAKE2<G>, Vec<u8>) {
395         Self::start_internal(
396             Side::B,
397             &password,
398             &id_a,
399             &id_b,
400             &Identity::new(b""),
401             xy_scalar,
402         )
403     }
404
405     fn start_symmetric_internal(
406         password: &Password,
407         id_s: &Identity,
408         xy_scalar: G::Scalar,
409     ) -> (SPAKE2<G>, Vec<u8>) {
410         Self::start_internal(
411             Side::Symmetric,
412             &password,
413             &Identity::new(b""),
414             &Identity::new(b""),
415             &id_s,
416             xy_scalar,
417         )
418     }
419
420     pub fn start_a(password: &Password, id_a: &Identity, id_b: &Identity) -> (SPAKE2<G>, Vec<u8>) {
421         let mut cspring: OsRng = OsRng::new().unwrap();
422         let xy_scalar: G::Scalar = G::random_scalar(&mut cspring);
423         Self::start_a_internal(&password, &id_a, &id_b, xy_scalar)
424     }
425
426     pub fn start_b(password: &Password, id_a: &Identity, id_b: &Identity) -> (SPAKE2<G>, Vec<u8>) {
427         let mut cspring: OsRng = OsRng::new().unwrap();
428         let xy_scalar: G::Scalar = G::random_scalar(&mut cspring);
429         Self::start_b_internal(&password, &id_a, &id_b, xy_scalar)
430     }
431
432     pub fn start_symmetric(password: &Password, id_s: &Identity) -> (SPAKE2<G>, Vec<u8>) {
433         let mut cspring: OsRng = OsRng::new().unwrap();
434         let xy_scalar: G::Scalar = G::random_scalar(&mut cspring);
435         Self::start_symmetric_internal(&password, &id_s, xy_scalar)
436     }
437
438     pub fn finish(self, msg2: &[u8]) -> Result<Vec<u8>, SPAKEErr> {
439         if msg2.len() != 1 + G::element_length() {
440             return Err(SPAKEErr {
441                 kind: ErrorType::WrongLength,
442             });
443         }
444         let msg_side = msg2[0];
445
446         match self.side {
447             Side::A => match msg_side {
448                 0x42 => (), // 'B'
449                 _ => {
450                     return Err(SPAKEErr {
451                         kind: ErrorType::BadSide,
452                     })
453                 }
454             },
455             Side::B => match msg_side {
456                 0x41 => (), // 'A'
457                 _ => {
458                     return Err(SPAKEErr {
459                         kind: ErrorType::BadSide,
460                     })
461                 }
462             },
463             Side::Symmetric => match msg_side {
464                 0x53 => (), // 'S'
465                 _ => {
466                     return Err(SPAKEErr {
467                         kind: ErrorType::BadSide,
468                     })
469                 }
470             },
471         }
472
473         let msg2_element = match G::bytes_to_element(&msg2[1..]) {
474             Some(x) => x,
475             None => {
476                 return Err(SPAKEErr {
477                     kind: ErrorType::CorruptMessage,
478                 })
479             }
480         };
481
482         // a: K = (Y+N*(-pw))*x
483         // b: K = (X+M*(-pw))*y
484         let unblinding = match self.side {
485             Side::A => G::const_n(),
486             Side::B => G::const_m(),
487             Side::Symmetric => G::const_s(),
488         };
489         let tmp1 = G::scalarmult(&unblinding, &G::scalar_neg(&self.password_scalar));
490         let tmp2 = G::add(&msg2_element, &tmp1);
491         let key_element = G::scalarmult(&tmp2, &self.xy_scalar);
492         let key_bytes = G::element_to_bytes(&key_element);
493
494         // key = H(H(pw) + H(idA) + H(idB) + X + Y + K)
495         //transcript = b"".join([sha256(pw).digest(),
496         //                       sha256(idA).digest(), sha256(idB).digest(),
497         //                       X_msg, Y_msg, K_bytes])
498         //key = sha256(transcript).digest()
499         // note that both sides must use the same order
500
501         Ok(match self.side {
502             Side::A => ed25519_hash_ab(
503                 &self.password_vec,
504                 &self.id_a,
505                 &self.id_b,
506                 self.msg1.as_slice(),
507                 &msg2[1..],
508                 &key_bytes,
509             ),
510             Side::B => ed25519_hash_ab(
511                 &self.password_vec,
512                 &self.id_a,
513                 &self.id_b,
514                 &msg2[1..],
515                 self.msg1.as_slice(),
516                 &key_bytes,
517             ),
518             Side::Symmetric => ed25519_hash_symmetric(
519                 &self.password_vec,
520                 &self.id_s,
521                 &self.msg1,
522                 &msg2[1..],
523                 &key_bytes,
524             ),
525         })
526     }
527 }
528
529 fn maybe_utf8(s: &[u8]) -> String {
530     match String::from_utf8(s.to_vec()) {
531         Ok(m) => format!("(s={})", m),
532         Err(_) => format!("(hex={})", hex::encode(s)),
533     }
534 }
535
536 impl<G: Group> fmt::Debug for SPAKE2<G> {
537     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
538         write!(
539             f,
540             "SPAKE2(G=?, side={:?}, idA={}, idB={}, idS={})",
541             self.side,
542             maybe_utf8(&self.id_a),
543             maybe_utf8(&self.id_b),
544             maybe_utf8(&self.id_s)
545         )
546     }
547 }
548
549 #[cfg(test)]
550 mod tests;