diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index ebbfc1e..1beba50 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -646,6 +646,8 @@ pub struct AgentLoop { observer: Option>, emitted_message_handler: Option>, max_iterations: usize, + /// 取消信号接收端:Agent 在每次迭代开始时检查是否被取消 + cancel_token: Option>, } #[derive(Debug, Clone)] @@ -742,6 +744,7 @@ impl AgentLoop { tool_context: ToolContext::default(), observer: None, emitted_message_handler: None, + cancel_token: None, max_iterations, }) } @@ -764,6 +767,7 @@ impl AgentLoop { tool_context: ToolContext::default(), observer: None, emitted_message_handler: None, + cancel_token: None, max_iterations, }) } @@ -787,6 +791,7 @@ impl AgentLoop { tool_context: ToolContext::default(), observer: None, emitted_message_handler: None, + cancel_token: None, max_iterations, }) } @@ -814,6 +819,7 @@ impl AgentLoop { tool_context: ToolContext::default(), observer: None, emitted_message_handler: None, + cancel_token: None, max_iterations, }) } @@ -834,6 +840,15 @@ impl AgentLoop { self } + /// 设置取消信号接收端。 + /// + /// Agent 在每次迭代开始时检查 `cancel_token.has_changed()`, + /// 如果已收到取消信号则提前返回。 + pub fn with_cancel_token(mut self, token: tokio::sync::watch::Receiver<()>) -> Self { + self.cancel_token = Some(token); + self + } + pub fn tools(&self) -> &Arc { &self.tools } @@ -988,6 +1003,25 @@ impl AgentLoop { #[cfg(debug_assertions)] tracing::debug!(iteration, "Agent iteration started"); + // 检查取消信号 + if let Some(ref token) = self.cancel_token { + if token.has_changed().unwrap_or(false) { + tracing::info!(iteration, "Agent execution cancelled by user"); + let cancel_message = format!( + "\n\n[用户已取消执行。已迭代 {} 次,取消前共生成了 {} 条消息。]", + iteration, + emitted_messages.len() + ); + let assistant_message = ChatMessage::assistant(cancel_message); + emitted_messages.push(assistant_message.clone()); + self.emit_live_tool_call_message(assistant_message.clone()).await; + return Ok(AgentProcessResult { + final_response: assistant_message, + emitted_messages, + }); + } + } + // Build request let tool_defs = self.tools.get_definitions(); let tools = if tool_defs.is_empty() { diff --git a/src/command/adapters/channel.rs b/src/command/adapters/channel.rs index 292d179..d89e20b 100644 --- a/src/command/adapters/channel.rs +++ b/src/command/adapters/channel.rs @@ -133,6 +133,11 @@ impl InputAdapter for ChannelInputAdapter { return Ok(Some(Command::GetCurrentSession)); } + // 解析 /stop 命令 - 停止当前执行的 Agent + if trimmed == "/stop" { + return Ok(Some(Command::StopExecution)); + } + // 解析 /help 命令 - 显示所有支持的命令 if trimmed == "/help" { return Ok(Some(Command::Help)); diff --git a/src/command/adapters/cli.rs b/src/command/adapters/cli.rs index 5e3dca0..9a4544c 100644 --- a/src/command/adapters/cli.rs +++ b/src/command/adapters/cli.rs @@ -134,6 +134,11 @@ impl InputAdapter for CliInputAdapter { return Ok(Some(Command::GetCurrentSession)); } + // 解析 /stop 命令 - 停止当前执行的 Agent + if trimmed == "/stop" { + return Ok(Some(Command::StopExecution)); + } + // 解析 /help 命令 - 显示所有支持的命令 if trimmed == "/help" { return Ok(Some(Command::Help)); diff --git a/src/command/handlers/mod.rs b/src/command/handlers/mod.rs index 9696970..4986697 100644 --- a/src/command/handlers/mod.rs +++ b/src/command/handlers/mod.rs @@ -11,6 +11,7 @@ pub mod load_topic; pub mod save_session; pub mod save_topic; pub mod session; +pub mod stop_execution; pub mod switch_topic; // 导出公共函数供其他模块复用 diff --git a/src/command/handlers/stop_execution.rs b/src/command/handlers/stop_execution.rs new file mode 100644 index 0000000..9e8b0da --- /dev/null +++ b/src/command/handlers/stop_execution.rs @@ -0,0 +1,52 @@ +use async_trait::async_trait; + +use crate::command::context::CommandContext; +use crate::command::handler::{CommandHandler, CommandMetadata}; +use crate::command::response::{CommandError, CommandResponse, MessageKind}; +use crate::command::Command; +use crate::gateway::cancel_manager::CancelManager; + +/// 处理 StopExecution 命令:取消当前正在执行的 Agent。 +pub struct StopExecutionCommandHandler { + cancel_manager: CancelManager, +} + +impl StopExecutionCommandHandler { + pub fn new(cancel_manager: CancelManager) -> Self { + Self { cancel_manager } + } +} + +#[async_trait] +impl CommandHandler for StopExecutionCommandHandler { + fn can_handle(&self, cmd: &Command) -> bool { + matches!(cmd, Command::StopExecution) + } + + fn metadata(&self) -> Option { + Some(CommandMetadata { + name: "stop", + description: "停止当前正在执行的 Agent", + usage: "/stop", + }) + } + + async fn handle( + &self, + _cmd: Command, + ctx: CommandContext, + ) -> Result { + let channel = &ctx.channel_name; + let chat_id = ctx.chat_id.as_deref().unwrap_or("default"); + + let cancelled = self.cancel_manager.cancel(channel, chat_id).await; + + if cancelled { + Ok(CommandResponse::success(ctx.request_id) + .with_message(MessageKind::Notification, "正在停止当前任务...")) + } else { + Ok(CommandResponse::success(ctx.request_id) + .with_message(MessageKind::Notification, "当前没有正在执行的任务")) + } + } +} diff --git a/src/command/mod.rs b/src/command/mod.rs index 1c2b0b2..25ca552 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -52,6 +52,8 @@ pub enum Command { channel: String, chat_id: String, }, + /// 停止当前正在执行的 Agent + StopExecution, } impl Command { @@ -72,6 +74,7 @@ impl Command { Command::LoadTaskMessages { .. } => "load_task_messages", Command::ListSchedulerJobs => "list_scheduler_jobs", Command::LoadChatMessages { .. } => "load_chat_messages", + Command::StopExecution => "stop_execution", } } } diff --git a/src/gateway/agent_factory.rs b/src/gateway/agent_factory.rs index 7d7c0ba..93bc6be 100644 --- a/src/gateway/agent_factory.rs +++ b/src/gateway/agent_factory.rs @@ -23,6 +23,8 @@ pub(crate) struct AgentBuildRequest<'a> { pub(crate) sender_id: Option<&'a str>, pub(crate) message_id: Option<&'a str>, pub(crate) provider_config: LLMProviderConfig, + /// 取消信号接收端(可选):Agent 在每次迭代时检查是否被取消 + pub(crate) cancel_token: Option>, } impl AgentFactory { @@ -64,7 +66,7 @@ impl AgentFactory { let tool_chat_id = request .notification_chat_id .unwrap_or(request.session_chat_id); - agent.with_tool_context(ToolContext { + let mut agent = agent.with_tool_context(ToolContext { channel_name: Some(request.channel_name.to_string()), sender_id: request.sender_id.map(str::to_string), chat_id: Some(tool_chat_id.to_string()), @@ -73,7 +75,12 @@ impl AgentFactory { message_id: request.message_id.map(str::to_string), message_seq: None, subagent_description: None, - }) + }); + // 如果有取消信号接收端,注入 Agent + if let Some(token) = request.cancel_token { + agent = agent.with_cancel_token(token); + } + agent }) } } diff --git a/src/gateway/cancel_manager.rs b/src/gateway/cancel_manager.rs new file mode 100644 index 0000000..a541a7a --- /dev/null +++ b/src/gateway/cancel_manager.rs @@ -0,0 +1,62 @@ +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{Mutex, watch}; + +/// 共享的 Agent 取消注册表。 +/// +/// 每个正在执行的 Agent 在启动前注册一个 watch::Sender, +/// 外部(如 /stop 命令)通过 cancel() 发送取消信号。 +/// Agent 循环内部通过 watch::Receiver::has_changed() 检测取消。 +#[derive(Clone)] +pub struct CancelManager { + tokens: Arc>>>, +} + +impl CancelManager { + pub fn new() -> Self { + Self { + tokens: Arc::new(Mutex::new(HashMap::new())), + } + } + + /// 注册一个取消通道,返回 receiver 供 Agent 持有。 + /// + /// 如果同 (channel, chat_id) 已有注册,旧 sender 被覆盖并 drop, + /// 旧 receiver 将收到通道关闭信号。 + pub async fn register(&self, channel: &str, chat_id: &str) -> watch::Receiver<()> { + let (tx, rx) = watch::channel(()); + let key = (channel.to_string(), chat_id.to_string()); + self.tokens.lock().await.insert(key, tx); + rx + } + + /// 发送取消信号并移除注册条目。 + /// + /// 返回 `true` 表示找到了对应的任务并发送了取消信号, + /// 返回 `false` 表示没有找到对应的任务(可能已经完成或从未注册)。 + pub async fn cancel(&self, channel: &str, chat_id: &str) -> bool { + let key = (channel.to_string(), chat_id.to_string()); + if let Some(tx) = self.tokens.lock().await.remove(&key) { + // send 可能失败(receiver 已被 drop),这不影响语义 + let _ = tx.send(()); + true + } else { + false + } + } + + /// 正常完成后清理注册条目(幂等)。 + /// + /// 与 cancel() 不同,此方法不发送取消信号,仅移除条目。 + /// 如果条目已被 cancel() 移除,此调用为 no-op。 + pub async fn remove(&self, channel: &str, chat_id: &str) { + let key = (channel.to_string(), chat_id.to_string()); + self.tokens.lock().await.remove(&key); + } +} + +impl Default for CancelManager { + fn default() -> Self { + Self::new() + } +} diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index c5d560b..2453933 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -1,6 +1,7 @@ pub mod agent_factory; pub mod agent_prompt_provider; pub mod agent_task_executor; +pub mod cancel_manager; pub mod cli_session; pub mod command; pub mod compaction; @@ -42,6 +43,7 @@ use crate::scheduler::Scheduler; use crate::skills::SkillRuntime; use crate::tools::task::repository::TaskRepository; use agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService}; +use cancel_manager::CancelManager; use outbound_dispatcher::OutboundDispatcher; use processor::InboundProcessor; use runtime::build_session_manager_with_sender; @@ -55,6 +57,7 @@ pub struct GatewayState { pub channel_manager: ChannelManager, pub bus: Arc, pub task_repository: Arc, + pub cancel_manager: CancelManager, } impl GatewayState { @@ -96,12 +99,15 @@ impl GatewayState { Some(bus.clone()), )?; + let cancel_manager = CancelManager::new(); + Ok(Self { config, session_manager, channel_manager, bus, task_repository, + cancel_manager, }) } @@ -122,7 +128,7 @@ impl GatewayState { } }; let inbound_processor = - InboundProcessor::new(self.bus.clone(), self.session_manager.clone(), semaphore, provider_config); + InboundProcessor::new(self.bus.clone(), self.session_manager.clone(), semaphore, provider_config, self.cancel_manager.clone()); tokio::spawn(inbound_processor.run()); // Spawn outbound dispatcher diff --git a/src/gateway/processor.rs b/src/gateway/processor.rs index 7bf15e1..6cf934d 100644 --- a/src/gateway/processor.rs +++ b/src/gateway/processor.rs @@ -14,9 +14,11 @@ use crate::command::handlers::load_topic::LoadTopicCommandHandler; use crate::command::handlers::save_session::SaveSessionCommandHandler; use crate::command::handlers::save_topic::SaveTopicCommandHandler; use crate::command::handlers::session::SessionCommandHandler; +use crate::command::handlers::stop_execution::StopExecutionCommandHandler; use crate::command::handlers::switch_topic::SwitchTopicCommandHandler; use crate::config::LLMProviderConfig; use crate::gateway::agent_prompt_provider::AgentPromptProvider; +use crate::gateway::cancel_manager::CancelManager; use crate::providers::{create_provider, ProviderRuntimeConfig}; use crate::skills::SkillPromptProvider; use crate::storage::persistent_session_id; @@ -31,6 +33,7 @@ pub struct InboundProcessor { semaphore: Arc, provider_config: LLMProviderConfig, command_router: Arc, + cancel_manager: CancelManager, } impl InboundProcessor { @@ -39,6 +42,7 @@ impl InboundProcessor { session_manager: SessionManager, semaphore: Arc, provider_config: LLMProviderConfig, + cancel_manager: CancelManager, ) -> Self { // 创建命令路由器并注册处理器 let mut command_router = CommandRouter::new(); @@ -97,12 +101,18 @@ impl InboundProcessor { let metadata = command_router.metadata_arc(); command_router.register(Box::new(HelpCommandHandler::new(metadata))); + // 注册 stop_execution 处理器 + command_router.register(Box::new(StopExecutionCommandHandler::new( + cancel_manager.clone(), + ))); + Self { bus, session_manager, semaphore, provider_config, command_router: Arc::new(command_router), + cancel_manager, } } @@ -236,6 +246,16 @@ impl InboundProcessor { current_topic.clone(), )); + // 保存 channel 和 chat_id 用于后续清理(因 match 中可能 move inbound) + let channel = inbound.channel.clone(); + let chat_id = inbound.chat_id.clone(); + + // 注册取消信号:Agent 构建时通过 Session 消费该 receiver + let cancel_rx = self.cancel_manager.register(&channel, &chat_id).await; + self.session_manager + .set_agent_cancel_token(&channel, &chat_id, cancel_rx) + .await; + match self .session_manager .handle_message( @@ -308,6 +328,9 @@ impl InboundProcessor { } } + // 清理取消信号注册(幂等:如果已被 cancel() 移除则为 no-op) + self.cancel_manager.remove(&channel, &chat_id).await; + Ok(()) } } diff --git a/src/gateway/session.rs b/src/gateway/session.rs index c35ab76..8f961b1 100644 --- a/src/gateway/session.rs +++ b/src/gateway/session.rs @@ -45,6 +45,9 @@ pub struct Session { compressor: ContextCompressor, history: SessionHistory, store: Arc, + /// 等待中的取消信号接收端(按 chat_id 索引)。 + /// 在 Agent 执行前由外部注入,Agent 构建时消费。 + pending_cancel_tokens: HashMap>, } pub struct BusToolCallEmitter { @@ -163,6 +166,7 @@ impl Session { skill_events, ), store, + pending_cancel_tokens: HashMap::new(), }) } @@ -179,6 +183,19 @@ impl Session { } } + /// 存入待使用的取消信号接收端。 + /// + /// 在 Agent 执行前由处理器调用,Agent 构建时(create_agent)自动消费。 + /// 每个 chat_id 同时只允许一个 pending token;新 token 会替换旧 token。 + pub fn set_cancel_receiver( + &mut self, + chat_id: &str, + receiver: tokio::sync::watch::Receiver<()>, + ) { + self.pending_cancel_tokens + .insert(chat_id.to_string(), receiver); + } + /// 获取当前话题 ID(指定 chat) pub fn current_topic(&self, chat_id: &str) -> Option<&str> { self.history.chat_topic(chat_id) @@ -420,7 +437,7 @@ impl Session { /// 创建一个临时的 AgentLoop 实例来处理消息 pub fn create_agent( - &self, + &mut self, chat_id: &str, sender_id: Option<&str>, message_id: Option<&str>, @@ -435,13 +452,15 @@ impl Session { } pub fn create_agent_with_provider_config( - &self, + &mut self, session_chat_id: &str, notification_chat_id: Option<&str>, sender_id: Option<&str>, message_id: Option<&str>, provider_config: LLMProviderConfig, ) -> Result { + // 消费 pending 的取消信号接收端(如果存在) + let cancel_token = self.pending_cancel_tokens.remove(session_chat_id); self.agent_factory.create(AgentBuildRequest { channel_name: &self.channel_name, session_chat_id, @@ -449,6 +468,7 @@ impl Session { sender_id, message_id, provider_config, + cancel_token, }) } } @@ -607,6 +627,20 @@ impl SessionManager { } } + /// 存入 Agent 取消信号接收端,供 Agent 构建时消费。 + /// + /// 在 Agent 执行前由处理器调用。Agent 在 create_agent() 时自动取出。 + pub async fn set_agent_cancel_token( + &self, + channel_name: &str, + chat_id: &str, + token: tokio::sync::watch::Receiver<()>, + ) { + if let Some(session) = self.get(channel_name).await { + session.lock().await.set_cancel_receiver(chat_id, token); + } + } + /// 更新最后活跃时间 pub async fn touch(&self, channel_name: &str) { self.lifecycle.touch(channel_name).await; diff --git a/src/gateway/ws.rs b/src/gateway/ws.rs index 7ca9543..616b0ff 100644 --- a/src/gateway/ws.rs +++ b/src/gateway/ws.rs @@ -17,6 +17,7 @@ use crate::command::handlers::load_task_messages::LoadTaskMessagesCommandHandler use crate::command::handlers::load_topic::LoadTopicCommandHandler; use crate::command::handlers::save_session::SaveSessionCommandHandler; use crate::command::handlers::session::SessionCommandHandler; +use crate::command::handlers::stop_execution::StopExecutionCommandHandler; use crate::command::handlers::switch_topic::SwitchTopicCommandHandler; use crate::gateway::agent_prompt_provider::AgentPromptProvider; use crate::protocol::{WsInbound, WsOutbound, MediaSummary, parse_inbound, serialize_outbound}; @@ -405,6 +406,10 @@ async fn handle_inbound( router.register(Box::new(ListSchedulerJobsCommandHandler::new(store.clone()))); // 注册 load_chat_messages 处理器 router.register(Box::new(LoadChatMessagesCommandHandler::new())); + // 注册 stop_execution 处理器 + router.register(Box::new(StopExecutionCommandHandler::new( + state.cancel_manager.clone(), + ))); // 构建命令上下文 tracing::debug!( diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 54b0e3d..2094a34 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -187,6 +187,8 @@ pub enum WsOutbound { SchedulerJobList { jobs: Vec, }, + #[serde(rename = "execution_cancelled")] + ExecutionCancelled { message: String }, #[serde(rename = "pong")] Pong, } diff --git a/tests/test_request_format.rs b/tests/test_request_format.rs index e86d6c5..3db4278 100644 --- a/tests/test_request_format.rs +++ b/tests/test_request_format.rs @@ -155,6 +155,7 @@ fn test_tool_result_outbound_serialization() { content: "工具结果: calculator\n\n2".to_string(), role: "tool".to_string(), subagent_task_id: None, + duration_ms: None, }; let json = serde_json::to_string(&msg).unwrap(); diff --git a/web/src/App.tsx b/web/src/App.tsx index 0cad1e9..be2c9c2 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -51,6 +51,7 @@ function App() { requestTopicList, enterSubAgentView, exitSubAgentView, + handleStop, } = useChat() const { status, sendMessage } = useWebSocket({ @@ -127,6 +128,9 @@ function App() { case 'save': cmd = { type: 'save_topic', filepath: args[0] || undefined, include_subagents: false } break + case 'stop': + cmd = { type: 'stop_execution' } + break default: alert(`Unknown command: /${command}`) return @@ -147,6 +151,12 @@ function App() { [sendMessage, handleMessage, handleCommand, sessionId, chatId, isReadOnly] ) + const handleStopExecution = useCallback(() => { + const cmd = handleStop() + handleCommand(cmd) + sendMessage({ type: 'command', payload: JSON.stringify(cmd) }) + }, [sendMessage, handleCommand, handleStop]) + const handleCreateTopic = useCallback(() => { if (isReadOnly || !sessionId) { return @@ -415,6 +425,7 @@ function App() { } onSendMessage={subAgentView || schedulerView ? () => {} : handleSendMessage} onNavigateToSubAgent={handleNavigateToSubAgent} + onStop={handleStopExecution} /> diff --git a/web/src/components/Chat/ChatContainer.tsx b/web/src/components/Chat/ChatContainer.tsx index 6e71d4e..c3602dd 100644 --- a/web/src/components/Chat/ChatContainer.tsx +++ b/web/src/components/Chat/ChatContainer.tsx @@ -9,6 +9,7 @@ interface ChatContainerProps { channelName?: string onSendMessage: (content: string, attachments: Attachment[]) => void onNavigateToSubAgent?: (taskId: string, description: string) => void + onStop?: () => void } export function ChatContainer({ @@ -18,6 +19,7 @@ export function ChatContainer({ channelName, onSendMessage, onNavigateToSubAgent, + onStop, }: ChatContainerProps) { return (
@@ -26,6 +28,7 @@ export function ChatContainer({
void + onStop?: () => void disabled?: boolean isLoading?: boolean placeholder?: string @@ -29,6 +30,7 @@ function getMediaType(mimeType: string): string { export function MessageInput({ onSend, + onStop, disabled = false, isLoading = false, placeholder = '输入消息...按 / 查看命令', @@ -366,18 +368,28 @@ export function MessageInput({ - {/* 发送按钮 */} - + {/* 发送/停止按钮 */} + {isLoading && onStop ? ( + + ) : ( + + )} {/* 提示 */} diff --git a/web/src/hooks/useChat.ts b/web/src/hooks/useChat.ts index c450b3a..09edfb4 100644 --- a/web/src/hooks/useChat.ts +++ b/web/src/hooks/useChat.ts @@ -72,6 +72,9 @@ interface UseChatReturn { schedulerView: SchedulerJobView | null enterSchedulerJobView: (lookup: SchedulerJobSessionLookup, jobId: string, description: string) => Command exitSchedulerJobView: () => void + + // 停止当前 Agent 执行 + handleStop: () => Command } interface SubAgentView { @@ -257,8 +260,8 @@ export function useChat(): UseChatReturn { const msgSubagentTaskId = getSubagentTaskId(message) if (msgSubagentTaskId && msgSubagentTaskId === currentSubAgentView.taskId) { appendToSubAgentViewMessage(message) + return } - return } // In main view, skip sub-agent messages (they belong to sub-agent view). @@ -410,6 +413,21 @@ export function useChat(): UseChatReturn { break } + case 'execution_cancelled': { + setMessages((prev) => [ + ...prev, + { + id: generateMessageId(), + role: 'assistant', + content: (message as { type: 'execution_cancelled'; message: string }).message, + timestamp: Date.now(), + type: 'message', + }, + ]) + setIsLoading(false) + break + } + case 'error': { setMessages((prev) => [ ...prev, @@ -568,6 +586,10 @@ export function useChat(): UseChatReturn { setSchedulerView(null) }, []) + const handleStop = useCallback((): Command => { + return { type: 'stop_execution' } + }, []) + // Memoize messages: sub-agent view > scheduler view > main const resolvedMessages = useMemo(() => { if (subAgentView) { @@ -612,5 +634,6 @@ export function useChat(): UseChatReturn { schedulerView, enterSchedulerJobView, exitSchedulerJobView, + handleStop, } } diff --git a/web/src/types/protocol.ts b/web/src/types/protocol.ts index 9543d8a..3b3e7ee 100644 --- a/web/src/types/protocol.ts +++ b/web/src/types/protocol.ts @@ -192,6 +192,11 @@ export interface TaskMessagesLoaded { summary?: string } +export interface ExecutionCancelled { + type: 'execution_cancelled' + message: string +} + export type WsOutbound = | AssistantResponse | ToolCall @@ -207,6 +212,7 @@ export type WsOutbound = | ChannelList | TaskMessagesLoaded | SchedulerJobList + | ExecutionCancelled | Pong // ============================================================================ @@ -284,6 +290,10 @@ export interface LoadChatMessagesCommand { chat_id: string } +export interface StopExecutionCommand { + type: 'stop_execution' +} + export type Command = | CreateSessionCommand | ListSessionsCommand @@ -299,6 +309,7 @@ export type Command = | LoadTaskMessagesCommand | ListSchedulerJobsCommand | LoadChatMessagesCommand + | StopExecutionCommand // ============================================================================ // UI Types