fuzzytags-sim/src/server.rs

153 lines
6.6 KiB
Rust

use crate::oracle::Oracle;
use crate::probability::binomial::{at_least_with_replacement, nchoosek};
use fuzzytags::{DetectionKey, Tag, TaggingKey};
use hashbrown::{HashMap, HashSet};
use itertools::Itertools;
use rayon::iter::IntoParallelRefIterator;
use rayon::iter::ParallelIterator;
use std::sync::mpsc::channel;
use std::sync::Arc;
use std::time::Duration;
use tracing::event;
use tracing::span;
use tracing::Level;
/// In-memory model of a server that receives tagged messages and records
/// which registered detection keys match each tag.
pub struct SimulatedServer {
    /// One entry per registered party: its detection key, its tagging key and
    /// the tagging key's id string (as returned by `TaggingKey::id`).
    keybase: Vec<(DetectionKey<24>, TaggingKey<24>, String)>,
    /// Detection keys copied out of `keybase` by `finalize`, in the same
    /// order; this is the list handed to `DetectionKey::test_tag_bulk`.
    detection_key_cache: Vec<DetectionKey<24>>,
    /// Every message added so far, as (tag, sender's tagging key).
    messages: Vec<(Tag<24>, TaggingKey<24>)>,
    /// Tag (string form) -> id of the tagging key that sent it.
    sender_tags: HashMap<String, String>,
    /// Sender id -> number of messages that sender has added, kept as `f64`
    /// because it is only used in floating-point arithmetic.
    sender_count: HashMap<String, f64>,
    /// Tag (string form) -> ids of the parties whose detection key matched it.
    tags_to_keys_cache: HashMap<String, HashSet<String>>,
    /// Party id -> tags (string form) that the party's detection key matched.
    keys_to_tags_cache: HashMap<String, HashSet<String>>,
    /// Receives party registrations, match events and suspected relationships.
    oracle: Oracle,
}
/// Private tuple record of three strings and one `f64`.
///
/// NOTE(review): nothing in the visible code constructs or reads this type,
/// so the meaning of the individual fields is not established here — confirm
/// whether it is still needed.
struct Event(String, String, String, f64);
/// Totals describing the whole simulation round, as returned by
/// `SimulatedServer::statistics`.
///
/// Derives `Debug` (like `PartyStatistics`) so that the full statistics tuple
/// can be logged, plus the cheap value-type traits since both fields are
/// plain `usize` counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RoundStatistics {
    /// Number of parties registered with the server.
    pub num_registered_parties: usize,
    /// Number of messages added to the server.
    pub num_messages: usize,
}
/// Per-party detection statistics produced by `SimulatedServer::statistics`.
#[derive(Debug)]
pub struct PartyStatistics {
    /// False-positive probability reported by the party's detection key.
    pub ideal_rate: f64,
    /// `ideal_rate` multiplied by the total number of messages.
    pub expected_messages: f64,
    /// Number of tags the party's detection key actually matched.
    pub observed_messages: usize,
    /// `observed_messages` divided by the total number of messages.
    pub observed_rate: f64,
    /// `observed_messages` minus `expected_messages`.
    pub observed_skew_messages: f64,
    /// `observed_messages` divided by `expected_messages`.
    pub observed_skew: f64,
    /// Number of matched tags that matched this party and nobody else.
    pub trivial_breaks: usize,
}
impl SimulatedServer {
pub fn new() -> SimulatedServer {
SimulatedServer {
keybase: vec![],
messages: vec![],
detection_key_cache: vec![],
sender_tags: HashMap::new(),
sender_count: HashMap::new(),
tags_to_keys_cache: HashMap::new(),
keys_to_tags_cache: HashMap::new(),
oracle: Oracle::new(),
}
}
pub fn register_key(&mut self, detection_key: &DetectionKey<24>, tagging_key: &TaggingKey<24>) {
self.keybase.push((detection_key.clone(), tagging_key.clone(), tagging_key.id()));
self.keys_to_tags_cache.insert(tagging_key.id(), HashSet::new());
self.oracle.register_party(tagging_key.id());
}
pub fn finalize(&mut self) {
self.detection_key_cache = self.keybase.iter().map(|(d, t, _)| d.clone()).collect();
}
pub fn add_message(&mut self, tag: Tag<24>, sender_tagging_key: &TaggingKey<24>) {
let tag_id = tag.to_string();
let sender_id = sender_tagging_key.id();
self.messages.push((tag.clone(), sender_tagging_key.clone()));
self.tags_to_keys_cache.insert(tag_id.clone(), HashSet::new());
self.sender_tags.insert(tag_id.clone(), sender_tagging_key.id());
let count = match self.sender_count.get(sender_tagging_key.id().as_str()) {
Some(count) => *count + 1.0,
_ => 1.0,
};
self.sender_count.insert(sender_tagging_key.id(), count);
let results = DetectionKey::test_tag_bulk(&self.detection_key_cache, &tag);
for index in results {
let detection_key = &self.keybase[index].0;
let receiver_tagging_id = &self.keybase[index].2;
event!(Level::TRACE, "Matched detection key for {key} to tag {tag} ", key = receiver_tagging_id.clone(), tag = tag_id.clone());
self.tags_to_keys_cache.get_mut(tag_id.as_str()).unwrap().insert(receiver_tagging_id.clone());
self.keys_to_tags_cache.get_mut(receiver_tagging_id.as_str()).unwrap().insert(tag_id.clone());
self.oracle.add_event(sender_id.clone(), receiver_tagging_id.clone(), None, detection_key.false_positive_probability());
}
}
pub fn statistics(&mut self) -> (Oracle, RoundStatistics, HashMap<String, PartyStatistics>) {
let mut party_stats = HashMap::new();
let round_stats = RoundStatistics {
num_messages: self.messages.len(),
num_registered_parties: self.keybase.len(),
};
for (recipient_index, (party, pub_key, id)) in self.keybase.iter().enumerate() {
let matched = self.keys_to_tags_cache[pub_key.id().as_str()].clone();
let observed_messages = matched.len();
let ideal_rate = party.false_positive_probability();
let expected_messages = ideal_rate * (round_stats.num_messages as f64);
let observed_rate = (observed_messages as f64) / (self.messages.len() as f64);
let observed_skew_messages = (observed_messages as f64) - expected_messages;
let observed_skew = (observed_messages as f64) / expected_messages;
let mut trivial_breaks = 0;
let mut num_times_matched_with = HashMap::new();
for tag in matched.iter() {
let sender = self.sender_tags[tag].clone();
let num = match num_times_matched_with.get(sender.as_str()) {
Some(num) => *num + 1.0,
_ => 1.0,
};
num_times_matched_with.insert(sender, num);
if self.tags_to_keys_cache[tag.to_string().as_str()].len() == 1 {
trivial_breaks += 1;
}
}
for (sender, count) in num_times_matched_with.iter() {
let expected_matched_count = (ideal_rate * self.sender_count[sender.as_str()]);
let actual_matched_count = (*count);
let prob = at_least_with_replacement(actual_matched_count as u64, self.sender_count[sender.as_str()] as u64, ideal_rate).to_f64();
let diff = f64::abs(actual_matched_count - expected_matched_count);
// these numbers are arbitrary, but the point is probability only works if your sample is big enough...
if diff > 4.0 && actual_matched_count > expected_matched_count && prob < 0.0001 {
let sender_index = self.keybase.iter().find_position(|(d, t, _)| t.id() == *sender).unwrap().0;
event!(Level::INFO, "Found Anomalous Relationship Between {sender}({sender_index}) and {recipient}({recipient_index}) {falsepositiverate}, {detected} >> {expected} | Probability of Event: {prob:.9}", sender=sender, sender_index=sender_index, recipient=pub_key.id(), recipient_index=recipient_index, falsepositiverate=ideal_rate, detected=actual_matched_count, expected=expected_matched_count, prob=prob);
self.oracle.add_suspect(&sender, &id, 1.0 - prob);
}
}
let p_stats = PartyStatistics {
ideal_rate,
expected_messages,
observed_messages,
observed_rate,
observed_skew_messages,
observed_skew,
trivial_breaks,
};
party_stats.insert(pub_key.id(), p_stats);
}
(self.oracle.clone(), round_stats, party_stats)
}
}