Compare commits

..

No commits in common. "4725b5406e7af5a32739e85935a5f6e0ba70eecb" and "d35e89a44c17c1ae21488ad30733ddb6df101569" have entirely different histories.

7 changed files with 56 additions and 258 deletions

View File

@ -1,4 +1,3 @@
use async_trait::async_trait;
use crate::bus::message::ContentBlock; use crate::bus::message::ContentBlock;
use crate::bus::ChatMessage; use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
@ -228,7 +227,6 @@ pub struct AgentLoop {
skill_event_session_id: Option<String>, skill_event_session_id: Option<String>,
tool_context: ToolContext, tool_context: ToolContext,
observer: Option<Arc<dyn Observer>>, observer: Option<Arc<dyn Observer>>,
emitted_message_handler: Option<Arc<dyn EmittedMessageHandler>>,
max_iterations: usize, max_iterations: usize,
} }
@ -238,11 +236,6 @@ pub struct AgentProcessResult {
pub emitted_messages: Vec<ChatMessage>, pub emitted_messages: Vec<ChatMessage>,
} }
#[async_trait]
pub trait EmittedMessageHandler: Send + Sync + 'static {
async fn handle(&self, message: ChatMessage);
}
impl AgentLoop { impl AgentLoop {
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> { pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations; let max_iterations = provider_config.max_tool_iterations;
@ -257,7 +250,6 @@ impl AgentLoop {
skill_event_session_id: None, skill_event_session_id: None,
tool_context: ToolContext::default(), tool_context: ToolContext::default(),
observer: None, observer: None,
emitted_message_handler: None,
max_iterations, max_iterations,
}) })
} }
@ -275,7 +267,6 @@ impl AgentLoop {
skill_event_session_id: None, skill_event_session_id: None,
tool_context: ToolContext::default(), tool_context: ToolContext::default(),
observer: None, observer: None,
emitted_message_handler: None,
max_iterations, max_iterations,
}) })
} }
@ -297,7 +288,6 @@ impl AgentLoop {
skill_event_session_id: None, skill_event_session_id: None,
tool_context: ToolContext::default(), tool_context: ToolContext::default(),
observer: None, observer: None,
emitted_message_handler: None,
max_iterations, max_iterations,
}) })
} }
@ -319,11 +309,6 @@ impl AgentLoop {
self self
} }
pub fn with_emitted_message_handler(mut self, handler: Arc<dyn EmittedMessageHandler>) -> Self {
self.emitted_message_handler = Some(handler);
self
}
pub fn tools(&self) -> &Arc<ToolRegistry> { pub fn tools(&self) -> &Arc<ToolRegistry> {
&self.tools &self.tools
} }
@ -403,7 +388,6 @@ impl AgentLoop {
); );
messages.push(assistant_message.clone()); messages.push(assistant_message.clone());
emitted_messages.push(assistant_message); emitted_messages.push(assistant_message);
self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await;
// Execute tools and add results to messages // Execute tools and add results to messages
let tool_results = self.execute_tools(&response.tool_calls).await; let tool_results = self.execute_tools(&response.tool_calls).await;
@ -503,16 +487,6 @@ impl AgentLoop {
} }
} }
async fn emit_live_tool_call_message(&self, message: ChatMessage) {
if !message.is_assistant_tool_call_message() {
return;
}
if let Some(handler) = &self.emitted_message_handler {
handler.handle(message).await;
}
}
/// Determine whether to execute tools in parallel or sequentially. /// Determine whether to execute tools in parallel or sequentially.
/// ///
/// Returns true if: /// Returns true if:

View File

@ -1,5 +1,5 @@
pub mod agent_loop; pub mod agent_loop;
pub mod context_compressor; pub mod context_compressor;
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult, EmittedMessageHandler}; pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
pub use context_compressor::ContextCompressor; pub use context_compressor::ContextCompressor;

View File

@ -153,15 +153,6 @@ impl ChatMessage {
tool_calls: None, tool_calls: None,
} }
} }
pub fn is_assistant_tool_call_message(&self) -> bool {
self.role == "assistant"
&& self
.tool_calls
.as_ref()
.map(|calls| !calls.is_empty())
.unwrap_or(false)
}
} }
// ============================================================================ // ============================================================================
@ -328,39 +319,36 @@ impl OutboundMessage {
)] )]
} }
} }
"tool" => Vec::new(), "tool" => vec![Self::tool_result(
channel.to_string(),
chat_id.to_string(),
message.tool_call_id.clone().unwrap_or_else(|| message.id.clone()),
message.tool_name.clone().unwrap_or_else(|| "tool".to_string()),
message.content.clone(),
reply_to,
metadata.clone(),
)],
_ => Vec::new(), _ => Vec::new(),
} }
} }
} }
pub(crate) fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String { fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String {
let mut lines = vec![format!("### {}", tool_name)]; format!(
"调用工具: {}\n\n输入参数:\n{}",
match tool_arguments { tool_name,
serde_json::Value::Object(map) if !map.is_empty() => { format_json_value(tool_arguments),
let mut entries: Vec<_> = map.iter().collect(); )
entries.sort_by(|(left, _), (right, _)| left.cmp(right));
for (key, value) in entries {
lines.push(format!("- {}: {}", key, format_tool_argument_value(value)));
}
}
serde_json::Value::Object(_) => {}
other => lines.push(format!("- args: {}", format_tool_argument_value(other))),
}
lines.join("\n")
} }
fn format_tool_result_content(tool_name: &str, content: &str) -> String { fn format_tool_result_content(tool_name: &str, content: &str) -> String {
format!("工具结果: {}\n\n{}", tool_name, content) format!("工具结果: {}\n\n{}", tool_name, content)
} }
fn format_tool_argument_value(value: &serde_json::Value) -> String { fn format_json_value(value: &serde_json::Value) -> String {
match value { match value {
serde_json::Value::String(text) => text.clone(), serde_json::Value::Object(map) if map.is_empty() => "{}".to_string(),
serde_json::Value::Null => "null".to_string(), other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
other => serde_json::to_string(other).unwrap_or_else(|_| other.to_string()),
} }
} }
@ -412,13 +400,11 @@ mod tests {
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall); assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall);
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator")); assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
assert_eq!(outbound[0].tool_arguments.as_ref().unwrap()["expression"], "1 + 1"); assert_eq!(outbound[0].tool_arguments.as_ref().unwrap()["expression"], "1 + 1");
assert_eq!(outbound[0].content, "### calculator\n- expression: 1 + 1");
assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read")); assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read"));
assert_eq!(outbound[1].content, "### file_read\n- path: README.md");
} }
#[test] #[test]
fn test_from_chat_message_omits_tool_result() { fn test_from_chat_message_maps_tool_result() {
let message = ChatMessage::tool("call-9", "calculator", "2"); let message = ChatMessage::tool("call-9", "calculator", "2");
let outbound = OutboundMessage::from_chat_message( let outbound = OutboundMessage::from_chat_message(
@ -429,6 +415,10 @@ mod tests {
&message, &message,
); );
assert!(outbound.is_empty()); assert_eq!(outbound.len(), 1);
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult);
assert_eq!(outbound[0].tool_call_id.as_deref(), Some("call-9"));
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
assert!(outbound[0].content.contains("工具结果: calculator"));
} }
} }

View File

@ -11,7 +11,7 @@ use crate::channels::ChannelManager;
use crate::config::Config; use crate::config::Config;
use crate::logging; use crate::logging;
use crate::skills::SkillRuntime; use crate::skills::SkillRuntime;
use session::{BusToolCallEmitter, SessionManager}; use session::SessionManager;
pub struct GatewayState { pub struct GatewayState {
pub config: Config, pub config: Config,
@ -74,19 +74,12 @@ impl GatewayState {
} }
// Process via session manager // Process via session manager
let live_emitter = Arc::new(BusToolCallEmitter::new(
bus_for_inbound.clone(),
inbound.channel.clone(),
inbound.chat_id.clone(),
inbound.forwarded_metadata.clone(),
));
match session_manager.handle_message( match session_manager.handle_message(
&inbound.channel, &inbound.channel,
&inbound.sender_id, &inbound.sender_id,
&inbound.chat_id, &inbound.chat_id,
&inbound.content, &inbound.content,
inbound.media, inbound.media,
Some(live_emitter),
).await { ).await {
Ok(outbound_messages) => { Ok(outbound_messages) => {
// Forward channel-specific metadata from inbound to outbound. // Forward channel-specific metadata from inbound to outbound.

View File

@ -1,12 +1,11 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use async_trait::async_trait;
use tokio::sync::{Mutex, mpsc}; use tokio::sync::{Mutex, mpsc};
use uuid::Uuid; use uuid::Uuid;
use crate::bus::{ChatMessage, MessageBus, OutboundMessage}; use crate::bus::{ChatMessage, OutboundMessage};
use crate::config::LLMProviderConfig; use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError, ContextCompressor, EmittedMessageHandler}; use crate::agent::{AgentLoop, AgentError, ContextCompressor};
use crate::protocol::WsOutbound; use crate::protocol::WsOutbound;
use crate::skills::SkillRuntime; use crate::skills::SkillRuntime;
use crate::storage::{SessionRecord, SessionStore, persistent_session_id}; use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
@ -31,46 +30,6 @@ pub struct Session {
store: Arc<SessionStore>, store: Arc<SessionStore>,
} }
pub struct BusToolCallEmitter {
bus: Arc<MessageBus>,
channel_name: String,
chat_id: String,
metadata: HashMap<String, String>,
}
impl BusToolCallEmitter {
pub fn new(
bus: Arc<MessageBus>,
channel_name: impl Into<String>,
chat_id: impl Into<String>,
metadata: HashMap<String, String>,
) -> Self {
Self {
bus,
channel_name: channel_name.into(),
chat_id: chat_id.into(),
metadata,
}
}
}
#[async_trait]
impl EmittedMessageHandler for BusToolCallEmitter {
async fn handle(&self, message: ChatMessage) {
for outbound in OutboundMessage::from_chat_message(
&self.channel_name,
&self.chat_id,
None,
&self.metadata,
&message,
) {
if let Err(error) = self.bus.publish_outbound(outbound).await {
tracing::error!(error = %error, channel = %self.channel_name, chat_id = %self.chat_id, "Failed to publish live outbound tool call");
}
}
}
}
impl Session { impl Session {
pub async fn new( pub async fn new(
channel_name: String, channel_name: String,
@ -478,7 +437,6 @@ impl SessionManager {
chat_id: &str, chat_id: &str,
content: &str, content: &str,
media: Vec<crate::bus::MediaItem>, media: Vec<crate::bus::MediaItem>,
live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
) -> Result<Vec<OutboundMessage>, AgentError> { ) -> Result<Vec<OutboundMessage>, AgentError> {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@ -544,10 +502,7 @@ impl SessionManager {
session_guard.record_skill_offer(chat_id)?; session_guard.record_skill_offer(chat_id)?;
// 创建 agent 并处理 // 创建 agent 并处理
let mut agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message_id))?; let agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message_id))?;
if let Some(handler) = live_emitter.clone() {
agent = agent.with_emitted_message_handler(handler);
}
let result = agent.process(history).await?; let result = agent.process(history).await?;
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复 // 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
@ -556,7 +511,6 @@ impl SessionManager {
result result
.emitted_messages .emitted_messages
.iter() .iter()
.filter(|message| !message.is_assistant_tool_call_message() || live_emitter.is_none())
.flat_map(|message| { .flat_map(|message| {
OutboundMessage::from_chat_message( OutboundMessage::from_chat_message(
channel_name, channel_name,

View File

@ -1,29 +1,13 @@
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait;
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage}; use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
use axum::extract::State; use axum::extract::State;
use axum::response::Response; use axum::response::Response;
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{mpsc, Mutex};
use crate::agent::EmittedMessageHandler;
use crate::bus::message::format_tool_call_content;
use crate::bus::ChatMessage; use crate::bus::ChatMessage;
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound}; use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
use super::{GatewayState, session::{Session, handle_in_chat_command}}; use super::{GatewayState, session::{Session, handle_in_chat_command}};
struct WsToolCallEmitter {
sender: mpsc::Sender<WsOutbound>,
}
#[async_trait]
impl EmittedMessageHandler for WsToolCallEmitter {
async fn handle(&self, message: ChatMessage) {
for outbound in ws_outbound_from_chat_message(&message) {
let _ = self.sender.send(outbound).await;
}
}
}
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response { pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
ws.on_upgrade(|socket| async { ws.on_upgrade(|socket| async {
handle_socket(socket, state).await; handle_socket(socket, state).await;
@ -157,6 +141,10 @@ fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
} }
} }
fn format_tool_arguments(arguments: &serde_json::Value) -> String {
serde_json::to_string_pretty(arguments).unwrap_or_else(|_| arguments.to_string())
}
fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> { fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
match message.role.as_str() { match message.role.as_str() {
"assistant" => { "assistant" => {
@ -168,7 +156,11 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
tool_call_id: tool_call.id.clone(), tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(), tool_name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(), arguments: tool_call.arguments.clone(),
content: format_tool_call_content(&tool_call.name, &tool_call.arguments), content: format!(
"调用工具: {}\n\n输入参数:\n{}",
tool_call.name,
format_tool_arguments(&tool_call.arguments),
),
role: message.role.clone(), role: message.role.clone(),
}) })
.collect() .collect()
@ -180,7 +172,17 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
}] }]
} }
} }
"tool" => Vec::new(), "tool" => vec![WsOutbound::ToolResult {
id: message.id.clone(),
tool_call_id: message.tool_call_id.clone().unwrap_or_else(|| message.id.clone()),
tool_name: message.tool_name.clone().unwrap_or_else(|| "tool".to_string()),
content: format!(
"工具结果: {}\n\n{}",
message.tool_name.clone().unwrap_or_else(|| "tool".to_string()),
message.content,
),
role: message.role.clone(),
}],
_ => Vec::new(), _ => Vec::new(),
} }
} }
@ -229,19 +231,13 @@ async fn handle_inbound(
session_guard.record_skill_offer(&chat_id)?; session_guard.record_skill_offer(&chat_id)?;
let live_emitter = Arc::new(WsToolCallEmitter { let agent = session_guard.create_agent(&chat_id, None, Some(&user_message_id))?;
sender: session_guard.user_tx.clone(),
});
let agent = session_guard
.create_agent(&chat_id, None, Some(&user_message_id))?
.with_emitted_message_handler(live_emitter);
match agent.process(history).await { match agent.process(history).await {
Ok(result) => { Ok(result) => {
session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?; session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
for outbound in result for outbound in result
.emitted_messages .emitted_messages
.iter() .iter()
.filter(|message| !message.is_assistant_tool_call_message())
.flat_map(ws_outbound_from_chat_message) .flat_map(ws_outbound_from_chat_message)
{ {
let _ = session_guard.send(outbound).await; let _ = session_guard.send(outbound).await;
@ -383,46 +379,3 @@ async fn handle_inbound(
} }
} }
} }
#[cfg(test)]
mod tests {
use super::ws_outbound_from_chat_message;
use crate::bus::ChatMessage;
use crate::providers::ToolCall;
use crate::protocol::WsOutbound;
use serde_json::json;
#[test]
fn test_ws_outbound_from_chat_message_expands_tool_calls() {
let message = ChatMessage::assistant_with_tool_calls(
"",
vec![ToolCall {
id: "call-1".to_string(),
name: "calculator".to_string(),
arguments: json!({"expression": "1 + 1"}),
}],
);
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 1);
match &outbound[0] {
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, content, .. } => {
assert_eq!(tool_call_id, "call-1");
assert_eq!(tool_name, "calculator");
assert_eq!(arguments["expression"], "1 + 1");
assert_eq!(content, "### calculator\n- expression: 1 + 1");
}
other => panic!("unexpected outbound variant: {:?}", other),
}
}
#[test]
fn test_ws_outbound_from_chat_message_omits_tool_results() {
let message = ChatMessage::tool("call-1", "calculator", "2");
let outbound = ws_outbound_from_chat_message(&message);
assert!(outbound.is_empty());
}
}

View File

@ -23,7 +23,7 @@ impl Tool for MemoryManageTool {
} }
fn description(&self) -> &str { fn description(&self) -> &str {
"Manage user memories stored in SQLite. Supports actions: list, search, get, put, update, delete. Use search first when looking for user preferences, historical facts, prior decisions, or previously stored information by keyword. Memories are scoped to the current channel and sender, and record the originating session/message when available." "Manage user memories stored in SQLite. Supports actions: list, get, put, update, delete. Memories are scoped to the current channel and sender, and record the originating session/message when available."
} }
fn parameters_schema(&self) -> serde_json::Value { fn parameters_schema(&self) -> serde_json::Value {
@ -32,20 +32,16 @@ impl Tool for MemoryManageTool {
"properties": { "properties": {
"action": { "action": {
"type": "string", "type": "string",
"enum": ["list", "search", "get", "put", "update", "delete"], "enum": ["list", "get", "put", "update", "delete"],
"description": "Management action to perform. Prefer 'search' for keyword lookup across stored memories, 'get' for an exact namespace/key lookup, and 'list' for browsing recent memories." "description": "Management action to perform"
}, },
"namespace": { "namespace": {
"type": "string", "type": "string",
"description": "Optional memory namespace filter, such as profile, preferences, or tasks" "description": "Memory namespace, such as profile, preferences, or tasks"
},
"query": {
"type": "string",
"description": "Keyword query for full-text memory search, such as a preference, fact, name, topic, or prior decision"
}, },
"key": { "key": {
"type": "string", "type": "string",
"description": "Exact memory key within the namespace" "description": "Memory key within the namespace"
}, },
"content": { "content": {
"type": "string", "type": "string",
@ -53,7 +49,7 @@ impl Tool for MemoryManageTool {
}, },
"limit": { "limit": {
"type": "integer", "type": "integer",
"description": "Maximum number of memories to return", "description": "Maximum number of memories to list",
"minimum": 1, "minimum": 1,
"default": 20 "default": 20
} }
@ -82,7 +78,6 @@ impl Tool for MemoryManageTool {
}; };
let namespace = args.get("namespace").and_then(|value| value.as_str()); let namespace = args.get("namespace").and_then(|value| value.as_str());
let query = args.get("query").and_then(|value| value.as_str());
let key = args.get("key").and_then(|value| value.as_str()); let key = args.get("key").and_then(|value| value.as_str());
let payload = match action { let payload = match action {
@ -99,24 +94,6 @@ impl Tool for MemoryManageTool {
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>() "memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
}) })
} }
"search" => {
let query = match query {
Some(query) if !query.trim().is_empty() => query,
_ => return Ok(error_result("Missing required parameter: query")),
};
let limit = args
.get("limit")
.and_then(|value| value.as_u64())
.unwrap_or(20) as usize;
let memories = self
.store
.search_memories("user", &scope_key, query, namespace, limit)?;
json!({
"query": query,
"count": memories.len(),
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
})
}
"get" => { "get" => {
let namespace = match namespace { let namespace = match namespace {
Some(namespace) => namespace, Some(namespace) => namespace,
@ -315,49 +292,6 @@ mod tests {
assert!(get.output.contains("msg-1")); assert!(get.output.contains("msg-1"));
} }
#[tokio::test]
async fn test_memory_manage_search() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let tool = MemoryManageTool::new(store);
let context = ToolContext {
channel_name: Some("feishu".to_string()),
sender_id: Some("user-1".to_string()),
chat_id: Some("chat-1".to_string()),
session_id: Some("feishu:chat-1".to_string()),
message_id: Some("msg-1".to_string()),
message_seq: Some(1),
};
let put = tool
.execute_with_context(
&context,
json!({
"action": "put",
"namespace": "profile",
"key": "editor",
"content": "Prefers rust-analyzer over clippy hints"
}),
)
.await
.unwrap();
assert!(put.success);
let search = tool
.execute_with_context(
&context,
json!({
"action": "search",
"query": "rust-analyzer",
"limit": 5
}),
)
.await
.unwrap();
assert!(search.success);
assert!(search.output.contains("rust-analyzer"));
assert!(search.output.contains("editor"));
}
#[tokio::test] #[tokio::test]
async fn test_memory_manage_requires_context() { async fn test_memory_manage_requires_context() {
let store = Arc::new(SessionStore::in_memory().unwrap()); let store = Arc::new(SessionStore::in_memory().unwrap());