Compare commits
No commits in common. "bee1a39a063fee846fead0b8578dca167fc4e950" and "8edc7ef9050949536cd8d154cf25934b489aef3d" have entirely different histories.
bee1a39a06
...
8edc7ef905
@ -100,63 +100,15 @@ impl Default for SkillsConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ToolsConfig {
|
||||
#[serde(default)]
|
||||
pub disabled: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub task: TaskConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct TaskConfig {
|
||||
#[serde(default = "default_task_enabled")]
|
||||
pub enabled: bool,
|
||||
#[serde(default = "default_task_max_execution_secs")]
|
||||
pub max_execution_secs: u64,
|
||||
#[serde(default = "default_task_ttl_hours")]
|
||||
pub ttl_hours: u64,
|
||||
#[serde(default = "default_task_allowed_tools")]
|
||||
pub allowed_tools: Vec<String>,
|
||||
}
|
||||
|
||||
fn default_task_enabled() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_task_max_execution_secs() -> u64 {
|
||||
300
|
||||
}
|
||||
|
||||
fn default_task_ttl_hours() -> u64 {
|
||||
24
|
||||
}
|
||||
|
||||
fn default_task_allowed_tools() -> Vec<String> {
|
||||
vec![
|
||||
"file_read".to_string(),
|
||||
"file_edit".to_string(),
|
||||
"file_write".to_string(),
|
||||
"bash".to_string(),
|
||||
"http_request".to_string(),
|
||||
"web_fetch".to_string(),
|
||||
"memory_search".to_string(),
|
||||
"get_time".to_string(),
|
||||
"calculator".to_string(),
|
||||
"skill_activate".to_string(),
|
||||
"skill_list".to_string(),
|
||||
"send_session_message".to_string(),
|
||||
]
|
||||
}
|
||||
|
||||
impl Default for TaskConfig {
|
||||
impl Default for ToolsConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: default_task_enabled(),
|
||||
max_execution_secs: default_task_max_execution_secs(),
|
||||
ttl_hours: default_task_ttl_hours(),
|
||||
allowed_tools: default_task_allowed_tools(),
|
||||
}
|
||||
Self { disabled: Vec::new() }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -71,7 +71,6 @@ impl AgentFactory {
|
||||
session_id: Some(session_id),
|
||||
message_id: request.message_id.map(str::to_string),
|
||||
message_seq: None,
|
||||
subagent_description: None,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@ -81,7 +81,6 @@ impl GatewayState {
|
||||
skills,
|
||||
Arc::new(BusSessionMessageSender::new(bus.clone())),
|
||||
std::collections::HashSet::new(),
|
||||
config.tools.task.clone(),
|
||||
chat_history_ttl_hours,
|
||||
session_ttl_hours,
|
||||
)?;
|
||||
|
||||
@ -2,17 +2,14 @@ use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::agent::AgentError;
|
||||
use crate::config::{LLMProviderConfig, TaskConfig};
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{
|
||||
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
|
||||
SessionStore, SkillEventRepository,
|
||||
};
|
||||
use crate::tools::{
|
||||
DefaultSubAgentRuntime, InMemoryTaskRepository, NoopSessionMessageSender,
|
||||
SessionMessageSender, SubAgentRuntime, SubAgentRuntimeConfig, ToolRegistry,
|
||||
};
|
||||
use crate::tools::{NoopSessionMessageSender, SessionMessageSender, ToolRegistry};
|
||||
|
||||
use super::agent_factory::AgentFactory;
|
||||
use super::cli_session::CliSessionService;
|
||||
@ -32,7 +29,6 @@ pub(crate) fn build_session_manager(
|
||||
provider_configs: HashMap<String, LLMProviderConfig>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
disabled_tools: HashSet<String>,
|
||||
task_config: TaskConfig,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
session_ttl_hours: Option<u64>,
|
||||
) -> Result<SessionManager, AgentError> {
|
||||
@ -45,7 +41,6 @@ pub(crate) fn build_session_manager(
|
||||
skills,
|
||||
Arc::new(NoopSessionMessageSender),
|
||||
disabled_tools,
|
||||
task_config,
|
||||
chat_history_ttl_hours,
|
||||
session_ttl_hours,
|
||||
)
|
||||
@ -60,7 +55,6 @@ pub(crate) fn build_session_manager_with_sender(
|
||||
skills: Arc<SkillRuntime>,
|
||||
session_message_sender: Arc<dyn SessionMessageSender>,
|
||||
disabled_tools: HashSet<String>,
|
||||
task_config: TaskConfig,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
session_ttl_hours: Option<u64>,
|
||||
) -> Result<SessionManager, AgentError> {
|
||||
@ -80,49 +74,20 @@ pub(crate) fn build_session_manager_with_sender(
|
||||
let memories: Arc<dyn MemoryRepository> = store.clone();
|
||||
let scheduler_jobs: Arc<dyn SchedulerJobRepository> = store.clone();
|
||||
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
|
||||
let conversations: Arc<dyn ConversationRepository> = store.clone();
|
||||
|
||||
// 创建 ToolRegistryFactory
|
||||
let factory = ToolRegistryFactory::new(
|
||||
let tools = Arc::new(
|
||||
ToolRegistryFactory::new(
|
||||
skills.clone(),
|
||||
memories,
|
||||
scheduler_jobs,
|
||||
skill_events.clone(),
|
||||
session_message_sender.clone(),
|
||||
conversations.clone(),
|
||||
session_message_sender,
|
||||
known_agents,
|
||||
default_timezone,
|
||||
disabled_tools,
|
||||
task_config.clone(),
|
||||
)
|
||||
.build(),
|
||||
);
|
||||
|
||||
// 创建 SubAgentRuntime(如果 task 工具启用)
|
||||
let factory = if task_config.enabled {
|
||||
let task_repository = Arc::new(InMemoryTaskRepository::new());
|
||||
let subagent_tools = Arc::new(factory.build_subagent_tools());
|
||||
|
||||
let runtime_config = SubAgentRuntimeConfig {
|
||||
allowed_tools: task_config.allowed_tools.iter().cloned().collect(),
|
||||
max_execution_secs: task_config.max_execution_secs,
|
||||
explore_max_tool_calls: 20,
|
||||
ttl_hours: task_config.ttl_hours,
|
||||
};
|
||||
|
||||
let subagent_runtime = Arc::new(DefaultSubAgentRuntime::new(
|
||||
runtime_config,
|
||||
task_repository,
|
||||
conversations.clone(),
|
||||
subagent_tools,
|
||||
provider_config.clone(),
|
||||
));
|
||||
|
||||
factory.with_subagent_runtime(subagent_runtime)
|
||||
} else {
|
||||
factory
|
||||
};
|
||||
|
||||
let tools = Arc::new(factory.build());
|
||||
|
||||
let prompt_repository: Arc<dyn PromptInjectionRepository> = store.clone();
|
||||
let agent_factory = AgentFactory::new(
|
||||
tools.clone(),
|
||||
@ -130,6 +95,7 @@ pub(crate) fn build_session_manager_with_sender(
|
||||
agent_prompt_reinject_every as usize,
|
||||
prompt_repository.clone(),
|
||||
);
|
||||
let conversations: Arc<dyn ConversationRepository> = store.clone();
|
||||
let session_factory = SessionFactory::new(
|
||||
provider_config.clone(),
|
||||
skills.clone(),
|
||||
|
||||
@ -485,7 +485,6 @@ impl SessionManager {
|
||||
provider_configs: HashMap<String, LLMProviderConfig>,
|
||||
skills: Arc<SkillRuntime>,
|
||||
disabled_tools: std::collections::HashSet<String>,
|
||||
task_config: crate::config::TaskConfig,
|
||||
chat_history_ttl_hours: Option<u64>,
|
||||
session_ttl_hours: Option<u64>,
|
||||
) -> Result<Self, AgentError> {
|
||||
@ -497,7 +496,6 @@ impl SessionManager {
|
||||
provider_configs,
|
||||
skills,
|
||||
disabled_tools,
|
||||
task_config,
|
||||
chat_history_ttl_hours,
|
||||
session_ttl_hours,
|
||||
)
|
||||
|
||||
@ -1,15 +1,13 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::TaskConfig;
|
||||
use crate::skills::SkillRuntime;
|
||||
use crate::storage::{ConversationRepository, MemoryRepository, SchedulerJobRepository, SkillEventRepository};
|
||||
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
|
||||
use crate::tools::{
|
||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||||
HttpRequestTool, MemoryManageTool, MemorySearchTool,
|
||||
SchedulerManageTool, SessionMessageSender, SessionSendTool, SkillActivateTool, SkillListTool,
|
||||
SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
|
||||
ToolRegistry, WebFetchTool,
|
||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool,
|
||||
MemoryManageTool, MemorySearchTool, SchedulerManageTool, SessionMessageSender,
|
||||
SessionSendTool, SkillActivateTool, SkillListTool, SkillManageTool, TimeTool, ToolRegistry,
|
||||
WebFetchTool,
|
||||
};
|
||||
|
||||
pub(crate) struct ToolRegistryFactory {
|
||||
@ -18,12 +16,9 @@ pub(crate) struct ToolRegistryFactory {
|
||||
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
session_message_sender: Arc<dyn SessionMessageSender>,
|
||||
conversations: Arc<dyn ConversationRepository>,
|
||||
known_agents: HashSet<String>,
|
||||
default_timezone: String,
|
||||
disabled_tools: HashSet<String>,
|
||||
task_config: TaskConfig,
|
||||
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
|
||||
}
|
||||
|
||||
impl ToolRegistryFactory {
|
||||
@ -33,11 +28,9 @@ impl ToolRegistryFactory {
|
||||
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
|
||||
skill_events: Arc<dyn SkillEventRepository>,
|
||||
session_message_sender: Arc<dyn SessionMessageSender>,
|
||||
conversations: Arc<dyn ConversationRepository>,
|
||||
known_agents: HashSet<String>,
|
||||
default_timezone: String,
|
||||
disabled_tools: HashSet<String>,
|
||||
task_config: TaskConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
skills,
|
||||
@ -45,23 +38,12 @@ impl ToolRegistryFactory {
|
||||
scheduler_jobs,
|
||||
skill_events,
|
||||
session_message_sender,
|
||||
conversations,
|
||||
known_agents,
|
||||
default_timezone,
|
||||
disabled_tools,
|
||||
task_config,
|
||||
subagent_runtime: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_subagent_runtime(
|
||||
mut self,
|
||||
runtime: Arc<dyn SubAgentRuntime>,
|
||||
) -> Self {
|
||||
self.subagent_runtime = Some(runtime);
|
||||
self
|
||||
}
|
||||
|
||||
fn is_enabled(&self, tool_name: &str) -> bool {
|
||||
!self.disabled_tools.contains(tool_name)
|
||||
}
|
||||
@ -126,74 +108,6 @@ impl ToolRegistryFactory {
|
||||
registry.register(WebFetchTool::new(50_000, 30));
|
||||
}
|
||||
|
||||
// 注册 Task 工具(如果启用且有 subagent_runtime)
|
||||
if self.is_enabled("task") && self.task_config.enabled {
|
||||
if let Some(runtime) = &self.subagent_runtime {
|
||||
registry.register(TaskTool::new(runtime.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
registry
|
||||
}
|
||||
|
||||
/// 构建子代理专用工具集(不包含 task 工具防止递归)
|
||||
pub(crate) fn build_subagent_tools(&self) -> ToolRegistry {
|
||||
let mut registry = ToolRegistry::new();
|
||||
|
||||
// 基础工具
|
||||
if self.is_enabled("calculator") {
|
||||
registry.register(CalculatorTool::new());
|
||||
}
|
||||
if self.is_enabled("get_time") {
|
||||
registry.register(TimeTool::new(self.default_timezone.clone()));
|
||||
}
|
||||
if self.is_enabled("file_read") {
|
||||
registry.register(FileReadTool::new());
|
||||
}
|
||||
if self.is_enabled("file_write") {
|
||||
registry.register(FileWriteTool::new());
|
||||
}
|
||||
if self.is_enabled("file_edit") {
|
||||
registry.register(FileEditTool::new());
|
||||
}
|
||||
if self.is_enabled("bash") {
|
||||
registry.register(BashTool::new());
|
||||
}
|
||||
if self.is_enabled("http_request") {
|
||||
registry.register(HttpRequestTool::new(
|
||||
vec!["*".to_string()],
|
||||
1_000_000,
|
||||
30,
|
||||
false,
|
||||
));
|
||||
}
|
||||
if self.is_enabled("web_fetch") {
|
||||
registry.register(WebFetchTool::new(50_000, 30));
|
||||
}
|
||||
|
||||
// 记忆工具(只读)
|
||||
if self.is_enabled("memory_search") {
|
||||
registry.register(MemorySearchTool::new(self.memories.clone()));
|
||||
}
|
||||
|
||||
// Skill 工具
|
||||
if self.is_enabled("skill_activate") {
|
||||
registry.register(SkillActivateTool::new(
|
||||
self.skills.clone(),
|
||||
self.skill_events.clone(),
|
||||
));
|
||||
}
|
||||
if self.is_enabled("skill_list") {
|
||||
registry.register(SkillListTool::new(self.skills.clone()));
|
||||
}
|
||||
|
||||
// 进度通知工具
|
||||
if self.is_enabled("session_send") {
|
||||
registry.register(SessionSendTool::new(self.session_message_sender.clone()));
|
||||
}
|
||||
|
||||
// 注意:不注册 task 工具,防止递归创建子代理
|
||||
|
||||
registry
|
||||
}
|
||||
}
|
||||
|
||||
@ -12,7 +12,6 @@ pub mod session_send;
|
||||
pub mod schema;
|
||||
pub mod skill_activate;
|
||||
pub mod skill_manage;
|
||||
pub mod task;
|
||||
pub mod time;
|
||||
pub mod traits;
|
||||
pub mod web_fetch;
|
||||
@ -34,10 +33,6 @@ pub use session_send::{
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use skill_activate::SkillActivateTool;
|
||||
pub use skill_manage::{SkillListTool, SkillManageTool};
|
||||
pub use task::{
|
||||
DefaultSubAgentRuntime, InMemoryTaskRepository, SubAgentRuntime, SubAgentRuntimeConfig,
|
||||
TaskError, TaskRepository, TaskTool,
|
||||
};
|
||||
pub use time::TimeTool;
|
||||
pub use traits::{Tool, ToolContext, ToolResult};
|
||||
pub use web_fetch::WebFetchTool;
|
||||
|
||||
@ -1,48 +0,0 @@
|
||||
use crate::storage::StorageError;
|
||||
|
||||
/// 任务错误类型
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum TaskError {
|
||||
#[error("Task session not found: {0}")]
|
||||
SessionNotFound(String),
|
||||
|
||||
#[error("Invalid parent session")]
|
||||
InvalidParentSession,
|
||||
|
||||
#[error("Failed to create subagent: {0}")]
|
||||
AgentCreationFailed(String),
|
||||
|
||||
#[error("Execution failed: {0}")]
|
||||
ExecutionFailed(String),
|
||||
|
||||
#[error("Task execution timed out")]
|
||||
Timeout,
|
||||
|
||||
#[error("Repository error: {0}")]
|
||||
RepositoryError(#[from] StorageError),
|
||||
|
||||
#[error("Serialization error: {0}")]
|
||||
SerializationError(#[from] serde_json::Error),
|
||||
|
||||
#[error("Missing required context: {0}")]
|
||||
MissingContext(String),
|
||||
|
||||
#[error("Invalid arguments: {0}")]
|
||||
InvalidArguments(String),
|
||||
}
|
||||
|
||||
impl TaskError {
|
||||
pub fn as_status(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Timeout => "timeout",
|
||||
Self::SessionNotFound(_) => "failed",
|
||||
Self::InvalidParentSession => "failed",
|
||||
Self::AgentCreationFailed(_) => "failed",
|
||||
Self::ExecutionFailed(_) => "failed",
|
||||
Self::RepositoryError(_) => "failed",
|
||||
Self::SerializationError(_) => "failed",
|
||||
Self::MissingContext(_) => "failed",
|
||||
Self::InvalidArguments(_) => "failed",
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,13 +0,0 @@
|
||||
pub mod error;
|
||||
pub mod prompt;
|
||||
pub mod repository;
|
||||
pub mod runtime;
|
||||
pub mod tool;
|
||||
pub mod types;
|
||||
|
||||
pub use error::TaskError;
|
||||
pub use prompt::SubagentPromptBuilder;
|
||||
pub use repository::{InMemoryTaskRepository, TaskRepository};
|
||||
pub use runtime::{DefaultSubAgentRuntime, SubAgentRuntime, SubAgentRuntimeConfig, StaticSystemPromptProvider};
|
||||
pub use tool::TaskTool;
|
||||
pub use types::{SubagentType, TaskDefinition, TaskHandle, TaskSession, TaskSessionState, TaskToolArgs, TaskToolResult};
|
||||
@ -1,80 +0,0 @@
|
||||
use super::types::SubagentType;
|
||||
|
||||
/// 子代理系统提示词构建器
|
||||
pub struct SubagentPromptBuilder;
|
||||
|
||||
impl SubagentPromptBuilder {
|
||||
/// 构建子代理系统提示词
|
||||
pub fn build(
|
||||
subagent_type: SubagentType,
|
||||
description: &str,
|
||||
_prompt: &str,
|
||||
) -> String {
|
||||
match subagent_type {
|
||||
SubagentType::General => Self::build_general_prompt(description),
|
||||
SubagentType::Explore => Self::build_explore_prompt(description),
|
||||
}
|
||||
}
|
||||
|
||||
/// 构建恢复任务的提示词
|
||||
pub fn build_resume_prompt(session_description: &str, additional_prompt: &str) -> String {
|
||||
format!(
|
||||
"你正在继续执行一个之前创建的子代理任务。\n\n\
|
||||
任务描述: {}\n\n\
|
||||
继续执行指令: {}\n\n\
|
||||
你应该:\n\
|
||||
1. 回顾之前的工作进度(如果已有历史)\n\
|
||||
2. 继续完成任务,不要偏离目标\n\
|
||||
3. 完成后给出简洁的总结\n\
|
||||
4. 不要尝试创建新的子代理任务\n\n\
|
||||
注意: 你在一个独立的执行上下文中,没有访问主对话历史的权限。",
|
||||
session_description, additional_prompt
|
||||
)
|
||||
}
|
||||
|
||||
fn build_general_prompt(description: &str) -> String {
|
||||
format!(
|
||||
"你是一个专注的子代理,正在执行一个独立任务。\n\n\
|
||||
任务描述: {}\n\n\
|
||||
你应该:\n\
|
||||
1. 专注于完成任务,不要偏离目标\n\
|
||||
2. 使用可用的工具进行必要操作\n\
|
||||
3. 完成后给出简洁的总结\n\
|
||||
4. 不要尝试创建新的子代理任务\n\n\
|
||||
注意: 你没有访问主对话历史的权限,这是一个独立的执行上下文。",
|
||||
description
|
||||
)
|
||||
}
|
||||
|
||||
fn build_explore_prompt(description: &str) -> String {
|
||||
format!(
|
||||
"你是一个只读探索代理,用于代码库探索和信息收集。\n\n\
|
||||
任务描述: {}\n\n\
|
||||
你应该:\n\
|
||||
1. 只使用只读工具进行探索\n\
|
||||
2. 专注于理解和收集信息\n\
|
||||
3. 不要进行任何写操作\n\
|
||||
4. 给出简洁的发现总结\n\n\
|
||||
注意: 你是一个只读代理,禁止执行任何修改操作。",
|
||||
description
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// 从子代理输出提取简洁摘要
|
||||
pub fn extract_summary(content: &str) -> String {
|
||||
// 取第一段或前 500 字符
|
||||
let first_paragraph = content
|
||||
.lines()
|
||||
.take_while(|line| !line.trim().is_empty())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
if first_paragraph.len() > 500 {
|
||||
first_paragraph.chars().take(500).collect()
|
||||
} else if first_paragraph.is_empty() {
|
||||
content.chars().take(200).collect()
|
||||
} else {
|
||||
first_paragraph
|
||||
}
|
||||
}
|
||||
@ -1,98 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::storage::StorageError;
|
||||
|
||||
use super::types::TaskSession;
|
||||
|
||||
/// 任务持久化接口
|
||||
#[async_trait]
|
||||
pub trait TaskRepository: Send + Sync + 'static {
|
||||
/// 保存任务会话
|
||||
async fn save_task_session(&self, session: &TaskSession) -> Result<(), StorageError>;
|
||||
|
||||
/// 加载任务会话
|
||||
async fn load_task_session(&self, task_id: &str) -> Result<Option<TaskSession>, StorageError>;
|
||||
|
||||
/// 删除任务会话
|
||||
async fn delete_task_session(&self, task_id: &str) -> Result<bool, StorageError>;
|
||||
|
||||
/// 列出父会话的所有任务
|
||||
async fn list_tasks_for_session(
|
||||
&self,
|
||||
parent_session_id: &str,
|
||||
) -> Result<Vec<TaskSession>, StorageError>;
|
||||
|
||||
/// 清理过期任务(超过指定小时)
|
||||
async fn cleanup_expired_tasks(&self, ttl_hours: u64) -> Result<usize, StorageError>;
|
||||
}
|
||||
|
||||
/// 内存实现(用于测试)
|
||||
pub struct InMemoryTaskRepository {
|
||||
sessions: RwLock<HashMap<String, TaskSession>>,
|
||||
}
|
||||
|
||||
impl InMemoryTaskRepository {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sessions: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for InMemoryTaskRepository {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TaskRepository for InMemoryTaskRepository {
|
||||
async fn save_task_session(&self, session: &TaskSession) -> Result<(), StorageError> {
|
||||
self.sessions
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(session.id.clone(), session.clone());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn load_task_session(&self, task_id: &str) -> Result<Option<TaskSession>, StorageError> {
|
||||
Ok(self.sessions.read().unwrap().get(task_id).cloned())
|
||||
}
|
||||
|
||||
async fn delete_task_session(&self, task_id: &str) -> Result<bool, StorageError> {
|
||||
Ok(self.sessions.write().unwrap().remove(task_id).is_some())
|
||||
}
|
||||
|
||||
async fn list_tasks_for_session(
|
||||
&self,
|
||||
parent_session_id: &str,
|
||||
) -> Result<Vec<TaskSession>, StorageError> {
|
||||
Ok(self
|
||||
.sessions
|
||||
.read()
|
||||
.unwrap()
|
||||
.values()
|
||||
.filter(|s| s.parent_session_id == parent_session_id)
|
||||
.cloned()
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn cleanup_expired_tasks(&self, ttl_hours: u64) -> Result<usize, StorageError> {
|
||||
let now = current_timestamp();
|
||||
let ttl_millis = ttl_hours * 3600 * 1000;
|
||||
let mut sessions = self.sessions.write().unwrap();
|
||||
let before = sessions.len();
|
||||
sessions.retain(|_, s| now - s.updated_at < ttl_millis as i64);
|
||||
Ok(before - sessions.len())
|
||||
}
|
||||
}
|
||||
|
||||
fn current_timestamp() -> i64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.expect("system clock before unix epoch")
|
||||
.as_millis() as i64
|
||||
}
|
||||
@ -1,378 +0,0 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::agent::{AgentLoop, AgentRuntimeConfig, SystemPrompt, SystemPromptContext, SystemPromptProvider};
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::config::LLMProviderConfig;
|
||||
use crate::storage::ConversationRepository;
|
||||
use crate::tools::{ToolContext, ToolRegistry};
|
||||
|
||||
use super::error::TaskError;
|
||||
use super::prompt::{extract_summary, SubagentPromptBuilder};
|
||||
use super::repository::TaskRepository;
|
||||
use super::types::{SubagentType, TaskDefinition, TaskHandle, TaskSession, TaskSessionState, TaskToolResult};
|
||||
|
||||
/// 子代理运行时配置
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SubAgentRuntimeConfig {
|
||||
/// 子代理可用的工具列表(白名单)
|
||||
pub allowed_tools: HashSet<String>,
|
||||
/// 最大执行时间(秒)
|
||||
pub max_execution_secs: u64,
|
||||
/// 探索类型的最大工具调用次数
|
||||
pub explore_max_tool_calls: usize,
|
||||
/// 任务 TTL(小时)
|
||||
pub ttl_hours: u64,
|
||||
}
|
||||
|
||||
impl Default for SubAgentRuntimeConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
allowed_tools: HashSet::from([
|
||||
"file_read".to_string(),
|
||||
"file_edit".to_string(),
|
||||
"file_write".to_string(),
|
||||
"bash".to_string(),
|
||||
"http_request".to_string(),
|
||||
"web_fetch".to_string(),
|
||||
"memory_search".to_string(),
|
||||
"get_time".to_string(),
|
||||
"calculator".to_string(),
|
||||
"skill_activate".to_string(),
|
||||
"skill_list".to_string(),
|
||||
"send_session_message".to_string(), // 用于进度通知
|
||||
]),
|
||||
max_execution_secs: 300, // 5分钟
|
||||
explore_max_tool_calls: 20,
|
||||
ttl_hours: 24,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 子代理运行时抽象接口
|
||||
#[async_trait]
|
||||
pub trait SubAgentRuntime: Send + Sync + 'static {
|
||||
/// 创建并执行子代理任务
|
||||
async fn spawn(
|
||||
&self,
|
||||
parent_context: &ToolContext,
|
||||
task: TaskDefinition,
|
||||
) -> Result<TaskToolResult, TaskError>;
|
||||
|
||||
/// 恢复现有任务
|
||||
async fn resume(
|
||||
&self,
|
||||
task_id: &str,
|
||||
parent_context: &ToolContext,
|
||||
additional_prompt: String,
|
||||
) -> Result<TaskToolResult, TaskError>;
|
||||
|
||||
/// 发送消息给子代理(支持中断或补充指令)
|
||||
async fn send_message(&self, task_id: &str, message: String) -> Result<(), TaskError>;
|
||||
|
||||
/// 清理过期任务
|
||||
async fn cleanup_expired(&self) -> Result<usize, TaskError>;
|
||||
}
|
||||
|
||||
/// 静态系统提示词提供者(用于子代理)
|
||||
pub struct StaticSystemPromptProvider {
|
||||
prompt: String,
|
||||
}
|
||||
|
||||
impl StaticSystemPromptProvider {
|
||||
pub fn new(prompt: String) -> Self {
|
||||
Self { prompt }
|
||||
}
|
||||
}
|
||||
|
||||
impl SystemPromptProvider for StaticSystemPromptProvider {
|
||||
fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
|
||||
Some(SystemPrompt {
|
||||
content: self.prompt.clone(),
|
||||
context: Some("subagent".to_string()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// 默认子代理运行时实现
|
||||
pub struct DefaultSubAgentRuntime {
|
||||
config: SubAgentRuntimeConfig,
|
||||
task_repository: Arc<dyn TaskRepository>,
|
||||
conversation_repository: Arc<dyn ConversationRepository>,
|
||||
subagent_tools: Arc<ToolRegistry>,
|
||||
provider_config: LLMProviderConfig,
|
||||
}
|
||||
|
||||
impl DefaultSubAgentRuntime {
|
||||
pub fn new(
|
||||
config: SubAgentRuntimeConfig,
|
||||
task_repository: Arc<dyn TaskRepository>,
|
||||
conversation_repository: Arc<dyn ConversationRepository>,
|
||||
subagent_tools: Arc<ToolRegistry>,
|
||||
provider_config: LLMProviderConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
task_repository,
|
||||
conversation_repository,
|
||||
subagent_tools,
|
||||
provider_config,
|
||||
}
|
||||
}
|
||||
|
||||
/// 创建子代理实例
|
||||
fn create_subagent(
|
||||
&self,
|
||||
session: &TaskSession,
|
||||
system_prompt: String,
|
||||
) -> Result<AgentLoop, TaskError> {
|
||||
let prompt_provider = Arc::new(StaticSystemPromptProvider::new(system_prompt));
|
||||
|
||||
AgentLoop::with_tools_and_system_prompt_provider(
|
||||
AgentRuntimeConfig::from(self.provider_config.clone()),
|
||||
self.subagent_tools.clone(),
|
||||
prompt_provider,
|
||||
None, // 子代理不需要 skill provider
|
||||
)
|
||||
.map(|agent| {
|
||||
agent.with_tool_context(ToolContext {
|
||||
channel_name: Some(session.parent_channel_name.clone()),
|
||||
sender_id: None,
|
||||
chat_id: Some(session.parent_chat_id.clone()), // 使用父会话 chat_id
|
||||
session_id: Some(session.session_id.clone()), // 子代理自己的 session_id
|
||||
message_id: None,
|
||||
message_seq: None,
|
||||
subagent_description: Some(session.description.clone()),
|
||||
})
|
||||
})
|
||||
.map_err(|e| TaskError::AgentCreationFailed(e.to_string()))
|
||||
}
|
||||
|
||||
/// 执行任务(带超时控制)
|
||||
async fn execute_task(
|
||||
&self,
|
||||
agent: AgentLoop,
|
||||
session: &TaskSession,
|
||||
prompt: String,
|
||||
) -> Result<TaskToolResult, TaskError> {
|
||||
// 构建初始消息
|
||||
let history = vec![ChatMessage::user(prompt)];
|
||||
let system_prompt_context = SystemPromptContext {
|
||||
session_id: Some(session.session_id.clone()),
|
||||
chat_id: session.session_id.clone(),
|
||||
user_message_count: 1,
|
||||
};
|
||||
|
||||
// 设置超时
|
||||
let max_secs = if session.subagent_type == SubagentType::Explore {
|
||||
self.config.max_execution_secs / 2 // Explore 类型时间更短
|
||||
} else {
|
||||
self.config.max_execution_secs
|
||||
};
|
||||
let timeout_duration = Duration::from_secs(max_secs);
|
||||
|
||||
let result = tokio::time::timeout(
|
||||
timeout_duration,
|
||||
agent.process(history, Some(&system_prompt_context)),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(process_result)) => {
|
||||
let final_message = process_result.final_response;
|
||||
Ok(TaskToolResult {
|
||||
status: "success".to_string(),
|
||||
summary: extract_summary(&final_message.content),
|
||||
output: final_message.content,
|
||||
task_id: session.id.clone(),
|
||||
})
|
||||
}
|
||||
Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())),
|
||||
Err(_) => Err(TaskError::Timeout),
|
||||
}
|
||||
}
|
||||
|
||||
/// 使用历史继续执行
|
||||
async fn execute_task_with_history(
|
||||
&self,
|
||||
agent: AgentLoop,
|
||||
session: &TaskSession,
|
||||
additional_prompt: String,
|
||||
) -> Result<TaskToolResult, TaskError> {
|
||||
// 加载历史 + 新消息
|
||||
let mut history = self
|
||||
.conversation_repository
|
||||
.load_messages(&session.session_id)
|
||||
.map_err(TaskError::RepositoryError)?;
|
||||
history.push(ChatMessage::user(additional_prompt));
|
||||
|
||||
let user_message_count = history.iter().filter(|m| m.role == "user").count();
|
||||
let system_prompt_context = SystemPromptContext {
|
||||
session_id: Some(session.session_id.clone()),
|
||||
chat_id: session.session_id.clone(),
|
||||
user_message_count,
|
||||
};
|
||||
|
||||
let timeout_duration = Duration::from_secs(self.config.max_execution_secs);
|
||||
|
||||
let result = tokio::time::timeout(
|
||||
timeout_duration,
|
||||
agent.process(history, Some(&system_prompt_context)),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(process_result)) => {
|
||||
let final_message = process_result.final_response;
|
||||
Ok(TaskToolResult {
|
||||
status: "success".to_string(),
|
||||
summary: extract_summary(&final_message.content),
|
||||
output: final_message.content,
|
||||
task_id: session.id.clone(),
|
||||
})
|
||||
}
|
||||
Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())),
|
||||
Err(_) => Err(TaskError::Timeout),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SubAgentRuntime for DefaultSubAgentRuntime {
|
||||
async fn spawn(
|
||||
&self,
|
||||
parent_context: &ToolContext,
|
||||
task: TaskDefinition,
|
||||
) -> Result<TaskToolResult, TaskError> {
|
||||
// 1. 验证上下文
|
||||
let session_id = parent_context
|
||||
.session_id
|
||||
.clone()
|
||||
.ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?;
|
||||
let chat_id = parent_context
|
||||
.chat_id
|
||||
.clone()
|
||||
.ok_or_else(|| TaskError::MissingContext("chat_id".to_string()))?;
|
||||
let channel_name = parent_context
|
||||
.channel_name
|
||||
.clone()
|
||||
.ok_or_else(|| TaskError::MissingContext("channel_name".to_string()))?;
|
||||
|
||||
// 2. 创建任务会话
|
||||
let session = TaskSession::new(
|
||||
session_id,
|
||||
chat_id,
|
||||
channel_name,
|
||||
task.description.clone(),
|
||||
task.subagent_type,
|
||||
);
|
||||
|
||||
// 3. 保存会话
|
||||
self.task_repository.save_task_session(&session).await?;
|
||||
|
||||
// 4. 构建子代理系统提示词
|
||||
let system_prompt = SubagentPromptBuilder::build(
|
||||
task.subagent_type,
|
||||
&task.description,
|
||||
&task.prompt,
|
||||
);
|
||||
|
||||
// 5. 创建子代理
|
||||
let agent = self.create_subagent(&session, system_prompt)?;
|
||||
|
||||
// 6. 执行任务
|
||||
let result = self
|
||||
.execute_task(agent, &session, task.prompt.clone())
|
||||
.await;
|
||||
|
||||
// 7. 更新会话状态并保存
|
||||
match result {
|
||||
Ok(tool_result) => {
|
||||
let mut session = session;
|
||||
session.mark_completed(tool_result.summary.clone());
|
||||
self.task_repository.save_task_session(&session).await?;
|
||||
Ok(tool_result)
|
||||
}
|
||||
Err(e) => {
|
||||
let mut session = session;
|
||||
let status = e.as_status();
|
||||
if status == "timeout" {
|
||||
session.mark_timeout();
|
||||
} else {
|
||||
session.mark_failed(e.to_string());
|
||||
}
|
||||
self.task_repository.save_task_session(&session).await?;
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn resume(
|
||||
&self,
|
||||
task_id: &str,
|
||||
parent_context: &ToolContext,
|
||||
additional_prompt: String,
|
||||
) -> Result<TaskToolResult, TaskError> {
|
||||
// 1. 加载现有会话
|
||||
let session = self
|
||||
.task_repository
|
||||
.load_task_session(task_id)
|
||||
.await?
|
||||
.ok_or_else(|| TaskError::SessionNotFound(task_id.to_string()))?;
|
||||
|
||||
// 2. 验证父会话匹配
|
||||
let parent_session_id = parent_context
|
||||
.session_id
|
||||
.clone()
|
||||
.ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?;
|
||||
if session.parent_session_id != parent_session_id {
|
||||
return Err(TaskError::InvalidParentSession);
|
||||
}
|
||||
|
||||
// 3. 构建恢复提示词
|
||||
let system_prompt = SubagentPromptBuilder::build_resume_prompt(
|
||||
&session.description,
|
||||
&additional_prompt,
|
||||
);
|
||||
|
||||
// 4. 创建子代理
|
||||
let agent = self.create_subagent(&session, system_prompt)?;
|
||||
|
||||
// 5. 使用历史继续执行
|
||||
let result = self
|
||||
.execute_task_with_history(agent, &session, additional_prompt)
|
||||
.await;
|
||||
|
||||
// 6. 更新会话状态
|
||||
match result {
|
||||
Ok(tool_result) => {
|
||||
let mut session = session;
|
||||
session.mark_completed(tool_result.summary.clone());
|
||||
self.task_repository.save_task_session(&session).await?;
|
||||
Ok(tool_result)
|
||||
}
|
||||
Err(e) => {
|
||||
let mut session = session;
|
||||
session.mark_failed(e.to_string());
|
||||
self.task_repository.save_task_session(&session).await?;
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_message(&self, _task_id: &str, _message: String) -> Result<(), TaskError> {
|
||||
// TODO: 实现双向通信
|
||||
// 需要在 TaskSession 中添加 pending_messages 队列
|
||||
Err(TaskError::InvalidArguments("send_message not implemented yet".to_string()))
|
||||
}
|
||||
|
||||
async fn cleanup_expired(&self) -> Result<usize, TaskError> {
|
||||
self.task_repository
|
||||
.cleanup_expired_tasks(self.config.ttl_hours)
|
||||
.await
|
||||
.map_err(TaskError::from)
|
||||
}
|
||||
}
|
||||
@ -1,153 +0,0 @@
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::tools::{Tool, ToolContext, ToolResult};
|
||||
|
||||
use super::runtime::SubAgentRuntime;
|
||||
use super::types::{TaskDefinition, TaskToolArgs};
|
||||
|
||||
/// Task 工具 - 创建和管理子代理
|
||||
pub struct TaskTool {
|
||||
runtime: Arc<dyn SubAgentRuntime>,
|
||||
}
|
||||
|
||||
impl TaskTool {
|
||||
pub fn new(runtime: Arc<dyn SubAgentRuntime>) -> Self {
|
||||
Self { runtime }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for TaskTool {
|
||||
fn name(&self) -> &str {
|
||||
"task"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Launch a specialized subagent to handle complex, multi-step tasks. \
|
||||
Subagents run in isolated contexts and can work in parallel. \
|
||||
Use 'general' type for complex tasks, 'explore' type for read-only exploration. \
|
||||
You can resume a previous task by providing its task_id."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Short description (3-5 words) of what this task does",
|
||||
"maxLength": 50
|
||||
},
|
||||
"prompt": {
|
||||
"type": "string",
|
||||
"description": "Detailed instructions for the subagent to execute"
|
||||
},
|
||||
"subagent_type": {
|
||||
"type": "string",
|
||||
"enum": ["general", "explore"],
|
||||
"default": "general",
|
||||
"description": "Type of subagent: 'general' for complex multi-step tasks, 'explore' for read-only search/exploration"
|
||||
},
|
||||
"task_id": {
|
||||
"type": "string",
|
||||
"description": "Optional: Resume an existing task session by providing its task_id"
|
||||
}
|
||||
},
|
||||
"required": ["description", "prompt"]
|
||||
})
|
||||
}
|
||||
|
||||
fn read_only(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn exclusive(&self) -> bool {
|
||||
// Task 工具创建子代理,不应与其他工具并发执行
|
||||
true
|
||||
}
|
||||
|
||||
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
// Task 工具必须通过 execute_with_context 获取父会话信息
|
||||
Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"task tool requires tool context with session_id, chat_id, and channel_name"
|
||||
.to_string(),
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute_with_context(
|
||||
&self,
|
||||
context: &ToolContext,
|
||||
args: serde_json::Value,
|
||||
) -> anyhow::Result<ToolResult> {
|
||||
// 1. 解析参数
|
||||
let task_args: TaskToolArgs = serde_json::from_value(args.clone())
|
||||
.map_err(|e| anyhow::anyhow!("invalid task arguments: {}", e))?;
|
||||
|
||||
// 2. 验证描述长度
|
||||
let word_count = task_args.description.split_whitespace().count();
|
||||
if task_args.description.len() > 50 || word_count > 7 || word_count < 1 {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(
|
||||
"description should be 1-5 words, max 50 characters".to_string(),
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
// 3. 验证上下文
|
||||
if context.session_id.is_none() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("task tool requires session_id in context".to_string()),
|
||||
});
|
||||
}
|
||||
if context.chat_id.is_none() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("task tool requires chat_id in context".to_string()),
|
||||
});
|
||||
}
|
||||
if context.channel_name.is_none() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("task tool requires channel_name in context".to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
// 4. 执行任务
|
||||
let result = if let Some(task_id) = task_args.task_id {
|
||||
// 恢复现有任务
|
||||
self.runtime
|
||||
.resume(&task_id, context, task_args.prompt)
|
||||
.await
|
||||
} else {
|
||||
// 创建新任务
|
||||
let task_def = TaskDefinition::from(task_args);
|
||||
self.runtime.spawn(context, task_def).await
|
||||
};
|
||||
|
||||
// 5. 构建返回结果
|
||||
match result {
|
||||
Ok(task_result) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: serde_json::to_string(&task_result)?,
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e.to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,190 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// 子代理会话状态
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum TaskSessionState {
|
||||
/// 正在执行
|
||||
Running,
|
||||
/// 已完成
|
||||
Completed,
|
||||
/// 已失败
|
||||
Failed,
|
||||
/// 已超时
|
||||
Timeout,
|
||||
}
|
||||
|
||||
impl Default for TaskSessionState {
|
||||
fn default() -> Self {
|
||||
Self::Running
|
||||
}
|
||||
}
|
||||
|
||||
/// 子代理类型
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum SubagentType {
|
||||
/// 通用型 - 处理复杂多步骤任务
|
||||
#[default]
|
||||
General,
|
||||
/// 探索型 - 只读搜索代理
|
||||
Explore,
|
||||
}
|
||||
|
||||
impl SubagentType {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::General => "general",
|
||||
Self::Explore => "explore",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
match s {
|
||||
"general" => Some(Self::General),
|
||||
"explore" => Some(Self::Explore),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 任务会话记录
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TaskSession {
|
||||
/// 任务唯一 ID (UUID)
|
||||
pub id: String,
|
||||
/// 子代理独立的 session_id(存储在 message 表)
|
||||
pub session_id: String,
|
||||
/// 父会话 ID (用于关联)
|
||||
pub parent_session_id: String,
|
||||
/// 父 chat_id
|
||||
pub parent_chat_id: String,
|
||||
/// 父 channel_name
|
||||
pub parent_channel_name: String,
|
||||
/// 任务描述
|
||||
pub description: String,
|
||||
/// 子代理类型
|
||||
pub subagent_type: SubagentType,
|
||||
/// 当前状态
|
||||
pub state: TaskSessionState,
|
||||
/// 创建时间
|
||||
pub created_at: i64,
|
||||
/// 最后更新时间
|
||||
pub updated_at: i64,
|
||||
/// 执行摘要
|
||||
pub summary: Option<String>,
|
||||
/// 错误信息
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
impl TaskSession {
|
||||
pub fn new(
|
||||
parent_session_id: String,
|
||||
parent_chat_id: String,
|
||||
parent_channel_name: String,
|
||||
description: String,
|
||||
subagent_type: SubagentType,
|
||||
) -> Self {
|
||||
let id = format!("task:{}", uuid::Uuid::new_v4());
|
||||
let session_id = format!("sub:{}:{}", parent_session_id, id);
|
||||
let now = current_timestamp();
|
||||
Self {
|
||||
id,
|
||||
session_id,
|
||||
parent_session_id,
|
||||
parent_chat_id,
|
||||
parent_channel_name,
|
||||
description,
|
||||
subagent_type,
|
||||
state: TaskSessionState::Running,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
summary: None,
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// 标记完成
|
||||
pub fn mark_completed(&mut self, summary: String) {
|
||||
self.state = TaskSessionState::Completed;
|
||||
self.summary = Some(summary);
|
||||
self.updated_at = current_timestamp();
|
||||
}
|
||||
|
||||
/// 标记失败
|
||||
pub fn mark_failed(&mut self, error: String) {
|
||||
self.state = TaskSessionState::Failed;
|
||||
self.error = Some(error);
|
||||
self.updated_at = current_timestamp();
|
||||
}
|
||||
|
||||
/// 标记超时
|
||||
pub fn mark_timeout(&mut self) {
|
||||
self.state = TaskSessionState::Timeout;
|
||||
self.error = Some("Task execution timed out".to_string());
|
||||
self.updated_at = current_timestamp();
|
||||
}
|
||||
}
|
||||
|
||||
/// 任务工具参数
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct TaskToolArgs {
|
||||
/// 简短描述(3-5词)
|
||||
pub description: String,
|
||||
/// 详细指令
|
||||
pub prompt: String,
|
||||
/// 子代理类型
|
||||
#[serde(default)]
|
||||
pub subagent_type: SubagentType,
|
||||
/// 恢复现有会话的 task_id
|
||||
#[serde(default)]
|
||||
pub task_id: Option<String>,
|
||||
}
|
||||
|
||||
/// 任务定义(用于 SubAgentRuntime::spawn)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TaskDefinition {
|
||||
pub description: String,
|
||||
pub prompt: String,
|
||||
pub subagent_type: SubagentType,
|
||||
pub max_execution_secs: Option<u64>,
|
||||
}
|
||||
|
||||
impl From<TaskToolArgs> for TaskDefinition {
|
||||
fn from(args: TaskToolArgs) -> Self {
|
||||
Self {
|
||||
description: args.description,
|
||||
prompt: args.prompt,
|
||||
subagent_type: args.subagent_type,
|
||||
max_execution_secs: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 任务句柄(运行中任务)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TaskHandle {
|
||||
pub task_id: String,
|
||||
pub session_id: String,
|
||||
pub status: TaskSessionState,
|
||||
}
|
||||
|
||||
/// 任务执行结果
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct TaskToolResult {
|
||||
/// 状态: success/failed/timeout
|
||||
pub status: String,
|
||||
/// 任务完成总结
|
||||
pub summary: String,
|
||||
/// 详细输出
|
||||
pub output: String,
|
||||
/// 会话 ID(用于恢复)
|
||||
pub task_id: String,
|
||||
}
|
||||
|
||||
fn current_timestamp() -> i64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.expect("system clock before unix epoch")
|
||||
.as_millis() as i64
|
||||
}
|
||||
@ -15,8 +15,6 @@ pub struct ToolContext {
|
||||
pub session_id: Option<String>,
|
||||
pub message_id: Option<String>,
|
||||
pub message_seq: Option<i64>,
|
||||
/// 子代理标识,用于标注消息来源
|
||||
pub subagent_description: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user