From 33f2fa9a88278c28e6aefbf8c52aa61176a8b9e3 Mon Sep 17 00:00:00 2001 From: Sarah Jamie Lewis Date: Wed, 19 May 2021 23:31:14 -0700 Subject: [PATCH] Performance improvement during testing. By splitting up the hash function H(U, x_iU, W) we can precompute the application of U and W, and only apply x_iU during the main detection key loop. Saving about 9.6% of the overall time. --- benches/fuzzy_tags_benches.rs | 8 +- rustfmt.toml | 11 ++- src/lib.rs | 152 +++++++++++++++++++++++----------- 3 files changed, 118 insertions(+), 53 deletions(-) diff --git a/benches/fuzzy_tags_benches.rs b/benches/fuzzy_tags_benches.rs index c61933d..01c28b6 100644 --- a/benches/fuzzy_tags_benches.rs +++ b/benches/fuzzy_tags_benches.rs @@ -9,7 +9,9 @@ fn benchmark_generate_tag(c: &mut Criterion) { let secret_key = RootSecret::<24>::generate(); for p in [5, 10, 15].iter() { let public_key = secret_key.tagging_key(); - group.bench_with_input(BenchmarkId::from_parameter(p), p, |b, _gamma| b.iter(|| public_key.generate_tag())); + group.bench_with_input(BenchmarkId::from_parameter(p), p, |b, _gamma| { + b.iter(|| public_key.generate_tag()) + }); } } @@ -22,7 +24,9 @@ fn benchmark_test_tag(c: &mut Criterion) { for p in [5, 10, 15].iter() { let tag = secret_key.tagging_key().generate_tag(); let detection_key = secret_key.extract_detection_key(*p); - group.bench_with_input(BenchmarkId::from_parameter(p), p, |b, _gamma| b.iter(|| detection_key.test_tag(&tag))); + group.bench_with_input(BenchmarkId::from_parameter(p), p, |b, _gamma| { + b.iter(|| detection_key.test_tag(&tag)) + }); } } diff --git a/rustfmt.toml b/rustfmt.toml index ad894cd..3d075c7 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -2,8 +2,15 @@ max_width = 200 hard_tabs = false tab_spaces = 4 newline_style = "Auto" -use_small_heuristics = "Default" indent_style = "Block" +use_small_heuristics = "Default" +fn_call_width = 60 +attr_fn_like_width = 70 +struct_lit_width = 18 +struct_variant_width = 35 +array_width = 60 +chain_width = 60 +single_line_if_else_max_width = 50 wrap_comments = false format_code_in_doc_comments = false comment_width = 80 @@ -54,7 +61,7 @@ use_field_init_shorthand = false force_explicit_abi = true condense_wildcard_suffixes = false color = "Auto" -required_version = "1.4.34" +required_version = "1.4.37" unstable_features = false disable_all_formatting = false skip_children = false diff --git a/src/lib.rs b/src/lib.rs index 815c3a9..ae27be8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ use curve25519_dalek::scalar::Scalar; use curve25519_dalek::traits::MultiscalarMul; use rand::rngs::OsRng; use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer}; -use sha3::Sha3_512; +use sha3::{Sha3_256, Sha3_512}; use std::convert::TryFrom; use std::fmt; use std::fmt::{Display, Formatter}; @@ -32,6 +32,13 @@ use rayon::iter::ParallelIterator; #[cfg(feature = "bulk_verify")] use std::sync::mpsc::channel; +#[cfg(feature = "encrypt_to_tag")] +use secretbox::CipherType::Salsa20; +#[cfg(feature = "encrypt_to_tag")] +use secretbox::SecretBox; +#[cfg(feature = "encrypt_to_tag")] +use std::string::FromUtf8Error; + /// A tag is a probabilistic cryptographic structure. When constructed for a given `TaggingKey` /// it will pass the `DetectionKey::test_tag` 100% of the time. For other tagging keys /// it will pass the test with probability `GAMMA` related to the security parameter of the system. @@ -81,7 +88,10 @@ impl<'de, const GAMMA: u8> Deserialize<'de> for Tag<{ GAMMA }> { { let mut bytes = vec![]; for i in 0..64 { - bytes.push(seq.next_element()?.ok_or(serde::de::Error::invalid_length(i, &"expected at least 64 bytes"))?); + bytes.push(seq.next_element()?.ok_or(serde::de::Error::invalid_length( + i, + &"expected at least 64 bytes", + ))?); } loop { match seq.next_element().unwrap_or(None) { @@ -157,7 +167,10 @@ impl Tag<{ GAMMA }> { }; let mut ciphertexts = BitVec::from_bytes(ciphertext); ciphertexts.truncate(GAMMA as usize); - return match (CompressedRistretto::from_slice(u_bytes).decompress(), Scalar::from_canonical_bytes(y_bytes_fixed)) { + return match ( + CompressedRistretto::from_slice(u_bytes).decompress(), + Scalar::from_canonical_bytes(y_bytes_fixed), + ) { (Some(u), Some(y)) => Some(Tag { u, y, ciphertexts }), _ => None, }; @@ -178,6 +191,12 @@ impl Display for Tag<{ GAMMA }> { } } +/// PrecomputeH is an encapsulation around the precomputation of the H function which +/// significantly speeds up testing. We define it for some additional type safety (to +/// prevent us from passing an uninitialized hash function to post_h +#[derive(Clone)] +struct PrecomputeH(Sha3_256); + /// The complete secret. Can't directly be used for testing. Instead you will need to generate /// a DetectionKey using `extract_detection_key` #[derive(Serialize, Deserialize)] @@ -203,7 +222,7 @@ impl RootSecret<{ GAMMA }> { let sk_i = Scalar::random(&mut rng); secret.push(sk_i); } - RootSecret:: { secret: secret } + RootSecret:: { secret } } /// extract a detection key for a given false positive (p = 2^-n) @@ -238,13 +257,28 @@ impl RootSecret<{ GAMMA }> { TaggingKey:: { 0: tagging_key } } + /// precompute the first part of h + fn pre_h(u: RistrettoPoint, w: RistrettoPoint) -> PrecomputeH { + let mut hash = sha3::Sha3_256::new(); + hash.update(&[GAMMA]); + hash.update(u.compress().as_bytes()); + hash.update(w.compress().as_bytes()); + return PrecomputeH(hash); + } + + /// compute the rest of h from a precomputed hash + fn post_h(mut hash: PrecomputeH, h: RistrettoPoint) -> u8 { + hash.0.update(h.compress().as_bytes()); + return hash.0.finalize().as_slice()[0] & 0x01; + } + /// a hash function that takes 3 ristretto points as a parameter and outputs 0 or 1. fn h(u: RistrettoPoint, h: RistrettoPoint, w: RistrettoPoint) -> u8 { let mut hash = sha3::Sha3_256::new(); hash.update(&[GAMMA]); hash.update(u.compress().as_bytes()); - hash.update(h.compress().as_bytes()); hash.update(w.compress().as_bytes()); + hash.update(h.compress().as_bytes()); return hash.finalize().as_slice()[0] & 0x01; } @@ -338,30 +372,28 @@ impl DetectionKey<{ GAMMA }> { // See below for a full explanation as to the reason for this: let w = RistrettoPoint::multiscalar_mul(&[m, tag.y], &[g, tag.u]); + let pre_h = RootSecret::::pre_h(tag.u, w); + // for each secret part... - let mut result = true; - for (i, x_i) in self.0.iter().enumerate() { + let mut result = 0; + for (x_i, c_i) in self.0.iter().zip(&tag.ciphertexts) { // re-derive the key from the tag - let k_i = RootSecret::::h(tag.u, tag.u.mul(x_i), w); + let k_i = RootSecret::::post_h(pre_h.clone(), tag.u.mul(x_i)); // calculate the "original" plaintext - let c_i = match tag.ciphertexts.get(i) { - Some(true) => 0x01, - Some(false) => 0x00, - _ => 0x00, - // we've run out of ciphertext, it doesn't really matter what we put here, the rest of the test will fail - // since the security of k_i is modelled as a random oracle, (k_i ^ 0) should also be random - }; - - let b_i = k_i ^ c_i; - - if b_i != 1 { + let b_i = k_i ^ (c_i as u8); + // short circuit + if b_i != 0x01 { return false; } // assert that the plaintext is all 1's - result = result & (b_i == 1); + result += 1; } - return result; + // Assert that number of sequential ones is equal to the length of the detection key + // If it isn't it indicates that the tag ciphertext is shorter than the verification key, + // Given the checks on deserialization that should never happen, but we throw in a check + // here anyway for defense in depth. + return result == self.0.len(); } /// A bulk testing function that takes in an vector of detection keys and returns a vector @@ -406,38 +438,41 @@ impl DetectionKey<{ GAMMA }> { // for each secret part... let mut results: Vec = vec![]; - detection_keys.par_iter().enumerate().for_each_with(tx.clone(), |tx, (index, detection_key)| { - let mut result = true; - for (i, x_i) in detection_key.0.iter().enumerate() { - // re-derive the key from the tag - let k_i = RootSecret::::h(tag.u, tag.u.mul(x_i), w); + detection_keys + .par_iter() + .enumerate() + .for_each_with(tx.clone(), |tx, (index, detection_key)| { + let mut result = true; + for (i, x_i) in detection_key.0.iter().enumerate() { + // re-derive the key from the tag + let k_i = RootSecret::::h(tag.u, tag.u.mul(x_i), w); - // calculate the "original" plaintext - let c_i = match tag.ciphertexts.get(i) { - Some(true) => 0x01, - Some(false) => 0x00, - _ => 0x00, - // we've run out of ciphertext, it doesn't really matter what we put here, the rest of the test will fail - // since the security of k_i is modelled as a random oracle, (k_i ^ 0) should also be random - }; + // calculate the "original" plaintext + let c_i = match tag.ciphertexts.get(i) { + Some(true) => 0x01, + Some(false) => 0x00, + _ => 0x00, + // we've run out of ciphertext, it doesn't really matter what we put here, the rest of the test will fail + // since the security of k_i is modelled as a random oracle, (k_i ^ 0) should also be random + }; - let b_i = k_i ^ c_i; + let b_i = k_i ^ c_i; - if b_i != 1 { - result = false; - break; + if b_i != 1 { + result = false; + break; + } + // assert that the plaintext is all 1's + result = result & (b_i == 1); } - // assert that the plaintext is all 1's - result = result & (b_i == 1); - } - if result { - match tx.send(index) { - _ => { - // TODO...surface this error... + if result { + match tx.send(index) { + _ => { + // TODO...surface this error... + } } } - } - }); + }); std::mem::drop(tx); loop { @@ -586,10 +621,11 @@ impl TaggingKey<{ GAMMA }> { #[cfg(test)] mod tests { - use crate::{DetectionKey, RootSecret, Tag}; + use crate::{RootSecret, Tag}; use bit_vec::BitVec; use curve25519_dalek::ristretto::RistrettoPoint; use curve25519_dalek::scalar::Scalar; + use rand::rngs::OsRng; #[test] fn test_compression() { @@ -674,7 +710,10 @@ mod tests { let tagging_keys: Vec> = secrets.iter().map(|x| x.tagging_key()).collect(); // it takes ~15 minutes on a standard desktop to find a length=24 match for 2 parties, so for testing let's keep things light let entangled_tag = TaggingKey::generate_entangled_tag(tagging_keys, 16); - let detection_keys = secrets.iter().map(|x| x.extract_detection_key(16)).collect(); + let detection_keys = secrets + .iter() + .map(|x| x.extract_detection_key(16)) + .collect(); let results = DetectionKey::test_tag_bulk(&detection_keys, &entangled_tag); assert_eq!(results.len(), 2); @@ -710,6 +749,21 @@ mod tests { tag } + #[test] + fn assert_h_and_pre_post_h() { + let mut rng = OsRng::default(); + + for _ in 0..100 { + let a = RistrettoPoint::random(&mut rng); + let b = RistrettoPoint::random(&mut rng); + let c = RistrettoPoint::random(&mut rng); + assert_eq!( + RootSecret::<24>::post_h(RootSecret::<24>::pre_h(a, b), c), + RootSecret::<24>::h(a, c, b) + ); + } + } + #[test] // Thanks to Lee Bousfield who noticed an all zeros or all ones tag would // validate against a tagging key with 50% probability, allowing universal