Compare commits

...

3 Commits

7 changed files with 258 additions and 56 deletions

View File

@ -1,3 +1,4 @@
use async_trait::async_trait;
use crate::bus::message::ContentBlock;
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
@ -227,6 +228,7 @@ pub struct AgentLoop {
skill_event_session_id: Option<String>,
tool_context: ToolContext,
observer: Option<Arc<dyn Observer>>,
emitted_message_handler: Option<Arc<dyn EmittedMessageHandler>>,
max_iterations: usize,
}
@ -236,6 +238,11 @@ pub struct AgentProcessResult {
pub emitted_messages: Vec<ChatMessage>,
}
#[async_trait]
pub trait EmittedMessageHandler: Send + Sync + 'static {
async fn handle(&self, message: ChatMessage);
}
impl AgentLoop {
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
let max_iterations = provider_config.max_tool_iterations;
@ -250,6 +257,7 @@ impl AgentLoop {
skill_event_session_id: None,
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
max_iterations,
})
}
@ -267,6 +275,7 @@ impl AgentLoop {
skill_event_session_id: None,
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
max_iterations,
})
}
@ -288,6 +297,7 @@ impl AgentLoop {
skill_event_session_id: None,
tool_context: ToolContext::default(),
observer: None,
emitted_message_handler: None,
max_iterations,
})
}
@ -309,6 +319,11 @@ impl AgentLoop {
self
}
pub fn with_emitted_message_handler(mut self, handler: Arc<dyn EmittedMessageHandler>) -> Self {
self.emitted_message_handler = Some(handler);
self
}
pub fn tools(&self) -> &Arc<ToolRegistry> {
&self.tools
}
@ -388,6 +403,7 @@ impl AgentLoop {
);
messages.push(assistant_message.clone());
emitted_messages.push(assistant_message);
self.emit_live_tool_call_message(emitted_messages.last().expect("assistant message just pushed").clone()).await;
// Execute tools and add results to messages
let tool_results = self.execute_tools(&response.tool_calls).await;
@ -487,6 +503,16 @@ impl AgentLoop {
}
}
async fn emit_live_tool_call_message(&self, message: ChatMessage) {
if !message.is_assistant_tool_call_message() {
return;
}
if let Some(handler) = &self.emitted_message_handler {
handler.handle(message).await;
}
}
/// Determine whether to execute tools in parallel or sequentially.
///
/// Returns true if:

View File

@ -1,5 +1,5 @@
pub mod agent_loop;
pub mod context_compressor;
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult, EmittedMessageHandler};
pub use context_compressor::ContextCompressor;

View File

@ -153,6 +153,15 @@ impl ChatMessage {
tool_calls: None,
}
}
pub fn is_assistant_tool_call_message(&self) -> bool {
self.role == "assistant"
&& self
.tool_calls
.as_ref()
.map(|calls| !calls.is_empty())
.unwrap_or(false)
}
}
// ============================================================================
@ -319,36 +328,39 @@ impl OutboundMessage {
)]
}
}
"tool" => vec![Self::tool_result(
channel.to_string(),
chat_id.to_string(),
message.tool_call_id.clone().unwrap_or_else(|| message.id.clone()),
message.tool_name.clone().unwrap_or_else(|| "tool".to_string()),
message.content.clone(),
reply_to,
metadata.clone(),
)],
"tool" => Vec::new(),
_ => Vec::new(),
}
}
}
fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String {
format!(
"调用工具: {}\n\n输入参数:\n{}",
tool_name,
format_json_value(tool_arguments),
)
pub(crate) fn format_tool_call_content(tool_name: &str, tool_arguments: &serde_json::Value) -> String {
let mut lines = vec![format!("### {}", tool_name)];
match tool_arguments {
serde_json::Value::Object(map) if !map.is_empty() => {
let mut entries: Vec<_> = map.iter().collect();
entries.sort_by(|(left, _), (right, _)| left.cmp(right));
for (key, value) in entries {
lines.push(format!("- {}: {}", key, format_tool_argument_value(value)));
}
}
serde_json::Value::Object(_) => {}
other => lines.push(format!("- args: {}", format_tool_argument_value(other))),
}
lines.join("\n")
}
fn format_tool_result_content(tool_name: &str, content: &str) -> String {
format!("工具结果: {}\n\n{}", tool_name, content)
}
fn format_json_value(value: &serde_json::Value) -> String {
fn format_tool_argument_value(value: &serde_json::Value) -> String {
match value {
serde_json::Value::Object(map) if map.is_empty() => "{}".to_string(),
other => serde_json::to_string_pretty(other).unwrap_or_else(|_| other.to_string()),
serde_json::Value::String(text) => text.clone(),
serde_json::Value::Null => "null".to_string(),
other => serde_json::to_string(other).unwrap_or_else(|_| other.to_string()),
}
}
@ -400,11 +412,13 @@ mod tests {
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolCall);
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
assert_eq!(outbound[0].tool_arguments.as_ref().unwrap()["expression"], "1 + 1");
assert_eq!(outbound[0].content, "### calculator\n- expression: 1 + 1");
assert_eq!(outbound[1].tool_name.as_deref(), Some("file_read"));
assert_eq!(outbound[1].content, "### file_read\n- path: README.md");
}
#[test]
fn test_from_chat_message_maps_tool_result() {
fn test_from_chat_message_omits_tool_result() {
let message = ChatMessage::tool("call-9", "calculator", "2");
let outbound = OutboundMessage::from_chat_message(
@ -415,10 +429,6 @@ mod tests {
&message,
);
assert_eq!(outbound.len(), 1);
assert_eq!(outbound[0].event_kind, OutboundEventKind::ToolResult);
assert_eq!(outbound[0].tool_call_id.as_deref(), Some("call-9"));
assert_eq!(outbound[0].tool_name.as_deref(), Some("calculator"));
assert!(outbound[0].content.contains("工具结果: calculator"));
assert!(outbound.is_empty());
}
}

View File

@ -11,7 +11,7 @@ use crate::channels::ChannelManager;
use crate::config::Config;
use crate::logging;
use crate::skills::SkillRuntime;
use session::SessionManager;
use session::{BusToolCallEmitter, SessionManager};
pub struct GatewayState {
pub config: Config,
@ -74,12 +74,19 @@ impl GatewayState {
}
// Process via session manager
let live_emitter = Arc::new(BusToolCallEmitter::new(
bus_for_inbound.clone(),
inbound.channel.clone(),
inbound.chat_id.clone(),
inbound.forwarded_metadata.clone(),
));
match session_manager.handle_message(
&inbound.channel,
&inbound.sender_id,
&inbound.chat_id,
&inbound.content,
inbound.media,
Some(live_emitter),
).await {
Ok(outbound_messages) => {
// Forward channel-specific metadata from inbound to outbound.

View File

@ -1,11 +1,12 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use tokio::sync::{Mutex, mpsc};
use uuid::Uuid;
use crate::bus::{ChatMessage, OutboundMessage};
use crate::bus::{ChatMessage, MessageBus, OutboundMessage};
use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
use crate::agent::{AgentLoop, AgentError, ContextCompressor, EmittedMessageHandler};
use crate::protocol::WsOutbound;
use crate::skills::SkillRuntime;
use crate::storage::{SessionRecord, SessionStore, persistent_session_id};
@ -30,6 +31,46 @@ pub struct Session {
store: Arc<SessionStore>,
}
pub struct BusToolCallEmitter {
bus: Arc<MessageBus>,
channel_name: String,
chat_id: String,
metadata: HashMap<String, String>,
}
impl BusToolCallEmitter {
pub fn new(
bus: Arc<MessageBus>,
channel_name: impl Into<String>,
chat_id: impl Into<String>,
metadata: HashMap<String, String>,
) -> Self {
Self {
bus,
channel_name: channel_name.into(),
chat_id: chat_id.into(),
metadata,
}
}
}
#[async_trait]
impl EmittedMessageHandler for BusToolCallEmitter {
async fn handle(&self, message: ChatMessage) {
for outbound in OutboundMessage::from_chat_message(
&self.channel_name,
&self.chat_id,
None,
&self.metadata,
&message,
) {
if let Err(error) = self.bus.publish_outbound(outbound).await {
tracing::error!(error = %error, channel = %self.channel_name, chat_id = %self.chat_id, "Failed to publish live outbound tool call");
}
}
}
}
impl Session {
pub async fn new(
channel_name: String,
@ -437,6 +478,7 @@ impl SessionManager {
chat_id: &str,
content: &str,
media: Vec<crate::bus::MediaItem>,
live_emitter: Option<Arc<dyn EmittedMessageHandler>>,
) -> Result<Vec<OutboundMessage>, AgentError> {
#[cfg(debug_assertions)]
{
@ -502,7 +544,10 @@ impl SessionManager {
session_guard.record_skill_offer(chat_id)?;
// 创建 agent 并处理
let agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message_id))?;
let mut agent = session_guard.create_agent(chat_id, Some(sender_id), Some(&user_message_id))?;
if let Some(handler) = live_emitter.clone() {
agent = agent.with_emitted_message_handler(handler);
}
let result = agent.process(history).await?;
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
@ -511,6 +556,7 @@ impl SessionManager {
result
.emitted_messages
.iter()
.filter(|message| !message.is_assistant_tool_call_message() || live_emitter.is_none())
.flat_map(|message| {
OutboundMessage::from_chat_message(
channel_name,

View File

@ -1,13 +1,29 @@
use std::sync::Arc;
use async_trait::async_trait;
use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message as WsMessage};
use axum::extract::State;
use axum::response::Response;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::{mpsc, Mutex};
use crate::agent::EmittedMessageHandler;
use crate::bus::message::format_tool_call_content;
use crate::bus::ChatMessage;
use crate::protocol::{parse_inbound, serialize_outbound, SessionSummary, WsInbound, WsOutbound};
use super::{GatewayState, session::{Session, handle_in_chat_command}};
struct WsToolCallEmitter {
sender: mpsc::Sender<WsOutbound>,
}
#[async_trait]
impl EmittedMessageHandler for WsToolCallEmitter {
async fn handle(&self, message: ChatMessage) {
for outbound in ws_outbound_from_chat_message(&message) {
let _ = self.sender.send(outbound).await;
}
}
}
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<GatewayState>>) -> Response {
ws.on_upgrade(|socket| async {
handle_socket(socket, state).await;
@ -141,10 +157,6 @@ fn to_session_summary(record: crate::storage::SessionRecord) -> SessionSummary {
}
}
fn format_tool_arguments(arguments: &serde_json::Value) -> String {
serde_json::to_string_pretty(arguments).unwrap_or_else(|_| arguments.to_string())
}
fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
match message.role.as_str() {
"assistant" => {
@ -156,11 +168,7 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
tool_call_id: tool_call.id.clone(),
tool_name: tool_call.name.clone(),
arguments: tool_call.arguments.clone(),
content: format!(
"调用工具: {}\n\n输入参数:\n{}",
tool_call.name,
format_tool_arguments(&tool_call.arguments),
),
content: format_tool_call_content(&tool_call.name, &tool_call.arguments),
role: message.role.clone(),
})
.collect()
@ -172,17 +180,7 @@ fn ws_outbound_from_chat_message(message: &ChatMessage) -> Vec<WsOutbound> {
}]
}
}
"tool" => vec![WsOutbound::ToolResult {
id: message.id.clone(),
tool_call_id: message.tool_call_id.clone().unwrap_or_else(|| message.id.clone()),
tool_name: message.tool_name.clone().unwrap_or_else(|| "tool".to_string()),
content: format!(
"工具结果: {}\n\n{}",
message.tool_name.clone().unwrap_or_else(|| "tool".to_string()),
message.content,
),
role: message.role.clone(),
}],
"tool" => Vec::new(),
_ => Vec::new(),
}
}
@ -231,13 +229,19 @@ async fn handle_inbound(
session_guard.record_skill_offer(&chat_id)?;
let agent = session_guard.create_agent(&chat_id, None, Some(&user_message_id))?;
let live_emitter = Arc::new(WsToolCallEmitter {
sender: session_guard.user_tx.clone(),
});
let agent = session_guard
.create_agent(&chat_id, None, Some(&user_message_id))?
.with_emitted_message_handler(live_emitter);
match agent.process(history).await {
Ok(result) => {
session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
for outbound in result
.emitted_messages
.iter()
.filter(|message| !message.is_assistant_tool_call_message())
.flat_map(ws_outbound_from_chat_message)
{
let _ = session_guard.send(outbound).await;
@ -379,3 +383,46 @@ async fn handle_inbound(
}
}
}
#[cfg(test)]
mod tests {
use super::ws_outbound_from_chat_message;
use crate::bus::ChatMessage;
use crate::providers::ToolCall;
use crate::protocol::WsOutbound;
use serde_json::json;
#[test]
fn test_ws_outbound_from_chat_message_expands_tool_calls() {
let message = ChatMessage::assistant_with_tool_calls(
"",
vec![ToolCall {
id: "call-1".to_string(),
name: "calculator".to_string(),
arguments: json!({"expression": "1 + 1"}),
}],
);
let outbound = ws_outbound_from_chat_message(&message);
assert_eq!(outbound.len(), 1);
match &outbound[0] {
WsOutbound::ToolCall { tool_call_id, tool_name, arguments, content, .. } => {
assert_eq!(tool_call_id, "call-1");
assert_eq!(tool_name, "calculator");
assert_eq!(arguments["expression"], "1 + 1");
assert_eq!(content, "### calculator\n- expression: 1 + 1");
}
other => panic!("unexpected outbound variant: {:?}", other),
}
}
#[test]
fn test_ws_outbound_from_chat_message_omits_tool_results() {
let message = ChatMessage::tool("call-1", "calculator", "2");
let outbound = ws_outbound_from_chat_message(&message);
assert!(outbound.is_empty());
}
}

View File

@ -23,7 +23,7 @@ impl Tool for MemoryManageTool {
}
fn description(&self) -> &str {
"Manage user memories stored in SQLite. Supports actions: list, get, put, update, delete. Memories are scoped to the current channel and sender, and record the originating session/message when available."
"Manage user memories stored in SQLite. Supports actions: list, search, get, put, update, delete. Use search first when looking for user preferences, historical facts, prior decisions, or previously stored information by keyword. Memories are scoped to the current channel and sender, and record the originating session/message when available."
}
fn parameters_schema(&self) -> serde_json::Value {
@ -32,16 +32,20 @@ impl Tool for MemoryManageTool {
"properties": {
"action": {
"type": "string",
"enum": ["list", "get", "put", "update", "delete"],
"description": "Management action to perform"
"enum": ["list", "search", "get", "put", "update", "delete"],
"description": "Management action to perform. Prefer 'search' for keyword lookup across stored memories, 'get' for an exact namespace/key lookup, and 'list' for browsing recent memories."
},
"namespace": {
"type": "string",
"description": "Memory namespace, such as profile, preferences, or tasks"
"description": "Optional memory namespace filter, such as profile, preferences, or tasks"
},
"query": {
"type": "string",
"description": "Keyword query for full-text memory search, such as a preference, fact, name, topic, or prior decision"
},
"key": {
"type": "string",
"description": "Memory key within the namespace"
"description": "Exact memory key within the namespace"
},
"content": {
"type": "string",
@ -49,7 +53,7 @@ impl Tool for MemoryManageTool {
},
"limit": {
"type": "integer",
"description": "Maximum number of memories to list",
"description": "Maximum number of memories to return",
"minimum": 1,
"default": 20
}
@ -78,6 +82,7 @@ impl Tool for MemoryManageTool {
};
let namespace = args.get("namespace").and_then(|value| value.as_str());
let query = args.get("query").and_then(|value| value.as_str());
let key = args.get("key").and_then(|value| value.as_str());
let payload = match action {
@ -94,6 +99,24 @@ impl Tool for MemoryManageTool {
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
})
}
"search" => {
let query = match query {
Some(query) if !query.trim().is_empty() => query,
_ => return Ok(error_result("Missing required parameter: query")),
};
let limit = args
.get("limit")
.and_then(|value| value.as_u64())
.unwrap_or(20) as usize;
let memories = self
.store
.search_memories("user", &scope_key, query, namespace, limit)?;
json!({
"query": query,
"count": memories.len(),
"memories": memories.into_iter().map(memory_to_json).collect::<Vec<_>>()
})
}
"get" => {
let namespace = match namespace {
Some(namespace) => namespace,
@ -292,6 +315,49 @@ mod tests {
assert!(get.output.contains("msg-1"));
}
#[tokio::test]
async fn test_memory_manage_search() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let tool = MemoryManageTool::new(store);
let context = ToolContext {
channel_name: Some("feishu".to_string()),
sender_id: Some("user-1".to_string()),
chat_id: Some("chat-1".to_string()),
session_id: Some("feishu:chat-1".to_string()),
message_id: Some("msg-1".to_string()),
message_seq: Some(1),
};
let put = tool
.execute_with_context(
&context,
json!({
"action": "put",
"namespace": "profile",
"key": "editor",
"content": "Prefers rust-analyzer over clippy hints"
}),
)
.await
.unwrap();
assert!(put.success);
let search = tool
.execute_with_context(
&context,
json!({
"action": "search",
"query": "rust-analyzer",
"limit": 5
}),
)
.await
.unwrap();
assert!(search.success);
assert!(search.output.contains("rust-analyzer"));
assert!(search.output.contains("editor"));
}
#[tokio::test]
async fn test_memory_manage_requires_context() {
let store = Arc::new(SessionStore::in_memory().unwrap());