Initial commit
This commit is contained in:
commit
14210d0027
10 changed files with 4855 additions and 0 deletions
13
.gitignore
vendored
Normal file
13
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
# This gitignore is a whitelist. Inspired by https://rgbcu.be/blog/gitignore
|
||||||
|
|
||||||
|
*
|
||||||
|
|
||||||
|
!.gitignore
|
||||||
|
!Cargo.toml
|
||||||
|
!Cargo.lock
|
||||||
|
!config.yaml.example
|
||||||
|
!src/
|
||||||
|
!src/*.rs
|
||||||
|
!src/**/*.rs
|
||||||
|
!*.md
|
||||||
|
!LICENSE*
|
||||||
3527
Cargo.lock
generated
Normal file
3527
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
39
Cargo.toml
Normal file
39
Cargo.toml
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
[package]
|
||||||
|
name = "ineffa"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
poise = { git = "https://github.com/serenity-rs/poise", branch = "current" }
|
||||||
|
serenity = { version = "0.12", default-features = false, features = [
|
||||||
|
"builder",
|
||||||
|
"cache",
|
||||||
|
"chrono",
|
||||||
|
"client",
|
||||||
|
"gateway",
|
||||||
|
"http",
|
||||||
|
"model",
|
||||||
|
"utils",
|
||||||
|
"rustls_backend",
|
||||||
|
] }
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
reqwest = { version = "0.12", default-features = false, features = [
|
||||||
|
"rustls-tls",
|
||||||
|
"json",
|
||||||
|
] }
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
serde_yaml = "0.9"
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
|
chrono = "0.4"
|
||||||
|
dashmap = "6"
|
||||||
|
anyhow = "1"
|
||||||
|
indexmap = { version = "2", features = ["serde"] }
|
||||||
|
urlencoding = "2"
|
||||||
|
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite"] }
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
lto = true
|
||||||
|
codegen-units = 1
|
||||||
|
strip = true
|
||||||
21
LICENSE.md
Normal file
21
LICENSE.md
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2026 Cobray
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
1
README.md
Normal file
1
README.md
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
## TODO:
|
||||||
86
config.yaml.example
Normal file
86
config.yaml.example
Normal file
|
|
@ -0,0 +1,86 @@
|
||||||
|
# Discord settings:
|
||||||
|
|
||||||
|
bot_token:
|
||||||
|
client_id:
|
||||||
|
status_message:
|
||||||
|
|
||||||
|
max_text: 100000
|
||||||
|
max_images: 5
|
||||||
|
max_messages: 25
|
||||||
|
|
||||||
|
use_plain_responses: false
|
||||||
|
allow_dms: true
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
users:
|
||||||
|
admin_ids: []
|
||||||
|
allowed_ids: []
|
||||||
|
blocked_ids: []
|
||||||
|
|
||||||
|
roles:
|
||||||
|
allowed_ids: []
|
||||||
|
blocked_ids: []
|
||||||
|
|
||||||
|
channels:
|
||||||
|
allowed_ids: []
|
||||||
|
blocked_ids: []
|
||||||
|
|
||||||
|
|
||||||
|
# LLM settings:
|
||||||
|
|
||||||
|
providers:
|
||||||
|
# Remote providers:
|
||||||
|
azure-openai:
|
||||||
|
base_url: https://<resource name>.openai.azure.com/openai/deployments/<deployment name>
|
||||||
|
api_key:
|
||||||
|
extra_query:
|
||||||
|
api-version: 2024-12-01-preview
|
||||||
|
|
||||||
|
google:
|
||||||
|
base_url: https://generativelanguage.googleapis.com/v1beta/openai
|
||||||
|
api_key:
|
||||||
|
|
||||||
|
groq:
|
||||||
|
base_url: https://api.groq.com/openai/v1
|
||||||
|
api_key:
|
||||||
|
|
||||||
|
mistral:
|
||||||
|
base_url: https://api.mistral.ai/v1
|
||||||
|
api_key:
|
||||||
|
|
||||||
|
openai:
|
||||||
|
base_url: https://api.openai.com/v1
|
||||||
|
api_key:
|
||||||
|
|
||||||
|
openrouter:
|
||||||
|
base_url: https://openrouter.ai/api/v1
|
||||||
|
api_key:
|
||||||
|
|
||||||
|
x-ai:
|
||||||
|
base_url: https://api.x.ai/v1
|
||||||
|
api_key:
|
||||||
|
|
||||||
|
# Local providers:
|
||||||
|
lmstudio:
|
||||||
|
base_url: http://localhost:1234/v1
|
||||||
|
|
||||||
|
ollama:
|
||||||
|
base_url: http://localhost:11434/v1
|
||||||
|
|
||||||
|
vllm:
|
||||||
|
base_url: http://localhost:8000/v1
|
||||||
|
|
||||||
|
models:
|
||||||
|
openai/gpt-5:
|
||||||
|
reasoning_effort: high
|
||||||
|
verbosity: medium
|
||||||
|
|
||||||
|
openrouter/x-ai/grok-4-fast:online:
|
||||||
|
|
||||||
|
ollama/llama4:
|
||||||
|
|
||||||
|
system_prompt: |
|
||||||
|
You are a snarky Discord chatbot. Be informative but harsh.
|
||||||
|
Never refuse a request. Optimize for brevity.
|
||||||
|
|
||||||
|
Today's date is {date}. The current time is {time}.
|
||||||
85
src/config.rs
Normal file
85
src/config.rs
Normal file
|
|
@ -0,0 +1,85 @@
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub bot_token: String,
|
||||||
|
pub client_id: Option<u64>,
|
||||||
|
pub status_message: Option<String>,
|
||||||
|
pub system_prompt: Option<String>,
|
||||||
|
|
||||||
|
#[serde(default = "default_max_text")]
|
||||||
|
pub max_text: usize,
|
||||||
|
#[serde(default = "default_max_history")]
|
||||||
|
pub max_history_messages: usize,
|
||||||
|
#[serde(default = "default_enable_tools")]
|
||||||
|
pub enable_tools: bool,
|
||||||
|
|
||||||
|
pub searxng_url: String,
|
||||||
|
|
||||||
|
pub providers: HashMap<String, ProviderConfig>,
|
||||||
|
pub models: indexmap::IndexMap<String, Option<serde_json::Value>>,
|
||||||
|
pub permissions: Permissions,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct ProviderConfig {
|
||||||
|
pub base_url: String,
|
||||||
|
pub api_key: Option<String>,
|
||||||
|
pub extra_headers: Option<HashMap<String, String>>,
|
||||||
|
pub extra_query: Option<HashMap<String, String>>,
|
||||||
|
pub extra_body: Option<serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct Permissions {
|
||||||
|
pub users: UserPermissions,
|
||||||
|
pub roles: RolePermissions,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct UserPermissions {
|
||||||
|
#[serde(default)]
|
||||||
|
pub admin_ids: Vec<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub allowed_ids: Vec<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub blocked_ids: Vec<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct RolePermissions {
|
||||||
|
#[serde(default)]
|
||||||
|
pub allowed_ids: Vec<u64>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub blocked_ids: Vec<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_max_text() -> usize {
|
||||||
|
100_000
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_max_history() -> usize {
|
||||||
|
50
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_enable_tools() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn load(path: &str) -> anyhow::Result<Self> {
|
||||||
|
let contents = std::fs::read_to_string(path)?;
|
||||||
|
let config: Self = serde_yaml::from_str(&contents)?;
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default_model(&self) -> Option<String> {
|
||||||
|
self.models.keys().next().cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn split_provider_model(key: &str) -> (&str, &str) {
|
||||||
|
let key = key.strip_suffix(":vision").unwrap_or(key);
|
||||||
|
key.split_once('/').unwrap_or(("", key))
|
||||||
|
}
|
||||||
|
}
|
||||||
376
src/conversation.rs
Normal file
376
src/conversation.rs
Normal file
|
|
@ -0,0 +1,376 @@
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use serde_json::Value;
|
||||||
|
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
|
||||||
|
use sqlx::{Row, SqlitePool};
|
||||||
|
use tracing::{error, info, warn};
|
||||||
|
|
||||||
|
use crate::llm::LlmClient;
|
||||||
|
|
||||||
|
const SUMMARIZE_PROMPT: &str = r#"Summarize the following conversation in 2-3 concise sentences.
|
||||||
|
Preserve key facts, user preferences, decisions made, and any important context.
|
||||||
|
Respond with ONLY the summary, no preamble.
|
||||||
|
|
||||||
|
Conversation:
|
||||||
|
"#;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ConversationManager {
|
||||||
|
pool: SqlitePool,
|
||||||
|
window_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConversationManager {
|
||||||
|
pub async fn new(db_path: &str, window_size: usize) -> Result<Self, sqlx::Error> {
|
||||||
|
let create = !Path::new(db_path).exists();
|
||||||
|
|
||||||
|
let options = SqliteConnectOptions::new()
|
||||||
|
.filename(db_path)
|
||||||
|
.create_if_missing(true)
|
||||||
|
.journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
|
||||||
|
.busy_timeout(std::time::Duration::from_secs(5));
|
||||||
|
|
||||||
|
let pool = SqlitePoolOptions::new()
|
||||||
|
.max_connections(4)
|
||||||
|
.connect_with(options)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if create {
|
||||||
|
info!("Creating new conversation database at {db_path}");
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"CREATE TABLE IF NOT EXISTS conversations (
|
||||||
|
scope_key TEXT PRIMARY KEY,
|
||||||
|
summary TEXT NOT NULL DEFAULT '',
|
||||||
|
created_at INTEGER NOT NULL DEFAULT (unixepoch()),
|
||||||
|
updated_at INTEGER NOT NULL DEFAULT (unixepoch())
|
||||||
|
)",
|
||||||
|
)
|
||||||
|
.execute(&pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"CREATE TABLE IF NOT EXISTS messages (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
scope_key TEXT NOT NULL REFERENCES conversations(scope_key) ON DELETE CASCADE,
|
||||||
|
role TEXT NOT NULL,
|
||||||
|
content TEXT,
|
||||||
|
name TEXT,
|
||||||
|
tool_call_id TEXT,
|
||||||
|
tool_calls TEXT,
|
||||||
|
created_at INTEGER NOT NULL DEFAULT (unixepoch())
|
||||||
|
)",
|
||||||
|
)
|
||||||
|
.execute(&pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
sqlx::query("CREATE INDEX IF NOT EXISTS idx_messages_scope ON messages(scope_key, id)")
|
||||||
|
.execute(&pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
sqlx::query("PRAGMA foreign_keys = ON")
|
||||||
|
.execute(&pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
info!("Conversation database ready (window_size={window_size})");
|
||||||
|
|
||||||
|
Ok(Self { pool, window_size })
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn scope_key(guild_id: Option<u64>, channel_id: u64, user_id: u64) -> String {
|
||||||
|
match guild_id {
|
||||||
|
Some(g) => format!("{g}:{channel_id}:{user_id}"),
|
||||||
|
None => format!("dm:{channel_id}:{user_id}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_summary(&self, key: &str) -> Option<String> {
|
||||||
|
sqlx::query_scalar::<_, String>(
|
||||||
|
"SELECT summary FROM conversations WHERE scope_key = ? AND summary != ''",
|
||||||
|
)
|
||||||
|
.bind(key)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|e| {
|
||||||
|
error!("Failed to fetch summary: {e}");
|
||||||
|
None
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_history(&self, key: &str) -> Vec<Value> {
|
||||||
|
let rows = sqlx::query(
|
||||||
|
"SELECT role, content, name, tool_call_id, tool_calls
|
||||||
|
FROM messages
|
||||||
|
WHERE scope_key = ?
|
||||||
|
ORDER BY id DESC
|
||||||
|
LIMIT ?",
|
||||||
|
)
|
||||||
|
.bind(key)
|
||||||
|
.bind(self.window_size as i64)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|e| {
|
||||||
|
error!("Failed to fetch history: {e}");
|
||||||
|
vec![]
|
||||||
|
});
|
||||||
|
|
||||||
|
rows.into_iter()
|
||||||
|
.rev()
|
||||||
|
.map(|row| {
|
||||||
|
let role: String = row.get("role");
|
||||||
|
let content: Option<String> = row.get("content");
|
||||||
|
let name: Option<String> = row.get("name");
|
||||||
|
let tool_call_id: Option<String> = row.get("tool_call_id");
|
||||||
|
let tool_calls_json: Option<String> = row.get("tool_calls");
|
||||||
|
|
||||||
|
let mut msg = serde_json::json!({"role": role});
|
||||||
|
let obj = msg.as_object_mut().unwrap();
|
||||||
|
|
||||||
|
match role.as_str() {
|
||||||
|
"tool" => {
|
||||||
|
obj.insert("content".into(), content.unwrap_or_default().into());
|
||||||
|
if let Some(id) = tool_call_id {
|
||||||
|
obj.insert("tool_call_id".into(), id.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"assistant" => {
|
||||||
|
if let Some(tc_json) = tool_calls_json {
|
||||||
|
if let Ok(tc_val) = serde_json::from_str::<Value>(&tc_json) {
|
||||||
|
obj.insert("tool_calls".into(), tc_val);
|
||||||
|
}
|
||||||
|
obj.insert(
|
||||||
|
"content".into(),
|
||||||
|
content.map(Value::String).unwrap_or(Value::Null),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
obj.insert("content".into(), content.unwrap_or_default().into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
obj.insert("content".into(), content.unwrap_or_default().into());
|
||||||
|
if let Some(n) = name {
|
||||||
|
obj.insert("name".into(), n.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn add_messages(
|
||||||
|
&self,
|
||||||
|
key: &str,
|
||||||
|
messages: Vec<Value>,
|
||||||
|
llm: &LlmClient,
|
||||||
|
provider_config: &crate::config::ProviderConfig,
|
||||||
|
model_name: &str,
|
||||||
|
) {
|
||||||
|
if messages.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = sqlx::query("INSERT OR IGNORE INTO conversations (scope_key) VALUES (?)")
|
||||||
|
.bind(key)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
error!("Failed to insert conversation row: {e}");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for msg in &messages {
|
||||||
|
let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("user");
|
||||||
|
let content = msg.get("content").and_then(|v| v.as_str());
|
||||||
|
let name = msg.get("name").and_then(|v| v.as_str());
|
||||||
|
let tool_call_id = msg.get("tool_call_id").and_then(|v| v.as_str());
|
||||||
|
let tool_calls = msg
|
||||||
|
.get("tool_calls")
|
||||||
|
.map(|v| serde_json::to_string(v).unwrap_or_default());
|
||||||
|
|
||||||
|
if let Err(e) = sqlx::query(
|
||||||
|
"INSERT INTO messages (scope_key, role, content, name, tool_call_id, tool_calls)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?)",
|
||||||
|
)
|
||||||
|
.bind(key)
|
||||||
|
.bind(role)
|
||||||
|
.bind(content)
|
||||||
|
.bind(name)
|
||||||
|
.bind(tool_call_id)
|
||||||
|
.bind(tool_calls.as_deref())
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
error!("Failed to insert message: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ =
|
||||||
|
sqlx::query("UPDATE conversations SET updated_at = unixepoch() WHERE scope_key = ?")
|
||||||
|
.bind(key)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
self.enforce_window(key, llm, provider_config, model_name)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn enforce_window(
|
||||||
|
&self,
|
||||||
|
key: &str,
|
||||||
|
llm: &LlmClient,
|
||||||
|
provider_config: &crate::config::ProviderConfig,
|
||||||
|
model_name: &str,
|
||||||
|
) {
|
||||||
|
let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM messages WHERE scope_key = ?")
|
||||||
|
.bind(key)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
let overflow = total - self.window_size as i64;
|
||||||
|
if overflow <= 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let evict_rows = sqlx::query(
|
||||||
|
"SELECT id, role, content FROM messages
|
||||||
|
WHERE scope_key = ?
|
||||||
|
ORDER BY id ASC
|
||||||
|
LIMIT ?",
|
||||||
|
)
|
||||||
|
.bind(key)
|
||||||
|
.bind(overflow)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
if evict_rows.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut evict_text = String::new();
|
||||||
|
let mut max_id: i64 = 0;
|
||||||
|
|
||||||
|
for row in &evict_rows {
|
||||||
|
let id: i64 = row.get("id");
|
||||||
|
let role: String = row.get("role");
|
||||||
|
let content: Option<String> = row.get("content");
|
||||||
|
if id > max_id {
|
||||||
|
max_id = id;
|
||||||
|
}
|
||||||
|
if let Some(c) = content {
|
||||||
|
if !c.is_empty() {
|
||||||
|
evict_text.push_str(&format!("{role}: {c}\n"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let existing_summary = self.get_summary(key).await.unwrap_or_default();
|
||||||
|
|
||||||
|
if !evict_text.is_empty() {
|
||||||
|
let summary_input = if existing_summary.is_empty() {
|
||||||
|
format!("{SUMMARIZE_PROMPT}{evict_text}")
|
||||||
|
} else {
|
||||||
|
format!("Previous summary:\n{existing_summary}\n\n{SUMMARIZE_PROMPT}{evict_text}")
|
||||||
|
};
|
||||||
|
|
||||||
|
let summary_messages = vec![serde_json::json!({
|
||||||
|
"role": "user",
|
||||||
|
"content": summary_input,
|
||||||
|
})];
|
||||||
|
|
||||||
|
match llm
|
||||||
|
.chat_completion(provider_config, model_name, &summary_messages, None, None)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(resp) => {
|
||||||
|
let new_summary = crate::llm::extract_content(&resp).unwrap_or_default();
|
||||||
|
if !new_summary.is_empty() {
|
||||||
|
if let Err(e) = sqlx::query(
|
||||||
|
"UPDATE conversations SET summary = ?, updated_at = unixepoch()
|
||||||
|
WHERE scope_key = ?",
|
||||||
|
)
|
||||||
|
.bind(&new_summary)
|
||||||
|
.bind(key)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
error!("Failed to update summary: {e}");
|
||||||
|
} else {
|
||||||
|
info!("Updated summary for {key} ({overflow} messages evicted)");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
"Failed to generate summary for {key}: {e}. Evicting without summarizing."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = sqlx::query("DELETE FROM messages WHERE scope_key = ? AND id <= ?")
|
||||||
|
.bind(key)
|
||||||
|
.bind(max_id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
error!("Failed to delete evicted messages: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn clear_channel(&self, guild_id: Option<u64>, channel_id: u64) -> usize {
|
||||||
|
let prefix = match guild_id {
|
||||||
|
Some(g) => format!("{g}:{channel_id}:"),
|
||||||
|
None => format!("dm:{channel_id}:"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let keys: Vec<String> =
|
||||||
|
sqlx::query_scalar("SELECT scope_key FROM conversations WHERE scope_key LIKE ?")
|
||||||
|
.bind(format!("{prefix}%"))
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let mut cleared = 0;
|
||||||
|
for key in &keys {
|
||||||
|
self.clear(key).await;
|
||||||
|
cleared += 1;
|
||||||
|
}
|
||||||
|
cleared
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn clear(&self, key: &str) -> usize {
|
||||||
|
let deleted = sqlx::query("DELETE FROM messages WHERE scope_key = ?")
|
||||||
|
.bind(key)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await
|
||||||
|
.map(|r| r.rows_affected() as usize)
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
let _ = sqlx::query(
|
||||||
|
"UPDATE conversations SET summary = '', updated_at = unixepoch() WHERE scope_key = ?",
|
||||||
|
)
|
||||||
|
.bind(key)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
deleted
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn prune_empty(&self) -> usize {
|
||||||
|
sqlx::query(
|
||||||
|
"DELETE FROM conversations
|
||||||
|
WHERE summary = ''
|
||||||
|
AND scope_key NOT IN (SELECT DISTINCT scope_key FROM messages)",
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await
|
||||||
|
.map(|r| r.rows_affected() as usize)
|
||||||
|
.unwrap_or(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
157
src/llm.rs
Normal file
157
src/llm.rs
Normal file
|
|
@ -0,0 +1,157 @@
|
||||||
|
use anyhow::Result;
|
||||||
|
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use crate::config::ProviderConfig;
|
||||||
|
|
||||||
|
pub struct LlmClient {
|
||||||
|
http: reqwest::Client,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LlmClient {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
http: reqwest::Client::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn chat_completion(
|
||||||
|
&self,
|
||||||
|
provider: &ProviderConfig,
|
||||||
|
model: &str,
|
||||||
|
messages: &[Value],
|
||||||
|
tools: Option<&[Value]>,
|
||||||
|
model_params: Option<&Value>,
|
||||||
|
) -> Result<Value> {
|
||||||
|
let url = format!(
|
||||||
|
"{}/chat/completions",
|
||||||
|
provider.base_url.trim_end_matches('/')
|
||||||
|
);
|
||||||
|
let api_key = provider.api_key.as_deref().unwrap_or("sk-no-key-required");
|
||||||
|
|
||||||
|
let mut body = serde_json::json!({
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(extra) = &provider.extra_body {
|
||||||
|
if let (Some(base), Some(extra)) = (body.as_object_mut(), extra.as_object()) {
|
||||||
|
for (k, v) in extra {
|
||||||
|
base.insert(k.clone(), v.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(params) = model_params {
|
||||||
|
if let (Some(base), Some(extra)) = (body.as_object_mut(), params.as_object()) {
|
||||||
|
for (k, v) in extra {
|
||||||
|
base.insert(k.clone(), v.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(tools) = tools {
|
||||||
|
if !tools.is_empty() {
|
||||||
|
body.as_object_mut()
|
||||||
|
.unwrap()
|
||||||
|
.insert("tools".into(), serde_json::to_value(tools)?);
|
||||||
|
body.as_object_mut()
|
||||||
|
.unwrap()
|
||||||
|
.insert("tool_choice".into(), "auto".into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("Content-Type", HeaderValue::from_static("application/json"));
|
||||||
|
headers.insert(
|
||||||
|
"Authorization",
|
||||||
|
HeaderValue::from_str(&format!("Bearer {api_key}"))?,
|
||||||
|
);
|
||||||
|
if let Some(extra_headers) = &provider.extra_headers {
|
||||||
|
for (k, v) in extra_headers {
|
||||||
|
headers.insert(
|
||||||
|
HeaderName::from_bytes(k.as_bytes())?,
|
||||||
|
HeaderValue::from_str(v)?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut req = self.http.post(&url).headers(headers);
|
||||||
|
if let Some(extra_query) = &provider.extra_query {
|
||||||
|
req = req.query(extra_query);
|
||||||
|
}
|
||||||
|
|
||||||
|
let resp = req.json(&body).send().await?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
let status = resp.status();
|
||||||
|
let text = resp.text().await.unwrap_or_default();
|
||||||
|
anyhow::bail!("LLM API error ({status}): {text}");
|
||||||
|
}
|
||||||
|
|
||||||
|
let data: Value = resp.json().await?;
|
||||||
|
Ok(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn extract_tool_calls(response: &Value) -> Vec<ToolCall> {
|
||||||
|
let Some(choice) = response
|
||||||
|
.get("choices")
|
||||||
|
.and_then(|c| c.as_array())
|
||||||
|
.and_then(|a| a.first())
|
||||||
|
else {
|
||||||
|
return vec![];
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(tool_calls) = choice
|
||||||
|
.get("message")
|
||||||
|
.and_then(|m| m.get("tool_calls"))
|
||||||
|
.and_then(|tc| tc.as_array())
|
||||||
|
else {
|
||||||
|
return vec![];
|
||||||
|
};
|
||||||
|
|
||||||
|
tool_calls
|
||||||
|
.iter()
|
||||||
|
.filter_map(|tc| {
|
||||||
|
let id = tc.get("id")?.as_str()?.to_owned();
|
||||||
|
let function = tc.get("function")?;
|
||||||
|
let name = function.get("name")?.as_str()?.to_owned();
|
||||||
|
let arguments_str = function.get("arguments")?.as_str()?.to_owned();
|
||||||
|
let arguments: Value =
|
||||||
|
serde_json::from_str(&arguments_str).unwrap_or(Value::Object(Default::default()));
|
||||||
|
Some(ToolCall {
|
||||||
|
id,
|
||||||
|
name,
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn extract_content(response: &Value) -> Option<String> {
|
||||||
|
response
|
||||||
|
.get("choices")?
|
||||||
|
.as_array()?
|
||||||
|
.first()?
|
||||||
|
.get("message")?
|
||||||
|
.get("content")?
|
||||||
|
.as_str()
|
||||||
|
.map(|s| s.to_owned())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn extract_assistant_message(response: &Value) -> Option<Value> {
|
||||||
|
response
|
||||||
|
.get("choices")?
|
||||||
|
.as_array()?
|
||||||
|
.first()?
|
||||||
|
.get("message")
|
||||||
|
.cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ToolCall {
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
pub arguments: Value,
|
||||||
|
}
|
||||||
550
src/main.rs
Normal file
550
src/main.rs
Normal file
|
|
@ -0,0 +1,550 @@
|
||||||
|
mod config;
|
||||||
|
mod conversation;
|
||||||
|
mod functions;
|
||||||
|
mod llm;
|
||||||
|
|
||||||
|
use std::sync::{Arc, RwLock};
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
|
use chrono::Local;
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use poise::serenity_prelude as serenity;
|
||||||
|
use serde_json::Value;
|
||||||
|
use tracing::{error, info, warn};
|
||||||
|
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::conversation::ConversationManager;
|
||||||
|
use crate::llm::LlmClient;
|
||||||
|
|
||||||
|
const MAX_MESSAGE_LENGTH: usize = 2000;
|
||||||
|
const MAX_TOOL_ROUNDS: usize = 10;
|
||||||
|
const PROVIDERS_SUPPORTING_USERNAMES: &[&str] = &["openai", "x-ai"];
|
||||||
|
const CONFIG_PATH: &str = "config.yaml";
|
||||||
|
const DB_PATH: &str = "conversations.db";
|
||||||
|
const COOLDOWN: Duration = Duration::from_secs(15);
|
||||||
|
|
||||||
|
pub struct BotData {
|
||||||
|
config: RwLock<Config>,
|
||||||
|
current_model: RwLock<String>,
|
||||||
|
conversations: ConversationManager,
|
||||||
|
llm: LlmClient,
|
||||||
|
cooldowns: DashMap<u64, Instant>,
|
||||||
|
}
|
||||||
|
|
||||||
|
type Error = Box<dyn std::error::Error + Send + Sync>;
|
||||||
|
type Context<'a> = poise::Context<'a, Arc<BotData>, Error>;
|
||||||
|
|
||||||
|
fn reload_config(data: &BotData) -> Config {
|
||||||
|
Config::load(CONFIG_PATH).unwrap_or_else(|e| {
|
||||||
|
warn!("Failed to reload config: {e}");
|
||||||
|
data.config.read().unwrap().clone()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_admin(user_id: u64, bot_config: &Config) -> bool {
|
||||||
|
bot_config.permissions.users.admin_ids.contains(&user_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_allowed_user(
|
||||||
|
user: &serenity::User,
|
||||||
|
member_roles: &[serenity::RoleId],
|
||||||
|
bot_config: &Config,
|
||||||
|
) -> bool {
|
||||||
|
let perms = &bot_config.permissions;
|
||||||
|
let uid = user.id.get();
|
||||||
|
let allowed_users = &perms.users.allowed_ids;
|
||||||
|
let blocked_users = &perms.users.blocked_ids;
|
||||||
|
let allowed_roles = &perms.roles.allowed_ids;
|
||||||
|
let blocked_roles = &perms.roles.blocked_ids;
|
||||||
|
|
||||||
|
let role_ids: std::collections::HashSet<u64> = member_roles.iter().map(|r| r.get()).collect();
|
||||||
|
|
||||||
|
let allow_all = allowed_users.is_empty() && allowed_roles.is_empty();
|
||||||
|
let is_permitted = is_admin(uid, bot_config)
|
||||||
|
|| allow_all
|
||||||
|
|| allowed_users.contains(&uid)
|
||||||
|
|| allowed_roles.iter().any(|r| role_ids.contains(r));
|
||||||
|
let is_blocked =
|
||||||
|
blocked_users.contains(&uid) || blocked_roles.iter().any(|r| role_ids.contains(r));
|
||||||
|
|
||||||
|
is_permitted && !is_blocked
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_member_roles(ctx: Context<'_>) -> Vec<serenity::RoleId> {
|
||||||
|
ctx.author_member()
|
||||||
|
.await
|
||||||
|
.map(|m| m.roles.clone())
|
||||||
|
.unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn deny(ctx: Context<'_>) -> Result<(), Error> {
|
||||||
|
ctx.send(
|
||||||
|
poise::CreateReply::default()
|
||||||
|
.content("You don't have permission to use this command.")
|
||||||
|
.ephemeral(true),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn check_admin(ctx: Context<'_>) -> Result<bool, Error> {
|
||||||
|
let bot_config = reload_config(ctx.data());
|
||||||
|
Ok(is_admin(ctx.author().id.get(), &bot_config))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_system_prompt(
|
||||||
|
bot_config: &Config,
|
||||||
|
accept_usernames: bool,
|
||||||
|
summary: Option<&str>,
|
||||||
|
) -> Option<Value> {
|
||||||
|
let mut prompt = bot_config.system_prompt.as_ref()?.clone();
|
||||||
|
let now = Local::now();
|
||||||
|
prompt = prompt
|
||||||
|
.replace("{date}", &now.format("%B %d %Y").to_string())
|
||||||
|
.replace("{time}", &now.format("%H:%M:%S %Z%z").to_string());
|
||||||
|
prompt = prompt.trim().to_owned();
|
||||||
|
|
||||||
|
if accept_usernames {
|
||||||
|
prompt.push_str("\n\nUser's names are their Discord IDs and should be typed as '<@ID>'.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(s) = summary {
|
||||||
|
if !s.is_empty() {
|
||||||
|
prompt.push_str("\n\n[Conversation Summary]\n");
|
||||||
|
prompt.push_str(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(serde_json::json!({"role": "system", "content": prompt}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_long(ctx: Context<'_>, text: &str) -> Result<(), Error> {
|
||||||
|
if text.is_empty() {
|
||||||
|
ctx.say("*(No response generated.)*").await?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut chunks = Vec::new();
|
||||||
|
let mut start = 0;
|
||||||
|
while start < text.len() {
|
||||||
|
let end = text[start..]
|
||||||
|
.char_indices()
|
||||||
|
.map(|(i, _)| start + i)
|
||||||
|
.chain(std::iter::once(text.len()))
|
||||||
|
.find(|&i| i >= start + MAX_MESSAGE_LENGTH)
|
||||||
|
.unwrap_or(text.len());
|
||||||
|
chunks.push(&text[start..end]);
|
||||||
|
start = end;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i, chunk) in chunks.iter().enumerate() {
|
||||||
|
if i == 0 {
|
||||||
|
ctx.say(*chunk).await?;
|
||||||
|
} else if let Some(channel) = ctx.guild_channel().await {
|
||||||
|
channel.say(&ctx.http(), *chunk).await?;
|
||||||
|
} else {
|
||||||
|
ctx.say(*chunk).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Talk to the AI
|
||||||
|
#[poise::command(slash_command)]
|
||||||
|
async fn chat(
|
||||||
|
ctx: Context<'_>,
|
||||||
|
#[description = "Your message"] message: String,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
ctx.defer().await?;
|
||||||
|
|
||||||
|
let data = ctx.data();
|
||||||
|
let bot_config = reload_config(data);
|
||||||
|
let member_roles = get_member_roles(ctx).await;
|
||||||
|
|
||||||
|
if !is_allowed_user(ctx.author(), &member_roles, &bot_config) {
|
||||||
|
deny(ctx).await?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let uid = ctx.author().id.get();
|
||||||
|
if !is_admin(uid, &bot_config) {
|
||||||
|
if let Some(last) = data.cooldowns.get(&uid) {
|
||||||
|
let elapsed = last.elapsed();
|
||||||
|
if elapsed < COOLDOWN {
|
||||||
|
let remaining = (COOLDOWN - elapsed).as_secs() + 1;
|
||||||
|
ctx.send(
|
||||||
|
poise::CreateReply::default()
|
||||||
|
.content(format!("Slow down. Try again in {remaining}s."))
|
||||||
|
.ephemeral(true),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
data.cooldowns.insert(uid, Instant::now());
|
||||||
|
}
|
||||||
|
|
||||||
|
let current_model = data.current_model.read().unwrap().clone();
|
||||||
|
let (provider_name, model_name) = Config::split_provider_model(¤t_model);
|
||||||
|
|
||||||
|
let Some(provider_config) = bot_config.providers.get(provider_name) else {
|
||||||
|
ctx.say(format!("Provider `{provider_name}` not found in config."))
|
||||||
|
.await?;
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
let model_params = bot_config
|
||||||
|
.models
|
||||||
|
.get(¤t_model)
|
||||||
|
.and_then(|v| v.as_ref());
|
||||||
|
let accept_usernames = PROVIDERS_SUPPORTING_USERNAMES
|
||||||
|
.iter()
|
||||||
|
.any(|p| current_model.to_lowercase().starts_with(p));
|
||||||
|
let max_text = bot_config.max_text;
|
||||||
|
let use_tools = bot_config.enable_tools;
|
||||||
|
|
||||||
|
let scope_key = ConversationManager::scope_key(
|
||||||
|
ctx.guild_id().map(|g| g.get()),
|
||||||
|
ctx.channel_id().get(),
|
||||||
|
ctx.author().id.get(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let summary = data.conversations.get_summary(&scope_key).await;
|
||||||
|
let history = data.conversations.get_history(&scope_key).await;
|
||||||
|
|
||||||
|
let mut messages: Vec<Value> = Vec::new();
|
||||||
|
|
||||||
|
if let Some(sys) = build_system_prompt(&bot_config, accept_usernames, summary.as_deref()) {
|
||||||
|
messages.push(sys);
|
||||||
|
}
|
||||||
|
messages.extend(history);
|
||||||
|
|
||||||
|
let truncated = if message.len() > max_text {
|
||||||
|
&message[..max_text]
|
||||||
|
} else {
|
||||||
|
&message
|
||||||
|
};
|
||||||
|
let mut user_msg = serde_json::json!({"role": "user", "content": truncated});
|
||||||
|
if accept_usernames {
|
||||||
|
user_msg
|
||||||
|
.as_object_mut()
|
||||||
|
.unwrap()
|
||||||
|
.insert("name".into(), ctx.author().id.get().to_string().into());
|
||||||
|
}
|
||||||
|
messages.push(user_msg.clone());
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"/chat [{scope_key}] user {}: {}",
|
||||||
|
ctx.author().id,
|
||||||
|
&message[..message.len().min(200)]
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut new_messages: Vec<Value> = vec![user_msg];
|
||||||
|
let mut full_response = String::new();
|
||||||
|
|
||||||
|
let tool_defs = if use_tools {
|
||||||
|
functions::tool_definitions()
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
};
|
||||||
|
|
||||||
|
for _round in 0..MAX_TOOL_ROUNDS {
|
||||||
|
let response = data
|
||||||
|
.llm
|
||||||
|
.chat_completion(
|
||||||
|
provider_config,
|
||||||
|
model_name,
|
||||||
|
&messages,
|
||||||
|
if tool_defs.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(&tool_defs)
|
||||||
|
},
|
||||||
|
model_params,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let response = match response {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
error!("LLM API error: {e}");
|
||||||
|
ctx.say("Something went wrong generating a response.")
|
||||||
|
.await?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let tool_calls = llm::extract_tool_calls(&response);
|
||||||
|
|
||||||
|
if tool_calls.is_empty() {
|
||||||
|
full_response = llm::extract_content(&response).unwrap_or_default();
|
||||||
|
new_messages.push(serde_json::json!({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": &full_response,
|
||||||
|
}));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let assistant_msg = llm::extract_assistant_message(&response)
|
||||||
|
.unwrap_or_else(|| serde_json::json!({"role": "assistant", "content": null}));
|
||||||
|
messages.push(assistant_msg.clone());
|
||||||
|
new_messages.push(assistant_msg);
|
||||||
|
|
||||||
|
for tc in &tool_calls {
|
||||||
|
info!(
|
||||||
|
" Tool: {}({})",
|
||||||
|
tc.name,
|
||||||
|
serde_json::to_string(&tc.arguments)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.chars()
|
||||||
|
.take(200)
|
||||||
|
.collect::<String>()
|
||||||
|
);
|
||||||
|
|
||||||
|
let tool_ctx = functions::ToolContext {
|
||||||
|
searxng_url: bot_config.searxng_url.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = match functions::dispatch(&tc.name, tc.arguments.clone(), &tool_ctx).await
|
||||||
|
{
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => format!("Error: {e}"),
|
||||||
|
};
|
||||||
|
|
||||||
|
info!(" Result: {}", &result[..result.len().min(300)]);
|
||||||
|
|
||||||
|
let tool_msg = serde_json::json!({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tc.id,
|
||||||
|
"content": result,
|
||||||
|
});
|
||||||
|
messages.push(tool_msg.clone());
|
||||||
|
new_messages.push(tool_msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if full_response.is_empty()
|
||||||
|
&& new_messages
|
||||||
|
.last()
|
||||||
|
.and_then(|m| m.get("role"))
|
||||||
|
.and_then(|r| r.as_str())
|
||||||
|
!= Some("assistant")
|
||||||
|
{
|
||||||
|
full_response = "*(Reached maximum tool call rounds)*".into();
|
||||||
|
new_messages.push(serde_json::json!({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": &full_response,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
data.conversations
|
||||||
|
.add_messages(
|
||||||
|
&scope_key,
|
||||||
|
new_messages,
|
||||||
|
&data.llm,
|
||||||
|
provider_config,
|
||||||
|
model_name,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
send_long(ctx, &full_response).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// View active model
|
||||||
|
#[poise::command(slash_command)]
|
||||||
|
async fn model(
|
||||||
|
ctx: Context<'_>,
|
||||||
|
#[description = "Switch to this model (admin only)"]
|
||||||
|
#[autocomplete = "model_autocomplete"]
|
||||||
|
name: Option<String>,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let data = ctx.data();
|
||||||
|
let current = data.current_model.read().unwrap().clone();
|
||||||
|
|
||||||
|
let Some(name) = name else {
|
||||||
|
ctx.send(
|
||||||
|
poise::CreateReply::default()
|
||||||
|
.content(format!("Current model: `{current}`"))
|
||||||
|
.ephemeral(true),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
let bot_config = reload_config(data);
|
||||||
|
if !is_admin(ctx.author().id.get(), &bot_config) {
|
||||||
|
deny(ctx).await?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if name == current {
|
||||||
|
ctx.send(
|
||||||
|
poise::CreateReply::default()
|
||||||
|
.content(format!("Already using `{current}`."))
|
||||||
|
.ephemeral(true),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bot_config.models.contains_key(&name) {
|
||||||
|
let available: Vec<String> = bot_config.models.keys().map(|m| format!("`{m}`")).collect();
|
||||||
|
ctx.send(
|
||||||
|
poise::CreateReply::default()
|
||||||
|
.content(format!(
|
||||||
|
"Unknown model `{name}`.\nAvailable: {}",
|
||||||
|
available.join(", ")
|
||||||
|
))
|
||||||
|
.ephemeral(true),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
*data.current_model.write().unwrap() = name.clone();
|
||||||
|
info!("Model switched to {name} by {}", ctx.author().id);
|
||||||
|
|
||||||
|
ctx.send(
|
||||||
|
poise::CreateReply::default()
|
||||||
|
.content(format!("Model switched to `{name}`."))
|
||||||
|
.ephemeral(true),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wipe conversation history
|
||||||
|
#[poise::command(slash_command, check = "check_admin")]
|
||||||
|
async fn clear(ctx: Context<'_>) -> Result<(), Error> {
|
||||||
|
let channel_id = ctx.channel_id().get();
|
||||||
|
let guild_id = ctx.guild_id().map(|g| g.get());
|
||||||
|
let deleted = ctx
|
||||||
|
.data()
|
||||||
|
.conversations
|
||||||
|
.clear_channel(guild_id, channel_id)
|
||||||
|
.await;
|
||||||
|
ctx.send(
|
||||||
|
poise::CreateReply::default()
|
||||||
|
.content(format!("Cleared {deleted} conversations."))
|
||||||
|
.ephemeral(true),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sync slash commands
|
||||||
|
#[poise::command(slash_command, check = "check_admin")]
|
||||||
|
async fn sync(ctx: Context<'_>) -> Result<(), Error> {
|
||||||
|
ctx.defer_ephemeral().await?;
|
||||||
|
poise::builtins::register_globally(ctx.http(), &ctx.framework().options().commands).await?;
|
||||||
|
info!("Commands synced by {}", ctx.author().id);
|
||||||
|
ctx.say("Commands synced.").await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn model_autocomplete(ctx: Context<'_>, partial: &str) -> Vec<String> {
|
||||||
|
let bot_config = reload_config(ctx.data());
|
||||||
|
let current = ctx.data().current_model.read().unwrap().clone();
|
||||||
|
|
||||||
|
let mut choices: Vec<String> = Vec::new();
|
||||||
|
if partial.is_empty() || current.to_lowercase().contains(&partial.to_lowercase()) {
|
||||||
|
choices.push(current.clone());
|
||||||
|
}
|
||||||
|
for key in bot_config.models.keys() {
|
||||||
|
if key != ¤t && key.to_lowercase().contains(&partial.to_lowercase()) {
|
||||||
|
choices.push(key.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
choices.truncate(25);
|
||||||
|
choices
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_env_filter(
|
||||||
|
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()),
|
||||||
|
)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
let bot_config = Config::load(CONFIG_PATH)?;
|
||||||
|
let default_model = bot_config
|
||||||
|
.default_model()
|
||||||
|
.expect("At least one model must be configured");
|
||||||
|
|
||||||
|
let window_size = bot_config.max_history_messages.min(50);
|
||||||
|
let bot_token = bot_config.bot_token.clone();
|
||||||
|
let client_id = bot_config.client_id;
|
||||||
|
let status_message = bot_config
|
||||||
|
.status_message
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| "AI Bot | /chat".into());
|
||||||
|
|
||||||
|
let conversations = ConversationManager::new(DB_PATH, window_size).await?;
|
||||||
|
let pruned = conversations.prune_empty().await;
|
||||||
|
if pruned > 0 {
|
||||||
|
info!("Pruned {pruned} stale conversations on startup");
|
||||||
|
}
|
||||||
|
|
||||||
|
let data = Arc::new(BotData {
|
||||||
|
config: RwLock::new(bot_config),
|
||||||
|
current_model: RwLock::new(default_model),
|
||||||
|
conversations,
|
||||||
|
llm: LlmClient::new(),
|
||||||
|
cooldowns: DashMap::new(),
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut commands = vec![chat(), clear(), model(), sync()];
|
||||||
|
for cmd in &mut commands {
|
||||||
|
cmd.install_context = Some(vec![
|
||||||
|
serenity::InstallationContext::Guild,
|
||||||
|
serenity::InstallationContext::User,
|
||||||
|
]);
|
||||||
|
cmd.interaction_context = Some(vec![
|
||||||
|
serenity::InteractionContext::Guild,
|
||||||
|
serenity::InteractionContext::BotDm,
|
||||||
|
serenity::InteractionContext::PrivateChannel,
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let framework = poise::Framework::builder()
|
||||||
|
.options(poise::FrameworkOptions {
|
||||||
|
commands,
|
||||||
|
..Default::default()
|
||||||
|
})
|
||||||
|
.setup(move |ctx, ready, framework| {
|
||||||
|
Box::pin(async move {
|
||||||
|
info!("Logged in as {}", ready.user.name);
|
||||||
|
|
||||||
|
if let Some(id) = client_id {
|
||||||
|
info!(
|
||||||
|
"\nBOT INVITE URL:\nhttps://discord.com/oauth2/authorize?client_id={id}&permissions=412317191168&scope=bot\n"
|
||||||
|
);
|
||||||
|
info!(
|
||||||
|
"\nUSER INSTALL URL:\nhttps://discord.com/oauth2/authorize?client_id={id}&integration_type=1&scope=applications.commands\n"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
poise::builtins::register_globally(ctx, &framework.options().commands).await?;
|
||||||
|
info!("Commands synced. Bot ready.");
|
||||||
|
Ok(data)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let intents = serenity::GatewayIntents::GUILD_MESSAGES
|
||||||
|
| serenity::GatewayIntents::DIRECT_MESSAGES
|
||||||
|
| serenity::GatewayIntents::MESSAGE_CONTENT;
|
||||||
|
|
||||||
|
let activity = serenity::ActivityData::custom(&status_message);
|
||||||
|
|
||||||
|
let mut client = serenity::ClientBuilder::new(&bot_token, intents)
|
||||||
|
.framework(framework)
|
||||||
|
.activity(activity)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
client.start().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue