Much cleaner design with more reuse and safe state

This commit is contained in:
Sarah Jamie Lewis 2020-11-24 19:15:18 -08:00
parent 4d89c33882
commit 96ae0b1352
4 changed files with 174 additions and 171 deletions

View File

@ -1,5 +1,6 @@
use crate::applications::authentication_app::AuthenticationAppError::NotAuthenticatedError;
use crate::connections::utils::public_key_to_hostname;
use crate::connections::{Connection, InboundConnection, OutboundConnection};
use crate::connections::{Connection, ConnectionInterface, InboundConnection, OutboundConnection};
use crate::primitives::identity::Identity;
use crate::primitives::transcript::Transcript;
use ed25519_dalek::PublicKey;
@ -7,6 +8,7 @@ use integer_encoding::VarInt;
use serde::Deserialize;
use serde::Serialize;
use sha3::Digest;
use std::io::Error;
use std::sync::Arc;
use subtle::ConstantTimeEq;
@ -14,6 +16,11 @@ pub struct AuthenicationApp {
identity: Arc<Identity>,
}
#[derive(Debug)]
pub enum AuthenticationAppError {
NotAuthenticatedError,
}
struct AuthenticationSession<Direction> {
long_term_identity: Arc<Identity>,
ephemeral_identity: Identity,
@ -28,73 +35,57 @@ struct AuthenticationSession<Direction> {
conn: Connection<Direction>,
}
/// Macro for constructing the identity exchange...we need to do this for both inbound and outbound and
/// there is slightly less code complication doing this in a macro v.s. duplicating the code.
macro_rules! identity_exchange {
($conn:ident, $long_term_identity:ident, $ephemeral_identity:ident, $local_auth_message_json: ident, $remote_long_term_identity:ident, $remote_ephemeral_identity: ident, $remote_auth_message_json: ident) => {
let $ephemeral_identity = Identity::initialize_ephemeral_identity();
let auth_message = AuthMessage {
longterm_public_key: base64::encode($long_term_identity.public_key()),
ephemeral_public_key: base64::encode($ephemeral_identity.public_key().to_bytes()),
};
let $local_auth_message_json = serde_json::to_string(&auth_message).unwrap();
$conn.send(&$local_auth_message_json);
let $remote_auth_message_json = String::from_utf8($conn.expect()).unwrap();
let remote_auth_message: AuthMessage = serde_json::from_str($remote_auth_message_json.as_str()).unwrap();
let $remote_long_term_identity = PublicKey::from_bytes(base64::decode(&remote_auth_message.longterm_public_key).unwrap().as_slice()).unwrap();
let $remote_ephemeral_identity = PublicKey::from_bytes(base64::decode(&remote_auth_message.ephemeral_public_key).unwrap().as_slice()).unwrap();
};
}
impl<Direction> AuthenticationSession<Direction> {
pub fn new_outbound(mut conn: Connection<OutboundConnection>, long_term_identity: Arc<Identity>) -> AuthenticationSession<OutboundConnection> {
identity_exchange!(
conn,
let ephemeral_identity = Identity::initialize_ephemeral_identity();
let mut auth_session = AuthenticationSession {
long_term_identity,
ephemeral_identity,
local_auth_message_json,
remote_long_term_identity,
remote_ephemeral_identity,
remote_auth_message_json
);
AuthenticationSession {
long_term_identity,
ephemeral_identity,
local_auth_message_json,
remote_auth_message_json,
remote_long_term_identity,
remote_ephemeral_identity,
local_auth_message_json: "".to_string(),
remote_auth_message_json: "".to_string(),
remote_long_term_identity: Default::default(),
remote_ephemeral_identity: Default::default(),
conn,
challenge: [0; 64],
}
};
auth_session.identity_exchange();
auth_session
}
pub fn new_inbound(mut conn: Connection<InboundConnection>, long_term_identity: Arc<Identity>) -> AuthenticationSession<InboundConnection> {
identity_exchange!(
conn,
let ephemeral_identity = Identity::initialize_ephemeral_identity();
let mut auth_session = AuthenticationSession {
long_term_identity,
ephemeral_identity,
local_auth_message_json,
remote_long_term_identity,
remote_ephemeral_identity,
remote_auth_message_json
);
AuthenticationSession {
long_term_identity,
ephemeral_identity,
local_auth_message_json,
remote_auth_message_json,
remote_long_term_identity,
remote_ephemeral_identity,
local_auth_message_json: "".to_string(),
remote_auth_message_json: "".to_string(),
remote_long_term_identity: Default::default(),
remote_ephemeral_identity: Default::default(),
conn,
challenge: [0; 64],
}
};
auth_session.identity_exchange();
auth_session
}
/// exchange identities with the remote parties
/// Not that if network connections fail during this empty buffers will be returned which will
/// cause subsequent steps to fail
fn identity_exchange(&mut self) {
let auth_message = AuthMessage {
longterm_public_key: base64::encode(self.long_term_identity.public_key()),
ephemeral_public_key: base64::encode(self.ephemeral_identity.public_key().to_bytes()),
};
self.local_auth_message_json = serde_json::to_string(&auth_message).unwrap();
self.conn.send(&self.local_auth_message_json);
self.remote_auth_message_json = String::from_utf8(self.conn.expect().unwrap()).unwrap();
let remote_auth_message: AuthMessage = serde_json::from_str(self.remote_auth_message_json.as_str()).unwrap();
self.remote_long_term_identity = PublicKey::from_bytes(base64::decode(&remote_auth_message.longterm_public_key).unwrap().as_slice()).unwrap();
self.remote_ephemeral_identity = PublicKey::from_bytes(base64::decode(&remote_auth_message.ephemeral_public_key).unwrap().as_slice()).unwrap();
}
/// generate the message for the local challenge based a new transcript hash.
fn generate_challenge_message(&mut self) -> Vec<u8> {
fn generate_challenge_message(&self) -> Vec<u8> {
let mut msg = vec![];
let mut len = [0u8; 2];
((self.challenge.len() + self.long_term_identity.hostname().as_bytes().len()) as u16).encode_var(&mut len);
@ -105,22 +96,32 @@ impl<Direction> AuthenticationSession<Direction> {
}
/// check the challenge from the remote
fn check_remote_challenge(&mut self, remote_challenge: Vec<u8>) -> bool {
let mut cmp_challenge = vec![];
cmp_challenge.extend_from_slice(self.challenge.as_slice());
cmp_challenge.extend_from_slice(public_key_to_hostname(&self.remote_long_term_identity).as_bytes());
return remote_challenge.ct_eq(cmp_challenge.as_slice()).unwrap_u8() == 1;
fn check_remote_challenge(&mut self) -> Result<(), AuthenticationAppError> {
match self.conn.send_encrypted(self.generate_challenge_message()) {
Ok(()) => {
let remote_challenge = self.conn.expect_encrypted();
let mut cmp_challenge = vec![];
cmp_challenge.extend_from_slice(self.challenge.as_slice());
cmp_challenge.extend_from_slice(public_key_to_hostname(&self.remote_long_term_identity).as_bytes());
if remote_challenge.ct_eq(cmp_challenge.as_slice()).unwrap_u8() == 1 {
return Ok(());
}
self.conn.shutdown();
Err(NotAuthenticatedError)
}
Err(err) => Err(NotAuthenticatedError),
}
}
}
impl AuthenticationSession<InboundConnection> {
pub fn generate_challenge(&mut self, transcript: &mut Transcript) -> Vec<u8> {
impl AuthenticationSession<OutboundConnection> {
pub fn generate_challenge(&mut self, transcript: &mut Transcript) -> Result<(), AuthenticationAppError> {
let l2e = self.long_term_identity.edh(self.remote_ephemeral_identity);
let e2l = self.ephemeral_identity.edh(self.remote_long_term_identity);
let e2e = self.ephemeral_identity.edh(self.remote_ephemeral_identity);
let mut shared_secret = vec![];
// inbound calculates shared secret as Sha256(l2e:e2l:e2e)
// outbound calculates shared secret as Sha256(l2e:e2l:e2e)
shared_secret.extend_from_slice(l2e.as_bytes());
shared_secret.extend_from_slice(e2l.as_bytes());
shared_secret.extend_from_slice(e2e.as_bytes());
@ -133,13 +134,13 @@ impl AuthenticationSession<InboundConnection> {
transcript.add_to_transcript("outbound-challenge", self.local_auth_message_json.as_bytes());
transcript.add_to_transcript("inbound-challenge", self.remote_auth_message_json.as_bytes());
self.challenge = transcript.commit_to_transcript("3dh-auth-challenge");
let key = key.clone();
key.to_vec()
self.conn.enable_encryption(key.to_vec());
self.check_remote_challenge()
}
}
impl AuthenticationSession<OutboundConnection> {
pub fn generate_challenge(&mut self, transcript: &mut Transcript) -> Vec<u8> {
impl AuthenticationSession<InboundConnection> {
pub fn generate_challenge(&mut self, transcript: &mut Transcript) -> Result<(), AuthenticationAppError> {
let l2e = self.long_term_identity.edh(self.remote_ephemeral_identity);
let e2l = self.ephemeral_identity.edh(self.remote_long_term_identity);
let e2e = self.ephemeral_identity.edh(self.remote_ephemeral_identity);
@ -158,9 +159,8 @@ impl AuthenticationSession<OutboundConnection> {
transcript.add_to_transcript("outbound-challenge", self.remote_auth_message_json.as_bytes());
transcript.add_to_transcript("inbound-challenge", self.local_auth_message_json.as_bytes());
self.challenge = transcript.commit_to_transcript("3dh-auth-challenge");
let key = key.clone();
key.to_vec()
self.conn.enable_encryption(key.to_vec());
self.check_remote_challenge()
}
}
@ -177,31 +177,13 @@ impl AuthenicationApp {
AuthenicationApp { identity }
}
pub fn run_outbound(&mut self, conn: &mut Connection<OutboundConnection>, transcript: &mut Transcript) {
pub fn run_outbound(&mut self, conn: &mut Connection<OutboundConnection>, transcript: &mut Transcript) -> Result<(), AuthenticationAppError> {
let mut auth_session = AuthenticationSession::<OutboundConnection>::new_outbound(conn.try_clone(), self.identity.clone());
let key = auth_session.generate_challenge(transcript);
conn.enable_encryption(key);
conn.send_encrypted(auth_session.generate_challenge_message());
let remote_challenge = conn.expect_encrypted();
if auth_session.check_remote_challenge(remote_challenge) {
println!("Authenticated")
} else {
println!("Failed Authentication");
conn.shutdown();
}
auth_session.generate_challenge(transcript)
}
pub fn run_inbound(&mut self, conn: &mut Connection<InboundConnection>, transcript: &mut Transcript) {
pub fn run_inbound(&mut self, conn: &mut Connection<InboundConnection>, transcript: &mut Transcript) -> Result<(), AuthenticationAppError> {
let mut auth_session = AuthenticationSession::<InboundConnection>::new_inbound(conn.try_clone(), self.identity.clone());
let key = auth_session.generate_challenge(transcript);
conn.enable_encryption(key);
conn.send_encrypted(auth_session.generate_challenge_message());
let remote_challenge = conn.expect_encrypted();
if auth_session.check_remote_challenge(remote_challenge) {
println!("Authenticated")
} else {
println!("Failed Authentication");
conn.shutdown();
}
auth_session.generate_challenge(transcript)
}
}

View File

@ -1,7 +1,7 @@
use integer_encoding::{FixedInt, VarInt};
use secretbox::CipherType::Salsa20;
use secretbox::SecretBox;
use std::io::{Read, Write};
use std::io::{Error, Read, Write};
use std::net::{Shutdown, TcpStream};
/// Connections provides an interface for manage sets of connections on top of a particular
@ -27,6 +27,92 @@ pub struct Connection<Direction> {
key: Vec<u8>,
}
pub trait ConnectionInterface {
fn enable_encryption(&mut self, key: Vec<u8>);
fn send(&mut self, amsg: &String) -> Result<(), Error>;
fn send_encrypted(&mut self, msg: Vec<u8>) -> Result<(), Error>;
fn expect_encrypted(&mut self) -> Vec<u8>;
fn expect(&mut self) -> Result<Vec<u8>, Error>;
fn shutdown(&mut self);
}
impl<Direction> ConnectionInterface for Connection<Direction> {
fn enable_encryption(&mut self, key: Vec<u8>) {
self.key = key
}
fn send(&mut self, amsg: &String) -> Result<(), Error> {
let mut len = [0u8; 2];
(amsg.len() as u16).encode_var(&mut len);
let mut msg = vec![];
msg.extend_from_slice(len.as_slice());
msg.extend_from_slice(amsg.as_bytes());
let mut msg = msg;
while msg.len() < 8192 {
msg.push(0);
}
self.conn.write_all(msg.as_slice())
}
fn send_encrypted(&mut self, msg: Vec<u8>) -> Result<(), Error> {
let mut msg = msg;
while msg.len() < 8192 - 40 {
msg.push(0);
}
let secret_box = SecretBox::new(&self.key, Salsa20).unwrap();
let msg = secret_box.easy_seal(msg.as_slice());
self.conn.write_all(msg.as_slice())
}
fn expect_encrypted(&mut self) -> Vec<u8> {
let secret_box = SecretBox::new(&self.key, Salsa20).unwrap();
let mut result = [0u8; 8192];
match self.conn.read_exact(&mut result) {
Err(e) => {
eprintln!("{}", e.to_string());
return vec![];
}
_ => {}
}
let msg = match secret_box.easy_unseal(&result) {
Some(msg) => msg,
_ => {
return vec![];
}
};
let msg = msg.as_slice();
let _len_bytes = [0u8; 2];
let len = u16::decode_fixed(&msg[0..2]) as usize;
if len > 8192 {
eprintln!("invalid length: {}", len);
return vec![]; // lol no.
}
msg[2..len + 2].to_vec()
}
fn expect(&mut self) -> Result<Vec<u8>, Error> {
let mut msg = [0; 8192];
let result = self.conn.read_exact(&mut msg);
match result {
Err(e) => Err(e),
Ok(()) => {
// TODO why did I decide to use varints here?!?!
let len = u16::decode_var(&msg[0..2]).unwrap().0 as usize;
println!("{} [{}]", len, String::from_utf8(msg[2..len + 2].to_vec()).unwrap());
return Ok(msg[2..len + 2].to_vec());
}
}
}
fn shutdown(&mut self) {
match self.conn.shutdown(Shutdown::Both) {
_ => {} // If anything bad happens we will know soon enough...
}
}
}
impl<Direction> Connection<Direction>
where
Direction: Clone,
@ -54,79 +140,6 @@ where
key: self.key.clone(),
}
}
pub fn enable_encryption(&mut self, key: Vec<u8>) {
self.key = key
}
pub fn send(&mut self, amsg: &String) {
let mut len = [0u8; 2];
(amsg.len() as u16).encode_var(&mut len);
let mut msg = vec![];
msg.extend_from_slice(len.as_slice());
msg.extend_from_slice(amsg.as_bytes());
let mut msg = msg;
while msg.len() < 8192 {
msg.push(0);
}
match self.conn.write_all(msg.as_slice()) {
_ => {} // If anything bad happens we will know soon enough...
}
}
pub fn send_encrypted(&mut self, msg: Vec<u8>) {
let mut msg = msg;
while msg.len() < 8192 - 40 {
msg.push(0);
}
let secret_box = SecretBox::new(&self.key, Salsa20).unwrap();
let msg = secret_box.easy_seal(msg.as_slice());
match self.conn.write_all(msg.as_slice()) {
_ => {} // If anything bad happens we will know soon enough...
}
}
pub fn expect_encrypted(&mut self) -> Vec<u8> {
let secret_box = SecretBox::new(&self.key, Salsa20).unwrap();
let mut result = [0u8; 8192];
match self.conn.read_exact(&mut result) {
Err(e) => eprintln!("{}", e.to_string()),
_ => {}
}
let msg = secret_box.easy_unseal(&result).unwrap();
let msg = msg.as_slice();
let _len_bytes = [0u8; 2];
let len = u16::decode_fixed(&msg[0..2]) as usize;
if len > 8192 {
eprintln!("invalid length: {}", len);
return vec![]; // lol no.
}
msg[2..len + 2].to_vec()
}
pub fn expect(&mut self) -> Vec<u8> {
let mut msg = [0; 8192];
let result = self.conn.read_exact(&mut msg);
match result {
Err(e) => {
println!("{:?}", e);
vec![]
}
Ok(()) => {
// TODO why did I decide to use varints here?!?!
let len = u16::decode_var(&msg[0..2]).unwrap().0 as usize;
println!("{} [{}]", len, String::from_utf8(msg[2..len + 2].to_vec()).unwrap());
return msg[2..len + 2].to_vec();
}
}
}
pub fn shutdown(&mut self) {
match self.conn.shutdown(Shutdown::Both) {
_ => {} // If anything bad happens we will know soon enough...
}
}
}
pub struct Hostname {}

View File

@ -19,13 +19,13 @@ pub struct Service<ListenService> {
impl<ListenService> Service<ListenService> {
pub fn connect<F>(&mut self, hostname: &str, application: F) -> Result<(), ServiceError>
where
F: FnOnce(Connection<InboundConnection>) + Send + Clone + 'static,
F: FnOnce(Connection<OutboundConnection>) + Send + Clone + 'static,
{
let conn = Socks5Stream::connect(format!("127.0.0.1:9050"), Domain(format!("{}.onion", hostname), 9878));
match conn {
Ok(conn) => {
let application = application.clone();
spawn(move || application(Connection::<InboundConnection>::new_inbound(conn.into_inner())));
spawn(move || application(Connection::<OutboundConnection>::new_outbound(conn.into_inner())));
Ok(())
}
Err(err) => Err(ConnectionFailed(err.to_string())),
@ -43,7 +43,7 @@ impl Service<NoListenService> {
pub fn listen<F>(self, port: u16, application: F) -> Result<Service<JoinHandle<ServiceError>>, ServiceError>
where
F: FnOnce(Connection<OutboundConnection>) + Send + Clone + 'static,
F: FnOnce(Connection<InboundConnection>) + Send + Clone + 'static,
{
let jh = spawn(move || {
let listener = TcpListener::bind(format!("127.0.0.1:{}", port));
@ -53,7 +53,7 @@ impl Service<NoListenService> {
match stream {
Ok(conn) => {
let application = application.clone();
spawn(move || application(Connection::<OutboundConnection>::new_outbound(conn)));
spawn(move || application(Connection::<InboundConnection>::new_inbound(conn)));
}
Err(_) => {}
}

View File

@ -32,18 +32,26 @@ mod test {
let outbound_service = |mut conn: Connection<OutboundConnection>| {
let mut transcript = Transcript::new_transcript("tapir-transcript");
let mut auth_app = AuthenicationApp::new(outbound_identity);
auth_app.run_outbound(&mut conn, &mut transcript);
match auth_app.run_outbound(&mut conn, &mut transcript) {
Ok(()) => {
println!("Authenticated");
}
Err(err) => {
println!("Error: {:?}", err);
}
}
};
let mut service = service.listen(10029, outbound_service.clone()).unwrap_or_else(|_| panic!());
let identity = identity.clone();
let inbound_service = |mut conn: Connection<InboundConnection>| {
let mut transcript = Transcript::new_transcript("tapir-transcript");
let mut auth_app = AuthenicationApp::new(identity);
auth_app.run_inbound(&mut conn, &mut transcript);
};
match service.connect("hejcfc3pduhcymo7i3zkoxezbjvw3vdsbavyy7gfsd677l5py7ccnlyd", inbound_service.clone()) {
let mut service = service.listen(10029, inbound_service.clone()).unwrap_or_else(|_| panic!());
match service.connect("yxdappsm5v45njecytgok7j2lwbsxkcfbtuvfkfrx7l2jqf2a73n47yd", outbound_service.clone()) {
_ => {}
}
service.close();