Performance improvement during testing.

By splitting up the hash function H(U, x_iU, W) we can
precompute the application of U and W, and only apply x_iU during
the main detection key loop. Saving about 9.6% of the overall time.
This commit is contained in:
Sarah Jamie Lewis 2021-05-19 23:31:14 -07:00
parent 931ca3050a
commit 33f2fa9a88
3 changed files with 118 additions and 53 deletions

View File

@ -9,7 +9,9 @@ fn benchmark_generate_tag(c: &mut Criterion) {
let secret_key = RootSecret::<24>::generate(); let secret_key = RootSecret::<24>::generate();
for p in [5, 10, 15].iter() { for p in [5, 10, 15].iter() {
let public_key = secret_key.tagging_key(); let public_key = secret_key.tagging_key();
group.bench_with_input(BenchmarkId::from_parameter(p), p, |b, _gamma| b.iter(|| public_key.generate_tag())); group.bench_with_input(BenchmarkId::from_parameter(p), p, |b, _gamma| {
b.iter(|| public_key.generate_tag())
});
} }
} }
@ -22,7 +24,9 @@ fn benchmark_test_tag(c: &mut Criterion) {
for p in [5, 10, 15].iter() { for p in [5, 10, 15].iter() {
let tag = secret_key.tagging_key().generate_tag(); let tag = secret_key.tagging_key().generate_tag();
let detection_key = secret_key.extract_detection_key(*p); let detection_key = secret_key.extract_detection_key(*p);
group.bench_with_input(BenchmarkId::from_parameter(p), p, |b, _gamma| b.iter(|| detection_key.test_tag(&tag))); group.bench_with_input(BenchmarkId::from_parameter(p), p, |b, _gamma| {
b.iter(|| detection_key.test_tag(&tag))
});
} }
} }

View File

@ -2,8 +2,15 @@ max_width = 200
hard_tabs = false hard_tabs = false
tab_spaces = 4 tab_spaces = 4
newline_style = "Auto" newline_style = "Auto"
use_small_heuristics = "Default"
indent_style = "Block" indent_style = "Block"
use_small_heuristics = "Default"
fn_call_width = 60
attr_fn_like_width = 70
struct_lit_width = 18
struct_variant_width = 35
array_width = 60
chain_width = 60
single_line_if_else_max_width = 50
wrap_comments = false wrap_comments = false
format_code_in_doc_comments = false format_code_in_doc_comments = false
comment_width = 80 comment_width = 80
@ -54,7 +61,7 @@ use_field_init_shorthand = false
force_explicit_abi = true force_explicit_abi = true
condense_wildcard_suffixes = false condense_wildcard_suffixes = false
color = "Auto" color = "Auto"
required_version = "1.4.34" required_version = "1.4.37"
unstable_features = false unstable_features = false
disable_all_formatting = false disable_all_formatting = false
skip_children = false skip_children = false

View File

@ -12,7 +12,7 @@ use curve25519_dalek::scalar::Scalar;
use curve25519_dalek::traits::MultiscalarMul; use curve25519_dalek::traits::MultiscalarMul;
use rand::rngs::OsRng; use rand::rngs::OsRng;
use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer}; use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer};
use sha3::Sha3_512; use sha3::{Sha3_256, Sha3_512};
use std::convert::TryFrom; use std::convert::TryFrom;
use std::fmt; use std::fmt;
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
@ -32,6 +32,13 @@ use rayon::iter::ParallelIterator;
#[cfg(feature = "bulk_verify")] #[cfg(feature = "bulk_verify")]
use std::sync::mpsc::channel; use std::sync::mpsc::channel;
#[cfg(feature = "encrypt_to_tag")]
use secretbox::CipherType::Salsa20;
#[cfg(feature = "encrypt_to_tag")]
use secretbox::SecretBox;
#[cfg(feature = "encrypt_to_tag")]
use std::string::FromUtf8Error;
/// A tag is a probabilistic cryptographic structure. When constructed for a given `TaggingKey` /// A tag is a probabilistic cryptographic structure. When constructed for a given `TaggingKey`
/// it will pass the `DetectionKey::test_tag` 100% of the time. For other tagging keys /// it will pass the `DetectionKey::test_tag` 100% of the time. For other tagging keys
/// it will pass the test with probability `GAMMA` related to the security parameter of the system. /// it will pass the test with probability `GAMMA` related to the security parameter of the system.
@ -81,7 +88,10 @@ impl<'de, const GAMMA: u8> Deserialize<'de> for Tag<{ GAMMA }> {
{ {
let mut bytes = vec![]; let mut bytes = vec![];
for i in 0..64 { for i in 0..64 {
bytes.push(seq.next_element()?.ok_or(serde::de::Error::invalid_length(i, &"expected at least 64 bytes"))?); bytes.push(seq.next_element()?.ok_or(serde::de::Error::invalid_length(
i,
&"expected at least 64 bytes",
))?);
} }
loop { loop {
match seq.next_element().unwrap_or(None) { match seq.next_element().unwrap_or(None) {
@ -157,7 +167,10 @@ impl<const GAMMA: u8> Tag<{ GAMMA }> {
}; };
let mut ciphertexts = BitVec::from_bytes(ciphertext); let mut ciphertexts = BitVec::from_bytes(ciphertext);
ciphertexts.truncate(GAMMA as usize); ciphertexts.truncate(GAMMA as usize);
return match (CompressedRistretto::from_slice(u_bytes).decompress(), Scalar::from_canonical_bytes(y_bytes_fixed)) { return match (
CompressedRistretto::from_slice(u_bytes).decompress(),
Scalar::from_canonical_bytes(y_bytes_fixed),
) {
(Some(u), Some(y)) => Some(Tag { u, y, ciphertexts }), (Some(u), Some(y)) => Some(Tag { u, y, ciphertexts }),
_ => None, _ => None,
}; };
@ -178,6 +191,12 @@ impl<const GAMMA: u8> Display for Tag<{ GAMMA }> {
} }
} }
/// PrecomputeH is an encapsulation around the precomputation of the H function which
/// significantly speeds up testing. We define it for some additional type safety (to
/// prevent us from passing an uninitialized hash function to post_h
#[derive(Clone)]
struct PrecomputeH(Sha3_256);
/// The complete secret. Can't directly be used for testing. Instead you will need to generate /// The complete secret. Can't directly be used for testing. Instead you will need to generate
/// a DetectionKey using `extract_detection_key` /// a DetectionKey using `extract_detection_key`
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
@ -203,7 +222,7 @@ impl<const GAMMA: u8> RootSecret<{ GAMMA }> {
let sk_i = Scalar::random(&mut rng); let sk_i = Scalar::random(&mut rng);
secret.push(sk_i); secret.push(sk_i);
} }
RootSecret::<GAMMA> { secret: secret } RootSecret::<GAMMA> { secret }
} }
/// extract a detection key for a given false positive (p = 2^-n) /// extract a detection key for a given false positive (p = 2^-n)
@ -238,13 +257,28 @@ impl<const GAMMA: u8> RootSecret<{ GAMMA }> {
TaggingKey::<GAMMA> { 0: tagging_key } TaggingKey::<GAMMA> { 0: tagging_key }
} }
/// precompute the first part of h
fn pre_h(u: RistrettoPoint, w: RistrettoPoint) -> PrecomputeH {
let mut hash = sha3::Sha3_256::new();
hash.update(&[GAMMA]);
hash.update(u.compress().as_bytes());
hash.update(w.compress().as_bytes());
return PrecomputeH(hash);
}
/// compute the rest of h from a precomputed hash
fn post_h(mut hash: PrecomputeH, h: RistrettoPoint) -> u8 {
hash.0.update(h.compress().as_bytes());
return hash.0.finalize().as_slice()[0] & 0x01;
}
/// a hash function that takes 3 ristretto points as a parameter and outputs 0 or 1. /// a hash function that takes 3 ristretto points as a parameter and outputs 0 or 1.
fn h(u: RistrettoPoint, h: RistrettoPoint, w: RistrettoPoint) -> u8 { fn h(u: RistrettoPoint, h: RistrettoPoint, w: RistrettoPoint) -> u8 {
let mut hash = sha3::Sha3_256::new(); let mut hash = sha3::Sha3_256::new();
hash.update(&[GAMMA]); hash.update(&[GAMMA]);
hash.update(u.compress().as_bytes()); hash.update(u.compress().as_bytes());
hash.update(h.compress().as_bytes());
hash.update(w.compress().as_bytes()); hash.update(w.compress().as_bytes());
hash.update(h.compress().as_bytes());
return hash.finalize().as_slice()[0] & 0x01; return hash.finalize().as_slice()[0] & 0x01;
} }
@ -338,30 +372,28 @@ impl<const GAMMA: u8> DetectionKey<{ GAMMA }> {
// See below for a full explanation as to the reason for this: // See below for a full explanation as to the reason for this:
let w = RistrettoPoint::multiscalar_mul(&[m, tag.y], &[g, tag.u]); let w = RistrettoPoint::multiscalar_mul(&[m, tag.y], &[g, tag.u]);
let pre_h = RootSecret::<GAMMA>::pre_h(tag.u, w);
// for each secret part... // for each secret part...
let mut result = true; let mut result = 0;
for (i, x_i) in self.0.iter().enumerate() { for (x_i, c_i) in self.0.iter().zip(&tag.ciphertexts) {
// re-derive the key from the tag // re-derive the key from the tag
let k_i = RootSecret::<GAMMA>::h(tag.u, tag.u.mul(x_i), w); let k_i = RootSecret::<GAMMA>::post_h(pre_h.clone(), tag.u.mul(x_i));
// calculate the "original" plaintext // calculate the "original" plaintext
let c_i = match tag.ciphertexts.get(i) { let b_i = k_i ^ (c_i as u8);
Some(true) => 0x01, // short circuit
Some(false) => 0x00, if b_i != 0x01 {
_ => 0x00,
// we've run out of ciphertext, it doesn't really matter what we put here, the rest of the test will fail
// since the security of k_i is modelled as a random oracle, (k_i ^ 0) should also be random
};
let b_i = k_i ^ c_i;
if b_i != 1 {
return false; return false;
} }
// assert that the plaintext is all 1's // assert that the plaintext is all 1's
result = result & (b_i == 1); result += 1;
} }
return result; // Assert that number of sequential ones is equal to the length of the detection key
// If it isn't it indicates that the tag ciphertext is shorter than the verification key,
// Given the checks on deserialization that should never happen, but we throw in a check
// here anyway for defense in depth.
return result == self.0.len();
} }
/// A bulk testing function that takes in an vector of detection keys and returns a vector /// A bulk testing function that takes in an vector of detection keys and returns a vector
@ -406,38 +438,41 @@ impl<const GAMMA: u8> DetectionKey<{ GAMMA }> {
// for each secret part... // for each secret part...
let mut results: Vec<usize> = vec![]; let mut results: Vec<usize> = vec![];
detection_keys.par_iter().enumerate().for_each_with(tx.clone(), |tx, (index, detection_key)| { detection_keys
let mut result = true; .par_iter()
for (i, x_i) in detection_key.0.iter().enumerate() { .enumerate()
// re-derive the key from the tag .for_each_with(tx.clone(), |tx, (index, detection_key)| {
let k_i = RootSecret::<GAMMA>::h(tag.u, tag.u.mul(x_i), w); let mut result = true;
for (i, x_i) in detection_key.0.iter().enumerate() {
// re-derive the key from the tag
let k_i = RootSecret::<GAMMA>::h(tag.u, tag.u.mul(x_i), w);
// calculate the "original" plaintext // calculate the "original" plaintext
let c_i = match tag.ciphertexts.get(i) { let c_i = match tag.ciphertexts.get(i) {
Some(true) => 0x01, Some(true) => 0x01,
Some(false) => 0x00, Some(false) => 0x00,
_ => 0x00, _ => 0x00,
// we've run out of ciphertext, it doesn't really matter what we put here, the rest of the test will fail // we've run out of ciphertext, it doesn't really matter what we put here, the rest of the test will fail
// since the security of k_i is modelled as a random oracle, (k_i ^ 0) should also be random // since the security of k_i is modelled as a random oracle, (k_i ^ 0) should also be random
}; };
let b_i = k_i ^ c_i; let b_i = k_i ^ c_i;
if b_i != 1 { if b_i != 1 {
result = false; result = false;
break; break;
}
// assert that the plaintext is all 1's
result = result & (b_i == 1);
} }
// assert that the plaintext is all 1's if result {
result = result & (b_i == 1); match tx.send(index) {
} _ => {
if result { // TODO...surface this error...
match tx.send(index) { }
_ => {
// TODO...surface this error...
} }
} }
} });
});
std::mem::drop(tx); std::mem::drop(tx);
loop { loop {
@ -586,10 +621,11 @@ impl<const GAMMA: u8> TaggingKey<{ GAMMA }> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::{DetectionKey, RootSecret, Tag}; use crate::{RootSecret, Tag};
use bit_vec::BitVec; use bit_vec::BitVec;
use curve25519_dalek::ristretto::RistrettoPoint; use curve25519_dalek::ristretto::RistrettoPoint;
use curve25519_dalek::scalar::Scalar; use curve25519_dalek::scalar::Scalar;
use rand::rngs::OsRng;
#[test] #[test]
fn test_compression() { fn test_compression() {
@ -674,7 +710,10 @@ mod tests {
let tagging_keys: Vec<TaggingKey<24>> = secrets.iter().map(|x| x.tagging_key()).collect(); let tagging_keys: Vec<TaggingKey<24>> = secrets.iter().map(|x| x.tagging_key()).collect();
// it takes ~15 minutes on a standard desktop to find a length=24 match for 2 parties, so for testing let's keep things light // it takes ~15 minutes on a standard desktop to find a length=24 match for 2 parties, so for testing let's keep things light
let entangled_tag = TaggingKey::generate_entangled_tag(tagging_keys, 16); let entangled_tag = TaggingKey::generate_entangled_tag(tagging_keys, 16);
let detection_keys = secrets.iter().map(|x| x.extract_detection_key(16)).collect(); let detection_keys = secrets
.iter()
.map(|x| x.extract_detection_key(16))
.collect();
let results = DetectionKey::test_tag_bulk(&detection_keys, &entangled_tag); let results = DetectionKey::test_tag_bulk(&detection_keys, &entangled_tag);
assert_eq!(results.len(), 2); assert_eq!(results.len(), 2);
@ -710,6 +749,21 @@ mod tests {
tag tag
} }
#[test]
fn assert_h_and_pre_post_h() {
let mut rng = OsRng::default();
for _ in 0..100 {
let a = RistrettoPoint::random(&mut rng);
let b = RistrettoPoint::random(&mut rng);
let c = RistrettoPoint::random(&mut rng);
assert_eq!(
RootSecret::<24>::post_h(RootSecret::<24>::pre_h(a, b), c),
RootSecret::<24>::h(a, c, b)
);
}
}
#[test] #[test]
// Thanks to Lee Bousfield who noticed an all zeros or all ones tag would // Thanks to Lee Bousfield who noticed an all zeros or all ones tag would
// validate against a tagging key with 50% probability, allowing universal // validate against a tagging key with 50% probability, allowing universal