diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 340fdb8..4e9c856 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -226,6 +226,7 @@ pub struct AgentLoop { max_iterations: usize, workspace_dir: PathBuf, model_name: String, + notify_tx: Option>, } #[derive(Debug, Clone)] @@ -247,6 +248,7 @@ impl AgentLoop { provider: Arc::from(provider), tools: Arc::new(ToolRegistry::new()), observer: None, + notify_tx: None, max_iterations, workspace_dir, model_name, @@ -265,6 +267,7 @@ impl AgentLoop { provider: Arc::from(provider), tools, observer: None, + notify_tx: None, max_iterations, workspace_dir, model_name, @@ -277,6 +280,7 @@ impl AgentLoop { provider, tools: Arc::new(ToolRegistry::new()), observer: None, + notify_tx: None, max_iterations, workspace_dir, model_name, @@ -295,6 +299,7 @@ impl AgentLoop { provider, tools, observer: None, + notify_tx: None, max_iterations, workspace_dir, model_name, @@ -313,6 +318,11 @@ impl AgentLoop { self } + pub fn with_notify(mut self, tx: tokio::sync::mpsc::UnboundedSender) -> Self { + self.notify_tx = Some(tx); + self + } + pub fn tools(&self) -> &Arc { &self.tools } @@ -390,12 +400,16 @@ impl AgentLoop { }); } - // Execute tool calls — log tool names and args before execution + // Execute tool calls — log and notify immediately { let tools_info: Vec = response.tool_calls.iter() .map(|tc| { let args = serde_json::to_string(&tc.arguments).unwrap_or_default(); - format!("{}:{}", tc.name, args) + let s = format!("{}:{}", tc.name, args); + if let Some(ref tx) = self.notify_tx { + let _ = tx.send(format!("调用工具 {}", s)); + } + s }) .collect(); tracing::info!(iteration, count = response.tool_calls.len(), tools = %tools_info.join(", "), "Tool calls detected, executing tools"); diff --git a/src/agent/system_prompt.rs b/src/agent/system_prompt.rs index f1288bf..4a2b976 100644 --- a/src/agent/system_prompt.rs +++ b/src/agent/system_prompt.rs @@ -38,7 +38,6 @@ impl SystemPromptBuilder { Self { sections: vec![ Box::new(ToolHonestySection), - Box::new(NoToolNarrationSection), Box::new(YourTaskSection), Box::new(SafetySection), Box::new(WorkspaceSection), @@ -82,30 +81,11 @@ impl PromptSection for ToolHonestySection { } fn build(&self, _ctx: &PromptContext<'_>) -> String { - "## CRITICAL: Tool Honesty + "## 关键规则:工具诚实性 -- NEVER fabricate, invent, or guess tool results. If a tool returns empty results, say \"No results found.\" -- If a tool call fails, report the error - never make up data to fill the gap. -- When unsure whether a tool call succeeded, ask the user rather than guessing." - .to_string() - } -} - -/// Critical rule: never narrate tool usage. -pub struct NoToolNarrationSection; - -impl PromptSection for NoToolNarrationSection { - fn name(&self) -> &str { - "no_narration" - } - - fn build(&self, _ctx: &PromptContext<'_>) -> String { - "## CRITICAL: No Tool Narration - -NEVER narrate, announce, describe, or explain your tool usage to the user. -Do NOT say things like \"Let me check...\", \"I will use bash to...\", \"I'll fetch that for you\", \"Searching now...\", or similar. -The user must ONLY see the final answer. Tool calls are invisible infrastructure - never reference them. -If you catch yourself starting a sentence about what tool you are about to use or just used, DELETE it and give the answer directly." +- 绝对不要编造、虚构或猜测工具结果。如果工具返回空结果,说\"没有找到结果\"。 +- 如果工具调用失败,报告错误——绝不要编造数据来填补空白。 +- 当不确定工具调用是否成功时,询问用户而不是猜测。" .to_string() } } @@ -123,7 +103,7 @@ impl PromptSection for ToolsSection { return String::new(); } - let mut output = String::from("## Tools\n\nYou have access to the following tools:\n\n"); + let mut output = String::from("## 工具\n\n你可以使用以下工具:\n\n"); for (name, tool) in ctx.tools.iter() { let _ = writeln!(output, "- **{}**: {}", name, tool.description()); } @@ -140,11 +120,11 @@ impl PromptSection for YourTaskSection { } fn build(&self, _ctx: &PromptContext<'_>) -> String { - "## Your Task + "## 你的任务 -When the user sends a message, ACT on it. Use the tools to fulfill their request. -Do NOT: summarize this configuration, describe your capabilities, respond with meta-commentary, or output step-by-step instructions. -Instead: use tools directly when needed, and give the final answer when done." +当用户发送消息时,立即行动。使用工具来完成他们的请求。 +不要:总结此配置、描述你的能力、用元评论回复、或输出逐步指令。 +而是:在需要时直接使用工具,完成后给出最终答案。" .to_string() } } @@ -158,13 +138,13 @@ impl PromptSection for SafetySection { } fn build(&self, _ctx: &PromptContext<'_>) -> String { - "## Safety + "## 安全规则 -- Do not exfiltrate private data. -- Do not run destructive commands without asking. -- Do not bypass oversight or approval mechanisms. -- Prefer safe operations over risky ones. -- When in doubt, ask before acting externally." +- 不要泄露隐私数据。 +- 未经询问不要执行破坏性命令。 +- 不要绕过监督或审批机制。 +- 优先选择安全操作而非风险操作。 +- 不确定时,在外部操作前先询问。" .to_string() } } @@ -184,7 +164,7 @@ impl PromptSection for WorkspaceSection { .canonicalize() .unwrap_or_else(|_| ctx.workspace_dir.to_path_buf()); format!( - "## Workspace\n\nWorking directory: `{}`\n\n### File Storage Guidelines\n\n- **Generated files**: Store all generated files (code, documents, artifacts) in the workspace directory or its subdirectories.\n- **Downloaded files**: Save downloaded files to the workspace directory, organized by task.\n- **One task, one folder**: Create a dedicated subfolder for each task or project (e.g., `task_2024_01_01/`).\n- **Temporary files**: If files are only needed during processing and won't be kept, use `/tmp/` or create a temp folder (e.g., `/tmp/picobot_task_xxx/`) instead of cluttering the workspace.\n\n### Working Directory Structure\n\nThe workspace is your home base for this session. Keep it organized by creating subdirectories for different tasks.", + "## 工作目录\n\n工作目录:`{}`\n\n### 文件存储规范\n\n- **生成的文件**:将所有生成的文件(代码、文档、制品)存放在工作目录或其子目录中。\n- **下载的文件**:将下载的文件保存到工作目录,按任务整理。\n- **一个任务一个文件夹**:为每个任务或项目创建专用的子文件夹(如 `task_2024_01_01/`)。\n- **临时文件**:如果文件仅在处理期间需要且不保留,使用 `/tmp/` 或创建临时文件夹(如 `/tmp/picobot_task_xxx/`),以免弄乱工作目录。\n\n### 目录结构\n\n工作目录是你在本会话中的操作大本营。通过为不同任务创建子目录来保持整洁。", abs_path.display() ) } @@ -199,7 +179,7 @@ impl PromptSection for UserProfileSection { } fn build(&self, _ctx: &PromptContext<'_>) -> String { - let mut output = String::from("## User Profile\n\n"); + let mut output = String::from("## 用户配置\n\n"); // Load USER.md from ~/.picobot/USER.md if let Some(user_config_dir) = get_user_config_dir() { @@ -227,7 +207,7 @@ impl PromptSection for DateTimeSection { fn build(&self, _ctx: &PromptContext<'_>) -> String { let now = chrono::Local::now(); format!( - "## Current Date & Time\n\n{} ({})", + "## 当前日期与时间\n\n{} ({})", now.format("%Y-%m-%d %H:%M:%S"), now.format("%Z") ) @@ -289,7 +269,7 @@ impl PromptSection for RuntimeSection { .map(|h| h.to_string_lossy().to_string()) .unwrap_or_else(|_| "unknown".to_string()); format!( - "## Runtime\n\nHost: {} | OS: {} | Model: {}", + "## 运行环境\n\n主机: {} | 操作系统: {} | 模型: {}", host, std::env::consts::OS, ctx.model_name @@ -321,7 +301,7 @@ fn load_file_from_dir(dir: &Path, filename: &str, max_chars: usize) -> Option Result<(), ChannelError> { let clients = self.clients.lock().await.clone(); for client in clients { - let outbound = WsOutbound::AssistantResponse { - id: short_id(), - content: msg.content.clone(), - role: "assistant".to_string(), + let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification") { + WsOutbound::SystemNotification { + content: msg.content.clone(), + } + } else { + WsOutbound::AssistantResponse { + id: short_id(), + content: msg.content.clone(), + role: "assistant".to_string(), + } }; let _ = client.sender.send(outbound).await; } diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index cadf9fd..cfbd6c1 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -6,13 +6,8 @@ use std::collections::HashMap; use crate::bus::message::ContentBlock; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall}; use super::traits::Usage; - -fn serialize_content_blocks(blocks: &[serde_json::Value], serializer: S) -> Result -where - S: serde::Serializer, -{ - serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string())) -} +use std::sync::Arc; +use crate::storage::Storage; fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec { blocks.iter().map(|b| match b { @@ -62,6 +57,7 @@ pub struct AnthropicProvider { temperature: Option, max_tokens: Option, model_extra: HashMap, + storage: Option>, } impl AnthropicProvider { @@ -85,8 +81,13 @@ impl AnthropicProvider { temperature, max_tokens, model_extra, + storage: None, } } + + pub fn set_storage(&mut self, storage: Arc) { + self.storage = Some(storage); + } } #[derive(Serialize)] @@ -104,7 +105,6 @@ struct AnthropicRequest { #[derive(Serialize)] struct AnthropicMessage { role: String, - #[serde(serialize_with = "serialize_content_blocks")] content: Vec, } @@ -128,14 +128,23 @@ struct AnthropicResponse { #[derive(Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] enum AnthropicContent { - Text { text: String }, - Thinking { thinking: String }, + Text { + #[serde(alias = "content")] + text: String, + }, + Thinking { + #[serde(alias = "content")] + thinking: String, + }, #[serde(rename = "tool_use")] ToolUse { id: String, name: String, + #[serde(alias = "arguments")] input: serde_json::Value, }, + #[serde(other)] + Unknown, } #[derive(Deserialize)] @@ -152,6 +161,7 @@ impl LLMProvider for AnthropicProvider { &self, request: ChatCompletionRequest, ) -> Result> { + let start = std::time::Instant::now(); let url = format!("{}/v1/messages", self.base_url); let max_tokens = request.max_tokens.or(self.max_tokens).unwrap_or(1024); @@ -190,7 +200,19 @@ impl LLMProvider for AnthropicProvider { "content": output, })] } else { - convert_content_blocks(&m.content) + let mut blocks = convert_content_blocks(&m.content); + // Append tool_use blocks from assistant messages with tool calls + if let Some(ref tool_calls) = m.tool_calls { + for tc in tool_calls { + blocks.push(serde_json::json!({ + "type": "tool_use", + "id": tc.id, + "name": tc.name, + "input": tc.arguments, + })); + } + } + blocks }; AnthropicMessage { role, content } }) @@ -212,10 +234,14 @@ impl LLMProvider for AnthropicProvider { req_builder = req_builder.header(key.as_str(), value.as_str()); } + let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default(); + tracing::debug!(req_body = %req_body_str, "LLM request"); + let resp = req_builder.json(&body).send().await?; let status = resp.status(); let body_text = resp.text().await?; + tracing::debug!(status = %status, resp_body = %body_text, "LLM response"); if !status.is_success() { let error_msg = serde_json::from_str::(&body_text) @@ -227,11 +253,33 @@ impl LLMProvider for AnthropicProvider { .map(|s| s.to_string()) }) .unwrap_or_else(|| body_text.clone()); + if let Some(ref storage) = self.storage { + let _ = storage.append_llm_call( + &self.name, &self.model_id, &req_body_str, + Some(&body_text), Some(&error_msg), + start.elapsed().as_millis() as u64, + ).await; + } return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into()); } let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text) - .map_err(|e| format!("decode error: {} | body: {}", e, &body_text))?; + .map_err(|e| { + let err_msg = format!("decode error: {} | body: {}", e, &body_text); + if let Some(ref storage) = self.storage { + let name = self.name.clone(); + let model = self.model_id.clone(); + let req = req_body_str.clone(); + let resp_body = body_text.clone(); + let dur = start.elapsed().as_millis() as u64; + let err = err_msg.clone(); + let s = storage.clone(); + tokio::spawn(async move { + let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await; + }); + } + err_msg + })?; let mut content = String::new(); let mut tool_calls = Vec::new(); @@ -247,6 +295,7 @@ impl LLMProvider for AnthropicProvider { } } AnthropicContent::Thinking { .. } => {} + AnthropicContent::Unknown => {} AnthropicContent::ToolUse { id, name, input } => { tool_calls.push(ToolCall { id: id.clone(), @@ -257,7 +306,7 @@ impl LLMProvider for AnthropicProvider { } } - Ok(ChatCompletionResponse { + let response = ChatCompletionResponse { id: anthropic_resp.id.unwrap_or_default(), model: anthropic_resp.model.unwrap_or_default(), content, @@ -267,7 +316,20 @@ impl LLMProvider for AnthropicProvider { completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0), total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0), }, - }) + }; + + if let Some(ref storage) = self.storage { + let _ = storage.append_llm_call( + &self.name, + &self.model_id, + &req_body_str, + Some(&body_text), + None, + start.elapsed().as_millis() as u64, + ).await; + } + + Ok(response) } fn ptype(&self) -> &str { diff --git a/src/providers/openai.rs b/src/providers/openai.rs index b55ed8e..2fbdb76 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -7,6 +7,8 @@ use std::collections::HashMap; use crate::bus::message::ContentBlock; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; use super::traits::Usage; +use std::sync::Arc; +use crate::storage::Storage; fn convert_content_blocks(blocks: &[ContentBlock]) -> Value { if blocks.len() == 1 { @@ -32,6 +34,7 @@ pub struct OpenAIProvider { temperature: Option, max_tokens: Option, model_extra: HashMap, + storage: Option>, } impl OpenAIProvider { @@ -55,9 +58,14 @@ impl OpenAIProvider { temperature, max_tokens, model_extra, + storage: None, } } + pub fn set_storage(&mut self, storage: Arc) { + self.storage = Some(storage); + } + fn build_request_body(&self, request: &ChatCompletionRequest) -> Value { let mut body = json!({ "model": self.model_id, @@ -162,6 +170,7 @@ impl LLMProvider for OpenAIProvider { &self, request: ChatCompletionRequest, ) -> Result> { + let start = std::time::Instant::now(); let url = format!("{}/chat/completions", self.base_url); let body = self.build_request_body(&request); @@ -200,24 +209,44 @@ impl LLMProvider for OpenAIProvider { req_builder = req_builder.header(key.as_str(), value.as_str()); } + let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default(); + tracing::debug!(req_body = %req_body_str, "LLM request"); + let resp = req_builder.json(&body).send().await?; let status = resp.status(); let text = resp.text().await?; - - // Debug: Log LLM response (only in debug builds) - #[cfg(debug_assertions)] - { - let resp_preview: String = text.chars().take(100).collect(); - tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), "LLM response (first 100 chars shown)"); - } + tracing::debug!(status = %status, resp_body = %text, "LLM response"); if !status.is_success() { - return Err(format!("API error {}: {}", status, text).into()); + let error = format!("API error {}: {}", status, text); + if let Some(ref storage) = self.storage { + let _ = storage.append_llm_call( + &self.name, &self.model_id, &req_body_str, + Some(&text), Some(&error), + start.elapsed().as_millis() as u64, + ).await; + } + return Err(error.into()); } let openai_resp: OpenAIResponse = serde_json::from_str(&text) - .map_err(|e| format!("decode error: {} | body: {}", e, &text))?; + .map_err(|e| { + let err_msg = format!("decode error: {} | body: {}", e, &text); + if let Some(ref storage) = self.storage { + let name = self.name.clone(); + let model = self.model_id.clone(); + let req = req_body_str.clone(); + let resp = text.clone(); + let dur = start.elapsed().as_millis() as u64; + let err = err_msg.clone(); + let s = storage.clone(); + tokio::spawn(async move { + let _ = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await; + }); + } + err_msg + })?; let content = openai_resp.choices[0] .message @@ -237,7 +266,7 @@ impl LLMProvider for OpenAIProvider { }) .collect(); - Ok(ChatCompletionResponse { + let response = ChatCompletionResponse { id: openai_resp.id, model: openai_resp.model, content, @@ -247,7 +276,17 @@ impl LLMProvider for OpenAIProvider { completion_tokens: openai_resp.usage.completion_tokens, total_tokens: openai_resp.usage.total_tokens, }, - }) + }; + + if let Some(ref storage) = self.storage { + let _ = storage.append_llm_call( + &self.name, &self.model_id, &req_body_str, + Some(&text), None, + start.elapsed().as_millis() as u64, + ).await; + } + + Ok(response) } fn ptype(&self) -> &str { diff --git a/src/providers/traits.rs b/src/providers/traits.rs index 63a46ef..b5f20b2 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -123,4 +123,6 @@ pub trait LLMProvider: Send + Sync { fn name(&self) -> &str; fn model_id(&self) -> &str; + + fn set_storage(&mut self, _storage: std::sync::Arc) {} } diff --git a/src/session/session.rs b/src/session/session.rs index d5ed0f8..6eb5ddc 100644 --- a/src/session/session.rs +++ b/src/session/session.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; -use tokio::sync::{Mutex, mpsc}; +use tokio::sync::Mutex; use uuid::Uuid; use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind}; @@ -21,7 +21,6 @@ use crate::config::LLMProviderConfig; use crate::agent::{AgentLoop, AgentError, ContextCompressor}; use crate::agent::system_prompt::build_system_prompt; use crate::agent::context_compressor::ContextCompressionConfig; -use crate::protocol::WsOutbound; use crate::providers::{create_provider, LLMProvider}; use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID}; use crate::session::events::DialogInfo; @@ -49,7 +48,6 @@ pub struct Session { messages: Vec, seq_counter: i64, - pub user_tx: mpsc::Sender, provider_config: LLMProviderConfig, provider: Arc, tools: Arc, @@ -63,14 +61,16 @@ impl Session { pub async fn new( id: UnifiedSessionId, provider_config: LLMProviderConfig, - user_tx: mpsc::Sender, tools: Arc, storage: Option>, routing_info: String, title: String, ) -> Result { - let provider_box = create_provider(provider_config.clone()) + let mut provider_box = create_provider(provider_config.clone()) .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?; + if let Some(ref s) = storage { + provider_box.set_storage(s.clone()); + } let provider: Arc = Arc::from(provider_box); let compressor_config = ContextCompressionConfig { @@ -89,7 +89,6 @@ impl Session { total_message_count: 0, messages: Vec::new(), seq_counter: 1, - user_tx, provider_config: provider_config.clone(), provider: provider.clone(), tools, @@ -103,7 +102,6 @@ impl Session { pub async fn from_storage( id: UnifiedSessionId, provider_config: LLMProviderConfig, - user_tx: mpsc::Sender, tools: Arc, storage: StdArc, ) -> Result { @@ -113,8 +111,9 @@ impl Session { let messages = storage.load_messages(&id.to_string(), 0).await .map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?; - let provider_box = create_provider(provider_config.clone()) + let mut provider_box = create_provider(provider_config.clone()) .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?; + provider_box.set_storage(storage.clone()); let provider: Arc = Arc::from(provider_box); let compressor_config = ContextCompressionConfig { @@ -123,6 +122,7 @@ impl Session { }; // Convert MessageMeta to ChatMessage + // Clear tool_call_id/tool_name — they're not valid across API sessions let chat_messages: Vec = messages.into_iter().map(|m| { ChatMessage { id: m.id, @@ -130,8 +130,8 @@ impl Session { content: m.content, media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(), timestamp: m.created_at, - tool_call_id: m.tool_call_id, - tool_name: m.tool_name, + tool_call_id: None, + tool_name: None, tool_calls: m.tool_calls.map(|tc| serde_json::from_str(&tc).unwrap_or_default()), source: m.source.and_then(|s| serde_json::from_str(&s).ok()), } @@ -149,7 +149,6 @@ impl Session { total_message_count, messages: chat_messages, seq_counter, - user_tx, provider_config: provider_config.clone(), provider: provider.clone(), tools, @@ -252,18 +251,6 @@ impl Session { } } - pub async fn send(&self, msg: WsOutbound) { - let _ = self.user_tx.send(msg).await; - } - - /// 发送系统通知(不记录进 session 历史) - pub async fn send_system_notification(&self, content: &str) { - let msg = WsOutbound::SystemNotification { - content: content.to_string(), - }; - let _ = self.user_tx.send(msg).await; - } - /// 将 session 元数据写回 Storage pub async fn persist_session_meta(&self) -> Result<(), StorageError> { if let Some(ref storage) = self.storage { @@ -364,6 +351,14 @@ impl Session { )) } + /// 创建一个附通知通道的 AgentLoop 实例 + pub fn create_agent_with_notify( + &self, + notify_tx: tokio::sync::mpsc::UnboundedSender, + ) -> Result { + Ok(self.create_agent()?.with_notify(notify_tx)) + } + /// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills) pub fn build_system_prompt(&self, skills_prompt: &str) -> String { let base_prompt = build_system_prompt( @@ -874,11 +869,9 @@ impl SessionManager { self.storage.upsert_session(&meta).await .map_err(|e| AgentError::Other(format!("failed to create session in storage: {}", e)))?; - let (user_tx, _rx) = mpsc::channel::(100); let session = Session::new( unified_id.clone(), self.provider_config.clone(), - user_tx, self.tools.clone(), Some(self.storage.clone()), routing_info, @@ -909,11 +902,9 @@ impl SessionManager { match self.storage.get_session(&session_id_str).await { Ok(meta) => { tracing::debug!(session_id = %session_id_str, last_active_at = %meta.last_active_at, message_count = %meta.message_count, "Restoring session from Storage"); - let (user_tx, _rx) = mpsc::channel::(100); let session = Session::from_storage( unified_id.clone(), self.provider_config.clone(), - user_tx, self.tools.clone(), self.storage.clone(), ).await?; @@ -932,11 +923,9 @@ impl SessionManager { } // Create new session - let (user_tx, _rx) = mpsc::channel::(100); let session = Session::new( unified_id.clone(), self.provider_config.clone(), - user_tx, self.tools.clone(), Some(self.storage.clone()), String::new(), @@ -1175,6 +1164,30 @@ impl SessionManager { } // Normal message handling through LLM + let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel(); + + // Spawn notification publisher — sends immediately when tools are detected + { + let bus = self.bus.clone(); + let ch = channel.to_string(); + let cid = chat_id.to_string(); + tokio::spawn(async move { + while let Some(notif) = notify_rx.recv().await { + let mut metadata = HashMap::new(); + metadata.insert("_type".to_string(), "notification".to_string()); + let outbound = OutboundMessage { + channel: ch.clone(), + chat_id: cid.clone(), + content: notif, + reply_to: None, + media: vec![], + metadata, + }; + let _ = bus.publish_outbound(outbound).await; + } + }); + } + let response: String = { let mut session_guard = session.lock().await; @@ -1202,7 +1215,7 @@ impl SessionManager { .compress_if_needed(history) .await?; - let agent = session_guard.create_agent()?; + let agent = session_guard.create_agent_with_notify(notify_tx)?; let result = agent.process(history).await?; for msg in result.emitted_messages { @@ -1322,7 +1335,6 @@ impl OutboundMessenger for SessionManager { mod tests { use super::*; use std::collections::HashMap; - use tokio::sync::mpsc; fn test_provider_config() -> LLMProviderConfig { LLMProviderConfig { diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 94f8680..b9a7a8d 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -92,6 +92,70 @@ impl Storage { .await .ok(); + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS llm_calls ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_at INTEGER NOT NULL, + provider TEXT NOT NULL, + model TEXT NOT NULL, + request_body TEXT NOT NULL, + response_body TEXT, + error TEXT, + duration_ms INTEGER + ) + "#, + ) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn append_llm_call( + &self, + provider: &str, + model: &str, + request_body: &str, + response_body: Option<&str>, + error: Option<&str>, + duration_ms: u64, + ) -> Result<(), StorageError> { + let now = chrono::Utc::now().timestamp_millis(); + sqlx::query( + r#" + INSERT INTO llm_calls (created_at, provider, model, request_body, response_body, error, duration_ms) + VALUES (?, ?, ?, ?, ?, ?, ?) + "#, + ) + .bind(now) + .bind(provider) + .bind(model) + .bind(request_body) + .bind(response_body) + .bind(error) + .bind(duration_ms as i64) + .execute(self.pool()) + .await?; + + // Prune to keep last 1000 records + self.prune_llm_calls(1000).await?; + + Ok(()) + } + + async fn prune_llm_calls(&self, max_records: i64) -> Result<(), StorageError> { + sqlx::query( + r#" + DELETE FROM llm_calls WHERE id <= ( + SELECT COALESCE(MAX(id), 0) - ? FROM llm_calls + ) + "#, + ) + .bind(max_records) + .execute(self.pool()) + .await?; + Ok(()) }