diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 70aae0a..1c443ea 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -1,15 +1,15 @@ use async_trait::async_trait; use crate::bus::message::ContentBlock; use crate::bus::ChatMessage; +use crate::bus::message::ToolMessageState; use crate::config::LLMProviderConfig; use crate::observability::{ - truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, + truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState, }; use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall}; use crate::skills::SkillRuntime; use crate::storage::SessionStore; use crate::tools::{ToolContext, ToolRegistry}; -use serde::Deserialize; use std::collections::VecDeque; use std::hash::{Hash, Hasher}; use std::io::Read; @@ -22,11 +22,8 @@ const MAX_TOOL_RESULT_CHARS: usize = 16_000; /// Minimum characters to keep when truncating const TRUNCATION_SUFFIX_LEN: usize = 200; const MEMORY_AUTOSAVE_SYSTEM_PROMPT: &str = "你可以在处理任务过程中使用 memory_manage 工具维护长期记忆。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespace:preferences、profile、tasks、decisions。若需要写入,优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。搜索记忆时,优先使用 memory_manage(action='search'),并尽量同时提供中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 email / 邮件 / email_folder_preference。"; -const MEMORY_EXTRACTION_SYSTEM_PROMPT: &str = "你负责从一小段对话中提取值得长期记忆的信息。只返回 JSON 数组,不要输出解释或 Markdown。数组元素格式为 {\"namespace\": string, \"key\": string, \"content\": string}。只有在内容属于长期偏好、稳定事实、用户纠正、持续任务上下文或明确决策时才输出;否则返回 []。namespace 只能使用 preferences、profile、tasks、decisions。key 用简短稳定的 snake_case 英文标识。不要保存一次性工具结果、临时列表、敏感信息或猜测。最多输出 2 条。"; -const MEMORY_EXTRACTION_REASONING_EFFORT: &str = "none"; -const MEMORY_EXTRACTION_MAX_TOKENS: u32 = 192; -const MEMORY_EXTRACTION_CONTEXT_MESSAGES: usize = 4; -const MEMORY_EXTRACTION_MAX_CANDIDATES: usize = 2; +const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__"; +const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。"; /// Build content blocks from text and media paths fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec { @@ -98,6 +95,10 @@ fn truncate_tool_result(output: &str) -> String { } } +fn parse_pending_tool_output(output: &str) -> Option { + output.strip_prefix(PENDING_USER_ACTION_MARKER).map(|rest| rest.trim().to_string()) +} + /// Loop detection result. #[derive(Debug, Clone, PartialEq, Eq)] enum LoopDetectionResult { @@ -250,19 +251,6 @@ pub struct AgentProcessResult { pub emitted_messages: Vec, } -#[derive(Debug, Clone, Deserialize, PartialEq, Eq)] -struct MemoryCandidate { - namespace: String, - key: String, - content: String, -} - -#[derive(Debug, Deserialize)] -struct MemoryCandidateEnvelope { - #[serde(default)] - memories: Vec, -} - #[async_trait] pub trait EmittedMessageHandler: Send + Sync + 'static { async fn handle(&self, message: ChatMessage); @@ -370,7 +358,6 @@ impl AgentLoop { // Track tool calls for loop detection let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default()); let mut emitted_messages = Vec::new(); - let mut memory_write_occurred = false; for iteration in 0..self.max_iterations { #[cfg(debug_assertions)] @@ -416,9 +403,6 @@ impl AgentLoop { // If no tool calls, this is the final response if response.tool_calls.is_empty() { let assistant_message = ChatMessage::assistant(response.content); - if !memory_write_occurred { - self.maybe_extract_and_store_memories(&messages, &assistant_message).await?; - } emitted_messages.push(assistant_message.clone()); return Ok(AgentProcessResult { final_response: assistant_message, @@ -440,14 +424,6 @@ impl AgentLoop { // Execute tools and add results to messages let tool_results = self.execute_tools(&response.tool_calls).await; - if response - .tool_calls - .iter() - .zip(tool_results.iter()) - .any(|(tool_call, result)| did_successfully_write_memory(tool_call, result)) - { - memory_write_occurred = true; - } for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) { // Log function call with name and arguments @@ -471,19 +447,29 @@ impl AgentLoop { "Loop warning: {}", msg ); - let tool_message = ChatMessage::tool( + let tool_message = ChatMessage::tool_with_state( tool_call.id.clone(), tool_call.name.clone(), format!("{}\n\n[上一条结果]\n{}", msg, truncated_output), + if result.state == ToolExecutionState::PendingUserAction { + ToolMessageState::PendingUserAction + } else { + ToolMessageState::Completed + }, ); messages.push(tool_message.clone()); emitted_messages.push(tool_message); } LoopDetectionResult::Ok => { - let tool_message = ChatMessage::tool( + let tool_message = ChatMessage::tool_with_state( tool_call.id.clone(), tool_call.name.clone(), truncated_output, + if result.state == ToolExecutionState::PendingUserAction { + ToolMessageState::PendingUserAction + } else { + ToolMessageState::Completed + }, ); messages.push(tool_message.clone()); emitted_messages.push(tool_message); @@ -491,6 +477,29 @@ impl AgentLoop { } } + if let Some((tool_call, pending_result)) = response + .tool_calls + .iter() + .zip(tool_results.iter()) + .find(|(_, result)| result.state == ToolExecutionState::PendingUserAction) + { + let assistant_message = ChatMessage::assistant(format!( + "{}\n\n当前等待中的工具: {}", + pending_result + .output + .lines() + .next() + .filter(|line| !line.trim().is_empty()) + .unwrap_or(DEFAULT_PENDING_ASSISTANT_MESSAGE), + tool_call.name, + )); + emitted_messages.push(assistant_message.clone()); + return Ok(AgentProcessResult { + final_response: assistant_message, + emitted_messages, + }); + } + // Loop continues to next iteration with updated messages #[cfg(debug_assertions)] tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration"); @@ -524,9 +533,6 @@ impl AgentLoop { match (*self.provider).chat(request).await { Ok(response) => { let assistant_message = ChatMessage::assistant(response.content); - if !memory_write_occurred { - self.maybe_extract_and_store_memories(&messages, &assistant_message).await?; - } emitted_messages.push(assistant_message.clone()); Ok(AgentProcessResult { final_response: assistant_message, @@ -705,7 +711,11 @@ impl AgentLoop { match tool.execute_with_context(&self.tool_context, tool_call.arguments.clone()).await { Ok(result) => { if result.success { - ToolExecutionOutcome::success(result.output) + if let Some(pending_output) = parse_pending_tool_output(&result.output) { + ToolExecutionOutcome::pending(pending_output) + } else { + ToolExecutionOutcome::success(result.output) + } } else { let error = result.error.unwrap_or_default(); ToolExecutionOutcome::failure( @@ -743,227 +753,6 @@ impl AgentLoop { } } -fn did_successfully_write_memory(tool_call: &ToolCall, result: &ToolExecutionOutcome) -> bool { - if !result.success || tool_call.name != "memory_manage" { - return false; - } - - matches!( - tool_call.arguments.get("action").and_then(|value| value.as_str()), - Some("put") | Some("update") - ) -} - -fn lightweight_provider_config(provider_config: &LLMProviderConfig) -> LLMProviderConfig { - let mut config = provider_config.clone(); - if config.provider_type == "openai" { - config.model_extra.insert( - "reasoning_effort".to_string(), - serde_json::Value::String(MEMORY_EXTRACTION_REASONING_EFFORT.to_string()), - ); - } - config -} - -fn build_memory_extraction_context(messages: &[ChatMessage], final_response: &ChatMessage) -> String { - let recent_messages: Vec<&ChatMessage> = messages - .iter() - .rev() - .filter(|message| matches!(message.role.as_str(), "user" | "assistant")) - .take(MEMORY_EXTRACTION_CONTEXT_MESSAGES) - .collect::>() - .into_iter() - .rev() - .collect(); - - let mut lines = vec!["请判断这轮对话中是否有值得长期记忆的信息:".to_string()]; - for message in recent_messages { - let content = message.content.trim(); - if content.is_empty() { - continue; - } - - lines.push(format!( - "{}: {}", - message.role, - truncate_text_for_memory_extraction(content) - )); - } - - let final_content = final_response.content.trim(); - if !final_content.is_empty() { - lines.push(format!( - "assistant_final: {}", - truncate_text_for_memory_extraction(final_content) - )); - } - - lines.join("\n") -} - -fn truncate_text_for_memory_extraction(text: &str) -> String { - if text.chars().count() <= 240 { - return text.to_string(); - } - - let prefix: String = text.chars().take(240).collect(); - format!("{}...", prefix) -} - -fn parse_memory_candidates(raw: &str) -> Vec { - let trimmed = raw.trim(); - let json_payload = trimmed - .strip_prefix("```json") - .and_then(|value| value.strip_suffix("```")) - .map(str::trim) - .or_else(|| { - trimmed - .strip_prefix("```") - .and_then(|value| value.strip_suffix("```")) - .map(str::trim) - }) - .unwrap_or(trimmed); - - if let Ok(candidates) = serde_json::from_str::>(json_payload) { - return normalize_memory_candidates(candidates); - } - - if let Ok(envelope) = serde_json::from_str::(json_payload) { - return normalize_memory_candidates(envelope.memories); - } - - Vec::new() -} - -fn normalize_memory_candidates(candidates: Vec) -> Vec { - let mut normalized = Vec::new(); - - for candidate in candidates { - let Some(namespace) = canonical_memory_namespace(&candidate.namespace) else { - continue; - }; - let key = canonical_memory_key(&candidate.key); - let content = candidate.content.trim().to_string(); - if key.is_empty() || content.len() < 8 { - continue; - } - - normalized.push(MemoryCandidate { - namespace, - key, - content, - }); - - if normalized.len() >= MEMORY_EXTRACTION_MAX_CANDIDATES { - break; - } - } - - normalized -} - -fn canonical_memory_namespace(namespace: &str) -> Option { - match namespace.trim().to_lowercase().as_str() { - "preferences" | "preference" => Some("preferences".to_string()), - "profile" | "profiles" => Some("profile".to_string()), - "tasks" | "task" => Some("tasks".to_string()), - "decisions" | "decision" => Some("decisions".to_string()), - _ => None, - } -} - -fn canonical_memory_key(key: &str) -> String { - let mut normalized = String::new(); - let mut prev_is_sep = false; - - for ch in key.trim().chars() { - let mapped = if ch.is_ascii_alphanumeric() { - prev_is_sep = false; - Some(ch.to_ascii_lowercase()) - } else if matches!(ch, '_' | '-' | '/' | ' ') { - if prev_is_sep { - None - } else { - prev_is_sep = true; - Some('_') - } - } else { - None - }; - - if let Some(ch) = mapped { - normalized.push(ch); - } - } - - normalized.trim_matches('_').to_string() -} - -impl AgentLoop { - async fn maybe_extract_and_store_memories( - &self, - messages: &[ChatMessage], - final_response: &ChatMessage, - ) -> Result { - let Some(tool) = self.tools.get("memory_manage") else { - return Ok(0); - }; - if self.tool_context.channel_name.is_none() || self.tool_context.sender_id.is_none() { - return Ok(0); - } - - let context = build_memory_extraction_context(messages, final_response); - let provider = create_provider(lightweight_provider_config(&self.provider_config)) - .map_err(|err| AgentError::ProviderCreation(err.to_string()))?; - let request = ChatCompletionRequest { - messages: vec![ - Message::system(MEMORY_EXTRACTION_SYSTEM_PROMPT), - Message::user(context), - ], - temperature: Some(0.0), - max_tokens: Some(MEMORY_EXTRACTION_MAX_TOKENS), - tools: None, - }; - - let response = provider - .chat(request) - .await - .map_err(|err| AgentError::LlmError(err.to_string()))?; - let candidates = parse_memory_candidates(&response.content); - if candidates.is_empty() { - tracing::debug!("No auto-save memory candidates extracted from final response"); - return Ok(0); - } - - tracing::info!(candidate_count = candidates.len(), candidates = ?candidates, "Auto-save memory candidates extracted"); - - let mut saved = 0; - for candidate in candidates { - let result = tool - .execute_with_context( - &self.tool_context, - serde_json::json!({ - "action": "put", - "namespace": candidate.namespace, - "key": candidate.key, - "content": candidate.content, - }), - ) - .await - .map_err(|err| AgentError::Other(format!("auto-save memory tool error: {}", err)))?; - - if result.success { - saved += 1; - } else { - tracing::warn!(error = result.error.as_deref().unwrap_or("unknown"), "Auto-save memory write failed"); - } - } - - tracing::info!(saved_count = saved, "Auto-save memory extraction completed"); - Ok(saved) - } -} - #[cfg(test)] mod tests { use super::*; @@ -1051,70 +840,10 @@ mod tests { } #[test] - fn test_did_successfully_write_memory_only_accepts_successful_put_or_update() { - let tool_call = ToolCall { - id: "call_1".to_string(), - name: "memory_manage".to_string(), - arguments: serde_json::json!({ "action": "put" }), - }; - - assert!(did_successfully_write_memory( - &tool_call, - &ToolExecutionOutcome::success("ok".to_string()) - )); - - let failed = ToolExecutionOutcome::failure("err".to_string(), Some("boom".to_string())); - assert!(!did_successfully_write_memory(&tool_call, &failed)); - - let search_call = ToolCall { - id: "call_2".to_string(), - name: "memory_manage".to_string(), - arguments: serde_json::json!({ "action": "search" }), - }; - assert!(!did_successfully_write_memory( - &search_call, - &ToolExecutionOutcome::success("ok".to_string()) - )); - } - - #[test] - fn test_parse_memory_candidates_normalizes_and_limits_results() { - let raw = r#"```json - {"memories": [ - {"namespace": "preference", "key": "Email Folder Preference", "content": "用户提到邮件时默认查看代收邮箱而不是收件箱。"}, - {"namespace": "decision", "key": "mailbox strategy", "content": "后续默认先查看代收邮箱。"}, - {"namespace": "tasks", "key": "short", "content": "太短"} - ]} - ```"#; - - let candidates = parse_memory_candidates(raw); - assert_eq!(candidates.len(), 2); - assert_eq!(candidates[0].namespace, "preferences"); - assert_eq!(candidates[0].key, "email_folder_preference"); - assert_eq!(candidates[1].namespace, "decisions"); - assert_eq!(candidates[1].key, "mailbox_strategy"); - } - - #[test] - fn test_build_memory_extraction_context_uses_recent_user_and_assistant_messages() { - let messages = vec![ - ChatMessage::system("system"), - ChatMessage::tool("call-1", "calculator", "2"), - ChatMessage::user("first user"), - ChatMessage::assistant("first assistant"), - ChatMessage::user("second user"), - ChatMessage::assistant("second assistant"), - ChatMessage::user("third user"), - ]; - let final_response = ChatMessage::assistant("final assistant"); - - let context = build_memory_extraction_context(&messages, &final_response); - assert!(context.contains("user: first user") == false); - assert!(context.contains("user: second user")); - assert!(context.contains("assistant: second assistant")); - assert!(context.contains("user: third user")); - assert!(context.contains("assistant_final: final assistant")); - assert!(!context.contains("tool:")); + fn test_parse_pending_tool_output() { + let output = parse_pending_tool_output("__PICOBOT_PENDING_USER_ACTION__\n请完成授权"); + assert_eq!(output.as_deref(), Some("请完成授权")); + assert!(parse_pending_tool_output("normal output").is_none()); } } diff --git a/src/bus/message.rs b/src/bus/message.rs index bdb82f4..31be986 100644 --- a/src/bus/message.rs +++ b/src/bus/message.rs @@ -3,6 +3,13 @@ use serde::{Deserialize, Serialize}; use crate::providers::ToolCall; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum ToolMessageState { + Completed, + PendingUserAction, +} + // ============================================================================ // ContentBlock - Multimodal content representation (OpenAI-style) // ============================================================================ @@ -72,6 +79,8 @@ pub struct ChatMessage { #[serde(skip_serializing_if = "Option::is_none")] pub tool_name: Option, #[serde(skip_serializing_if = "Option::is_none")] + pub tool_state: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, } @@ -85,6 +94,7 @@ impl ChatMessage { timestamp: current_timestamp(), tool_call_id: None, tool_name: None, + tool_state: None, tool_calls: None, } } @@ -98,6 +108,7 @@ impl ChatMessage { timestamp: current_timestamp(), tool_call_id: None, tool_name: None, + tool_state: None, tool_calls: None, } } @@ -111,6 +122,7 @@ impl ChatMessage { timestamp: current_timestamp(), tool_call_id: None, tool_name: None, + tool_state: None, tool_calls: None, } } @@ -124,6 +136,7 @@ impl ChatMessage { timestamp: current_timestamp(), tool_call_id: None, tool_name: None, + tool_state: None, tool_calls: Some(tool_calls), } } @@ -137,11 +150,21 @@ impl ChatMessage { timestamp: current_timestamp(), tool_call_id: None, tool_name: None, + tool_state: None, tool_calls: None, } } pub fn tool(tool_call_id: impl Into, tool_name: impl Into, content: impl Into) -> Self { + Self::tool_with_state(tool_call_id, tool_name, content, ToolMessageState::Completed) + } + + pub fn tool_with_state( + tool_call_id: impl Into, + tool_name: impl Into, + content: impl Into, + tool_state: ToolMessageState, + ) -> Self { Self { id: uuid::Uuid::new_v4().to_string(), role: "tool".to_string(), @@ -150,6 +173,7 @@ impl ChatMessage { timestamp: current_timestamp(), tool_call_id: Some(tool_call_id.into()), tool_name: Some(tool_name.into()), + tool_state: Some(tool_state), tool_calls: None, } } @@ -212,6 +236,7 @@ pub enum OutboundEventKind { AssistantResponse, ToolCall, ToolResult, + ToolPending, } impl OutboundMessage { @@ -294,6 +319,33 @@ impl OutboundMessage { } } + pub fn tool_pending( + channel: impl Into, + chat_id: impl Into, + tool_call_id: impl Into, + tool_name: impl Into, + content: impl Into, + reply_to: Option, + metadata: HashMap, + ) -> Self { + let tool_name = tool_name.into(); + let raw_content = content.into(); + let content = format_tool_result_content(&tool_name, &raw_content); + Self { + channel: channel.into(), + chat_id: chat_id.into(), + content, + reply_to, + media: Vec::new(), + metadata, + event_kind: OutboundEventKind::ToolPending, + role: "tool".to_string(), + tool_call_id: Some(tool_call_id.into()), + tool_name: Some(tool_name), + tool_arguments: None, + } + } + pub fn from_chat_message( channel: &str, chat_id: &str, @@ -328,7 +380,26 @@ impl OutboundMessage { )] } } - "tool" => Vec::new(), + "tool" => match message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed) { + ToolMessageState::Completed => vec![Self::tool_result( + channel.to_string(), + chat_id.to_string(), + message.tool_call_id.clone().unwrap_or_default(), + message.tool_name.clone().unwrap_or_default(), + message.content.clone(), + reply_to, + metadata.clone(), + )], + ToolMessageState::PendingUserAction => vec![Self::tool_pending( + channel.to_string(), + chat_id.to_string(), + message.tool_call_id.clone().unwrap_or_default(), + message.tool_name.clone().unwrap_or_default(), + message.content.clone(), + reply_to, + metadata.clone(), + )], + }, _ => Vec::new(), } } @@ -377,7 +448,7 @@ fn current_timestamp() -> i64 { #[cfg(test)] mod tests { - use super::{ChatMessage, OutboundEventKind, OutboundMessage}; + use super::{ChatMessage, OutboundEventKind, OutboundMessage, ToolMessageState}; use crate::providers::ToolCall; use serde_json::json; use std::collections::HashMap; @@ -418,7 +489,7 @@ mod tests { } #[test] - fn test_from_chat_message_omits_tool_result() { + fn test_from_chat_message_includes_tool_result() { let message = ChatMessage::tool("call-9", "calculator", "2"); let outbound = OutboundMessage::from_chat_message( @@ -429,6 +500,28 @@ mod tests { &message, ); - assert!(outbound.is_empty()); + assert_eq!(outbound.len(), 1); + assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult); + } + + #[test] + fn test_from_chat_message_includes_tool_pending() { + let message = ChatMessage::tool_with_state( + "call-9", + "bash", + "等待你完成浏览器授权后再继续。", + ToolMessageState::PendingUserAction, + ); + + let outbound = OutboundMessage::from_chat_message( + "feishu", + "chat-1", + None, + &HashMap::new(), + &message, + ); + + assert_eq!(outbound.len(), 1); + assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolPending); } } diff --git a/src/client/mod.rs b/src/client/mod.rs index 78c9ea8..4bd50be 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -73,6 +73,9 @@ pub async fn run(gateway_url: &str) -> Result<(), Box> { WsOutbound::ToolResult { tool_name, content, .. } => { input.write_output(&format!("Tool result: {}\n{}\n", tool_name, content)).await?; } + WsOutbound::ToolPending { tool_name, content, resume_hint, .. } => { + input.write_output(&format!("Tool pending: {}\n{}\n{}\n", tool_name, content, resume_hint)).await?; + } WsOutbound::Error { message, .. } => { input.write_output(&format!("Error: {}", message)).await?; } diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 756dd2c..bc25218 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -19,7 +19,7 @@ use crate::tools::{ }; const DEFAULT_AGENT_PROMPT: &str = "# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot,一名务实、可靠的通用助理。\n- 你的目标是理解用户当下的真实需求,并给出清晰、可执行的帮助。\n\n## 工作方式\n- 优先理解意图,再给出回应或行动。\n- 保持简洁、准确、自然,不故作热情,也不空泛铺陈。\n- 能直接验证的内容尽量先验证,避免凭空猜测。\n- 当现有工具是完成任务的最直接方式时,优先使用工具。\n- 除非用户明确要求改变方向,否则保持用户原本目标不变。\n\n## 助理原则\n- 优先解决问题,而不是展示过程。\n- 输出要方便用户立即使用,结论尽量明确。\n- 对不确定的地方要直说,不把猜测包装成事实。\n- 复杂任务先收敛重点,简单任务直接给结果。\n- 避免不必要的重复、客套和冗长说明。\n\n## 回复规则\n- 除非用户另有要求,否则使用中文回复。\n- 默认短而清楚,按信息密度组织内容。\n- 如果任务涉及文件、命令、配置或下一步操作,优先给出最关键的那部分。\n- 如果存在限制、风险或前提条件,要直接说明。\n\n## 补充要求\n- 你是 PicoBot。\n- 回答应以帮助用户完成当前目标为中心。\n- 在信息不足时先补关键前提,在信息充分时直接执行。\n"; -const MEMORY_KEYWORD_SYSTEM_PROMPT: &str = "你负责为长期记忆检索生成关键词。根据给定的最近会话,仅输出 JSON 数组字符串。关键词必须同时覆盖中文和英文:优先为同一主题同时给出中文关键词和对应英文关键词。关键词必须是短词语,优先使用最容易命中记忆的核心检索词,不要输出完整句子、解释或长描述。必要时优先保留实体名、产品名、偏好名、snake_case key 风格短词。数组元素总数控制在 2 到 6 个简短关键词或短语。不要输出解释,不要输出 Markdown。"; +const MEMORY_KEYWORD_SYSTEM_PROMPT: &str = "你负责为长期记忆检索生成关键词。根据给定的最近会话,仅输出 JSON 数组字符串。关键词必须同时覆盖中文和英文:优先为同一主题同时给出中文关键词和对应英文关键词。关键词必须是短词语,优先使用最容易命中记忆的核心检索词,不要输出完整句子、解释或长描述。必要时优先保留实体名、产品名、偏好名、snake_case key 风格短词。数组元素总数控制在 2 到 6 个简短关键词或短语。不要输出解释,不要输出 Markdown。如果最新的会话跟前面有明显主题变化,且新主题可能与旧主题的关键词不同,优先输出新主题的关键词;如果最新会话是旧主题的延续,优先输出旧主题的关键词;如果最新会话同时包含新旧主题,优先输出同时覆盖新旧主题的关键词。"; const RELATED_MEMORY_SYSTEM_PROMPT_PREFIX: &str = "找到相关的记忆。你必须优先参考这些记忆,并在后续推理中把它们当作当前会话的补充上下文;若与用户本轮明确要求冲突,以用户本轮要求为准。"; const MEMORY_KEYWORD_REASONING_EFFORT: &str = "none"; const MEMORY_KEYWORD_MAX_CHARS: usize = 32; diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 88dd308..63b4e19 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -6,7 +6,7 @@ use axum::response::Response; use futures_util::{SinkExt, StreamExt}; use tokio::sync::{mpsc, Mutex}; use crate::agent::EmittedMessageHandler; -use crate::bus::message::format_tool_call_content; +use crate::bus::message::{format_tool_call_content, ToolMessageState}; use crate::bus::ChatMessage; use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound}; use super::{GatewayState, session::{Session, handle_in_chat_command}}; @@ -181,7 +181,23 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec { }] } } - "tool" => Vec::new(), + "tool" => match message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed) { + ToolMessageState::Completed => vec![WsOutbound::ToolResult { + id: message.id.clone(), + tool_call_id: message.tool_call_id.clone().unwrap_or_default(), + tool_name: message.tool_name.clone().unwrap_or_default(), + content: message.content.clone(), + role: message.role.clone(), + }], + ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending { + id: message.id.clone(), + tool_call_id: message.tool_call_id.clone().unwrap_or_default(), + tool_name: message.tool_name.clone().unwrap_or_default(), + content: message.content.clone(), + role: message.role.clone(), + resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(), + }], + }, _ => Vec::new(), } } @@ -391,6 +407,7 @@ async fn handle_inbound( mod tests { use super::ws_outbound_from_chat_message; use crate::bus::ChatMessage; + use crate::bus::message::ToolMessageState; use crate::providers::ToolCall; use crate::protocol::WsOutbound; use serde_json::json; @@ -421,11 +438,27 @@ mod tests { } #[test] - fn test_ws_outbound_from_chat_message_omits_tool_results() { + fn test_ws_outbound_from_chat_message_includes_tool_results() { let message = ChatMessage::tool("call-1", "calculator", "2"); let outbound = ws_outbound_from_chat_message(&message); - assert!(outbound.is_empty()); + assert_eq!(outbound.len(), 1); + assert!(matches!(outbound[0], WsOutbound::ToolResult { .. })); + } + + #[test] + fn test_ws_outbound_from_chat_message_includes_tool_pending() { + let message = ChatMessage::tool_with_state( + "call-1", + "bash", + "等待你完成授权后再继续。", + ToolMessageState::PendingUserAction, + ); + + let outbound = ws_outbound_from_chat_message(&message); + + assert_eq!(outbound.len(), 1); + assert!(matches!(outbound[0], WsOutbound::ToolPending { .. })); } } diff --git a/src/observability/mod.rs b/src/observability/mod.rs index d050567..e66e86c 100644 --- a/src/observability/mod.rs +++ b/src/observability/mod.rs @@ -5,6 +5,12 @@ use std::time::Duration; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ToolExecutionState { + Completed, + PendingUserAction, +} + /// Events emitted during agent and tool execution. #[derive(Debug, Clone)] pub enum ObserverEvent { @@ -60,6 +66,8 @@ pub struct ToolExecutionOutcome { pub error_reason: Option, /// How long the tool took to execute. pub duration: Duration, + /// Whether the tool completed or is waiting for external user action. + pub state: ToolExecutionState, } impl ToolExecutionOutcome { @@ -70,6 +78,7 @@ impl ToolExecutionOutcome { success: true, error_reason: None, duration: Duration::ZERO, + state: ToolExecutionState::Completed, } } @@ -80,6 +89,18 @@ impl ToolExecutionOutcome { success: true, error_reason: None, duration, + state: ToolExecutionState::Completed, + } + } + + /// Create a pending outcome with zero duration. + pub fn pending(output: String) -> Self { + Self { + output, + success: true, + error_reason: None, + duration: Duration::ZERO, + state: ToolExecutionState::PendingUserAction, } } @@ -90,6 +111,7 @@ impl ToolExecutionOutcome { success: false, error_reason, duration: Duration::ZERO, + state: ToolExecutionState::Completed, } } @@ -100,6 +122,7 @@ impl ToolExecutionOutcome { success: false, error_reason, duration, + state: ToolExecutionState::Completed, } } } @@ -201,6 +224,7 @@ mod tests { assert_eq!(outcome.output, "output content"); assert!(outcome.error_reason.is_none()); assert_eq!(outcome.duration, Duration::ZERO); + assert_eq!(outcome.state, ToolExecutionState::Completed); } #[test] @@ -211,6 +235,7 @@ mod tests { ); assert!(outcome.success); assert_eq!(outcome.duration, Duration::from_millis(100)); + assert_eq!(outcome.state, ToolExecutionState::Completed); } #[test] @@ -223,6 +248,14 @@ mod tests { assert_eq!(outcome.output, "error output"); assert_eq!(outcome.error_reason, Some("error reason".to_string())); assert_eq!(outcome.duration, Duration::ZERO); + assert_eq!(outcome.state, ToolExecutionState::Completed); + } + + #[test] + fn test_tool_execution_outcome_pending() { + let outcome = ToolExecutionOutcome::pending("waiting".to_string()); + assert!(outcome.success); + assert_eq!(outcome.state, ToolExecutionState::PendingUserAction); } #[test] diff --git a/src/protocol.rs b/src/protocol.rs index 17e2bf9..c895d00 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -88,6 +88,15 @@ pub enum WsOutbound { content: String, role: String, }, + #[serde(rename = "tool_pending")] + ToolPending { + id: String, + tool_call_id: String, + tool_name: String, + content: String, + role: String, + resume_hint: String, + }, #[serde(rename = "error")] Error { code: String, message: String }, #[serde(rename = "session_established")] diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 459148e..9db2395 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -1030,6 +1030,7 @@ fn load_messages_after( timestamp: row.get(4)?, tool_call_id: row.get(5)?, tool_name: row.get(6)?, + tool_state: None, tool_calls, }) })?; diff --git a/src/tools/bash.rs b/src/tools/bash.rs index 9b9cd92..3354e5f 100644 --- a/src/tools/bash.rs +++ b/src/tools/bash.rs @@ -1,17 +1,21 @@ use std::path::Path; use std::process::Stdio; +use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; use serde_json::json; -use tokio::io::AsyncReadExt; +use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, BufReader}; use tokio::process::Command; -use tokio::time::timeout; +use tokio::sync::{Mutex, mpsc}; +use tokio::time::{Instant, sleep_until}; use crate::tools::traits::{Tool, ToolResult}; const MAX_TIMEOUT_SECS: u64 = 600; const MAX_OUTPUT_CHARS: usize = 50_000; +const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__"; +const USER_ACTION_HINT: &str = "该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。"; pub struct BashTool { timeout_secs: u64, @@ -79,6 +83,49 @@ impl BashTool { tail ) } + + fn pending_output(&self, output: &str) -> String { + format!( + "{}\n{}\n\n{}", + PENDING_USER_ACTION_MARKER, + USER_ACTION_HINT, + self.truncate_output(output.trim()) + ) + } + + fn should_return_pending(&self, interactive: bool, output: &str) -> bool { + let normalized = output.to_lowercase(); + let has_auth_phrase = [ + "等待用户授权", + "等待授权", + "等待你授权", + "在浏览器中打开以下链接进行认证", + "open the following link", + "waiting for authorization", + "waiting for user authorization", + "waiting for approval", + "device/verify", + "user_code=", + ] + .iter() + .any(|pattern| normalized.contains(pattern)); + + has_auth_phrase || (interactive && !output.trim().is_empty()) + } +} + +async fn drain_available_chunks( + rx: &mut mpsc::UnboundedReceiver<(bool, String)>, + stdout_buf: &Arc>, + stderr_buf: &Arc>, +) { + while let Ok((is_stderr, chunk)) = rx.try_recv() { + if is_stderr { + stderr_buf.lock().await.push_str(&chunk); + } else { + stdout_buf.lock().await.push_str(&chunk); + } + } } impl Default for BashTool { @@ -110,6 +157,10 @@ impl Tool for BashTool { "description": format!("Timeout in seconds (default {}, max {})", self.timeout_secs, MAX_TIMEOUT_SECS), "minimum": 1, "maximum": MAX_TIMEOUT_SECS + }, + "interactive": { + "type": "boolean", + "description": "Whether this command may enter a wait-for-user-action flow such as browser/device authentication" } }, "required": ["command"] @@ -146,6 +197,10 @@ impl Tool for BashTool { .and_then(|v| v.as_u64()) .unwrap_or(self.timeout_secs) .min(MAX_TIMEOUT_SECS); + let interactive = args + .get("interactive") + .and_then(|v| v.as_bool()) + .unwrap_or(false); let cwd = self .working_dir @@ -153,37 +208,29 @@ impl Tool for BashTool { .map(|d| Path::new(d)) .unwrap_or_else(|| Path::new(".")); - let result = timeout( - Duration::from_secs(timeout_secs), - self.run_command(command, cwd), - ) - .await; - - match result { - Ok(Ok(output)) => Ok(ToolResult { + match self.run_command(command, cwd, timeout_secs, interactive).await { + Ok(output) => Ok(ToolResult { success: true, output, error: None, }), - Ok(Err(e)) => Ok(ToolResult { + Err(e) => Ok(ToolResult { success: false, output: String::new(), error: Some(e), }), - Err(_) => Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!( - "Command timed out after {} seconds", - timeout_secs - )), - }), } } } impl BashTool { - async fn run_command(&self, command: &str, cwd: &Path) -> Result { + async fn run_command( + &self, + command: &str, + cwd: &Path, + timeout_secs: u64, + interactive: bool, + ) -> Result { let mut cmd = Command::new("bash"); cmd.args(["-c", command]) .stdout(Stdio::piped()) @@ -192,50 +239,114 @@ impl BashTool { let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?; - let mut stdout = Vec::new(); - let mut stderr = Vec::new(); + let stdout = child.stdout.take(); + let stderr = child.stderr.take(); + let (tx, mut rx) = mpsc::unbounded_channel::<(bool, String)>(); - if let Some(ref mut out) = child.stdout { - out.read_to_end(&mut stdout) - .await - .map_err(|e| format!("Failed to read stdout: {}", e))?; + if let Some(stdout) = stdout { + tokio::spawn(read_stream(stdout, false, tx.clone())); } - - if let Some(ref mut err) = child.stderr { - err.read_to_end(&mut stderr) - .await - .map_err(|e| format!("Failed to read stderr: {}", e))?; + if let Some(stderr) = stderr { + tokio::spawn(read_stream(stderr, true, tx.clone())); } + drop(tx); - let status = child - .wait() - .await - .map_err(|e| format!("Failed to wait: {}", e))?; + let stdout_buf = Arc::new(Mutex::new(String::new())); + let stderr_buf = Arc::new(Mutex::new(String::new())); + let deadline = Instant::now() + Duration::from_secs(timeout_secs); - let mut output = String::new(); - - if !stdout.is_empty() { - let stdout_str = String::from_utf8_lossy(&stdout); - output.push_str(&stdout_str); - } - - if !stderr.is_empty() { - let stderr_str = String::from_utf8_lossy(&stderr); - if !stderr_str.trim().is_empty() { - if !output.is_empty() { - output.push_str("\n"); + loop { + tokio::select! { + status = child.wait() => { + let status = status.map_err(|e| format!("Failed to wait: {}", e))?; + drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; + while let Some((is_stderr, chunk)) = rx.recv().await { + if is_stderr { + stderr_buf.lock().await.push_str(&chunk); + } else { + stdout_buf.lock().await.push_str(&chunk); + } + } + let output = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, Some(status.code().unwrap_or(-1))); + return Ok(self.truncate_output(&output)); + } + Some((is_stderr, chunk)) = rx.recv() => { + if is_stderr { + stderr_buf.lock().await.push_str(&chunk); + } else { + stdout_buf.lock().await.push_str(&chunk); + } + + let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); + if self.should_return_pending(interactive, &combined) { + drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; + let _ = child.start_kill(); + let _ = child.wait().await; + let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); + return Ok(self.pending_output(&combined)); + } + } + _ = sleep_until(deadline) => { + drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await; + let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None); + let _ = child.start_kill(); + let _ = child.wait().await; + if self.should_return_pending(interactive, &combined) { + return Ok(self.pending_output(&combined)); + } + return Err(format!("Command timed out after {} seconds", timeout_secs)); } - output.push_str("STDERR:\n"); - output.push_str(&stderr_str); } } - - output.push_str(&format!("\nExit code: {}", status.code().unwrap_or(-1))); - - Ok(self.truncate_output(&output)) } } +async fn read_stream(stream: R, is_stderr: bool, tx: mpsc::UnboundedSender<(bool, String)>) +where + R: AsyncRead + Unpin + Send + 'static, +{ + let mut reader = BufReader::new(stream); + let mut line = String::new(); + + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) => break, + Ok(_) => { + let _ = tx.send((is_stderr, line.clone())); + } + Err(_) => break, + } + } + + let mut remainder = String::new(); + if reader.read_to_string(&mut remainder).await.is_ok() && !remainder.is_empty() { + let _ = tx.send((is_stderr, remainder)); + } +} + +fn format_command_output(stdout: &str, stderr: &str, exit_code: Option) -> String { + let mut output = String::new(); + + if !stdout.is_empty() { + output.push_str(stdout); + } + + if !stderr.trim().is_empty() { + if !output.is_empty() { + output.push_str("\n"); + } + output.push_str("STDERR:\n"); + output.push_str(stderr); + } + + if let Some(code) = exit_code { + output.push_str(&format!("\nExit code: {}", code)); + } + + output +} + #[cfg(test)] mod tests { use super::*; @@ -319,6 +430,22 @@ mod tests { assert!(result.error.unwrap().contains("timed out")); } + #[tokio::test] + async fn test_pending_user_action_detection() { + let tool = BashTool::new(); + let result = tool + .execute(json!({ + "command": "printf '在浏览器中打开以下链接进行认证:\n\nhttps://example.com/device/verify\n\n等待用户授权...\n'; sleep 10", + "timeout": 1 + })) + .await + .unwrap(); + + assert!(result.success); + assert!(result.output.contains(PENDING_USER_ACTION_MARKER)); + assert!(result.output.contains("等待用户授权")); + } + #[test] fn test_truncate_output_handles_utf8_char_boundaries() { let tool = BashTool::new();