diff --git a/src/gateway/execution.rs b/src/gateway/execution.rs index ab95a69..3221012 100644 --- a/src/gateway/execution.rs +++ b/src/gateway/execution.rs @@ -1,11 +1,13 @@ use std::collections::HashMap; +use std::sync::Arc; use crate::agent::{AgentError, AgentProcessResult}; use crate::bus::message::ToolMessageState; use crate::bus::{ChatMessage, OutboundMessage}; use crate::config::LLMProviderConfig; +use tokio::sync::Mutex; -use super::session::Session; +use super::session::{Session, schedule_background_history_compaction}; const SCHEDULED_TASK_EXECUTION_SYSTEM_PROMPT: &str = "系统说明:当前输入来自一次已经触发的定时任务执行。你现在需要执行任务内容本身,而不是创建、修改、恢复、暂停或查询新的定时任务。除非当前任务内容明确要求管理调度器,否则不要调用任何定时任务管理工具;像“每小时”、“每天”、“cron”、“定时”等词,只应视为任务背景,不应再解释为新的建任务请求。"; @@ -112,6 +114,37 @@ impl AgentExecutionService { should_schedule_compaction: true, }) } + + pub(crate) async fn finalize_result_and_schedule_compaction( + &self, + session: Arc>, + request: FinalizeAgentResultRequest<'_>, + ) -> Result, AgentError> { + let channel_name = request.channel_name.to_string(); + let chat_id = request.chat_id.to_string(); + let execution_kind = request.execution_kind.to_string(); + + let finalized_result = { + let mut session_guard = session.lock().await; + self.finalize_result(&mut session_guard, request)? + }; + + if finalized_result.should_schedule_compaction { + if let Err(error) = + schedule_background_history_compaction(session.clone(), chat_id.clone()).await + { + tracing::warn!( + channel = %channel_name, + chat_id = %chat_id, + execution_kind = %execution_kind, + error = %error, + "Failed to schedule background history compaction" + ); + } + } + + Ok(finalized_result.outbound_messages) + } } pub(crate) fn should_display_message_to_user( diff --git a/src/gateway/session.rs b/src/gateway/session.rs index 1e1fcec..6a0bbe8 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -1430,11 +1430,10 @@ impl SessionManager { let result = agent.process(history).await?; - let finalized_result = { - let mut session_guard = session.lock().await; - let metadata = HashMap::new(); - AgentExecutionService::new(self.show_tool_results).finalize_result( - &mut session_guard, + let metadata = HashMap::new(); + let outbound_messages = AgentExecutionService::new(self.show_tool_results) + .finalize_result_and_schedule_compaction( + session.clone(), FinalizeAgentResultRequest { channel_name, chat_id, @@ -1444,26 +1443,18 @@ impl SessionManager { suppress_live_tool_calls: live_emitter.is_some(), execution_kind: "message", }, - )? - }; - - if finalized_result.should_schedule_compaction { - if let Err(error) = - schedule_background_history_compaction(session.clone(), chat_id.to_string()).await - { - tracing::warn!(channel = %channel_name, chat_id = %chat_id, error = %error, "Failed to schedule background history compaction"); - } - } + ) + .await?; #[cfg(debug_assertions)] tracing::debug!( channel = %channel_name, chat_id = %chat_id, - outbound_count = finalized_result.outbound_messages.len(), + outbound_count = outbound_messages.len(), "Agent response sequence received" ); - Ok(finalized_result.outbound_messages) + Ok(outbound_messages) } pub async fn run_scheduled_agent_task( @@ -1528,10 +1519,9 @@ impl SessionManager { let result = agent.process(history).await?; - let finalized_result = { - let mut session_guard = session.lock().await; - AgentExecutionService::new(self.show_tool_results).finalize_result( - &mut session_guard, + AgentExecutionService::new(self.show_tool_results) + .finalize_result_and_schedule_compaction( + session.clone(), FinalizeAgentResultRequest { channel_name, chat_id, @@ -1541,18 +1531,8 @@ impl SessionManager { suppress_live_tool_calls: false, execution_kind: "scheduled_task", }, - )? - }; - - if finalized_result.should_schedule_compaction { - if let Err(error) = - schedule_background_history_compaction(session.clone(), chat_id.to_string()).await - { - tracing::warn!(channel = %channel_name, chat_id = %chat_id, error = %error, "Failed to schedule background history compaction for scheduled task"); - } - } - - Ok(finalized_result.outbound_messages) + ) + .await } /// 清除指定 session 的所有历史 @@ -1725,16 +1705,21 @@ mod tests { } async fn start_mock_openai_server() -> String { - async fn handle(Json(body): Json) -> Json { + start_mock_openai_server_with_content(None).await + } + + async fn start_mock_openai_server_with_content( + mock_response_content: Option, + ) -> String { + async fn handle( + axum::extract::State(mock_response_content): axum::extract::State>, + Json(body): Json, + ) -> Json { let model = body .get("model") .and_then(|value| value.as_str()) .unwrap_or("unknown-model"); - let content = body - .get("mock_response_content") - .and_then(|value| value.as_str()) - .map(ToString::to_string) - .unwrap_or_else(|| format!("reply from {}", model)); + let content = mock_response_content.unwrap_or_else(|| format!("reply from {}", model)); Json(json!({ "id": "mock-response", @@ -1755,7 +1740,9 @@ mod tests { })) } - let app = Router::new().route("/chat/completions", post(handle)); + let app = Router::new() + .route("/chat/completions", post(handle)) + .with_state(mock_response_content); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let address = listener.local_addr().unwrap(); tokio::spawn(async move { @@ -1778,11 +1765,15 @@ mod tests { format!("http://{}", address) } - async fn start_mock_openai_flaky_server() -> String { + async fn start_mock_openai_flaky_server(mock_response_content: String) -> String { let attempts = StdArc::new(AtomicUsize::new(0)); + let state = (attempts, mock_response_content); async fn handle( - axum::extract::State(attempts): axum::extract::State>, + axum::extract::State((attempts, mock_response_content)): axum::extract::State<( + StdArc, + String, + )>, Json(body): Json, ) -> (StatusCode, Json) { let attempt = attempts.fetch_add(1, Ordering::SeqCst); @@ -1797,11 +1788,6 @@ mod tests { .get("model") .and_then(|value| value.as_str()) .unwrap_or("unknown-model"); - let content = body - .get("mock_response_content") - .and_then(|value| value.as_str()) - .unwrap_or("{\"user_facts\":[],\"preferences\":[],\"behavior_patterns\":[],\"merges\":[],\"conflicts\":[],\"low_value_ids\":[],\"managed_markdown\":\"\"}"); - ( StatusCode::OK, Json(json!({ @@ -1810,7 +1796,7 @@ mod tests { "choices": [ { "message": { - "content": content, + "content": mock_response_content, "tool_calls": [] } } @@ -1826,7 +1812,7 @@ mod tests { let app = Router::new() .route("/chat/completions", post(handle)) - .with_state(attempts); + .with_state(state); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let address = listener.local_addr().unwrap(); tokio::spawn(async move { @@ -2012,7 +1998,6 @@ mod tests { #[tokio::test] async fn test_summarize_memory_maintenance_for_scope_uses_model_output() { - let base_url = start_mock_openai_server().await; let mock_response_content = serde_json::to_string(&json!({ "user_facts": ["用户在做AI产品"], "preferences": ["偏好简洁表达"], @@ -2023,6 +2008,8 @@ mod tests { "managed_markdown": "### 用户事实\n- 用户在做AI产品\n\n### 用户偏好\n- 偏好简洁表达\n\n### 行为模式\n- 习惯先问方案再要代码" })) .unwrap(); + let base_url = + start_mock_openai_server_with_content(Some(mock_response_content.clone())).await; let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), @@ -2112,7 +2099,6 @@ mod tests { #[tokio::test] async fn test_summarize_memory_maintenance_retries_recoverable_provider_errors() { - let base_url = start_mock_openai_flaky_server().await; let mock_response_content = serde_json::to_string(&json!({ "user_facts": ["用户在做AI产品"], "preferences": [], @@ -2123,6 +2109,7 @@ mod tests { "managed_markdown": "### 用户事实\n- 用户在做AI产品" })) .unwrap(); + let base_url = start_mock_openai_flaky_server(mock_response_content.clone()).await; let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), @@ -2182,8 +2169,9 @@ mod tests { #[tokio::test] async fn test_summarize_memory_maintenance_for_scope_extracts_wrapped_json_object() { - let base_url = start_mock_openai_server().await; let mock_response_content = "结果如下:\n```json\n{\n \"user_facts\": [\"用户在做AI产品\"],\n \"preferences\": [],\n \"behavior_patterns\": [],\n \"merges\": [],\n \"conflicts\": [],\n \"low_value_ids\": [],\n \"managed_markdown\": \"### 用户事实\\n- 用户在做AI产品\"\n}\n```\n"; + let base_url = + start_mock_openai_server_with_content(Some(mock_response_content.to_string())).await; let provider_config = LLMProviderConfig { provider_type: "openai".to_string(), diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index a82f03f..7de9b50 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -515,7 +515,7 @@ mod tests { assert_eq!(tool_call_id, "call-1"); assert_eq!(tool_name, "calculator"); assert_eq!(arguments["expression"], "1 + 1"); - assert_eq!(content, "### calculator\n- expression: 1 + 1"); + assert_eq!(content, "calculator\nargs: {\"expression\":\"1 + 1\"}"); } other => panic!("unexpected outbound variant: {:?}", other), } diff --git a/tests/test_integration.rs b/tests/test_integration.rs index f428a1c..620ace5 100644 --- a/tests/test_integration.rs +++ b/tests/test_integration.rs @@ -25,9 +25,7 @@ fn load_config() -> Option { max_tokens: Some(100), model_extra: HashMap::new(), max_tool_iterations: 20, - token_limit: 128_000, tool_result_max_chars: 20_000, - context_summary_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }) } diff --git a/tests/test_tool_calling.rs b/tests/test_tool_calling.rs index 961832e..c79e3bc 100644 --- a/tests/test_tool_calling.rs +++ b/tests/test_tool_calling.rs @@ -25,9 +25,7 @@ fn load_openai_config() -> Option { max_tokens: Some(100), model_extra: HashMap::new(), max_tool_iterations: 20, - token_limit: 128_000, tool_result_max_chars: 20_000, - context_summary_max_chars: 20_000, context_tool_result_trim_chars: 20_000, }) }