feat: Enhance tool execution handling with pending user action state

- Introduced ToolMessageState enum to represent tool execution states (Completed, PendingUserAction).
- Updated ChatMessage struct to include tool_state for tracking tool execution status.
- Modified AgentLoop to handle tool results and pending actions, providing appropriate responses to users.
- Enhanced BashTool to detect when commands require user interaction, returning a pending state with hints.
- Updated WebSocket protocol to support tool pending messages, allowing clients to handle pending actions effectively.
- Refactored related tests to ensure proper functionality of new pending state handling.
This commit is contained in:
ooodc 2026-04-22 14:49:50 +08:00
parent 30d033e1d1
commit 038b5eccc6
9 changed files with 413 additions and 385 deletions

View File

@ -1,15 +1,15 @@
use async_trait::async_trait; use async_trait::async_trait;
use crate::bus::message::ContentBlock; use crate::bus::message::ContentBlock;
use crate::bus::ChatMessage; use crate::bus::ChatMessage;
use crate::bus::message::ToolMessageState;
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::observability::{ use crate::observability::{
truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, truncate_args, Observer, ObserverEvent, ToolExecutionOutcome, ToolExecutionState,
}; };
use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall}; use crate::providers::{create_provider, LLMProvider, ChatCompletionRequest, Message, ToolCall};
use crate::skills::SkillRuntime; use crate::skills::SkillRuntime;
use crate::storage::SessionStore; use crate::storage::SessionStore;
use crate::tools::{ToolContext, ToolRegistry}; use crate::tools::{ToolContext, ToolRegistry};
use serde::Deserialize;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use std::io::Read; use std::io::Read;
@ -22,11 +22,8 @@ const MAX_TOOL_RESULT_CHARS: usize = 16_000;
/// Minimum characters to keep when truncating /// Minimum characters to keep when truncating
const TRUNCATION_SUFFIX_LEN: usize = 200; const TRUNCATION_SUFFIX_LEN: usize = 200;
const MEMORY_AUTOSAVE_SYSTEM_PROMPT: &str = "你可以在处理任务过程中使用 memory_manage 工具维护长期记忆。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespacepreferences、profile、tasks、decisions。若需要写入优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。搜索记忆时,优先使用 memory_manage(action='search'),并尽量同时提供中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 email / 邮件 / email_folder_preference。"; const MEMORY_AUTOSAVE_SYSTEM_PROMPT: &str = "你可以在处理任务过程中使用 memory_manage 工具维护长期记忆。仅在遇到高价值且未来仍有用的信息时写入记忆:用户长期偏好、稳定事实、用户对你的纠正、持续任务/项目上下文、明确决策。不要保存一次性工具结果、临时列表、敏感凭证或不确定推测。写入时优先使用规范 namespacepreferences、profile、tasks、decisions。若需要写入优先调用 memory_manage(action='put');同一 namespace/key 可直接覆盖更新。搜索记忆时,优先使用 memory_manage(action='search'),并尽量同时提供中文关键词、英文别名,以及可能的 snake_case memory_key 词,例如 email / 邮件 / email_folder_preference。";
const MEMORY_EXTRACTION_SYSTEM_PROMPT: &str = "你负责从一小段对话中提取值得长期记忆的信息。只返回 JSON 数组,不要输出解释或 Markdown。数组元素格式为 {\"namespace\": string, \"key\": string, \"content\": string}。只有在内容属于长期偏好、稳定事实、用户纠正、持续任务上下文或明确决策时才输出;否则返回 []。namespace 只能使用 preferences、profile、tasks、decisions。key 用简短稳定的 snake_case 英文标识。不要保存一次性工具结果、临时列表、敏感信息或猜测。最多输出 2 条。"; const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
const MEMORY_EXTRACTION_REASONING_EFFORT: &str = "none"; const DEFAULT_PENDING_ASSISTANT_MESSAGE: &str = "工具已经启动并进入等待用户操作的状态。请先完成外部操作,完成后直接告诉我继续。";
const MEMORY_EXTRACTION_MAX_TOKENS: u32 = 192;
const MEMORY_EXTRACTION_CONTEXT_MESSAGES: usize = 4;
const MEMORY_EXTRACTION_MAX_CANDIDATES: usize = 2;
/// Build content blocks from text and media paths /// Build content blocks from text and media paths
fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> { fn build_content_blocks(text: &str, media_paths: &[String]) -> Vec<ContentBlock> {
@ -98,6 +95,10 @@ fn truncate_tool_result(output: &str) -> String {
} }
} }
fn parse_pending_tool_output(output: &str) -> Option<String> {
output.strip_prefix(PENDING_USER_ACTION_MARKER).map(|rest| rest.trim().to_string())
}
/// Loop detection result. /// Loop detection result.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
enum LoopDetectionResult { enum LoopDetectionResult {
@ -250,19 +251,6 @@ pub struct AgentProcessResult {
pub emitted_messages: Vec<ChatMessage>, pub emitted_messages: Vec<ChatMessage>,
} }
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
struct MemoryCandidate {
namespace: String,
key: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct MemoryCandidateEnvelope {
#[serde(default)]
memories: Vec<MemoryCandidate>,
}
#[async_trait] #[async_trait]
pub trait EmittedMessageHandler: Send + Sync + 'static { pub trait EmittedMessageHandler: Send + Sync + 'static {
async fn handle(&self, message: ChatMessage); async fn handle(&self, message: ChatMessage);
@ -370,7 +358,6 @@ impl AgentLoop {
// Track tool calls for loop detection // Track tool calls for loop detection
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default()); let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
let mut emitted_messages = Vec::new(); let mut emitted_messages = Vec::new();
let mut memory_write_occurred = false;
for iteration in 0..self.max_iterations { for iteration in 0..self.max_iterations {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@ -416,9 +403,6 @@ impl AgentLoop {
// If no tool calls, this is the final response // If no tool calls, this is the final response
if response.tool_calls.is_empty() { if response.tool_calls.is_empty() {
let assistant_message = ChatMessage::assistant(response.content); let assistant_message = ChatMessage::assistant(response.content);
if !memory_write_occurred {
self.maybe_extract_and_store_memories(&messages, &assistant_message).await?;
}
emitted_messages.push(assistant_message.clone()); emitted_messages.push(assistant_message.clone());
return Ok(AgentProcessResult { return Ok(AgentProcessResult {
final_response: assistant_message, final_response: assistant_message,
@ -440,14 +424,6 @@ impl AgentLoop {
// Execute tools and add results to messages // Execute tools and add results to messages
let tool_results = self.execute_tools(&response.tool_calls).await; let tool_results = self.execute_tools(&response.tool_calls).await;
if response
.tool_calls
.iter()
.zip(tool_results.iter())
.any(|(tool_call, result)| did_successfully_write_memory(tool_call, result))
{
memory_write_occurred = true;
}
for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) { for (tool_call, result) in response.tool_calls.iter().zip(tool_results.iter()) {
// Log function call with name and arguments // Log function call with name and arguments
@ -471,19 +447,29 @@ impl AgentLoop {
"Loop warning: {}", "Loop warning: {}",
msg msg
); );
let tool_message = ChatMessage::tool( let tool_message = ChatMessage::tool_with_state(
tool_call.id.clone(), tool_call.id.clone(),
tool_call.name.clone(), tool_call.name.clone(),
format!("{}\n\n[上一条结果]\n{}", msg, truncated_output), format!("{}\n\n[上一条结果]\n{}", msg, truncated_output),
if result.state == ToolExecutionState::PendingUserAction {
ToolMessageState::PendingUserAction
} else {
ToolMessageState::Completed
},
); );
messages.push(tool_message.clone()); messages.push(tool_message.clone());
emitted_messages.push(tool_message); emitted_messages.push(tool_message);
} }
LoopDetectionResult::Ok => { LoopDetectionResult::Ok => {
let tool_message = ChatMessage::tool( let tool_message = ChatMessage::tool_with_state(
tool_call.id.clone(), tool_call.id.clone(),
tool_call.name.clone(), tool_call.name.clone(),
truncated_output, truncated_output,
if result.state == ToolExecutionState::PendingUserAction {
ToolMessageState::PendingUserAction
} else {
ToolMessageState::Completed
},
); );
messages.push(tool_message.clone()); messages.push(tool_message.clone());
emitted_messages.push(tool_message); emitted_messages.push(tool_message);
@ -491,6 +477,29 @@ impl AgentLoop {
} }
} }
if let Some((tool_call, pending_result)) = response
.tool_calls
.iter()
.zip(tool_results.iter())
.find(|(_, result)| result.state == ToolExecutionState::PendingUserAction)
{
let assistant_message = ChatMessage::assistant(format!(
"{}\n\n当前等待中的工具: {}",
pending_result
.output
.lines()
.next()
.filter(|line| !line.trim().is_empty())
.unwrap_or(DEFAULT_PENDING_ASSISTANT_MESSAGE),
tool_call.name,
));
emitted_messages.push(assistant_message.clone());
return Ok(AgentProcessResult {
final_response: assistant_message,
emitted_messages,
});
}
// Loop continues to next iteration with updated messages // Loop continues to next iteration with updated messages
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration"); tracing::debug!(iteration, message_count = messages.len(), "Tool execution complete, continuing to next iteration");
@ -524,9 +533,6 @@ impl AgentLoop {
match (*self.provider).chat(request).await { match (*self.provider).chat(request).await {
Ok(response) => { Ok(response) => {
let assistant_message = ChatMessage::assistant(response.content); let assistant_message = ChatMessage::assistant(response.content);
if !memory_write_occurred {
self.maybe_extract_and_store_memories(&messages, &assistant_message).await?;
}
emitted_messages.push(assistant_message.clone()); emitted_messages.push(assistant_message.clone());
Ok(AgentProcessResult { Ok(AgentProcessResult {
final_response: assistant_message, final_response: assistant_message,
@ -705,7 +711,11 @@ impl AgentLoop {
match tool.execute_with_context(&self.tool_context, tool_call.arguments.clone()).await { match tool.execute_with_context(&self.tool_context, tool_call.arguments.clone()).await {
Ok(result) => { Ok(result) => {
if result.success { if result.success {
ToolExecutionOutcome::success(result.output) if let Some(pending_output) = parse_pending_tool_output(&result.output) {
ToolExecutionOutcome::pending(pending_output)
} else {
ToolExecutionOutcome::success(result.output)
}
} else { } else {
let error = result.error.unwrap_or_default(); let error = result.error.unwrap_or_default();
ToolExecutionOutcome::failure( ToolExecutionOutcome::failure(
@ -743,227 +753,6 @@ impl AgentLoop {
} }
} }
fn did_successfully_write_memory(tool_call: &ToolCall, result: &ToolExecutionOutcome) -> bool {
if !result.success || tool_call.name != "memory_manage" {
return false;
}
matches!(
tool_call.arguments.get("action").and_then(|value| value.as_str()),
Some("put") | Some("update")
)
}
fn lightweight_provider_config(provider_config: &LLMProviderConfig) -> LLMProviderConfig {
let mut config = provider_config.clone();
if config.provider_type == "openai" {
config.model_extra.insert(
"reasoning_effort".to_string(),
serde_json::Value::String(MEMORY_EXTRACTION_REASONING_EFFORT.to_string()),
);
}
config
}
fn build_memory_extraction_context(messages: &[ChatMessage], final_response: &ChatMessage) -> String {
let recent_messages: Vec<&ChatMessage> = messages
.iter()
.rev()
.filter(|message| matches!(message.role.as_str(), "user" | "assistant"))
.take(MEMORY_EXTRACTION_CONTEXT_MESSAGES)
.collect::<Vec<_>>()
.into_iter()
.rev()
.collect();
let mut lines = vec!["请判断这轮对话中是否有值得长期记忆的信息:".to_string()];
for message in recent_messages {
let content = message.content.trim();
if content.is_empty() {
continue;
}
lines.push(format!(
"{}: {}",
message.role,
truncate_text_for_memory_extraction(content)
));
}
let final_content = final_response.content.trim();
if !final_content.is_empty() {
lines.push(format!(
"assistant_final: {}",
truncate_text_for_memory_extraction(final_content)
));
}
lines.join("\n")
}
fn truncate_text_for_memory_extraction(text: &str) -> String {
if text.chars().count() <= 240 {
return text.to_string();
}
let prefix: String = text.chars().take(240).collect();
format!("{}...", prefix)
}
fn parse_memory_candidates(raw: &str) -> Vec<MemoryCandidate> {
let trimmed = raw.trim();
let json_payload = trimmed
.strip_prefix("```json")
.and_then(|value| value.strip_suffix("```"))
.map(str::trim)
.or_else(|| {
trimmed
.strip_prefix("```")
.and_then(|value| value.strip_suffix("```"))
.map(str::trim)
})
.unwrap_or(trimmed);
if let Ok(candidates) = serde_json::from_str::<Vec<MemoryCandidate>>(json_payload) {
return normalize_memory_candidates(candidates);
}
if let Ok(envelope) = serde_json::from_str::<MemoryCandidateEnvelope>(json_payload) {
return normalize_memory_candidates(envelope.memories);
}
Vec::new()
}
fn normalize_memory_candidates(candidates: Vec<MemoryCandidate>) -> Vec<MemoryCandidate> {
let mut normalized = Vec::new();
for candidate in candidates {
let Some(namespace) = canonical_memory_namespace(&candidate.namespace) else {
continue;
};
let key = canonical_memory_key(&candidate.key);
let content = candidate.content.trim().to_string();
if key.is_empty() || content.len() < 8 {
continue;
}
normalized.push(MemoryCandidate {
namespace,
key,
content,
});
if normalized.len() >= MEMORY_EXTRACTION_MAX_CANDIDATES {
break;
}
}
normalized
}
fn canonical_memory_namespace(namespace: &str) -> Option<String> {
match namespace.trim().to_lowercase().as_str() {
"preferences" | "preference" => Some("preferences".to_string()),
"profile" | "profiles" => Some("profile".to_string()),
"tasks" | "task" => Some("tasks".to_string()),
"decisions" | "decision" => Some("decisions".to_string()),
_ => None,
}
}
fn canonical_memory_key(key: &str) -> String {
let mut normalized = String::new();
let mut prev_is_sep = false;
for ch in key.trim().chars() {
let mapped = if ch.is_ascii_alphanumeric() {
prev_is_sep = false;
Some(ch.to_ascii_lowercase())
} else if matches!(ch, '_' | '-' | '/' | ' ') {
if prev_is_sep {
None
} else {
prev_is_sep = true;
Some('_')
}
} else {
None
};
if let Some(ch) = mapped {
normalized.push(ch);
}
}
normalized.trim_matches('_').to_string()
}
impl AgentLoop {
async fn maybe_extract_and_store_memories(
&self,
messages: &[ChatMessage],
final_response: &ChatMessage,
) -> Result<usize, AgentError> {
let Some(tool) = self.tools.get("memory_manage") else {
return Ok(0);
};
if self.tool_context.channel_name.is_none() || self.tool_context.sender_id.is_none() {
return Ok(0);
}
let context = build_memory_extraction_context(messages, final_response);
let provider = create_provider(lightweight_provider_config(&self.provider_config))
.map_err(|err| AgentError::ProviderCreation(err.to_string()))?;
let request = ChatCompletionRequest {
messages: vec![
Message::system(MEMORY_EXTRACTION_SYSTEM_PROMPT),
Message::user(context),
],
temperature: Some(0.0),
max_tokens: Some(MEMORY_EXTRACTION_MAX_TOKENS),
tools: None,
};
let response = provider
.chat(request)
.await
.map_err(|err| AgentError::LlmError(err.to_string()))?;
let candidates = parse_memory_candidates(&response.content);
if candidates.is_empty() {
tracing::debug!("No auto-save memory candidates extracted from final response");
return Ok(0);
}
tracing::info!(candidate_count = candidates.len(), candidates = ?candidates, "Auto-save memory candidates extracted");
let mut saved = 0;
for candidate in candidates {
let result = tool
.execute_with_context(
&self.tool_context,
serde_json::json!({
"action": "put",
"namespace": candidate.namespace,
"key": candidate.key,
"content": candidate.content,
}),
)
.await
.map_err(|err| AgentError::Other(format!("auto-save memory tool error: {}", err)))?;
if result.success {
saved += 1;
} else {
tracing::warn!(error = result.error.as_deref().unwrap_or("unknown"), "Auto-save memory write failed");
}
}
tracing::info!(saved_count = saved, "Auto-save memory extraction completed");
Ok(saved)
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -1051,70 +840,10 @@ mod tests {
} }
#[test] #[test]
fn test_did_successfully_write_memory_only_accepts_successful_put_or_update() { fn test_parse_pending_tool_output() {
let tool_call = ToolCall { let output = parse_pending_tool_output("__PICOBOT_PENDING_USER_ACTION__\n请完成授权");
id: "call_1".to_string(), assert_eq!(output.as_deref(), Some("请完成授权"));
name: "memory_manage".to_string(), assert!(parse_pending_tool_output("normal output").is_none());
arguments: serde_json::json!({ "action": "put" }),
};
assert!(did_successfully_write_memory(
&tool_call,
&ToolExecutionOutcome::success("ok".to_string())
));
let failed = ToolExecutionOutcome::failure("err".to_string(), Some("boom".to_string()));
assert!(!did_successfully_write_memory(&tool_call, &failed));
let search_call = ToolCall {
id: "call_2".to_string(),
name: "memory_manage".to_string(),
arguments: serde_json::json!({ "action": "search" }),
};
assert!(!did_successfully_write_memory(
&search_call,
&ToolExecutionOutcome::success("ok".to_string())
));
}
#[test]
fn test_parse_memory_candidates_normalizes_and_limits_results() {
let raw = r#"```json
{"memories": [
{"namespace": "preference", "key": "Email Folder Preference", "content": "用户提到邮件时默认查看代收邮箱而不是收件箱。"},
{"namespace": "decision", "key": "mailbox strategy", "content": "后续默认先查看代收邮箱。"},
{"namespace": "tasks", "key": "short", "content": "太短"}
]}
```"#;
let candidates = parse_memory_candidates(raw);
assert_eq!(candidates.len(), 2);
assert_eq!(candidates[0].namespace, "preferences");
assert_eq!(candidates[0].key, "email_folder_preference");
assert_eq!(candidates[1].namespace, "decisions");
assert_eq!(candidates[1].key, "mailbox_strategy");
}
#[test]
fn test_build_memory_extraction_context_uses_recent_user_and_assistant_messages() {
let messages = vec![
ChatMessage::system("system"),
ChatMessage::tool("call-1", "calculator", "2"),
ChatMessage::user("first user"),
ChatMessage::assistant("first assistant"),
ChatMessage::user("second user"),
ChatMessage::assistant("second assistant"),
ChatMessage::user("third user"),
];
let final_response = ChatMessage::assistant("final assistant");
let context = build_memory_extraction_context(&messages, &final_response);
assert!(context.contains("user: first user") == false);
assert!(context.contains("user: second user"));
assert!(context.contains("assistant: second assistant"));
assert!(context.contains("user: third user"));
assert!(context.contains("assistant_final: final assistant"));
assert!(!context.contains("tool:"));
} }
} }

View File

@ -3,6 +3,13 @@ use serde::{Deserialize, Serialize};
use crate::providers::ToolCall; use crate::providers::ToolCall;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ToolMessageState {
Completed,
PendingUserAction,
}
// ============================================================================ // ============================================================================
// ContentBlock - Multimodal content representation (OpenAI-style) // ContentBlock - Multimodal content representation (OpenAI-style)
// ============================================================================ // ============================================================================
@ -72,6 +79,8 @@ pub struct ChatMessage {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tool_name: Option<String>, pub tool_name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tool_state: Option<ToolMessageState>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>, pub tool_calls: Option<Vec<ToolCall>>,
} }
@ -85,6 +94,7 @@ impl ChatMessage {
timestamp: current_timestamp(), timestamp: current_timestamp(),
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
tool_state: None,
tool_calls: None, tool_calls: None,
} }
} }
@ -98,6 +108,7 @@ impl ChatMessage {
timestamp: current_timestamp(), timestamp: current_timestamp(),
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
tool_state: None,
tool_calls: None, tool_calls: None,
} }
} }
@ -111,6 +122,7 @@ impl ChatMessage {
timestamp: current_timestamp(), timestamp: current_timestamp(),
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
tool_state: None,
tool_calls: None, tool_calls: None,
} }
} }
@ -124,6 +136,7 @@ impl ChatMessage {
timestamp: current_timestamp(), timestamp: current_timestamp(),
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
tool_state: None,
tool_calls: Some(tool_calls), tool_calls: Some(tool_calls),
} }
} }
@ -137,11 +150,21 @@ impl ChatMessage {
timestamp: current_timestamp(), timestamp: current_timestamp(),
tool_call_id: None, tool_call_id: None,
tool_name: None, tool_name: None,
tool_state: None,
tool_calls: None, tool_calls: None,
} }
} }
pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self { pub fn tool(tool_call_id: impl Into<String>, tool_name: impl Into<String>, content: impl Into<String>) -> Self {
Self::tool_with_state(tool_call_id, tool_name, content, ToolMessageState::Completed)
}
pub fn tool_with_state(
tool_call_id: impl Into<String>,
tool_name: impl Into<String>,
content: impl Into<String>,
tool_state: ToolMessageState,
) -> Self {
Self { Self {
id: uuid::Uuid::new_v4().to_string(), id: uuid::Uuid::new_v4().to_string(),
role: "tool".to_string(), role: "tool".to_string(),
@ -150,6 +173,7 @@ impl ChatMessage {
timestamp: current_timestamp(), timestamp: current_timestamp(),
tool_call_id: Some(tool_call_id.into()), tool_call_id: Some(tool_call_id.into()),
tool_name: Some(tool_name.into()), tool_name: Some(tool_name.into()),
tool_state: Some(tool_state),
tool_calls: None, tool_calls: None,
} }
} }
@ -212,6 +236,7 @@ pub enum OutboundEventKind {
AssistantResponse, AssistantResponse,
ToolCall, ToolCall,
ToolResult, ToolResult,
ToolPending,
} }
impl OutboundMessage { impl OutboundMessage {
@ -294,6 +319,33 @@ impl OutboundMessage {
} }
} }
pub fn tool_pending(
channel: impl Into<String>,
chat_id: impl Into<String>,
tool_call_id: impl Into<String>,
tool_name: impl Into<String>,
content: impl Into<String>,
reply_to: Option<String>,
metadata: HashMap<String, String>,
) -> Self {
let tool_name = tool_name.into();
let raw_content = content.into();
let content = format_tool_result_content(&tool_name, &raw_content);
Self {
channel: channel.into(),
chat_id: chat_id.into(),
content,
reply_to,
media: Vec::new(),
metadata,
event_kind: OutboundEventKind::ToolPending,
role: "tool".to_string(),
tool_call_id: Some(tool_call_id.into()),
tool_name: Some(tool_name),
tool_arguments: None,
}
}
pub fn from_chat_message( pub fn from_chat_message(
channel: &str, channel: &str,
chat_id: &str, chat_id: &str,
@ -328,7 +380,26 @@ impl OutboundMessage {
)] )]
} }
} }
"tool" => Vec::new(), "tool" => match message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed) {
ToolMessageState::Completed => vec![Self::tool_result(
channel.to_string(),
chat_id.to_string(),
message.tool_call_id.clone().unwrap_or_default(),
message.tool_name.clone().unwrap_or_default(),
message.content.clone(),
reply_to,
metadata.clone(),
)],
ToolMessageState::PendingUserAction => vec![Self::tool_pending(
channel.to_string(),
chat_id.to_string(),
message.tool_call_id.clone().unwrap_or_default(),
message.tool_name.clone().unwrap_or_default(),
message.content.clone(),
reply_to,
metadata.clone(),
)],
},
_ => Vec::new(), _ => Vec::new(),
} }
} }
@ -377,7 +448,7 @@ fn current_timestamp() -> i64 {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{ChatMessage, OutboundEventKind, OutboundMessage}; use super::{ChatMessage, OutboundEventKind, OutboundMessage, ToolMessageState};
use crate::providers::ToolCall; use crate::providers::ToolCall;
use serde_json::json; use serde_json::json;
use std::collections::HashMap; use std::collections::HashMap;
@ -418,7 +489,7 @@ mod tests {
} }
#[test] #[test]
fn test_from_chat_message_omits_tool_result() { fn test_from_chat_message_includes_tool_result() {
let message = ChatMessage::tool("call-9", "calculator", "2"); let message = ChatMessage::tool("call-9", "calculator", "2");
let outbound = OutboundMessage::from_chat_message( let outbound = OutboundMessage::from_chat_message(
@ -429,6 +500,28 @@ mod tests {
&message, &message,
); );
assert!(outbound.is_empty()); assert_eq!(outbound.len(), 1);
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult);
}
#[test]
fn test_from_chat_message_includes_tool_pending() {
let message = ChatMessage::tool_with_state(
"call-9",
"bash",
"等待你完成浏览器授权后再继续。",
ToolMessageState::PendingUserAction,
);
let outbound = OutboundMessage::from_chat_message(
"feishu",
"chat-1",
None,
&HashMap::new(),
&message,
);
assert_eq!(outbound.len(), 1);
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolPending);
} }
} }

View File

@ -73,6 +73,9 @@ pub async fn run(gateway_url: &str) -> Result<(), Box<dyn std::error::Error>> {
WsOutbound::ToolResult { tool_name, content, .. } => { WsOutbound::ToolResult { tool_name, content, .. } => {
input.write_output(&format!("Tool result: {}\n{}\n", tool_name, content)).await?; input.write_output(&format!("Tool result: {}\n{}\n", tool_name, content)).await?;
} }
WsOutbound::ToolPending { tool_name, content, resume_hint, .. } => {
input.write_output(&format!("Tool pending: {}\n{}\n{}\n", tool_name, content, resume_hint)).await?;
}
WsOutbound::Error { message, .. } => { WsOutbound::Error { message, .. } => {
input.write_output(&format!("Error: {}", message)).await?; input.write_output(&format!("Error: {}", message)).await?;
} }

View File

@ -19,7 +19,7 @@ use crate::tools::{
}; };
const DEFAULT_AGENT_PROMPT: &str = "# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot一名务实、可靠的通用助理。\n- 你的目标是理解用户当下的真实需求,并给出清晰、可执行的帮助。\n\n## 工作方式\n- 优先理解意图,再给出回应或行动。\n- 保持简洁、准确、自然,不故作热情,也不空泛铺陈。\n- 能直接验证的内容尽量先验证,避免凭空猜测。\n- 当现有工具是完成任务的最直接方式时,优先使用工具。\n- 除非用户明确要求改变方向,否则保持用户原本目标不变。\n\n## 助理原则\n- 优先解决问题,而不是展示过程。\n- 输出要方便用户立即使用,结论尽量明确。\n- 对不确定的地方要直说,不把猜测包装成事实。\n- 复杂任务先收敛重点,简单任务直接给结果。\n- 避免不必要的重复、客套和冗长说明。\n\n## 回复规则\n- 除非用户另有要求,否则使用中文回复。\n- 默认短而清楚,按信息密度组织内容。\n- 如果任务涉及文件、命令、配置或下一步操作,优先给出最关键的那部分。\n- 如果存在限制、风险或前提条件,要直接说明。\n\n## 补充要求\n- 你是 PicoBot。\n- 回答应以帮助用户完成当前目标为中心。\n- 在信息不足时先补关键前提,在信息充分时直接执行。\n"; const DEFAULT_AGENT_PROMPT: &str = "# PicoBot 代理配置\n\n## 身份\n- 你是 PicoBot一名务实、可靠的通用助理。\n- 你的目标是理解用户当下的真实需求,并给出清晰、可执行的帮助。\n\n## 工作方式\n- 优先理解意图,再给出回应或行动。\n- 保持简洁、准确、自然,不故作热情,也不空泛铺陈。\n- 能直接验证的内容尽量先验证,避免凭空猜测。\n- 当现有工具是完成任务的最直接方式时,优先使用工具。\n- 除非用户明确要求改变方向,否则保持用户原本目标不变。\n\n## 助理原则\n- 优先解决问题,而不是展示过程。\n- 输出要方便用户立即使用,结论尽量明确。\n- 对不确定的地方要直说,不把猜测包装成事实。\n- 复杂任务先收敛重点,简单任务直接给结果。\n- 避免不必要的重复、客套和冗长说明。\n\n## 回复规则\n- 除非用户另有要求,否则使用中文回复。\n- 默认短而清楚,按信息密度组织内容。\n- 如果任务涉及文件、命令、配置或下一步操作,优先给出最关键的那部分。\n- 如果存在限制、风险或前提条件,要直接说明。\n\n## 补充要求\n- 你是 PicoBot。\n- 回答应以帮助用户完成当前目标为中心。\n- 在信息不足时先补关键前提,在信息充分时直接执行。\n";
const MEMORY_KEYWORD_SYSTEM_PROMPT: &str = "你负责为长期记忆检索生成关键词。根据给定的最近会话,仅输出 JSON 数组字符串。关键词必须同时覆盖中文和英文优先为同一主题同时给出中文关键词和对应英文关键词。关键词必须是短词语优先使用最容易命中记忆的核心检索词不要输出完整句子、解释或长描述。必要时优先保留实体名、产品名、偏好名、snake_case key 风格短词。数组元素总数控制在 2 到 6 个简短关键词或短语。不要输出解释,不要输出 Markdown。"; const MEMORY_KEYWORD_SYSTEM_PROMPT: &str = "你负责为长期记忆检索生成关键词。根据给定的最近会话,仅输出 JSON 数组字符串。关键词必须同时覆盖中文和英文优先为同一主题同时给出中文关键词和对应英文关键词。关键词必须是短词语优先使用最容易命中记忆的核心检索词不要输出完整句子、解释或长描述。必要时优先保留实体名、产品名、偏好名、snake_case key 风格短词。数组元素总数控制在 2 到 6 个简短关键词或短语。不要输出解释,不要输出 Markdown。如果最新的会话跟前面有明显主题变化,且新主题可能与旧主题的关键词不同,优先输出新主题的关键词;如果最新会话是旧主题的延续,优先输出旧主题的关键词;如果最新会话同时包含新旧主题,优先输出同时覆盖新旧主题的关键词。";
const RELATED_MEMORY_SYSTEM_PROMPT_PREFIX: &str = "找到相关的记忆。你必须优先参考这些记忆,并在后续推理中把它们当作当前会话的补充上下文;若与用户本轮明确要求冲突,以用户本轮要求为准。"; const RELATED_MEMORY_SYSTEM_PROMPT_PREFIX: &str = "找到相关的记忆。你必须优先参考这些记忆,并在后续推理中把它们当作当前会话的补充上下文;若与用户本轮明确要求冲突,以用户本轮要求为准。";
const MEMORY_KEYWORD_REASONING_EFFORT: &str = "none"; const MEMORY_KEYWORD_REASONING_EFFORT: &str = "none";
const MEMORY_KEYWORD_MAX_CHARS: usize = 32; const MEMORY_KEYWORD_MAX_CHARS: usize = 32;

View File

@ -6,7 +6,7 @@ use axum::response::Response;
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{mpsc, Mutex};
use crate::agent::EmittedMessageHandler; use crate::agent::EmittedMessageHandler;
use crate::bus::message::format_tool_call_content; use crate::bus::message::{format_tool_call_content, ToolMessageState};
use crate::bus::ChatMessage; use crate::bus::ChatMessage;
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound}; use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
use super::{GatewayState, session::{Session, handle_in_chat_command}}; use super::{GatewayState, session::{Session, handle_in_chat_command}};
@ -181,7 +181,23 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
}] }]
} }
} }
"tool" => Vec::new(), "tool" => match message.tool_state.as_ref().unwrap_or(&ToolMessageState::Completed) {
ToolMessageState::Completed => vec![WsOutbound::ToolResult {
id: message.id.clone(),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
}],
ToolMessageState::PendingUserAction => vec![WsOutbound::ToolPending {
id: message.id.clone(),
tool_call_id: message.tool_call_id.clone().unwrap_or_default(),
tool_name: message.tool_name.clone().unwrap_or_default(),
content: message.content.clone(),
role: message.role.clone(),
resume_hint: "完成外部操作后,直接发一条继续消息即可。".to_string(),
}],
},
_ => Vec::new(), _ => Vec::new(),
} }
} }
@ -391,6 +407,7 @@ async fn handle_inbound(
mod tests { mod tests {
use super::ws_outbound_from_chat_message; use super::ws_outbound_from_chat_message;
use crate::bus::ChatMessage; use crate::bus::ChatMessage;
use crate::bus::message::ToolMessageState;
use crate::providers::ToolCall; use crate::providers::ToolCall;
use crate::protocol::WsOutbound; use crate::protocol::WsOutbound;
use serde_json::json; use serde_json::json;
@ -421,11 +438,27 @@ mod tests {
} }
#[test] #[test]
fn test_ws_outbound_from_chat_message_omits_tool_results() { fn test_ws_outbound_from_chat_message_includes_tool_results() {
let message = ChatMessage::tool("call-1", "calculator", "2"); let message = ChatMessage::tool("call-1", "calculator", "2");
let outbound = ws_outbound_from_chat_message(&message); let outbound = ws_outbound_from_chat_message(&message);
assert!(outbound.is_empty()); assert_eq!(outbound.len(), 1);
assert!(matches!(outbound[0], WsOutbound::ToolResult { .. }));
}
#[test]
fn test_ws_outbound_from_chat_message_includes_tool_pending() {
let message = ChatMessage::tool_with_state(
"call-1",
"bash",
"等待你完成授权后再继续。",
ToolMessageState::PendingUserAction,
);
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 1);
assert!(matches!(outbound[0], WsOutbound::ToolPending { .. }));
} }
} }

View File

@ -5,6 +5,12 @@
use std::time::Duration; use std::time::Duration;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolExecutionState {
Completed,
PendingUserAction,
}
/// Events emitted during agent and tool execution. /// Events emitted during agent and tool execution.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum ObserverEvent { pub enum ObserverEvent {
@ -60,6 +66,8 @@ pub struct ToolExecutionOutcome {
pub error_reason: Option<String>, pub error_reason: Option<String>,
/// How long the tool took to execute. /// How long the tool took to execute.
pub duration: Duration, pub duration: Duration,
/// Whether the tool completed or is waiting for external user action.
pub state: ToolExecutionState,
} }
impl ToolExecutionOutcome { impl ToolExecutionOutcome {
@ -70,6 +78,7 @@ impl ToolExecutionOutcome {
success: true, success: true,
error_reason: None, error_reason: None,
duration: Duration::ZERO, duration: Duration::ZERO,
state: ToolExecutionState::Completed,
} }
} }
@ -80,6 +89,18 @@ impl ToolExecutionOutcome {
success: true, success: true,
error_reason: None, error_reason: None,
duration, duration,
state: ToolExecutionState::Completed,
}
}
/// Create a pending outcome with zero duration.
pub fn pending(output: String) -> Self {
Self {
output,
success: true,
error_reason: None,
duration: Duration::ZERO,
state: ToolExecutionState::PendingUserAction,
} }
} }
@ -90,6 +111,7 @@ impl ToolExecutionOutcome {
success: false, success: false,
error_reason, error_reason,
duration: Duration::ZERO, duration: Duration::ZERO,
state: ToolExecutionState::Completed,
} }
} }
@ -100,6 +122,7 @@ impl ToolExecutionOutcome {
success: false, success: false,
error_reason, error_reason,
duration, duration,
state: ToolExecutionState::Completed,
} }
} }
} }
@ -201,6 +224,7 @@ mod tests {
assert_eq!(outcome.output, "output content"); assert_eq!(outcome.output, "output content");
assert!(outcome.error_reason.is_none()); assert!(outcome.error_reason.is_none());
assert_eq!(outcome.duration, Duration::ZERO); assert_eq!(outcome.duration, Duration::ZERO);
assert_eq!(outcome.state, ToolExecutionState::Completed);
} }
#[test] #[test]
@ -211,6 +235,7 @@ mod tests {
); );
assert!(outcome.success); assert!(outcome.success);
assert_eq!(outcome.duration, Duration::from_millis(100)); assert_eq!(outcome.duration, Duration::from_millis(100));
assert_eq!(outcome.state, ToolExecutionState::Completed);
} }
#[test] #[test]
@ -223,6 +248,14 @@ mod tests {
assert_eq!(outcome.output, "error output"); assert_eq!(outcome.output, "error output");
assert_eq!(outcome.error_reason, Some("error reason".to_string())); assert_eq!(outcome.error_reason, Some("error reason".to_string()));
assert_eq!(outcome.duration, Duration::ZERO); assert_eq!(outcome.duration, Duration::ZERO);
assert_eq!(outcome.state, ToolExecutionState::Completed);
}
#[test]
fn test_tool_execution_outcome_pending() {
let outcome = ToolExecutionOutcome::pending("waiting".to_string());
assert!(outcome.success);
assert_eq!(outcome.state, ToolExecutionState::PendingUserAction);
} }
#[test] #[test]

View File

@ -88,6 +88,15 @@ pub enum WsOutbound {
content: String, content: String,
role: String, role: String,
}, },
#[serde(rename = "tool_pending")]
ToolPending {
id: String,
tool_call_id: String,
tool_name: String,
content: String,
role: String,
resume_hint: String,
},
#[serde(rename = "error")] #[serde(rename = "error")]
Error { code: String, message: String }, Error { code: String, message: String },
#[serde(rename = "session_established")] #[serde(rename = "session_established")]

View File

@ -1030,6 +1030,7 @@ fn load_messages_after(
timestamp: row.get(4)?, timestamp: row.get(4)?,
tool_call_id: row.get(5)?, tool_call_id: row.get(5)?,
tool_name: row.get(6)?, tool_name: row.get(6)?,
tool_state: None,
tool_calls, tool_calls,
}) })
})?; })?;

View File

@ -1,17 +1,21 @@
use std::path::Path; use std::path::Path;
use std::process::Stdio; use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::json; use serde_json::json;
use tokio::io::AsyncReadExt; use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, BufReader};
use tokio::process::Command; use tokio::process::Command;
use tokio::time::timeout; use tokio::sync::{Mutex, mpsc};
use tokio::time::{Instant, sleep_until};
use crate::tools::traits::{Tool, ToolResult}; use crate::tools::traits::{Tool, ToolResult};
const MAX_TIMEOUT_SECS: u64 = 600; const MAX_TIMEOUT_SECS: u64 = 600;
const MAX_OUTPUT_CHARS: usize = 50_000; const MAX_OUTPUT_CHARS: usize = 50_000;
const PENDING_USER_ACTION_MARKER: &str = "__PICOBOT_PENDING_USER_ACTION__";
const USER_ACTION_HINT: &str = "该命令正在等待你完成外部操作。完成后请告诉我继续,或重新运行后续检查命令。";
pub struct BashTool { pub struct BashTool {
timeout_secs: u64, timeout_secs: u64,
@ -79,6 +83,49 @@ impl BashTool {
tail tail
) )
} }
fn pending_output(&self, output: &str) -> String {
format!(
"{}\n{}\n\n{}",
PENDING_USER_ACTION_MARKER,
USER_ACTION_HINT,
self.truncate_output(output.trim())
)
}
fn should_return_pending(&self, interactive: bool, output: &str) -> bool {
let normalized = output.to_lowercase();
let has_auth_phrase = [
"等待用户授权",
"等待授权",
"等待你授权",
"在浏览器中打开以下链接进行认证",
"open the following link",
"waiting for authorization",
"waiting for user authorization",
"waiting for approval",
"device/verify",
"user_code=",
]
.iter()
.any(|pattern| normalized.contains(pattern));
has_auth_phrase || (interactive && !output.trim().is_empty())
}
}
async fn drain_available_chunks(
rx: &mut mpsc::UnboundedReceiver<(bool, String)>,
stdout_buf: &Arc<Mutex<String>>,
stderr_buf: &Arc<Mutex<String>>,
) {
while let Ok((is_stderr, chunk)) = rx.try_recv() {
if is_stderr {
stderr_buf.lock().await.push_str(&chunk);
} else {
stdout_buf.lock().await.push_str(&chunk);
}
}
} }
impl Default for BashTool { impl Default for BashTool {
@ -110,6 +157,10 @@ impl Tool for BashTool {
"description": format!("Timeout in seconds (default {}, max {})", self.timeout_secs, MAX_TIMEOUT_SECS), "description": format!("Timeout in seconds (default {}, max {})", self.timeout_secs, MAX_TIMEOUT_SECS),
"minimum": 1, "minimum": 1,
"maximum": MAX_TIMEOUT_SECS "maximum": MAX_TIMEOUT_SECS
},
"interactive": {
"type": "boolean",
"description": "Whether this command may enter a wait-for-user-action flow such as browser/device authentication"
} }
}, },
"required": ["command"] "required": ["command"]
@ -146,6 +197,10 @@ impl Tool for BashTool {
.and_then(|v| v.as_u64()) .and_then(|v| v.as_u64())
.unwrap_or(self.timeout_secs) .unwrap_or(self.timeout_secs)
.min(MAX_TIMEOUT_SECS); .min(MAX_TIMEOUT_SECS);
let interactive = args
.get("interactive")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let cwd = self let cwd = self
.working_dir .working_dir
@ -153,37 +208,29 @@ impl Tool for BashTool {
.map(|d| Path::new(d)) .map(|d| Path::new(d))
.unwrap_or_else(|| Path::new(".")); .unwrap_or_else(|| Path::new("."));
let result = timeout( match self.run_command(command, cwd, timeout_secs, interactive).await {
Duration::from_secs(timeout_secs), Ok(output) => Ok(ToolResult {
self.run_command(command, cwd),
)
.await;
match result {
Ok(Ok(output)) => Ok(ToolResult {
success: true, success: true,
output, output,
error: None, error: None,
}), }),
Ok(Err(e)) => Ok(ToolResult { Err(e) => Ok(ToolResult {
success: false, success: false,
output: String::new(), output: String::new(),
error: Some(e), error: Some(e),
}), }),
Err(_) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Command timed out after {} seconds",
timeout_secs
)),
}),
} }
} }
} }
impl BashTool { impl BashTool {
async fn run_command(&self, command: &str, cwd: &Path) -> Result<String, String> { async fn run_command(
&self,
command: &str,
cwd: &Path,
timeout_secs: u64,
interactive: bool,
) -> Result<String, String> {
let mut cmd = Command::new("bash"); let mut cmd = Command::new("bash");
cmd.args(["-c", command]) cmd.args(["-c", command])
.stdout(Stdio::piped()) .stdout(Stdio::piped())
@ -192,50 +239,114 @@ impl BashTool {
let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?; let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn: {}", e))?;
let mut stdout = Vec::new(); let stdout = child.stdout.take();
let mut stderr = Vec::new(); let stderr = child.stderr.take();
let (tx, mut rx) = mpsc::unbounded_channel::<(bool, String)>();
if let Some(ref mut out) = child.stdout { if let Some(stdout) = stdout {
out.read_to_end(&mut stdout) tokio::spawn(read_stream(stdout, false, tx.clone()));
.await
.map_err(|e| format!("Failed to read stdout: {}", e))?;
} }
if let Some(stderr) = stderr {
if let Some(ref mut err) = child.stderr { tokio::spawn(read_stream(stderr, true, tx.clone()));
err.read_to_end(&mut stderr)
.await
.map_err(|e| format!("Failed to read stderr: {}", e))?;
} }
drop(tx);
let status = child let stdout_buf = Arc::new(Mutex::new(String::new()));
.wait() let stderr_buf = Arc::new(Mutex::new(String::new()));
.await let deadline = Instant::now() + Duration::from_secs(timeout_secs);
.map_err(|e| format!("Failed to wait: {}", e))?;
let mut output = String::new(); loop {
tokio::select! {
if !stdout.is_empty() { status = child.wait() => {
let stdout_str = String::from_utf8_lossy(&stdout); let status = status.map_err(|e| format!("Failed to wait: {}", e))?;
output.push_str(&stdout_str); drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
} while let Some((is_stderr, chunk)) = rx.recv().await {
if is_stderr {
if !stderr.is_empty() { stderr_buf.lock().await.push_str(&chunk);
let stderr_str = String::from_utf8_lossy(&stderr); } else {
if !stderr_str.trim().is_empty() { stdout_buf.lock().await.push_str(&chunk);
if !output.is_empty() { }
output.push_str("\n"); }
let output = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, Some(status.code().unwrap_or(-1)));
return Ok(self.truncate_output(&output));
}
Some((is_stderr, chunk)) = rx.recv() => {
if is_stderr {
stderr_buf.lock().await.push_str(&chunk);
} else {
stdout_buf.lock().await.push_str(&chunk);
}
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
if self.should_return_pending(interactive, &combined) {
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
let _ = child.start_kill();
let _ = child.wait().await;
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
return Ok(self.pending_output(&combined));
}
}
_ = sleep_until(deadline) => {
drain_available_chunks(&mut rx, &stdout_buf, &stderr_buf).await;
let combined = format_command_output(&stdout_buf.lock().await, &stderr_buf.lock().await, None);
let _ = child.start_kill();
let _ = child.wait().await;
if self.should_return_pending(interactive, &combined) {
return Ok(self.pending_output(&combined));
}
return Err(format!("Command timed out after {} seconds", timeout_secs));
} }
output.push_str("STDERR:\n");
output.push_str(&stderr_str);
} }
} }
output.push_str(&format!("\nExit code: {}", status.code().unwrap_or(-1)));
Ok(self.truncate_output(&output))
} }
} }
async fn read_stream<R>(stream: R, is_stderr: bool, tx: mpsc::UnboundedSender<(bool, String)>)
where
R: AsyncRead + Unpin + Send + 'static,
{
let mut reader = BufReader::new(stream);
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) => break,
Ok(_) => {
let _ = tx.send((is_stderr, line.clone()));
}
Err(_) => break,
}
}
let mut remainder = String::new();
if reader.read_to_string(&mut remainder).await.is_ok() && !remainder.is_empty() {
let _ = tx.send((is_stderr, remainder));
}
}
fn format_command_output(stdout: &str, stderr: &str, exit_code: Option<i32>) -> String {
let mut output = String::new();
if !stdout.is_empty() {
output.push_str(stdout);
}
if !stderr.trim().is_empty() {
if !output.is_empty() {
output.push_str("\n");
}
output.push_str("STDERR:\n");
output.push_str(stderr);
}
if let Some(code) = exit_code {
output.push_str(&format!("\nExit code: {}", code));
}
output
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -319,6 +430,22 @@ mod tests {
assert!(result.error.unwrap().contains("timed out")); assert!(result.error.unwrap().contains("timed out"));
} }
#[tokio::test]
async fn test_pending_user_action_detection() {
let tool = BashTool::new();
let result = tool
.execute(json!({
"command": "printf '在浏览器中打开以下链接进行认证:\n\nhttps://example.com/device/verify\n\n等待用户授权...\n'; sleep 10",
"timeout": 1
}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains(PENDING_USER_ACTION_MARKER));
assert!(result.output.contains("等待用户授权"));
}
#[test] #[test]
fn test_truncate_output_handles_utf8_char_boundaries() { fn test_truncate_output_handles_utf8_char_boundaries() {
let tool = BashTool::new(); let tool = BashTool::new();