From 98eb7bea3d697e7324c63cb33db255ab45acc1e8 Mon Sep 17 00:00:00 2001 From: xiaoski Date: Mon, 4 May 2026 00:32:24 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E8=B7=A8session=E6=B6=88?= =?UTF-8?q?=E6=81=AF=E5=8F=91=E9=80=81=E8=83=BD=E5=8A=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AGENTS.md | 29 ++-- Cargo.toml | 2 +- src/agent/agent_loop.rs | 12 +- src/agent/system_prompt.rs | 43 ++++++ src/bus/message.rs | 42 ++++++ src/bus/mod.rs | 2 +- src/channels/manager.rs | 13 ++ src/gateway/mod.rs | 31 +++-- src/logging.rs | 4 + src/providers/anthropic.rs | 65 +++++++-- src/session/session.rs | 263 ++++++++++++++++++++++++++----------- src/storage/message.rs | 1 + src/storage/mod.rs | 18 ++- src/tools/mod.rs | 28 +++- src/tools/registry.rs | 32 +++-- src/tools/send_message.rs | 171 ++++++++++++++++++++++++ src/tools/traits.rs | 13 ++ 17 files changed, 636 insertions(+), 133 deletions(-) create mode 100644 src/tools/send_message.rs diff --git a/AGENTS.md b/AGENTS.md index 017d557..d063b68 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,9 +1,5 @@ # PicoBot -## Maintenance - -- **Update this file on any architectural change** — module boundaries, data flow, key constraints, or build/test commands must be reflected here - ## Build & Run - `cargo build` — build the binary @@ -12,18 +8,20 @@ ## Config -- Config file: `~/.picobot/config.json` or `./config.json` (fallback order) -- `.env` is loaded and env var placeholders `` are substituted into config +- Config file: `~/.picobot/config.json` or `./config.json` (fallback order, see `src/config/mod.rs:213`) +- `.env` is loaded manually (not via dotenv crate); env var placeholders `` in config JSON are substituted - Config example: `config.example.json` ## Tests -- `cargo test --lib` — run unit tests (FAILS: `src/session/session.rs:657` missing `workspace_dir` field in test helper) -- `cargo test --test test_integration -- --ignored` — run integration tests (requires 
`tests/test.env` with API keys) +- `cargo test --lib` — run unit tests (runs all `#[test]` in `src/`) +- `cargo test --test test_integration -- --ignored` — run integration tests (also `test_tool_calling`, `test_request_format`) +- **All** integration tests require `tests/test.env` with real API keys; copy from `tests/test.env.example` and fill in keys +- Integration tests are `#[ignore]` by default; use `-- --ignored` to run them ## Reference -- `reference/` — third-party reference implementations (nanobot, Mini-Agent, zeroclaw); not part of this project; use for similar functionality patterns +- `reference/` — third-party reference implementations (nanobot, Mini-Agent, zeroclaw); not part of this project; do not modify ## Architecture @@ -51,8 +49,13 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM | `session` | Conversation session lifecycle, dialog operations | `SessionManager`, `Session` | | `agent` | LLM call loop, tool execution, context compression | `AgentLoop` | | `providers` | LLM API clients (OpenAI-compatible, Anthropic) | `LLMProvider` trait, factory `create_provider()` | -| `tools` | Agent tools (bash, file operations, http, web, get_skill) | `ToolRegistry`, `Tool` trait | +| `tools` | Agent tools (bash, file ops, http, web, get_skill) | `ToolRegistry`, `Tool` trait | | `skills` | Skills loading, management, and prompt building | `SkillsLoader`, `Skill` | +| `storage` | SQLite persistence for sessions and messages | `Storage`, `SessionMeta`, `MessageMeta` | +| `observability` | Observer pattern for agent/tool telemetry events | `Observer` trait, `ObserverEvent`, `MultiObserver` | +| `protocol` | WebSocket protocol message types | `WsInbound`, `WsOutbound`, `SessionSummary` | +| `config` | Config loading, env substitution, path resolution | `Config`, `LLMProviderConfig` | +| `logging` | Tracing initialization with file rotation | `init_logging()`, `init_logging_console_only()` | ### Functional Boundaries @@ -68,9 +71,7 @@ 
Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM ### Key Constraints - Gateway **changes working directory** to workspace on startup (`src/gateway/mod.rs:31`) +- Session/message persistence uses SQLite via `sqlx`; DB stored in workspace as `.picobot_sessions.db` by default - `ChannelManager` owns the `MessageBus` and all channel instances - `OutboundDispatcher` routes outbound messages to the correct channel via `ChannelManager` - -## Known Issues - -- (No known issues at this time) +- Config `.env` loading uses `unsafe { env::set_var(...) }` — don't refactor to safer patterns without understanding side effects diff --git a/Cargo.toml b/Cargo.toml index b4c6b0d..ae4fd4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ clap = { version = "4", features = ["derive"] } dirs = "6.0.0" prost = "0.14" tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } +tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "local-time"] } tracing-appender = "0.2" anyhow = "1.0" mime_guess = "2.0" diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 8bcfa10..340fdb8 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -390,8 +390,16 @@ impl AgentLoop { }); } - // Execute tool calls - tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools"); + // Execute tool calls — log tool names and args before execution + { + let tools_info: Vec = response.tool_calls.iter() + .map(|tc| { + let args = serde_json::to_string(&tc.arguments).unwrap_or_default(); + format!("{}:{}", tc.name, args) + }) + .collect(); + tracing::info!(iteration, count = response.tool_calls.len(), tools = %tools_info.join(", "), "Tool calls detected, executing tools"); + } // Add assistant message with tool calls let assistant_message = ChatMessage::assistant_with_tool_calls( diff --git a/src/agent/system_prompt.rs b/src/agent/system_prompt.rs index 
97403ce..f1288bf 100644 --- a/src/agent/system_prompt.rs +++ b/src/agent/system_prompt.rs @@ -45,6 +45,7 @@ impl SystemPromptBuilder { Box::new(UserProfileSection), Box::new(DateTimeSection), Box::new(RuntimeSection), + Box::new(CrossChannelSection), ], } } @@ -233,6 +234,48 @@ impl PromptSection for DateTimeSection { } } +/// Cross-channel messaging and system notification guidance for LLM. +pub struct CrossChannelSection; + +impl PromptSection for CrossChannelSection { + fn name(&self) -> &str { + "cross_channel" + } + + fn build(&self, _ctx: &PromptContext<'_>) -> String { + r#"## 关于跨渠道消息和系统通知 + +当前对话中可能出现带有 `source` 标记的消息,这些消息不是用户直接输入: + +### 系统通知(source.kind = "system_notification") +来自机器人内部系统(如定时任务、后台任务)的通知。 +- `system_name`: 发出通知的系统名称 +- `task_id`: 关联的任务 ID + +### 跨渠道消息(source.kind = "cross_channel") +来自其他渠道的消息被写入当前对话。 +- `from_channel`: 来源渠道(如 "feishu") +- `from_user_id`: 来源用户 ID + +### send_message 工具 + +使用 `send_message` 向其他渠道发送消息。参数: +- `target_chat_id`: 目标会话ID,支持两种格式: + 1. `:` — 发送到该聊天下最新活跃的会话,若没有活跃会话则自动创建 + 2. `::` — 发送到指定会话,若会话已过期则自动激活 +- `content`: 要发送的消息内容 +- `origin`(可选): 消息来源标识,不填则自动使用当前会话的完整 session_id + +跨渠道消息到达目标会话时,内容前会带有 `[message from X to Y]` 标记, +表示该消息的来源和目标。目标会话的 LLM 应将此理解为来自其他渠道/会话的消息。 + +### 处理建议 +- 系统通知:可以提及但不建议以此为由改变对话主题 +- 跨渠道消息:当用户提及相关事务时可关联这些消息"# + .to_string() + } +} + /// Runtime environment information. 
pub struct RuntimeSection; diff --git a/src/bus/message.rs b/src/bus/message.rs index 880bbe6..be65d55 100644 --- a/src/bus/message.rs +++ b/src/bus/message.rs @@ -73,6 +73,28 @@ pub struct ChatMessage { pub tool_name: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub source: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum SourceKind { + #[serde(rename = "system_notification")] + SystemNotification, + #[serde(rename = "cross_channel")] + CrossChannel, + #[serde(rename = "external_trigger")] + ExternalTrigger, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageSource { + pub kind: SourceKind, + pub from_channel: Option, + pub from_session: Option, + pub from_user_id: Option, + pub system_name: Option, + pub task_id: Option, } impl ChatMessage { @@ -86,6 +108,7 @@ impl ChatMessage { tool_call_id: None, tool_name: None, tool_calls: None, + source: None, } } @@ -99,6 +122,7 @@ impl ChatMessage { tool_call_id: None, tool_name: None, tool_calls: None, + source: None, } } @@ -112,6 +136,7 @@ impl ChatMessage { tool_call_id: None, tool_name: None, tool_calls: None, + source: None, } } @@ -125,6 +150,21 @@ impl ChatMessage { tool_call_id: None, tool_name: None, tool_calls: Some(tool_calls), + source: None, + } + } + + pub fn assistant_with_source(content: impl Into, source: MessageSource) -> Self { + Self { + id: uuid::Uuid::new_v4().to_string(), + role: "assistant".to_string(), + content: content.into(), + media_refs: Vec::new(), + timestamp: current_timestamp(), + tool_call_id: None, + tool_name: None, + tool_calls: None, + source: Some(source), } } @@ -138,6 +178,7 @@ impl ChatMessage { tool_call_id: None, tool_name: None, tool_calls: None, + source: None, } } @@ -151,6 +192,7 @@ impl ChatMessage { tool_call_id: Some(tool_call_id.into()), tool_name: Some(tool_name.into()), tool_calls: None, + source: None, } } } diff --git 
a/src/bus/mod.rs b/src/bus/mod.rs index 7c2de5e..70be932 100644 --- a/src/bus/mod.rs +++ b/src/bus/mod.rs @@ -2,7 +2,7 @@ pub mod dispatcher; pub mod message; pub use dispatcher::OutboundDispatcher; -pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, OutboundMessage}; +pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MessageSource, OutboundMessage, SourceKind}; use std::sync::Arc; use tokio::sync::{mpsc, Mutex}; diff --git a/src/channels/manager.rs b/src/channels/manager.rs index b144dd3..1e1c200 100644 --- a/src/channels/manager.rs +++ b/src/channels/manager.rs @@ -24,6 +24,14 @@ impl ChannelManager { } } + pub fn with_bus(cli_chat_channel: Arc, bus: Arc) -> Self { + Self { + channels: Arc::new(RwLock::new(HashMap::new())), + cli_chat_channel, + bus, + } + } + /// Get a reference to the MessageBus pub fn bus(&self) -> Arc { self.bus.clone() @@ -99,6 +107,11 @@ impl ChannelManager { self.channels.read().await.get(name).cloned() } + /// Get list of registered channel names + pub async fn list_channel_names(&self) -> Vec { + self.channels.read().await.keys().cloned().collect() + } + /// Dispatch an outbound message to the appropriate channel pub async fn dispatch(&self, msg: OutboundMessage) -> Result<(), ChannelError> { let channel_name = &msg.channel; diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index b9cea73..8da35b6 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use axum::{routing, Router}; use tokio::net::TcpListener; -use crate::bus::{ControlMessage, OutboundDispatcher}; +use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher}; use crate::channels::{ChannelManager, CliChatChannel}; use crate::channels::base::{Channel, ChannelError}; use crate::config::{Config, expand_path, ensure_workspace_dir}; @@ -15,7 +15,7 @@ use crate::session::SessionManager; pub struct GatewayState { pub config: Config, pub workspace_dir: 
std::path::PathBuf, - pub session_manager: SessionManager, + pub session_manager: Arc, pub channel_manager: ChannelManager, } @@ -53,21 +53,32 @@ impl GatewayState { ); tracing::info!("Session storage: {}", db_path.display()); - let session_manager = SessionManager::new(session_ttl_hours, provider_config.clone(), storage.clone())?; + // Create MessageBus first (shared by SessionManager and ChannelManager) + let bus = MessageBus::new(100); + + // Create SessionManager with bus injection + let session_manager = SessionManager::new(session_ttl_hours, provider_config.clone(), storage.clone(), bus.clone())?; + let session_manager = Arc::new(session_manager); // Start background cleanup task (default 60 minutes) let cleanup_interval = config.gateway.cleanup_interval_minutes.unwrap_or(60); - Arc::new(session_manager.clone()).start_cleanup_task(cleanup_interval); + session_manager.clone().start_cleanup_task(cleanup_interval); tracing::info!("Session cleanup task started (interval: {} min)", cleanup_interval); - // Create CLI Chat Channel first (needed for ChannelManager) + // Create ChannelManager and init channels let cli_chat_channel = Arc::new(CliChatChannel::new()); - let channel_manager = ChannelManager::new(cli_chat_channel); + let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus); + channel_manager.init(&config, workspace_path.clone()).await + .map_err(|e| format!("Failed to init channels: {}", e))?; + + // Register send_message tool with available channel names + let available_channels = channel_manager.list_channel_names().await; + session_manager.register_outbound_tool(available_channels); Ok(Self { config, workspace_dir: workspace_path, - session_manager, + session_manager: session_manager.clone(), channel_manager, }) } @@ -231,11 +242,7 @@ pub async fn run(host: Option, port: Option) -> Result<(), Box, + model: Option, + #[serde(default)] content: Vec, - usage: AnthropicUsage, + #[serde(default)] + usage: Option, } #[derive(Deserialize)] @@ 
-138,7 +140,9 @@ enum AnthropicContent { #[derive(Deserialize)] struct AnthropicUsage { + #[serde(default)] input_tokens: u32, + #[serde(default)] output_tokens: u32, } @@ -167,9 +171,28 @@ impl LLMProvider for AnthropicProvider { messages: request .messages .iter() - .map(|m| AnthropicMessage { - role: m.role.clone(), - content: convert_content_blocks(&m.content), + .map(|m| { + let role = if m.role == "tool" { + // Anthropic uses "user" role for tool result messages + "user".to_string() + } else { + m.role.clone() + }; + let content = if let Some(ref tc_id) = m.tool_call_id { + // Tool result: wrap as tool_result content block + let output = m.content.iter() + .filter_map(|b| match b { ContentBlock::Text { text } => Some(text.as_str()), _ => None }) + .collect::>() + .join(""); + vec![serde_json::json!({ + "type": "tool_result", + "tool_use_id": tc_id, + "content": output, + })] + } else { + convert_content_blocks(&m.content) + }; + AnthropicMessage { role, content } }) .collect(), max_tokens, @@ -191,7 +214,24 @@ impl LLMProvider for AnthropicProvider { let resp = req_builder.json(&body).send().await?; - let anthropic_resp: AnthropicResponse = resp.json().await?; + let status = resp.status(); + let body_text = resp.text().await?; + + if !status.is_success() { + let error_msg = serde_json::from_str::(&body_text) + .ok() + .and_then(|v| { + v.get("error") + .and_then(|e| e.get("message")) + .and_then(|m| m.as_str()) + .map(|s| s.to_string()) + }) + .unwrap_or_else(|| body_text.clone()); + return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into()); + } + + let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text) + .map_err(|e| format!("decode error: {} | body: {}", e, &body_text))?; let mut content = String::new(); let mut tool_calls = Vec::new(); @@ -218,15 +258,14 @@ impl LLMProvider for AnthropicProvider { } Ok(ChatCompletionResponse { - id: anthropic_resp.id, - model: anthropic_resp.model, + id: 
anthropic_resp.id.unwrap_or_default(), + model: anthropic_resp.model.unwrap_or_default(), content, tool_calls, usage: Usage { - prompt_tokens: anthropic_resp.usage.input_tokens, - completion_tokens: anthropic_resp.usage.output_tokens, - total_tokens: anthropic_resp.usage.input_tokens - + anthropic_resp.usage.output_tokens, + prompt_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0), + completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0), + total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0), }, }) } diff --git a/src/session/session.rs b/src/session/session.rs index da1f822..d5ed0f8 100644 --- a/src/session/session.rs +++ b/src/session/session.rs @@ -5,7 +5,7 @@ use std::time::{Duration, Instant}; use tokio::sync::{Mutex, mpsc}; use uuid::Uuid; -use crate::bus::ChatMessage; +use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind}; use crate::storage::{Storage, StorageError}; use std::sync::Arc as StdArc; @@ -26,10 +26,10 @@ use crate::providers::{create_provider, LLMProvider}; use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID}; use crate::session::events::DialogInfo; use crate::skills::SkillsLoader; -use crate::tools::{ - BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, - GetSkillTool, HttpRequestTool, ToolRegistry, WebFetchTool, -}; +use crate::tools::{ToolRegistry, create_default_tools}; +use crate::bus::MessageBus; +use crate::tools::OutboundMessenger; +use crate::tools::SendMessageTool; /// Generate a short ID (8 characters) from a UUID fn short_id() -> String { @@ -133,6 +133,7 @@ impl Session { tool_call_id: m.tool_call_id, tool_name: m.tool_name, tool_calls: m.tool_calls.map(|tc| serde_json::from_str(&tc).unwrap_or_default()), + source: m.source.and_then(|s| serde_json::from_str(&s).ok()), } }).collect(); @@ -190,6 +191,7 @@ impl Session { tool_call_id: message.tool_call_id.clone(), 
tool_name: message.tool_name.clone(), tool_calls: message.tool_calls.as_ref().map(|tc| serde_json::to_string(tc).unwrap_or_default()), + source: message.source.as_ref().map(|s| serde_json::to_string(s).unwrap_or_default()), created_at: now, }; storage.append_message_with_retry(&self.id.to_string(), &msg_meta).await?; @@ -547,6 +549,8 @@ pub struct SessionManager { tools: Arc, skills_loader: Arc, storage: Arc, + bus: Arc, + current_source_session: Arc>>, } struct SessionManagerInner { @@ -558,23 +562,7 @@ struct SessionManagerInner { current_sessions: HashMap, } -fn create_default_tools(skills_loader: Arc) -> ToolRegistry { - let mut registry = ToolRegistry::new(); - registry.register(CalculatorTool::new()); - registry.register(FileReadTool::new()); - registry.register(FileWriteTool::new()); - registry.register(FileEditTool::new()); - registry.register(BashTool::new()); - registry.register(HttpRequestTool::new( - vec!["*".to_string()], - 1_000_000, - 30, - false, - )); - registry.register(WebFetchTool::new(50_000, 30)); - registry.register(GetSkillTool::new(skills_loader)); - registry -} + /// 斜杠命令定义 #[derive(Debug, Clone)] @@ -649,6 +637,7 @@ impl SessionManager { session_ttl_hours: u64, provider_config: LLMProviderConfig, storage: Arc, + bus: Arc, ) -> Result { let skills_loader = SkillsLoader::new(); skills_loader.load_skills(); @@ -667,9 +656,17 @@ impl SessionManager { tools, skills_loader, storage, + bus, + current_source_session: Arc::new(Mutex::new(None)), }) } + /// Register the send_message tool (requires self in Arc) + pub fn register_outbound_tool(self: &Arc, available_channels: Vec) { + let messenger: Arc = self.clone(); + self.tools.register(SendMessageTool::new(messenger, available_channels)); + } + pub fn tools(&self) -> Arc { self.tools.clone() } @@ -1047,65 +1044,111 @@ impl SessionManager { Err(AgentError::Other("clear_dialog_history not available".to_string())) } + /// Get or activate a specific session by its full UnifiedSessionId. 
+ /// Returns an error if the session does not exist in storage. + /// If the session was expired from memory but still in storage, + /// it will be restored (reactivated). + pub async fn get_or_activate_session( + &self, + unified_id: &UnifiedSessionId, + ) -> Result>, AgentError> { + let session_id_str = unified_id.to_string(); + match self.storage.get_session(&session_id_str).await { + Ok(_) => self.get_or_create_session(unified_id).await, + Err(StorageError::NotFound(_)) => { + Err(AgentError::Other(format!("session not found: {}", unified_id))) + } + Err(e) => Err(AgentError::Other(format!("storage error: {}", e))), + } + } + + async fn resolve_dialog_id( + &self, + channel: &str, + chat_id: &str, + ) -> Result { + let chat_scope = format!("{}:{}", channel, chat_id); + let current_id = { + self.inner.lock().await.current_sessions.get(&chat_scope).cloned() + }; + + if let Some(ref current_id) = current_id { + match self.storage.get_session(current_id).await { + Ok(_) => { + let parts: Vec<&str> = current_id.split(':').collect(); + if parts.len() == 3 { + return Ok(UnifiedSessionId::new(channel, chat_id, parts[2])); + } + } + Err(_) => {} + } + } + + let ttl_millis = self.inner.lock().await.session_ttl.as_millis() as i64; + match self.storage.find_active_session(channel, chat_id, ttl_millis).await { + Ok(Some(meta)) => Ok(UnifiedSessionId::new(channel, chat_id, &meta.dialog_id)), + _ => { + let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?; + Ok(new_id) + } + } + } + + /// Send a system notification (no LLM triggered). + /// + /// Flow: + /// 1. Resolve target session (resolve_dialog_id) + /// 2. Write assistant message with source tag to history + /// 3. 
Publish OutboundMessage via bus to target channel + pub async fn send_notification( + &self, + channel: &str, + chat_id: &str, + content: &str, + system_name: &str, + task_id: Option<&str>, + ) -> Result<(), AgentError> { + let unified_id = self.resolve_dialog_id(channel, chat_id).await?; + let session = self.get_or_create_session(&unified_id).await?; + { + let mut guard = session.lock().await; + let source = MessageSource { + kind: SourceKind::SystemNotification, + from_channel: None, + from_session: None, + from_user_id: None, + system_name: Some(system_name.to_string()), + task_id: task_id.map(|s| s.to_string()), + }; + let msg = ChatMessage::assistant_with_source(content, source); + guard.add_message(msg, true).await + .map_err(|e| AgentError::Other(format!("persist error: {}", e)))?; + } + + let outbound = OutboundMessage { + channel: channel.to_string(), + chat_id: chat_id.to_string(), + content: content.to_string(), + reply_to: None, + media: vec![], + metadata: HashMap::new(), + }; + self.bus.publish_outbound(outbound).await + .map_err(|e| AgentError::Other(format!("bus publish error: {}", e)))?; + + Ok(()) + } + pub async fn handle_message( &self, channel: &str, _sender_id: &str, chat_id: &str, content: &str, - media: Vec, + media: Vec, ) -> Result { - // Channel messages never carry dialog_id — routing is entirely via current_sessions - let unified_id = { - let chat_scope = format!("{}:{}", channel, chat_id); - let current_session_id = { - let inner = self.inner.lock().await; - inner.current_sessions.get(&chat_scope).cloned() - }; - if let Some(current_id) = current_session_id { - // Verify current session still exists in Storage - match self.storage.get_session(¤t_id).await { - Ok(_) => { - // Current session still valid, extract dialog_id - let parts: Vec<&str> = current_id.split(':').collect(); - if parts.len() == 3 { - UnifiedSessionId::new(channel, chat_id, parts[2]) - } else { - // Malformed, fallback to find or create - let ttl_millis = 
self.inner.lock().await.session_ttl.as_millis() as i64; - match self.storage.find_active_session(channel, chat_id, ttl_millis).await { - Ok(Some(m)) => UnifiedSessionId::new(channel, chat_id, &m.dialog_id), - _ => { - let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?; - new_id - } - } - } - } - Err(_) => { - // Current session no longer exists, create new - let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?; - new_id - } - } - } else { - // No current session tracked, find active or create new - let ttl_millis = self.inner.lock().await.session_ttl.as_millis() as i64; - tracing::debug!(channel, chat_id, ttl_millis, "No current_sessions entry, searching Storage for active session"); - match self.storage.find_active_session(channel, chat_id, ttl_millis).await { - Ok(Some(meta)) => { - tracing::debug!(session_id = %meta.id, dialog_id = %meta.dialog_id, last_active_at = %meta.last_active_at, "Found active session in Storage"); - UnifiedSessionId::new(channel, chat_id, &meta.dialog_id) - } - Ok(None) | Err(_) => { - tracing::debug!("No active session found in Storage, creating new session"); - // Create new session - let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?; - new_id - } - } - } - }; + let unified_id = self.resolve_dialog_id(channel, chat_id).await?; + *self.current_source_session.lock().await = Some(unified_id.to_string()); tracing::debug!(unified_id = %unified_id, "handle_message resolved unified_id"); let session = self.get_or_create_session(&unified_id).await?; @@ -1121,9 +1164,11 @@ impl SessionManager { match result { Ok((_new_session_id, response)) => { + *self.current_source_session.lock().await = None; return Ok(HandleResult::CommandOutput(response)); } Err(e) => { + *self.current_source_session.lock().await = None; return Ok(HandleResult::CommandOutput(e.to_string())); } } @@ -1183,6 +1228,8 @@ impl SessionManager { "Agent response received" ); + 
*self.current_source_session.lock().await = None; + Ok(HandleResult::AgentResponse(response)) } @@ -1203,6 +1250,74 @@ impl SessionManager { } } +#[async_trait::async_trait] +impl OutboundMessenger for SessionManager { + async fn send_message( + &self, + channel: &str, + chat_id: &str, + dialog_id: Option<&str>, + content: &str, + mut source: MessageSource, + ) -> Result<(), String> { + // Fill origin from current source session if not provided + if source.from_session.is_none() { + source.from_session = self.current_source_session.lock().await.clone(); + } + + let (target_sid, session) = if let Some(did) = dialog_id { + let sid = UnifiedSessionId::new(channel, chat_id, did); + let session = self.get_or_activate_session(&sid).await + .map_err(|e| e.to_string())?; + (sid, session) + } else { + let sid = self.resolve_dialog_id(channel, chat_id).await + .map_err(|e| e.to_string())?; + let session = self.get_or_create_session(&sid).await + .map_err(|e| e.to_string())?; + (sid, session) + }; + + // Build message prefix: [message from to ] + let target_id = target_sid.to_string(); + let origin = source.from_session.as_deref().unwrap_or("unknown"); + let origin_id = source.from_session.clone(); + let prefix = format!("[message from {} to {}] ", origin, target_id); + let marked_content = format!("{}\n{}", prefix, content); + + // Write source-tagged assistant message to target session history + { + let mut guard = session.lock().await; + let msg = ChatMessage::assistant_with_source(marked_content.clone(), source); + guard.add_message(msg, true).await + .map_err(|e| e.to_string())?; + } + + // Restore active dialog if source and target share channel:chat_id but differ in dialog_id + if let Some(ref origin_id) = origin_id { + let parts: Vec<&str> = origin_id.split(':').collect(); + if parts.len() == 3 && parts[0] == channel && parts[1] == chat_id && parts[2] != target_sid.dialog_id { + let scope = format!("{}:{}", channel, chat_id); + 
self.inner.lock().await.current_sessions.insert(scope, origin_id.clone()); + } + } + + // Publish OutboundMessage via bus to target channel + let outbound = OutboundMessage { + channel: channel.to_string(), + chat_id: chat_id.to_string(), + content: marked_content, + reply_to: None, + media: vec![], + metadata: HashMap::new(), + }; + self.bus.publish_outbound(outbound).await + .map_err(|e| e.to_string())?; + + Ok(()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/storage/message.rs b/src/storage/message.rs index 75ea03b..f9f7790 100644 --- a/src/storage/message.rs +++ b/src/storage/message.rs @@ -11,5 +11,6 @@ pub struct MessageMeta { pub tool_call_id: Option, pub tool_name: Option, pub tool_calls: Option, + pub source: Option, pub created_at: i64, } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index a5fbd51..94f8680 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -66,6 +66,7 @@ impl Storage { tool_call_id TEXT, tool_name TEXT, tool_calls TEXT, + source TEXT, created_at INTEGER NOT NULL, FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE ) @@ -83,6 +84,14 @@ impl Storage { .execute(&self.pool) .await?; + // Migration: add source column if upgrading from older schema + sqlx::query( + r#"ALTER TABLE messages ADD COLUMN source TEXT"#, + ) + .execute(&self.pool) + .await + .ok(); + Ok(()) } @@ -260,8 +269,8 @@ impl Storage { pub async fn append_message(&self, session_id: &str, msg: &crate::storage::message::MessageMeta) -> Result { sqlx::query( r#" - INSERT INTO messages (id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO messages (id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
"#, ) .bind(&msg.id) @@ -273,6 +282,7 @@ impl Storage { .bind(&msg.tool_call_id) .bind(&msg.tool_name) .bind(&msg.tool_calls) + .bind(&msg.source) .bind(msg.created_at) .execute(self.pool()) .await?; @@ -300,7 +310,7 @@ impl Storage { ) -> Result, StorageError> { let rows = sqlx::query( r#" - SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, created_at + SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at FROM messages WHERE session_id = ? AND seq >= ? ORDER BY seq ASC @@ -323,6 +333,7 @@ impl Storage { tool_call_id: row.get("tool_call_id"), tool_name: row.get("tool_name"), tool_calls: row.get("tool_calls"), + source: row.get("source"), created_at: row.get("created_at"), }) .collect()) @@ -486,6 +497,7 @@ mod tests { tool_call_id: None, tool_name: None, tool_calls: None, + source: None, created_at: 1000, }; diff --git a/src/tools/mod.rs b/src/tools/mod.rs index ab1c361..11a37be 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -7,6 +7,7 @@ pub mod get_skill; pub mod http_request; pub mod registry; pub mod schema; +pub mod send_message; pub mod traits; pub mod web_fetch; @@ -19,5 +20,30 @@ pub use get_skill::GetSkillTool; pub use http_request::HttpRequestTool; pub use registry::ToolRegistry; pub use schema::{CleaningStrategy, SchemaCleanr}; -pub use traits::{Tool, ToolResult}; +pub use send_message::SendMessageTool; +pub use traits::{OutboundMessenger, Tool, ToolResult}; pub use web_fetch::WebFetchTool; + +use std::sync::Arc; +use crate::skills::SkillsLoader; + +/// Create the base tool registry (without send_message). +/// `send_message` tool is registered later via `SessionManager::register_outbound_tool()` +/// once the available channel names are known. 
+pub fn create_default_tools(skills_loader: Arc) -> ToolRegistry { + let registry = ToolRegistry::new(); + registry.register(CalculatorTool::new()); + registry.register(FileReadTool::new()); + registry.register(FileWriteTool::new()); + registry.register(FileEditTool::new()); + registry.register(BashTool::new()); + registry.register(HttpRequestTool::new( + vec!["*".to_string()], + 1_000_000, + 30, + false, + )); + registry.register(WebFetchTool::new(50_000, 30)); + registry.register(GetSkillTool::new(skills_loader)); + registry +} diff --git a/src/tools/registry.rs b/src/tools/registry.rs index cccfc5d..3b2cc53 100644 --- a/src/tools/registry.rs +++ b/src/tools/registry.rs @@ -1,36 +1,39 @@ use std::collections::HashMap; +use std::sync::{Arc, Mutex}; use crate::providers::{Tool, ToolFunction}; use super::traits::Tool as ToolTrait; pub struct ToolRegistry { - tools: HashMap>, + tools: Mutex>>, } impl ToolRegistry { pub fn new() -> Self { Self { - tools: HashMap::new(), + tools: Mutex::new(HashMap::new()), } } - pub fn register(&mut self, tool: T) { - self.tools.insert(tool.name().to_string(), Box::new(tool)); + pub fn register(&self, tool: T) { + self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool)); } - pub fn get(&self, name: &str) -> Option<&Box> { - self.tools.get(name) + pub fn get(&self, name: &str) -> Option> { + self.tools.lock().unwrap().get(name).cloned() } /// Get all registered tools. /// Used for concurrent tool execution when we need to look up tools by name. 
- pub fn get_all(&self) -> Vec<&Box> { - self.tools.values().collect() + pub fn get_all(&self) -> Vec> { + self.tools.lock().unwrap().values().cloned().collect() } pub fn get_definitions(&self) -> Vec { self.tools + .lock() + .unwrap() .values() .map(|tool| Tool { tool_type: "function".to_string(), @@ -44,15 +47,20 @@ impl ToolRegistry { } pub fn has_tools(&self) -> bool { - !self.tools.is_empty() + !self.tools.lock().unwrap().is_empty() } pub fn tool_names(&self) -> Vec { - self.tools.keys().cloned().collect() + self.tools.lock().unwrap().keys().cloned().collect() } - pub fn iter(&self) -> impl Iterator)> { - self.tools.iter() + pub fn iter(&self) -> Vec<(String, Arc)> { + self.tools + .lock() + .unwrap() + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect() } } diff --git a/src/tools/send_message.rs b/src/tools/send_message.rs new file mode 100644 index 0000000..91647e0 --- /dev/null +++ b/src/tools/send_message.rs @@ -0,0 +1,171 @@ +use std::sync::Arc; +use std::collections::HashSet; + +use async_trait::async_trait; + +use crate::bus::{MessageSource, SourceKind}; + +use super::traits::{OutboundMessenger, Tool, ToolResult}; + +pub struct SendMessageTool { + messenger: Arc, + available_channels: HashSet, +} + +impl SendMessageTool { + pub fn new(messenger: Arc, available_channels: Vec) -> Self { + Self { + messenger, + available_channels: available_channels.into_iter().collect(), + } + } +} + +/// Parse target_chat_id into (channel, chat_id, optional dialog_id). 
+/// Accepts two formats: +/// - Two-part: `<channel>:<chat_id>` → sends to latest active session for that chat +/// - Three-part: `<channel>:<chat_id>:<dialog_id>` → sends to specific session +fn parse_target_chat_id(raw: &str) -> Result<(&str, &str, Option<&str>), String> { + let parts: Vec<&str> = raw.split(':').collect(); + match parts.len() { + 2 => { + if parts[0].is_empty() || parts[1].is_empty() { + Err(format!("Invalid target_chat_id format '{}': channel and chat_id must not be empty", raw)) + } else { + Ok((parts[0], parts[1], None)) + } + } + 3 => { + if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() { + Err(format!("Invalid target_chat_id format '{}': all three parts must not be empty", raw)) + } else { + Ok((parts[0], parts[1], Some(parts[2]))) + } + } + _ => Err(format!( + "Invalid target_chat_id format '{}'. Expected : or ::", + raw + )), + } +} + +#[async_trait] +impl Tool for SendMessageTool { + fn name(&self) -> &str { + "send_message" + } + + fn description(&self) -> &str { + "向指定渠道的会话发送消息。用于在用户请求下向其他渠道发送内容。\ +target_chat_id 支持两种格式::(发送到该聊天下最新活跃会话)\ +或 ::(发送到指定会话,过期则自动激活)" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "target_chat_id": { + "type": "string", + "description": "目标会话ID。支持两种格式: 1) : 发送到该聊天下最新活跃会话, 无则自动创建; 2) :: 发送到指定会话, 过期则自动激活。channel 可选值: feishu, cli_chat" + }, + "content": { + "type": "string", + "description": "要发送的消息内容" + }, + "origin": { + "type": "string", + "description": "可选。消息来源标识。不填则自动使用当前会话的完整 session_id (::)" + } + }, + "required": ["target_chat_id", "content"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let raw_id = args["target_chat_id"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("missing target_chat_id"))?; + let content = args["content"] + .as_str() + .ok_or_else(|| anyhow::anyhow!("missing content"))?; + + // 1.
Parse target_chat_id + let (channel, chat_id, dialog_id) = parse_target_chat_id(raw_id) + .map_err(|e| anyhow::anyhow!(e))?; + + // 2. Validate channel + if !self.available_channels.contains(channel) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Channel '{}' is not available. Available channels: {}", + channel, + self.available_channels.iter().cloned().collect::>().join(", ") + )), + }); + } + + let from_session = args["origin"].as_str().map(|s| s.to_string()); + + let source = MessageSource { + kind: SourceKind::CrossChannel, + from_channel: Some("tool".to_string()), + from_session, + from_user_id: None, + system_name: None, + task_id: None, + }; + + // 3. Send via messenger + match self.messenger + .send_message(channel, chat_id, dialog_id, content, source) + .await + { + Ok(()) => Ok(ToolResult { + success: true, + output: "消息已发送".to_string(), + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(e), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_target_chat_id_two_part() { + let (ch, cid, did) = parse_target_chat_id("feishu:oc_abc123").unwrap(); + assert_eq!(ch, "feishu"); + assert_eq!(cid, "oc_abc123"); + assert!(did.is_none()); + } + + #[test] + fn test_parse_target_chat_id_three_part() { + let (ch, cid, did) = parse_target_chat_id("feishu:oc_abc123:dialog1").unwrap(); + assert_eq!(ch, "feishu"); + assert_eq!(cid, "oc_abc123"); + assert_eq!(did, Some("dialog1")); + } + + #[test] + fn test_parse_target_chat_id_invalid_one_part() { + assert!(parse_target_chat_id("feishu").is_err()); + } + + #[test] + fn test_parse_target_chat_id_empty_parts() { + assert!(parse_target_chat_id("feishu:").is_err()); + assert!(parse_target_chat_id(":chat_id").is_err()); + assert!(parse_target_chat_id("feishu::dialog").is_err()); + } +} diff --git a/src/tools/traits.rs b/src/tools/traits.rs index f3ffdc4..6ac5770 100644 --- 
a/src/tools/traits.rs +++ b/src/tools/traits.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use crate::bus::MessageSource; #[derive(Debug, Clone)] pub struct ToolResult { @@ -29,3 +30,15 @@ pub trait Tool: Send + Sync + 'static { false } } + +#[async_trait] +pub trait OutboundMessenger: Send + Sync { + async fn send_message( + &self, + channel: &str, + chat_id: &str, + dialog_id: Option<&str>, + content: &str, + source: MessageSource, + ) -> Result<(), String>; +}