102 lines
3.6 KiB
Rust
102 lines
3.6 KiB
Rust
extern crate core;
|
|
|
|
use std::borrow::{BorrowMut};
|
|
use std::fmt::Debug;
|
|
use std::fs::{File};
|
|
use std::thread;
|
|
|
|
use ::imp::event::Event;
|
|
use ::imp::imp;
|
|
use ::imp::imp::Imp;
|
|
use ::imp::behaviour::AllowListMembers;
|
|
use ::imp::behaviour::NewContactPolicy;
|
|
use ::imp::behaviour::BehaviourBuilder;
|
|
use ::imp::behaviour::Behaviour;
|
|
use libcwtch;
|
|
use libcwtch::CwtchLib;
|
|
use libcwtch::structs::*;
|
|
use serde_json;
|
|
use serde::{Deserialize, Serialize};
|
|
use rust_bert::pipelines::conversation::{ConversationModel, ConversationManager};
|
|
|
|
|
|
/// Directory handed to `Imp::spawn` as the bot's Cwtch home.
///
/// NOTE(review): `std::fs` does not expand a leading `~`; whether `Imp::spawn`
/// expands it is not visible from here — TODO confirm.
const BOT_HOME: &str = "~/.cwtch/bots/rust_bert";

/// Profile password handed to `Imp::spawn`.
///
/// NOTE(review): this is a secret committed to source; consider loading it
/// from the environment or a config file instead.
const PASSWORD: &str = "be gay do crime";

/// Display name given to the bot's `Behaviour`.
const BOT_NAME: &str = "Rust-Bert";

/// JSON file holding the `UserConfig` allow list. It is opened with a relative
/// path, so it is resolved against the process's current working directory.
const USER_CONFIG_FILE: &str = "user_config.json";
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
struct UserConfig {
|
|
allow_list: Vec<String>
|
|
}
|
|
|
|
struct RustBertBot {
|
|
allow_list: Vec<String>,
|
|
conversation_model : ConversationModel,
|
|
conversation_manager: ConversationManager,
|
|
}
|
|
|
|
impl RustBertBot {
|
|
pub fn new() -> Self {
|
|
let file = File::open(USER_CONFIG_FILE).expect(&format!("could not open allow list file: {}", USER_CONFIG_FILE));
|
|
let config: UserConfig = serde_json::from_reader(file)
|
|
.expect("file should be proper JSON");
|
|
|
|
let conversation_model = match ConversationModel::new(Default::default()) {
|
|
Ok(cm) => cm,
|
|
Err(err) => panic!("ConversationModel::new error: {:?}", err),
|
|
};
|
|
let mut conversation_manager = ConversationManager::new();
|
|
|
|
RustBertBot {allow_list: config.allow_list, conversation_model, conversation_manager }
|
|
}
|
|
}
|
|
|
|
fn main() {
|
|
let mut rust_bert_bot = RustBertBot::new();
|
|
let allow_list = AllowListMembers::new(rust_bert_bot.allow_list.clone(), vec!());
|
|
let behaviour: Behaviour = BehaviourBuilder::new().name(BOT_NAME.to_string()).new_contact_policy(NewContactPolicy::AllowList(allow_list)).build();
|
|
|
|
let event_loop_handle = thread::spawn(move || {
|
|
let mut bot = Imp::spawn(behaviour, PASSWORD.to_string(), BOT_HOME.to_string());
|
|
|
|
bot.event_loop::<RustBertBot>(rust_bert_bot.borrow_mut());
|
|
});
|
|
|
|
event_loop_handle.join().expect("Error running event loop");
|
|
}
|
|
|
|
impl imp::EventHandler for RustBertBot {
|
|
fn handle(&mut self, cwtch: &dyn CwtchLib, profile_opt: Option<&Profile>, event: Event) {
|
|
match event {
|
|
Event::NewMessageFromPeer { conversation_id, handle, timestamp_received, message } => {
|
|
match profile_opt {
|
|
Some(profile) => {
|
|
let from = profile.handle.as_str();
|
|
let message_wrapper: Message =
|
|
serde_json::from_str(&message).expect("Error parsing message");
|
|
let mut message = message_wrapper.d.clone();
|
|
|
|
let bert_conversation_id = self.conversation_manager.create(&message);
|
|
let output = self.conversation_model.generate_responses(&mut self.conversation_manager);
|
|
|
|
let response = Message {
|
|
o: 1,
|
|
d: output[&bert_conversation_id].to_string(),
|
|
};
|
|
match serde_json::to_string(&response) {
|
|
Ok(response_json) => cwtch.send_message(&profile.handle, conversation_id, &response_json),
|
|
Err(e) => { println!("Error parsing json response: {}", e.to_string()); "".to_string() }
|
|
};
|
|
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
Event::ErrUnhandled { name, data } => eprintln!("unhandled event: {}!", name),
|
|
_ => ()
|
|
};
|
|
}
|
|
} |