feat: 重构工具和协议模块,添加工具注册和会话管理逻辑,优化消息处理
This commit is contained in:
parent
af7860f2fd
commit
8f27bd2735
@ -4,8 +4,8 @@ use std::sync::Arc;
|
||||
use tokio::sync::{RwLock, mpsc};
|
||||
|
||||
use crate::bus::{MessageBus, OutboundMessage};
|
||||
use crate::gateway::ws_adapter::ws_outbound_from_outbound_message;
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::protocol::ws_adapter::ws_outbound_from_outbound_message;
|
||||
|
||||
use super::base::{Channel, ChannelError};
|
||||
|
||||
|
||||
@ -1 +1,2 @@
|
||||
pub mod messages;
|
||||
pub mod tools;
|
||||
|
||||
15
src/domain/tools.rs
Normal file
15
src/domain/tools.rs
Normal file
@ -0,0 +1,15 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Tool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String,
|
||||
pub function: ToolFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolFunction {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
@ -15,7 +15,6 @@ pub mod session_factory;
|
||||
pub mod session_pool;
|
||||
pub mod tool_registry_factory;
|
||||
pub mod ws;
|
||||
pub mod ws_adapter;
|
||||
|
||||
use axum::{Router, routing};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
pub mod ws_adapter;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@ -4,7 +4,8 @@ use crate::bus::OutboundMessage;
|
||||
use crate::bus::message::OutboundEventKind;
|
||||
#[cfg(test)]
|
||||
use crate::bus::message::{ToolMessageState, format_tool_call_content};
|
||||
use crate::protocol::WsOutbound;
|
||||
|
||||
use super::WsOutbound;
|
||||
|
||||
const TOOL_PENDING_RESUME_HINT: &str = "完成外部操作后,直接发一条继续消息即可。";
|
||||
|
||||
@ -7,9 +7,8 @@ pub use self::openai::OpenAIProvider;
|
||||
|
||||
use crate::config::LLMProviderConfig;
|
||||
pub use crate::domain::messages::ToolCall;
|
||||
pub use traits::{
|
||||
ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Tool, ToolFunction, Usage,
|
||||
};
|
||||
pub use crate::domain::tools::{Tool, ToolFunction};
|
||||
pub use traits::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Message, Usage};
|
||||
|
||||
pub fn create_provider(config: LLMProviderConfig) -> Result<Box<dyn LLMProvider>, ProviderError> {
|
||||
match config.provider_type.as_str() {
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
use crate::domain::messages::{ContentBlock, ToolCall};
|
||||
use crate::domain::tools::Tool;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@ -77,20 +78,6 @@ impl Message {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Tool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String,
|
||||
pub function: ToolFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolFunction {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionRequest {
|
||||
pub messages: Vec<Message>,
|
||||
|
||||
@ -6,7 +6,7 @@ use std::path::{Path, PathBuf};
|
||||
use std::sync::RwLock;
|
||||
|
||||
use crate::config::SkillsConfig;
|
||||
use crate::providers::{Tool, ToolFunction};
|
||||
use crate::domain::tools::{Tool, ToolFunction};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Skill {
|
||||
|
||||
9
src/storage/error.rs
Normal file
9
src/storage/error.rs
Normal file
@ -0,0 +1,9 @@
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum StorageError {
|
||||
#[error("database error: {0}")]
|
||||
Database(#[from] rusqlite::Error),
|
||||
#[error("io error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("serialization error: {0}")]
|
||||
Serialization(#[from] serde_json::Error),
|
||||
}
|
||||
@ -3,193 +3,23 @@ use std::path::{Path, PathBuf};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use rusqlite::{Connection, OptionalExtension, params};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::bus::ChatMessage;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SkillEventRecord {
|
||||
pub id: String,
|
||||
pub session_id: Option<String>,
|
||||
pub event_type: String,
|
||||
pub skill_name: Option<String>,
|
||||
pub payload: serde_json::Value,
|
||||
pub created_at: i64,
|
||||
}
|
||||
pub mod error;
|
||||
pub mod records;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum StorageError {
|
||||
#[error("database error: {0}")]
|
||||
Database(#[from] rusqlite::Error),
|
||||
#[error("io error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("serialization error: {0}")]
|
||||
Serialization(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionRecord {
|
||||
pub id: String,
|
||||
pub title: String,
|
||||
pub channel_name: String,
|
||||
pub chat_id: String,
|
||||
pub summary: Option<String>,
|
||||
pub created_at: i64,
|
||||
pub updated_at: i64,
|
||||
pub last_active_at: i64,
|
||||
pub archived_at: Option<i64>,
|
||||
pub deleted_at: Option<i64>,
|
||||
pub message_count: i64,
|
||||
pub reset_cutoff_seq: i64,
|
||||
pub user_turn_count: i64,
|
||||
pub agent_prompt_reinjection_count: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MemoryRecord {
|
||||
pub id: String,
|
||||
pub scope_kind: String,
|
||||
pub scope_key: String,
|
||||
pub namespace: String,
|
||||
pub memory_key: String,
|
||||
pub content: String,
|
||||
pub source_type: String,
|
||||
pub source_session_id: Option<String>,
|
||||
pub source_message_id: Option<String>,
|
||||
pub source_message_seq: Option<i64>,
|
||||
pub source_channel_name: Option<String>,
|
||||
pub source_chat_id: Option<String>,
|
||||
pub created_at: i64,
|
||||
pub updated_at: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MemoryUpsert {
|
||||
pub scope_kind: String,
|
||||
pub scope_key: String,
|
||||
pub namespace: String,
|
||||
pub memory_key: String,
|
||||
pub content: String,
|
||||
pub source_type: String,
|
||||
pub source_session_id: Option<String>,
|
||||
pub source_message_id: Option<String>,
|
||||
pub source_message_seq: Option<i64>,
|
||||
pub source_channel_name: Option<String>,
|
||||
pub source_chat_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SchedulerJobState {
|
||||
Scheduled,
|
||||
Running,
|
||||
Paused,
|
||||
Completed,
|
||||
}
|
||||
|
||||
impl SchedulerJobState {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
SchedulerJobState::Scheduled => "scheduled",
|
||||
SchedulerJobState::Running => "running",
|
||||
SchedulerJobState::Paused => "paused",
|
||||
SchedulerJobState::Completed => "completed",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str(value: &str) -> Option<Self> {
|
||||
match value {
|
||||
"scheduled" => Some(Self::Scheduled),
|
||||
"running" => Some(Self::Running),
|
||||
"paused" => Some(Self::Paused),
|
||||
"completed" => Some(Self::Completed),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SchedulerJobStatus {
|
||||
Ok,
|
||||
Error,
|
||||
Skipped,
|
||||
}
|
||||
|
||||
impl SchedulerJobStatus {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
SchedulerJobStatus::Ok => "ok",
|
||||
SchedulerJobStatus::Error => "error",
|
||||
SchedulerJobStatus::Skipped => "skipped",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str(value: &str) -> Option<Self> {
|
||||
match value {
|
||||
"ok" => Some(Self::Ok),
|
||||
"error" => Some(Self::Error),
|
||||
"skipped" => Some(Self::Skipped),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SchedulerJobState {
|
||||
fn default() -> Self {
|
||||
Self::Scheduled
|
||||
}
|
||||
}
|
||||
pub use error::StorageError;
|
||||
pub use records::{
|
||||
MemoryRecord, MemoryUpsert, SchedulerJobRecord, SchedulerJobState, SchedulerJobStatus,
|
||||
SchedulerJobUpsert, SessionRecord, SkillEventRecord,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SessionStore {
|
||||
conn: Arc<Mutex<Connection>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SchedulerJobRecord {
|
||||
pub id: String,
|
||||
pub kind: String,
|
||||
pub schedule: serde_json::Value,
|
||||
pub interval_secs: i64,
|
||||
pub startup_delay_secs: i64,
|
||||
pub target: serde_json::Value,
|
||||
pub payload: serde_json::Value,
|
||||
pub enabled: bool,
|
||||
pub state: SchedulerJobState,
|
||||
pub last_status: Option<SchedulerJobStatus>,
|
||||
pub last_error: Option<String>,
|
||||
pub run_count: i64,
|
||||
pub max_runs: Option<i64>,
|
||||
pub last_fired_at: Option<i64>,
|
||||
pub next_fire_at: Option<i64>,
|
||||
pub paused_at: Option<i64>,
|
||||
pub completed_at: Option<i64>,
|
||||
pub created_at: i64,
|
||||
pub updated_at: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SchedulerJobUpsert {
|
||||
pub id: String,
|
||||
pub kind: String,
|
||||
pub schedule: serde_json::Value,
|
||||
pub interval_secs: i64,
|
||||
pub startup_delay_secs: i64,
|
||||
pub target: serde_json::Value,
|
||||
pub payload: serde_json::Value,
|
||||
pub enabled: bool,
|
||||
pub state: SchedulerJobState,
|
||||
pub last_status: Option<SchedulerJobStatus>,
|
||||
pub last_error: Option<String>,
|
||||
pub run_count: i64,
|
||||
pub max_runs: Option<i64>,
|
||||
pub last_fired_at: Option<i64>,
|
||||
pub next_fire_at: Option<i64>,
|
||||
pub paused_at: Option<i64>,
|
||||
pub completed_at: Option<i64>,
|
||||
}
|
||||
|
||||
impl SessionStore {
|
||||
#[cfg(test)]
|
||||
pub fn new() -> Result<Self, StorageError> {
|
||||
|
||||
169
src/storage/records.rs
Normal file
169
src/storage/records.rs
Normal file
@ -0,0 +1,169 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SkillEventRecord {
|
||||
pub id: String,
|
||||
pub session_id: Option<String>,
|
||||
pub event_type: String,
|
||||
pub skill_name: Option<String>,
|
||||
pub payload: serde_json::Value,
|
||||
pub created_at: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionRecord {
|
||||
pub id: String,
|
||||
pub title: String,
|
||||
pub channel_name: String,
|
||||
pub chat_id: String,
|
||||
pub summary: Option<String>,
|
||||
pub created_at: i64,
|
||||
pub updated_at: i64,
|
||||
pub last_active_at: i64,
|
||||
pub archived_at: Option<i64>,
|
||||
pub deleted_at: Option<i64>,
|
||||
pub message_count: i64,
|
||||
pub reset_cutoff_seq: i64,
|
||||
pub user_turn_count: i64,
|
||||
pub agent_prompt_reinjection_count: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MemoryRecord {
|
||||
pub id: String,
|
||||
pub scope_kind: String,
|
||||
pub scope_key: String,
|
||||
pub namespace: String,
|
||||
pub memory_key: String,
|
||||
pub content: String,
|
||||
pub source_type: String,
|
||||
pub source_session_id: Option<String>,
|
||||
pub source_message_id: Option<String>,
|
||||
pub source_message_seq: Option<i64>,
|
||||
pub source_channel_name: Option<String>,
|
||||
pub source_chat_id: Option<String>,
|
||||
pub created_at: i64,
|
||||
pub updated_at: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MemoryUpsert {
|
||||
pub scope_kind: String,
|
||||
pub scope_key: String,
|
||||
pub namespace: String,
|
||||
pub memory_key: String,
|
||||
pub content: String,
|
||||
pub source_type: String,
|
||||
pub source_session_id: Option<String>,
|
||||
pub source_message_id: Option<String>,
|
||||
pub source_message_seq: Option<i64>,
|
||||
pub source_channel_name: Option<String>,
|
||||
pub source_chat_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SchedulerJobState {
|
||||
Scheduled,
|
||||
Running,
|
||||
Paused,
|
||||
Completed,
|
||||
}
|
||||
|
||||
impl SchedulerJobState {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
SchedulerJobState::Scheduled => "scheduled",
|
||||
SchedulerJobState::Running => "running",
|
||||
SchedulerJobState::Paused => "paused",
|
||||
SchedulerJobState::Completed => "completed",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str(value: &str) -> Option<Self> {
|
||||
match value {
|
||||
"scheduled" => Some(Self::Scheduled),
|
||||
"running" => Some(Self::Running),
|
||||
"paused" => Some(Self::Paused),
|
||||
"completed" => Some(Self::Completed),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SchedulerJobStatus {
|
||||
Ok,
|
||||
Error,
|
||||
Skipped,
|
||||
}
|
||||
|
||||
impl SchedulerJobStatus {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
SchedulerJobStatus::Ok => "ok",
|
||||
SchedulerJobStatus::Error => "error",
|
||||
SchedulerJobStatus::Skipped => "skipped",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str(value: &str) -> Option<Self> {
|
||||
match value {
|
||||
"ok" => Some(Self::Ok),
|
||||
"error" => Some(Self::Error),
|
||||
"skipped" => Some(Self::Skipped),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SchedulerJobState {
|
||||
fn default() -> Self {
|
||||
Self::Scheduled
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SchedulerJobRecord {
|
||||
pub id: String,
|
||||
pub kind: String,
|
||||
pub schedule: serde_json::Value,
|
||||
pub interval_secs: i64,
|
||||
pub startup_delay_secs: i64,
|
||||
pub target: serde_json::Value,
|
||||
pub payload: serde_json::Value,
|
||||
pub enabled: bool,
|
||||
pub state: SchedulerJobState,
|
||||
pub last_status: Option<SchedulerJobStatus>,
|
||||
pub last_error: Option<String>,
|
||||
pub run_count: i64,
|
||||
pub max_runs: Option<i64>,
|
||||
pub last_fired_at: Option<i64>,
|
||||
pub next_fire_at: Option<i64>,
|
||||
pub paused_at: Option<i64>,
|
||||
pub completed_at: Option<i64>,
|
||||
pub created_at: i64,
|
||||
pub updated_at: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SchedulerJobUpsert {
|
||||
pub id: String,
|
||||
pub kind: String,
|
||||
pub schedule: serde_json::Value,
|
||||
pub interval_secs: i64,
|
||||
pub startup_delay_secs: i64,
|
||||
pub target: serde_json::Value,
|
||||
pub payload: serde_json::Value,
|
||||
pub enabled: bool,
|
||||
pub state: SchedulerJobState,
|
||||
pub last_status: Option<SchedulerJobStatus>,
|
||||
pub last_error: Option<String>,
|
||||
pub run_count: i64,
|
||||
pub max_runs: Option<i64>,
|
||||
pub last_fired_at: Option<i64>,
|
||||
pub next_fire_at: Option<i64>,
|
||||
pub paused_at: Option<i64>,
|
||||
pub completed_at: Option<i64>,
|
||||
}
|
||||
@ -1,6 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::providers::{Tool, ToolFunction};
|
||||
use crate::domain::tools::{Tool, ToolFunction};
|
||||
|
||||
use super::traits::Tool as ToolTrait;
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user