demo of an Imp bot using rust-bert to chat
This commit is contained in:
commit
149c92593c
|
@ -0,0 +1 @@
|
|||
/target
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,14 @@
|
|||
[package]
|
||||
name = "rust-bert-bot"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
libcwtch = "0.3.2"
|
||||
imp = {path = "../imp"}
|
||||
serde_json = "1.0"
|
||||
serde = { version = "1.0.127", features = ["derive"] }
|
||||
serde_with = { version = "1.10.0" }
|
||||
rust-bert = "0.17.0"
|
|
@ -0,0 +1,17 @@
|
|||
# Rust-Bert-Bot
|
||||
|
||||
An Imp Cwtch bot demo using [rust-bert](https://github.com/guillaume-be/rust-bert) to chat
|
||||
|
||||
## Usage
|
||||
|
||||
Create a `user_config.json` with an allow list of bot peer admins with their cwtch addresses:
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
{
|
||||
"allow_list": ["xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", "yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy"]
|
||||
}
|
||||
```
|
||||
|
||||
![Rust-Bert bot in Cwtch](rust-bert.png)
|
Binary file not shown.
After Width: | Height: | Size: 86 KiB |
|
@ -0,0 +1,102 @@
|
|||
extern crate core;
|
||||
|
||||
use std::borrow::{BorrowMut};
|
||||
use std::fmt::Debug;
|
||||
use std::fs::{File};
|
||||
use std::thread;
|
||||
|
||||
use ::imp::event::Event;
|
||||
use ::imp::imp;
|
||||
use ::imp::imp::Imp;
|
||||
use ::imp::behaviour::AllowListMembers;
|
||||
use ::imp::behaviour::NewContactPolicy;
|
||||
use ::imp::behaviour::BehaviourBuilder;
|
||||
use ::imp::behaviour::Behaviour;
|
||||
use libcwtch;
|
||||
use libcwtch::CwtchLib;
|
||||
use libcwtch::structs::*;
|
||||
use serde_json;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use rust_bert::pipelines::conversation::{ConversationModel, ConversationManager};
|
||||
|
||||
|
||||
const BOT_HOME: &str = "~/.cwtch/bots/rust_bert";
|
||||
const PASSWORD: &str = "be gay do crime";
|
||||
|
||||
const BOT_NAME: &str = "Rust-Bert";
|
||||
|
||||
const USER_CONFIG_FILE: &str = "user_config.json";
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct UserConfig {
|
||||
allow_list: Vec<String>
|
||||
}
|
||||
|
||||
struct RustBertBot {
|
||||
allow_list: Vec<String>,
|
||||
conversation_model : ConversationModel,
|
||||
conversation_manager: ConversationManager,
|
||||
}
|
||||
|
||||
impl RustBertBot {
|
||||
pub fn new() -> Self {
|
||||
let file = File::open(USER_CONFIG_FILE).expect(&format!("could not open allow list file: {}", USER_CONFIG_FILE));
|
||||
let config: UserConfig = serde_json::from_reader(file)
|
||||
.expect("file should be proper JSON");
|
||||
|
||||
let conversation_model = match ConversationModel::new(Default::default()) {
|
||||
Ok(cm) => cm,
|
||||
Err(err) => panic!("ConversationModel::new error: {:?}", err),
|
||||
};
|
||||
let mut conversation_manager = ConversationManager::new();
|
||||
|
||||
RustBertBot {allow_list: config.allow_list, conversation_model, conversation_manager }
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let mut rust_bert_bot = RustBertBot::new();
|
||||
let allow_list = AllowListMembers::new(rust_bert_bot.allow_list.clone(), vec!());
|
||||
let behaviour: Behaviour = BehaviourBuilder::new().name(BOT_NAME.to_string()).new_contact_policy(NewContactPolicy::AllowList(allow_list)).build();
|
||||
|
||||
let event_loop_handle = thread::spawn(move || {
|
||||
let mut bot = Imp::spawn(behaviour, PASSWORD.to_string(), BOT_HOME.to_string());
|
||||
|
||||
bot.event_loop::<RustBertBot>(rust_bert_bot.borrow_mut());
|
||||
});
|
||||
|
||||
event_loop_handle.join().expect("Error running event loop");
|
||||
}
|
||||
|
||||
impl imp::EventHandler for RustBertBot {
|
||||
fn handle(&mut self, cwtch: &dyn CwtchLib, profile_opt: Option<&Profile>, event: Event) {
|
||||
match event {
|
||||
Event::NewMessageFromPeer { conversation_id, handle, timestamp_received, message } => {
|
||||
match profile_opt {
|
||||
Some(profile) => {
|
||||
let from = profile.handle.as_str();
|
||||
let message_wrapper: Message =
|
||||
serde_json::from_str(&message).expect("Error parsing message");
|
||||
let mut message = message_wrapper.d.clone();
|
||||
|
||||
let bert_conversation_id = self.conversation_manager.create(&message);
|
||||
let output = self.conversation_model.generate_responses(&mut self.conversation_manager);
|
||||
|
||||
let response = Message {
|
||||
o: 1,
|
||||
d: output[&bert_conversation_id].to_string(),
|
||||
};
|
||||
match serde_json::to_string(&response) {
|
||||
Ok(response_json) => cwtch.send_message(&profile.handle, conversation_id, &response_json),
|
||||
Err(e) => { println!("Error parsing json response: {}", e.to_string()); "".to_string() }
|
||||
};
|
||||
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Event::ErrUnhandled { name, data } => eprintln!("unhandled event: {}!", name),
|
||||
_ => ()
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
{
|
||||
"allow_list": []
|
||||
}
|
Loading…
Reference in New Issue