rust-bert-bot/src/main.rs

102 lines
3.6 KiB
Rust

extern crate core;
use std::borrow::{BorrowMut};
use std::fmt::Debug;
use std::fs::{File};
use std::thread;
use ::imp::event::Event;
use ::imp::imp;
use ::imp::imp::Imp;
use ::imp::behaviour::AllowListMembers;
use ::imp::behaviour::NewContactPolicy;
use ::imp::behaviour::BehaviourBuilder;
use ::imp::behaviour::Behaviour;
use libcwtch;
use libcwtch::CwtchLib;
use libcwtch::structs::*;
use serde_json;
use serde::{Deserialize, Serialize};
use rust_bert::pipelines::conversation::{ConversationModel, ConversationManager};
/// Display name given to the bot via `BehaviourBuilder::name`.
const BOT_NAME: &str = "Rust-Bert";
/// Profile password handed to `Imp::spawn`.
// NOTE(review): hard-coded secret in source; consider reading it from the environment.
const PASSWORD: &str = "be gay do crime";
/// Home directory handed to `Imp::spawn`.
// NOTE(review): Rust never expands `~` itself; whether `Imp::spawn` does is not visible here — confirm.
const BOT_HOME: &str = "~/.cwtch/bots/rust_bert";
/// JSON file (resolved against the current working directory) holding the peer allow list.
const USER_CONFIG_FILE: &str = "user_config.json";
/// Configuration loaded from `USER_CONFIG_FILE`.
///
/// Expected JSON shape: `{"allow_list": ["<peer>", ...]}`.
#[derive(Serialize, Deserialize, Debug)]
struct UserConfig {
    /// Peers that end up in the bot's `NewContactPolicy::AllowList`.
    allow_list: Vec<String>,
}
/// Bot state shared with the `imp` event loop: allow list plus the rust-bert pipeline.
struct RustBertBot {
    /// Peers copied from `UserConfig::allow_list`; used to build the contact policy in `main`.
    allow_list: Vec<String>,
    /// Model that produces replies to incoming messages.
    conversation_model: ConversationModel,
    /// Conversation store passed to `ConversationModel::generate_responses`.
    conversation_manager: ConversationManager,
}
impl RustBertBot {
    /// Builds the bot from `USER_CONFIG_FILE` and a default-configured conversation model.
    ///
    /// `USER_CONFIG_FILE` is opened relative to the process's current working directory and
    /// must deserialize into [`UserConfig`].
    ///
    /// # Panics
    ///
    /// Panics if the config file cannot be opened, is not valid `UserConfig` JSON, or if
    /// `ConversationModel::new` returns an error.
    pub fn new() -> Self {
        // Closure form so the panic message is only formatted on failure (clippy
        // `expect_fun_call`); the text matches what `expect` would have printed.
        let file = File::open(USER_CONFIG_FILE).unwrap_or_else(|err| {
            panic!("could not open allow list file: {}: {:?}", USER_CONFIG_FILE, err)
        });
        // serde_json recommends buffering `File` readers; raw reads are byte-at-a-time.
        let config: UserConfig = serde_json::from_reader(std::io::BufReader::new(file))
            .expect("file should be proper JSON");
        let conversation_model = ConversationModel::new(Default::default())
            .unwrap_or_else(|err| panic!("ConversationModel::new error: {:?}", err));
        RustBertBot {
            allow_list: config.allow_list,
            conversation_model,
            conversation_manager: ConversationManager::new(),
        }
    }
}
/// Loads the bot, configures its Cwtch behaviour, and runs the `imp` event loop.
///
/// The loop runs on a dedicated thread. The main thread blocks on it and panics with
/// "Error running event loop" if that thread panics.
fn main() {
    let mut rust_bert_bot = RustBertBot::new();
    let peers = rust_bert_bot.allow_list.clone();
    let contact_policy = NewContactPolicy::AllowList(AllowListMembers::new(peers, vec!()));
    let behaviour: Behaviour = BehaviourBuilder::new()
        .name(BOT_NAME.to_string())
        .new_contact_policy(contact_policy)
        .build();
    thread::spawn(move || {
        let mut bot = Imp::spawn(behaviour, PASSWORD.to_string(), BOT_HOME.to_string());
        bot.event_loop::<RustBertBot>(&mut rust_bert_bot);
    })
    .join()
    .expect("Error running event loop");
}
impl imp::EventHandler for RustBertBot {
    /// Handles one Cwtch event delivered by the `imp` event loop.
    ///
    /// - `NewMessageFromPeer`: decodes the payload into a [`Message`] and feeds its text (`d`)
    ///   to the model as a fresh single-turn conversation. It then sends the reply (overlay `1`)
    ///   back to the same Cwtch conversation from the active profile. If there is no active
    ///   profile, the event is ignored.
    /// - `ErrUnhandled`: logs the event name to stderr.
    /// - Anything else is ignored.
    ///
    /// Malformed peer payloads, missing model output, and serialization failures are logged
    /// to stderr and the message is dropped, rather than panicking the event loop.
    fn handle(&mut self, cwtch: &dyn CwtchLib, profile_opt: Option<&Profile>, event: Event) {
        match event {
            Event::NewMessageFromPeer { conversation_id, message, .. } => {
                let profile = match profile_opt {
                    Some(profile) => profile,
                    None => return,
                };
                // The payload comes from a remote peer; it must not be able to crash the bot.
                let incoming: Message = match serde_json::from_str(&message) {
                    Ok(parsed) => parsed,
                    Err(e) => {
                        eprintln!("Error parsing message: {}", e);
                        return;
                    }
                };
                let reply = match self.generate_reply(&incoming.d) {
                    Some(reply) => reply,
                    None => {
                        eprintln!("conversation model produced no response");
                        return;
                    }
                };
                let response = Message { o: 1, d: reply };
                match serde_json::to_string(&response) {
                    Ok(response_json) => {
                        cwtch.send_message(&profile.handle, conversation_id, &response_json);
                    }
                    Err(e) => eprintln!("Error serializing json response: {}", e),
                }
            }
            Event::ErrUnhandled { name, .. } => eprintln!("unhandled event: {}!", name),
            _ => (),
        }
    }
}

impl RustBertBot {
    /// Generates a reply to `text` using a temporary, single-turn conversation.
    ///
    /// Returns `None` if the model output has no entry for that conversation.
    fn generate_reply(&mut self, text: &str) -> Option<String> {
        let bert_conversation_id = self.conversation_manager.create(text);
        // The response map borrows the manager; copy the text out before mutating it again.
        let reply = self
            .conversation_model
            .generate_responses(&mut self.conversation_manager)
            .get(&bert_conversation_id)
            .map(|response| response.to_string());
        // Each message gets its own conversation; drop it so they don't pile up in memory.
        self.conversation_manager.remove(&bert_conversation_id);
        reply
    }
}