Compare commits
No commits in common. "bee1a39a063fee846fead0b8578dca167fc4e950" and "8edc7ef9050949536cd8d154cf25934b489aef3d" have entirely different histories.
bee1a39a06
...
8edc7ef905
@ -100,63 +100,15 @@ impl Default for SkillsConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct ToolsConfig {
|
pub struct ToolsConfig {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub disabled: Vec<String>,
|
pub disabled: Vec<String>,
|
||||||
#[serde(default)]
|
|
||||||
pub task: TaskConfig,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
impl Default for ToolsConfig {
|
||||||
pub struct TaskConfig {
|
|
||||||
#[serde(default = "default_task_enabled")]
|
|
||||||
pub enabled: bool,
|
|
||||||
#[serde(default = "default_task_max_execution_secs")]
|
|
||||||
pub max_execution_secs: u64,
|
|
||||||
#[serde(default = "default_task_ttl_hours")]
|
|
||||||
pub ttl_hours: u64,
|
|
||||||
#[serde(default = "default_task_allowed_tools")]
|
|
||||||
pub allowed_tools: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_task_enabled() -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_task_max_execution_secs() -> u64 {
|
|
||||||
300
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_task_ttl_hours() -> u64 {
|
|
||||||
24
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_task_allowed_tools() -> Vec<String> {
|
|
||||||
vec![
|
|
||||||
"file_read".to_string(),
|
|
||||||
"file_edit".to_string(),
|
|
||||||
"file_write".to_string(),
|
|
||||||
"bash".to_string(),
|
|
||||||
"http_request".to_string(),
|
|
||||||
"web_fetch".to_string(),
|
|
||||||
"memory_search".to_string(),
|
|
||||||
"get_time".to_string(),
|
|
||||||
"calculator".to_string(),
|
|
||||||
"skill_activate".to_string(),
|
|
||||||
"skill_list".to_string(),
|
|
||||||
"send_session_message".to_string(),
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for TaskConfig {
|
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self { disabled: Vec::new() }
|
||||||
enabled: default_task_enabled(),
|
|
||||||
max_execution_secs: default_task_max_execution_secs(),
|
|
||||||
ttl_hours: default_task_ttl_hours(),
|
|
||||||
allowed_tools: default_task_allowed_tools(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -71,7 +71,6 @@ impl AgentFactory {
|
|||||||
session_id: Some(session_id),
|
session_id: Some(session_id),
|
||||||
message_id: request.message_id.map(str::to_string),
|
message_id: request.message_id.map(str::to_string),
|
||||||
message_seq: None,
|
message_seq: None,
|
||||||
subagent_description: None,
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -81,7 +81,6 @@ impl GatewayState {
|
|||||||
skills,
|
skills,
|
||||||
Arc::new(BusSessionMessageSender::new(bus.clone())),
|
Arc::new(BusSessionMessageSender::new(bus.clone())),
|
||||||
std::collections::HashSet::new(),
|
std::collections::HashSet::new(),
|
||||||
config.tools.task.clone(),
|
|
||||||
chat_history_ttl_hours,
|
chat_history_ttl_hours,
|
||||||
session_ttl_hours,
|
session_ttl_hours,
|
||||||
)?;
|
)?;
|
||||||
|
|||||||
@ -2,17 +2,14 @@ use std::collections::{HashMap, HashSet};
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::agent::AgentError;
|
use crate::agent::AgentError;
|
||||||
use crate::config::{LLMProviderConfig, TaskConfig};
|
use crate::config::LLMProviderConfig;
|
||||||
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
use crate::gateway::tool_registry_factory::ToolRegistryFactory;
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::{
|
use crate::storage::{
|
||||||
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
|
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
|
||||||
SessionStore, SkillEventRepository,
|
SessionStore, SkillEventRepository,
|
||||||
};
|
};
|
||||||
use crate::tools::{
|
use crate::tools::{NoopSessionMessageSender, SessionMessageSender, ToolRegistry};
|
||||||
DefaultSubAgentRuntime, InMemoryTaskRepository, NoopSessionMessageSender,
|
|
||||||
SessionMessageSender, SubAgentRuntime, SubAgentRuntimeConfig, ToolRegistry,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::agent_factory::AgentFactory;
|
use super::agent_factory::AgentFactory;
|
||||||
use super::cli_session::CliSessionService;
|
use super::cli_session::CliSessionService;
|
||||||
@ -32,7 +29,6 @@ pub(crate) fn build_session_manager(
|
|||||||
provider_configs: HashMap<String, LLMProviderConfig>,
|
provider_configs: HashMap<String, LLMProviderConfig>,
|
||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
disabled_tools: HashSet<String>,
|
disabled_tools: HashSet<String>,
|
||||||
task_config: TaskConfig,
|
|
||||||
chat_history_ttl_hours: Option<u64>,
|
chat_history_ttl_hours: Option<u64>,
|
||||||
session_ttl_hours: Option<u64>,
|
session_ttl_hours: Option<u64>,
|
||||||
) -> Result<SessionManager, AgentError> {
|
) -> Result<SessionManager, AgentError> {
|
||||||
@ -45,7 +41,6 @@ pub(crate) fn build_session_manager(
|
|||||||
skills,
|
skills,
|
||||||
Arc::new(NoopSessionMessageSender),
|
Arc::new(NoopSessionMessageSender),
|
||||||
disabled_tools,
|
disabled_tools,
|
||||||
task_config,
|
|
||||||
chat_history_ttl_hours,
|
chat_history_ttl_hours,
|
||||||
session_ttl_hours,
|
session_ttl_hours,
|
||||||
)
|
)
|
||||||
@ -60,7 +55,6 @@ pub(crate) fn build_session_manager_with_sender(
|
|||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
session_message_sender: Arc<dyn SessionMessageSender>,
|
session_message_sender: Arc<dyn SessionMessageSender>,
|
||||||
disabled_tools: HashSet<String>,
|
disabled_tools: HashSet<String>,
|
||||||
task_config: TaskConfig,
|
|
||||||
chat_history_ttl_hours: Option<u64>,
|
chat_history_ttl_hours: Option<u64>,
|
||||||
session_ttl_hours: Option<u64>,
|
session_ttl_hours: Option<u64>,
|
||||||
) -> Result<SessionManager, AgentError> {
|
) -> Result<SessionManager, AgentError> {
|
||||||
@ -80,49 +74,20 @@ pub(crate) fn build_session_manager_with_sender(
|
|||||||
let memories: Arc<dyn MemoryRepository> = store.clone();
|
let memories: Arc<dyn MemoryRepository> = store.clone();
|
||||||
let scheduler_jobs: Arc<dyn SchedulerJobRepository> = store.clone();
|
let scheduler_jobs: Arc<dyn SchedulerJobRepository> = store.clone();
|
||||||
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
|
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
|
||||||
let conversations: Arc<dyn ConversationRepository> = store.clone();
|
let tools = Arc::new(
|
||||||
|
ToolRegistryFactory::new(
|
||||||
// 创建 ToolRegistryFactory
|
|
||||||
let factory = ToolRegistryFactory::new(
|
|
||||||
skills.clone(),
|
skills.clone(),
|
||||||
memories,
|
memories,
|
||||||
scheduler_jobs,
|
scheduler_jobs,
|
||||||
skill_events.clone(),
|
skill_events.clone(),
|
||||||
session_message_sender.clone(),
|
session_message_sender,
|
||||||
conversations.clone(),
|
|
||||||
known_agents,
|
known_agents,
|
||||||
default_timezone,
|
default_timezone,
|
||||||
disabled_tools,
|
disabled_tools,
|
||||||
task_config.clone(),
|
)
|
||||||
|
.build(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// 创建 SubAgentRuntime(如果 task 工具启用)
|
|
||||||
let factory = if task_config.enabled {
|
|
||||||
let task_repository = Arc::new(InMemoryTaskRepository::new());
|
|
||||||
let subagent_tools = Arc::new(factory.build_subagent_tools());
|
|
||||||
|
|
||||||
let runtime_config = SubAgentRuntimeConfig {
|
|
||||||
allowed_tools: task_config.allowed_tools.iter().cloned().collect(),
|
|
||||||
max_execution_secs: task_config.max_execution_secs,
|
|
||||||
explore_max_tool_calls: 20,
|
|
||||||
ttl_hours: task_config.ttl_hours,
|
|
||||||
};
|
|
||||||
|
|
||||||
let subagent_runtime = Arc::new(DefaultSubAgentRuntime::new(
|
|
||||||
runtime_config,
|
|
||||||
task_repository,
|
|
||||||
conversations.clone(),
|
|
||||||
subagent_tools,
|
|
||||||
provider_config.clone(),
|
|
||||||
));
|
|
||||||
|
|
||||||
factory.with_subagent_runtime(subagent_runtime)
|
|
||||||
} else {
|
|
||||||
factory
|
|
||||||
};
|
|
||||||
|
|
||||||
let tools = Arc::new(factory.build());
|
|
||||||
|
|
||||||
let prompt_repository: Arc<dyn PromptInjectionRepository> = store.clone();
|
let prompt_repository: Arc<dyn PromptInjectionRepository> = store.clone();
|
||||||
let agent_factory = AgentFactory::new(
|
let agent_factory = AgentFactory::new(
|
||||||
tools.clone(),
|
tools.clone(),
|
||||||
@ -130,6 +95,7 @@ pub(crate) fn build_session_manager_with_sender(
|
|||||||
agent_prompt_reinject_every as usize,
|
agent_prompt_reinject_every as usize,
|
||||||
prompt_repository.clone(),
|
prompt_repository.clone(),
|
||||||
);
|
);
|
||||||
|
let conversations: Arc<dyn ConversationRepository> = store.clone();
|
||||||
let session_factory = SessionFactory::new(
|
let session_factory = SessionFactory::new(
|
||||||
provider_config.clone(),
|
provider_config.clone(),
|
||||||
skills.clone(),
|
skills.clone(),
|
||||||
|
|||||||
@ -485,7 +485,6 @@ impl SessionManager {
|
|||||||
provider_configs: HashMap<String, LLMProviderConfig>,
|
provider_configs: HashMap<String, LLMProviderConfig>,
|
||||||
skills: Arc<SkillRuntime>,
|
skills: Arc<SkillRuntime>,
|
||||||
disabled_tools: std::collections::HashSet<String>,
|
disabled_tools: std::collections::HashSet<String>,
|
||||||
task_config: crate::config::TaskConfig,
|
|
||||||
chat_history_ttl_hours: Option<u64>,
|
chat_history_ttl_hours: Option<u64>,
|
||||||
session_ttl_hours: Option<u64>,
|
session_ttl_hours: Option<u64>,
|
||||||
) -> Result<Self, AgentError> {
|
) -> Result<Self, AgentError> {
|
||||||
@ -497,7 +496,6 @@ impl SessionManager {
|
|||||||
provider_configs,
|
provider_configs,
|
||||||
skills,
|
skills,
|
||||||
disabled_tools,
|
disabled_tools,
|
||||||
task_config,
|
|
||||||
chat_history_ttl_hours,
|
chat_history_ttl_hours,
|
||||||
session_ttl_hours,
|
session_ttl_hours,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,15 +1,13 @@
|
|||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::config::TaskConfig;
|
|
||||||
use crate::skills::SkillRuntime;
|
use crate::skills::SkillRuntime;
|
||||||
use crate::storage::{ConversationRepository, MemoryRepository, SchedulerJobRepository, SkillEventRepository};
|
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
|
||||||
use crate::tools::{
|
use crate::tools::{
|
||||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool,
|
||||||
HttpRequestTool, MemoryManageTool, MemorySearchTool,
|
MemoryManageTool, MemorySearchTool, SchedulerManageTool, SessionMessageSender,
|
||||||
SchedulerManageTool, SessionMessageSender, SessionSendTool, SkillActivateTool, SkillListTool,
|
SessionSendTool, SkillActivateTool, SkillListTool, SkillManageTool, TimeTool, ToolRegistry,
|
||||||
SkillManageTool, SubAgentRuntime, TaskTool, TimeTool,
|
WebFetchTool,
|
||||||
ToolRegistry, WebFetchTool,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub(crate) struct ToolRegistryFactory {
|
pub(crate) struct ToolRegistryFactory {
|
||||||
@ -18,12 +16,9 @@ pub(crate) struct ToolRegistryFactory {
|
|||||||
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
|
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
|
||||||
skill_events: Arc<dyn SkillEventRepository>,
|
skill_events: Arc<dyn SkillEventRepository>,
|
||||||
session_message_sender: Arc<dyn SessionMessageSender>,
|
session_message_sender: Arc<dyn SessionMessageSender>,
|
||||||
conversations: Arc<dyn ConversationRepository>,
|
|
||||||
known_agents: HashSet<String>,
|
known_agents: HashSet<String>,
|
||||||
default_timezone: String,
|
default_timezone: String,
|
||||||
disabled_tools: HashSet<String>,
|
disabled_tools: HashSet<String>,
|
||||||
task_config: TaskConfig,
|
|
||||||
subagent_runtime: Option<Arc<dyn SubAgentRuntime>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolRegistryFactory {
|
impl ToolRegistryFactory {
|
||||||
@ -33,11 +28,9 @@ impl ToolRegistryFactory {
|
|||||||
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
|
scheduler_jobs: Arc<dyn SchedulerJobRepository>,
|
||||||
skill_events: Arc<dyn SkillEventRepository>,
|
skill_events: Arc<dyn SkillEventRepository>,
|
||||||
session_message_sender: Arc<dyn SessionMessageSender>,
|
session_message_sender: Arc<dyn SessionMessageSender>,
|
||||||
conversations: Arc<dyn ConversationRepository>,
|
|
||||||
known_agents: HashSet<String>,
|
known_agents: HashSet<String>,
|
||||||
default_timezone: String,
|
default_timezone: String,
|
||||||
disabled_tools: HashSet<String>,
|
disabled_tools: HashSet<String>,
|
||||||
task_config: TaskConfig,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
skills,
|
skills,
|
||||||
@ -45,23 +38,12 @@ impl ToolRegistryFactory {
|
|||||||
scheduler_jobs,
|
scheduler_jobs,
|
||||||
skill_events,
|
skill_events,
|
||||||
session_message_sender,
|
session_message_sender,
|
||||||
conversations,
|
|
||||||
known_agents,
|
known_agents,
|
||||||
default_timezone,
|
default_timezone,
|
||||||
disabled_tools,
|
disabled_tools,
|
||||||
task_config,
|
|
||||||
subagent_runtime: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn with_subagent_runtime(
|
|
||||||
mut self,
|
|
||||||
runtime: Arc<dyn SubAgentRuntime>,
|
|
||||||
) -> Self {
|
|
||||||
self.subagent_runtime = Some(runtime);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_enabled(&self, tool_name: &str) -> bool {
|
fn is_enabled(&self, tool_name: &str) -> bool {
|
||||||
!self.disabled_tools.contains(tool_name)
|
!self.disabled_tools.contains(tool_name)
|
||||||
}
|
}
|
||||||
@ -126,74 +108,6 @@ impl ToolRegistryFactory {
|
|||||||
registry.register(WebFetchTool::new(50_000, 30));
|
registry.register(WebFetchTool::new(50_000, 30));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 注册 Task 工具(如果启用且有 subagent_runtime)
|
|
||||||
if self.is_enabled("task") && self.task_config.enabled {
|
|
||||||
if let Some(runtime) = &self.subagent_runtime {
|
|
||||||
registry.register(TaskTool::new(runtime.clone()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
registry
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 构建子代理专用工具集(不包含 task 工具防止递归)
|
|
||||||
pub(crate) fn build_subagent_tools(&self) -> ToolRegistry {
|
|
||||||
let mut registry = ToolRegistry::new();
|
|
||||||
|
|
||||||
// 基础工具
|
|
||||||
if self.is_enabled("calculator") {
|
|
||||||
registry.register(CalculatorTool::new());
|
|
||||||
}
|
|
||||||
if self.is_enabled("get_time") {
|
|
||||||
registry.register(TimeTool::new(self.default_timezone.clone()));
|
|
||||||
}
|
|
||||||
if self.is_enabled("file_read") {
|
|
||||||
registry.register(FileReadTool::new());
|
|
||||||
}
|
|
||||||
if self.is_enabled("file_write") {
|
|
||||||
registry.register(FileWriteTool::new());
|
|
||||||
}
|
|
||||||
if self.is_enabled("file_edit") {
|
|
||||||
registry.register(FileEditTool::new());
|
|
||||||
}
|
|
||||||
if self.is_enabled("bash") {
|
|
||||||
registry.register(BashTool::new());
|
|
||||||
}
|
|
||||||
if self.is_enabled("http_request") {
|
|
||||||
registry.register(HttpRequestTool::new(
|
|
||||||
vec!["*".to_string()],
|
|
||||||
1_000_000,
|
|
||||||
30,
|
|
||||||
false,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
if self.is_enabled("web_fetch") {
|
|
||||||
registry.register(WebFetchTool::new(50_000, 30));
|
|
||||||
}
|
|
||||||
|
|
||||||
// 记忆工具(只读)
|
|
||||||
if self.is_enabled("memory_search") {
|
|
||||||
registry.register(MemorySearchTool::new(self.memories.clone()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skill 工具
|
|
||||||
if self.is_enabled("skill_activate") {
|
|
||||||
registry.register(SkillActivateTool::new(
|
|
||||||
self.skills.clone(),
|
|
||||||
self.skill_events.clone(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
if self.is_enabled("skill_list") {
|
|
||||||
registry.register(SkillListTool::new(self.skills.clone()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// 进度通知工具
|
|
||||||
if self.is_enabled("session_send") {
|
|
||||||
registry.register(SessionSendTool::new(self.session_message_sender.clone()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// 注意:不注册 task 工具,防止递归创建子代理
|
|
||||||
|
|
||||||
registry
|
registry
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -12,7 +12,6 @@ pub mod session_send;
|
|||||||
pub mod schema;
|
pub mod schema;
|
||||||
pub mod skill_activate;
|
pub mod skill_activate;
|
||||||
pub mod skill_manage;
|
pub mod skill_manage;
|
||||||
pub mod task;
|
|
||||||
pub mod time;
|
pub mod time;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
pub mod web_fetch;
|
pub mod web_fetch;
|
||||||
@ -34,10 +33,6 @@ pub use session_send::{
|
|||||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||||
pub use skill_activate::SkillActivateTool;
|
pub use skill_activate::SkillActivateTool;
|
||||||
pub use skill_manage::{SkillListTool, SkillManageTool};
|
pub use skill_manage::{SkillListTool, SkillManageTool};
|
||||||
pub use task::{
|
|
||||||
DefaultSubAgentRuntime, InMemoryTaskRepository, SubAgentRuntime, SubAgentRuntimeConfig,
|
|
||||||
TaskError, TaskRepository, TaskTool,
|
|
||||||
};
|
|
||||||
pub use time::TimeTool;
|
pub use time::TimeTool;
|
||||||
pub use traits::{Tool, ToolContext, ToolResult};
|
pub use traits::{Tool, ToolContext, ToolResult};
|
||||||
pub use web_fetch::WebFetchTool;
|
pub use web_fetch::WebFetchTool;
|
||||||
|
|||||||
@ -1,48 +0,0 @@
|
|||||||
use crate::storage::StorageError;
|
|
||||||
|
|
||||||
/// 任务错误类型
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
|
||||||
pub enum TaskError {
|
|
||||||
#[error("Task session not found: {0}")]
|
|
||||||
SessionNotFound(String),
|
|
||||||
|
|
||||||
#[error("Invalid parent session")]
|
|
||||||
InvalidParentSession,
|
|
||||||
|
|
||||||
#[error("Failed to create subagent: {0}")]
|
|
||||||
AgentCreationFailed(String),
|
|
||||||
|
|
||||||
#[error("Execution failed: {0}")]
|
|
||||||
ExecutionFailed(String),
|
|
||||||
|
|
||||||
#[error("Task execution timed out")]
|
|
||||||
Timeout,
|
|
||||||
|
|
||||||
#[error("Repository error: {0}")]
|
|
||||||
RepositoryError(#[from] StorageError),
|
|
||||||
|
|
||||||
#[error("Serialization error: {0}")]
|
|
||||||
SerializationError(#[from] serde_json::Error),
|
|
||||||
|
|
||||||
#[error("Missing required context: {0}")]
|
|
||||||
MissingContext(String),
|
|
||||||
|
|
||||||
#[error("Invalid arguments: {0}")]
|
|
||||||
InvalidArguments(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TaskError {
|
|
||||||
pub fn as_status(&self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Self::Timeout => "timeout",
|
|
||||||
Self::SessionNotFound(_) => "failed",
|
|
||||||
Self::InvalidParentSession => "failed",
|
|
||||||
Self::AgentCreationFailed(_) => "failed",
|
|
||||||
Self::ExecutionFailed(_) => "failed",
|
|
||||||
Self::RepositoryError(_) => "failed",
|
|
||||||
Self::SerializationError(_) => "failed",
|
|
||||||
Self::MissingContext(_) => "failed",
|
|
||||||
Self::InvalidArguments(_) => "failed",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,13 +0,0 @@
|
|||||||
pub mod error;
|
|
||||||
pub mod prompt;
|
|
||||||
pub mod repository;
|
|
||||||
pub mod runtime;
|
|
||||||
pub mod tool;
|
|
||||||
pub mod types;
|
|
||||||
|
|
||||||
pub use error::TaskError;
|
|
||||||
pub use prompt::SubagentPromptBuilder;
|
|
||||||
pub use repository::{InMemoryTaskRepository, TaskRepository};
|
|
||||||
pub use runtime::{DefaultSubAgentRuntime, SubAgentRuntime, SubAgentRuntimeConfig, StaticSystemPromptProvider};
|
|
||||||
pub use tool::TaskTool;
|
|
||||||
pub use types::{SubagentType, TaskDefinition, TaskHandle, TaskSession, TaskSessionState, TaskToolArgs, TaskToolResult};
|
|
||||||
@ -1,80 +0,0 @@
|
|||||||
use super::types::SubagentType;
|
|
||||||
|
|
||||||
/// 子代理系统提示词构建器
|
|
||||||
pub struct SubagentPromptBuilder;
|
|
||||||
|
|
||||||
impl SubagentPromptBuilder {
|
|
||||||
/// 构建子代理系统提示词
|
|
||||||
pub fn build(
|
|
||||||
subagent_type: SubagentType,
|
|
||||||
description: &str,
|
|
||||||
_prompt: &str,
|
|
||||||
) -> String {
|
|
||||||
match subagent_type {
|
|
||||||
SubagentType::General => Self::build_general_prompt(description),
|
|
||||||
SubagentType::Explore => Self::build_explore_prompt(description),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 构建恢复任务的提示词
|
|
||||||
pub fn build_resume_prompt(session_description: &str, additional_prompt: &str) -> String {
|
|
||||||
format!(
|
|
||||||
"你正在继续执行一个之前创建的子代理任务。\n\n\
|
|
||||||
任务描述: {}\n\n\
|
|
||||||
继续执行指令: {}\n\n\
|
|
||||||
你应该:\n\
|
|
||||||
1. 回顾之前的工作进度(如果已有历史)\n\
|
|
||||||
2. 继续完成任务,不要偏离目标\n\
|
|
||||||
3. 完成后给出简洁的总结\n\
|
|
||||||
4. 不要尝试创建新的子代理任务\n\n\
|
|
||||||
注意: 你在一个独立的执行上下文中,没有访问主对话历史的权限。",
|
|
||||||
session_description, additional_prompt
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_general_prompt(description: &str) -> String {
|
|
||||||
format!(
|
|
||||||
"你是一个专注的子代理,正在执行一个独立任务。\n\n\
|
|
||||||
任务描述: {}\n\n\
|
|
||||||
你应该:\n\
|
|
||||||
1. 专注于完成任务,不要偏离目标\n\
|
|
||||||
2. 使用可用的工具进行必要操作\n\
|
|
||||||
3. 完成后给出简洁的总结\n\
|
|
||||||
4. 不要尝试创建新的子代理任务\n\n\
|
|
||||||
注意: 你没有访问主对话历史的权限,这是一个独立的执行上下文。",
|
|
||||||
description
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_explore_prompt(description: &str) -> String {
|
|
||||||
format!(
|
|
||||||
"你是一个只读探索代理,用于代码库探索和信息收集。\n\n\
|
|
||||||
任务描述: {}\n\n\
|
|
||||||
你应该:\n\
|
|
||||||
1. 只使用只读工具进行探索\n\
|
|
||||||
2. 专注于理解和收集信息\n\
|
|
||||||
3. 不要进行任何写操作\n\
|
|
||||||
4. 给出简洁的发现总结\n\n\
|
|
||||||
注意: 你是一个只读代理,禁止执行任何修改操作。",
|
|
||||||
description
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 从子代理输出提取简洁摘要
|
|
||||||
pub fn extract_summary(content: &str) -> String {
|
|
||||||
// 取第一段或前 500 字符
|
|
||||||
let first_paragraph = content
|
|
||||||
.lines()
|
|
||||||
.take_while(|line| !line.trim().is_empty())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join("\n");
|
|
||||||
|
|
||||||
if first_paragraph.len() > 500 {
|
|
||||||
first_paragraph.chars().take(500).collect()
|
|
||||||
} else if first_paragraph.is_empty() {
|
|
||||||
content.chars().take(200).collect()
|
|
||||||
} else {
|
|
||||||
first_paragraph
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,98 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::RwLock;
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
|
|
||||||
use crate::storage::StorageError;
|
|
||||||
|
|
||||||
use super::types::TaskSession;
|
|
||||||
|
|
||||||
/// 任务持久化接口
|
|
||||||
#[async_trait]
|
|
||||||
pub trait TaskRepository: Send + Sync + 'static {
|
|
||||||
/// 保存任务会话
|
|
||||||
async fn save_task_session(&self, session: &TaskSession) -> Result<(), StorageError>;
|
|
||||||
|
|
||||||
/// 加载任务会话
|
|
||||||
async fn load_task_session(&self, task_id: &str) -> Result<Option<TaskSession>, StorageError>;
|
|
||||||
|
|
||||||
/// 删除任务会话
|
|
||||||
async fn delete_task_session(&self, task_id: &str) -> Result<bool, StorageError>;
|
|
||||||
|
|
||||||
/// 列出父会话的所有任务
|
|
||||||
async fn list_tasks_for_session(
|
|
||||||
&self,
|
|
||||||
parent_session_id: &str,
|
|
||||||
) -> Result<Vec<TaskSession>, StorageError>;
|
|
||||||
|
|
||||||
/// 清理过期任务(超过指定小时)
|
|
||||||
async fn cleanup_expired_tasks(&self, ttl_hours: u64) -> Result<usize, StorageError>;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 内存实现(用于测试)
|
|
||||||
pub struct InMemoryTaskRepository {
|
|
||||||
sessions: RwLock<HashMap<String, TaskSession>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl InMemoryTaskRepository {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
sessions: RwLock::new(HashMap::new()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for InMemoryTaskRepository {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl TaskRepository for InMemoryTaskRepository {
|
|
||||||
async fn save_task_session(&self, session: &TaskSession) -> Result<(), StorageError> {
|
|
||||||
self.sessions
|
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.insert(session.id.clone(), session.clone());
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn load_task_session(&self, task_id: &str) -> Result<Option<TaskSession>, StorageError> {
|
|
||||||
Ok(self.sessions.read().unwrap().get(task_id).cloned())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn delete_task_session(&self, task_id: &str) -> Result<bool, StorageError> {
|
|
||||||
Ok(self.sessions.write().unwrap().remove(task_id).is_some())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn list_tasks_for_session(
|
|
||||||
&self,
|
|
||||||
parent_session_id: &str,
|
|
||||||
) -> Result<Vec<TaskSession>, StorageError> {
|
|
||||||
Ok(self
|
|
||||||
.sessions
|
|
||||||
.read()
|
|
||||||
.unwrap()
|
|
||||||
.values()
|
|
||||||
.filter(|s| s.parent_session_id == parent_session_id)
|
|
||||||
.cloned()
|
|
||||||
.collect())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn cleanup_expired_tasks(&self, ttl_hours: u64) -> Result<usize, StorageError> {
|
|
||||||
let now = current_timestamp();
|
|
||||||
let ttl_millis = ttl_hours * 3600 * 1000;
|
|
||||||
let mut sessions = self.sessions.write().unwrap();
|
|
||||||
let before = sessions.len();
|
|
||||||
sessions.retain(|_, s| now - s.updated_at < ttl_millis as i64);
|
|
||||||
Ok(before - sessions.len())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn current_timestamp() -> i64 {
|
|
||||||
std::time::SystemTime::now()
|
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
|
||||||
.expect("system clock before unix epoch")
|
|
||||||
.as_millis() as i64
|
|
||||||
}
|
|
||||||
@ -1,378 +0,0 @@
|
|||||||
use std::collections::HashSet;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
|
|
||||||
use crate::agent::{AgentLoop, AgentRuntimeConfig, SystemPrompt, SystemPromptContext, SystemPromptProvider};
|
|
||||||
use crate::bus::ChatMessage;
|
|
||||||
use crate::config::LLMProviderConfig;
|
|
||||||
use crate::storage::ConversationRepository;
|
|
||||||
use crate::tools::{ToolContext, ToolRegistry};
|
|
||||||
|
|
||||||
use super::error::TaskError;
|
|
||||||
use super::prompt::{extract_summary, SubagentPromptBuilder};
|
|
||||||
use super::repository::TaskRepository;
|
|
||||||
use super::types::{SubagentType, TaskDefinition, TaskHandle, TaskSession, TaskSessionState, TaskToolResult};
|
|
||||||
|
|
||||||
/// 子代理运行时配置
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct SubAgentRuntimeConfig {
|
|
||||||
/// 子代理可用的工具列表(白名单)
|
|
||||||
pub allowed_tools: HashSet<String>,
|
|
||||||
/// 最大执行时间(秒)
|
|
||||||
pub max_execution_secs: u64,
|
|
||||||
/// 探索类型的最大工具调用次数
|
|
||||||
pub explore_max_tool_calls: usize,
|
|
||||||
/// 任务 TTL(小时)
|
|
||||||
pub ttl_hours: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for SubAgentRuntimeConfig {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
allowed_tools: HashSet::from([
|
|
||||||
"file_read".to_string(),
|
|
||||||
"file_edit".to_string(),
|
|
||||||
"file_write".to_string(),
|
|
||||||
"bash".to_string(),
|
|
||||||
"http_request".to_string(),
|
|
||||||
"web_fetch".to_string(),
|
|
||||||
"memory_search".to_string(),
|
|
||||||
"get_time".to_string(),
|
|
||||||
"calculator".to_string(),
|
|
||||||
"skill_activate".to_string(),
|
|
||||||
"skill_list".to_string(),
|
|
||||||
"send_session_message".to_string(), // 用于进度通知
|
|
||||||
]),
|
|
||||||
max_execution_secs: 300, // 5分钟
|
|
||||||
explore_max_tool_calls: 20,
|
|
||||||
ttl_hours: 24,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 子代理运行时抽象接口
|
|
||||||
#[async_trait]
|
|
||||||
pub trait SubAgentRuntime: Send + Sync + 'static {
|
|
||||||
/// 创建并执行子代理任务
|
|
||||||
async fn spawn(
|
|
||||||
&self,
|
|
||||||
parent_context: &ToolContext,
|
|
||||||
task: TaskDefinition,
|
|
||||||
) -> Result<TaskToolResult, TaskError>;
|
|
||||||
|
|
||||||
/// 恢复现有任务
|
|
||||||
async fn resume(
|
|
||||||
&self,
|
|
||||||
task_id: &str,
|
|
||||||
parent_context: &ToolContext,
|
|
||||||
additional_prompt: String,
|
|
||||||
) -> Result<TaskToolResult, TaskError>;
|
|
||||||
|
|
||||||
/// 发送消息给子代理(支持中断或补充指令)
|
|
||||||
async fn send_message(&self, task_id: &str, message: String) -> Result<(), TaskError>;
|
|
||||||
|
|
||||||
/// 清理过期任务
|
|
||||||
async fn cleanup_expired(&self) -> Result<usize, TaskError>;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 静态系统提示词提供者(用于子代理)
|
|
||||||
pub struct StaticSystemPromptProvider {
|
|
||||||
prompt: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StaticSystemPromptProvider {
|
|
||||||
pub fn new(prompt: String) -> Self {
|
|
||||||
Self { prompt }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SystemPromptProvider for StaticSystemPromptProvider {
|
|
||||||
fn build(&self, _context: &SystemPromptContext) -> Option<SystemPrompt> {
|
|
||||||
Some(SystemPrompt {
|
|
||||||
content: self.prompt.clone(),
|
|
||||||
context: Some("subagent".to_string()),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 默认子代理运行时实现
|
|
||||||
pub struct DefaultSubAgentRuntime {
|
|
||||||
config: SubAgentRuntimeConfig,
|
|
||||||
task_repository: Arc<dyn TaskRepository>,
|
|
||||||
conversation_repository: Arc<dyn ConversationRepository>,
|
|
||||||
subagent_tools: Arc<ToolRegistry>,
|
|
||||||
provider_config: LLMProviderConfig,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DefaultSubAgentRuntime {
|
|
||||||
pub fn new(
|
|
||||||
config: SubAgentRuntimeConfig,
|
|
||||||
task_repository: Arc<dyn TaskRepository>,
|
|
||||||
conversation_repository: Arc<dyn ConversationRepository>,
|
|
||||||
subagent_tools: Arc<ToolRegistry>,
|
|
||||||
provider_config: LLMProviderConfig,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
config,
|
|
||||||
task_repository,
|
|
||||||
conversation_repository,
|
|
||||||
subagent_tools,
|
|
||||||
provider_config,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 创建子代理实例
|
|
||||||
fn create_subagent(
|
|
||||||
&self,
|
|
||||||
session: &TaskSession,
|
|
||||||
system_prompt: String,
|
|
||||||
) -> Result<AgentLoop, TaskError> {
|
|
||||||
let prompt_provider = Arc::new(StaticSystemPromptProvider::new(system_prompt));
|
|
||||||
|
|
||||||
AgentLoop::with_tools_and_system_prompt_provider(
|
|
||||||
AgentRuntimeConfig::from(self.provider_config.clone()),
|
|
||||||
self.subagent_tools.clone(),
|
|
||||||
prompt_provider,
|
|
||||||
None, // 子代理不需要 skill provider
|
|
||||||
)
|
|
||||||
.map(|agent| {
|
|
||||||
agent.with_tool_context(ToolContext {
|
|
||||||
channel_name: Some(session.parent_channel_name.clone()),
|
|
||||||
sender_id: None,
|
|
||||||
chat_id: Some(session.parent_chat_id.clone()), // 使用父会话 chat_id
|
|
||||||
session_id: Some(session.session_id.clone()), // 子代理自己的 session_id
|
|
||||||
message_id: None,
|
|
||||||
message_seq: None,
|
|
||||||
subagent_description: Some(session.description.clone()),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.map_err(|e| TaskError::AgentCreationFailed(e.to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 执行任务(带超时控制)
|
|
||||||
async fn execute_task(
|
|
||||||
&self,
|
|
||||||
agent: AgentLoop,
|
|
||||||
session: &TaskSession,
|
|
||||||
prompt: String,
|
|
||||||
) -> Result<TaskToolResult, TaskError> {
|
|
||||||
// 构建初始消息
|
|
||||||
let history = vec![ChatMessage::user(prompt)];
|
|
||||||
let system_prompt_context = SystemPromptContext {
|
|
||||||
session_id: Some(session.session_id.clone()),
|
|
||||||
chat_id: session.session_id.clone(),
|
|
||||||
user_message_count: 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
// 设置超时
|
|
||||||
let max_secs = if session.subagent_type == SubagentType::Explore {
|
|
||||||
self.config.max_execution_secs / 2 // Explore 类型时间更短
|
|
||||||
} else {
|
|
||||||
self.config.max_execution_secs
|
|
||||||
};
|
|
||||||
let timeout_duration = Duration::from_secs(max_secs);
|
|
||||||
|
|
||||||
let result = tokio::time::timeout(
|
|
||||||
timeout_duration,
|
|
||||||
agent.process(history, Some(&system_prompt_context)),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(Ok(process_result)) => {
|
|
||||||
let final_message = process_result.final_response;
|
|
||||||
Ok(TaskToolResult {
|
|
||||||
status: "success".to_string(),
|
|
||||||
summary: extract_summary(&final_message.content),
|
|
||||||
output: final_message.content,
|
|
||||||
task_id: session.id.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())),
|
|
||||||
Err(_) => Err(TaskError::Timeout),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 使用历史继续执行
|
|
||||||
async fn execute_task_with_history(
|
|
||||||
&self,
|
|
||||||
agent: AgentLoop,
|
|
||||||
session: &TaskSession,
|
|
||||||
additional_prompt: String,
|
|
||||||
) -> Result<TaskToolResult, TaskError> {
|
|
||||||
// 加载历史 + 新消息
|
|
||||||
let mut history = self
|
|
||||||
.conversation_repository
|
|
||||||
.load_messages(&session.session_id)
|
|
||||||
.map_err(TaskError::RepositoryError)?;
|
|
||||||
history.push(ChatMessage::user(additional_prompt));
|
|
||||||
|
|
||||||
let user_message_count = history.iter().filter(|m| m.role == "user").count();
|
|
||||||
let system_prompt_context = SystemPromptContext {
|
|
||||||
session_id: Some(session.session_id.clone()),
|
|
||||||
chat_id: session.session_id.clone(),
|
|
||||||
user_message_count,
|
|
||||||
};
|
|
||||||
|
|
||||||
let timeout_duration = Duration::from_secs(self.config.max_execution_secs);
|
|
||||||
|
|
||||||
let result = tokio::time::timeout(
|
|
||||||
timeout_duration,
|
|
||||||
agent.process(history, Some(&system_prompt_context)),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(Ok(process_result)) => {
|
|
||||||
let final_message = process_result.final_response;
|
|
||||||
Ok(TaskToolResult {
|
|
||||||
status: "success".to_string(),
|
|
||||||
summary: extract_summary(&final_message.content),
|
|
||||||
output: final_message.content,
|
|
||||||
task_id: session.id.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())),
|
|
||||||
Err(_) => Err(TaskError::Timeout),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl SubAgentRuntime for DefaultSubAgentRuntime {
|
|
||||||
async fn spawn(
|
|
||||||
&self,
|
|
||||||
parent_context: &ToolContext,
|
|
||||||
task: TaskDefinition,
|
|
||||||
) -> Result<TaskToolResult, TaskError> {
|
|
||||||
// 1. 验证上下文
|
|
||||||
let session_id = parent_context
|
|
||||||
.session_id
|
|
||||||
.clone()
|
|
||||||
.ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?;
|
|
||||||
let chat_id = parent_context
|
|
||||||
.chat_id
|
|
||||||
.clone()
|
|
||||||
.ok_or_else(|| TaskError::MissingContext("chat_id".to_string()))?;
|
|
||||||
let channel_name = parent_context
|
|
||||||
.channel_name
|
|
||||||
.clone()
|
|
||||||
.ok_or_else(|| TaskError::MissingContext("channel_name".to_string()))?;
|
|
||||||
|
|
||||||
// 2. 创建任务会话
|
|
||||||
let session = TaskSession::new(
|
|
||||||
session_id,
|
|
||||||
chat_id,
|
|
||||||
channel_name,
|
|
||||||
task.description.clone(),
|
|
||||||
task.subagent_type,
|
|
||||||
);
|
|
||||||
|
|
||||||
// 3. 保存会话
|
|
||||||
self.task_repository.save_task_session(&session).await?;
|
|
||||||
|
|
||||||
// 4. 构建子代理系统提示词
|
|
||||||
let system_prompt = SubagentPromptBuilder::build(
|
|
||||||
task.subagent_type,
|
|
||||||
&task.description,
|
|
||||||
&task.prompt,
|
|
||||||
);
|
|
||||||
|
|
||||||
// 5. 创建子代理
|
|
||||||
let agent = self.create_subagent(&session, system_prompt)?;
|
|
||||||
|
|
||||||
// 6. 执行任务
|
|
||||||
let result = self
|
|
||||||
.execute_task(agent, &session, task.prompt.clone())
|
|
||||||
.await;
|
|
||||||
|
|
||||||
// 7. 更新会话状态并保存
|
|
||||||
match result {
|
|
||||||
Ok(tool_result) => {
|
|
||||||
let mut session = session;
|
|
||||||
session.mark_completed(tool_result.summary.clone());
|
|
||||||
self.task_repository.save_task_session(&session).await?;
|
|
||||||
Ok(tool_result)
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
let mut session = session;
|
|
||||||
let status = e.as_status();
|
|
||||||
if status == "timeout" {
|
|
||||||
session.mark_timeout();
|
|
||||||
} else {
|
|
||||||
session.mark_failed(e.to_string());
|
|
||||||
}
|
|
||||||
self.task_repository.save_task_session(&session).await?;
|
|
||||||
Err(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn resume(
|
|
||||||
&self,
|
|
||||||
task_id: &str,
|
|
||||||
parent_context: &ToolContext,
|
|
||||||
additional_prompt: String,
|
|
||||||
) -> Result<TaskToolResult, TaskError> {
|
|
||||||
// 1. 加载现有会话
|
|
||||||
let session = self
|
|
||||||
.task_repository
|
|
||||||
.load_task_session(task_id)
|
|
||||||
.await?
|
|
||||||
.ok_or_else(|| TaskError::SessionNotFound(task_id.to_string()))?;
|
|
||||||
|
|
||||||
// 2. 验证父会话匹配
|
|
||||||
let parent_session_id = parent_context
|
|
||||||
.session_id
|
|
||||||
.clone()
|
|
||||||
.ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?;
|
|
||||||
if session.parent_session_id != parent_session_id {
|
|
||||||
return Err(TaskError::InvalidParentSession);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 构建恢复提示词
|
|
||||||
let system_prompt = SubagentPromptBuilder::build_resume_prompt(
|
|
||||||
&session.description,
|
|
||||||
&additional_prompt,
|
|
||||||
);
|
|
||||||
|
|
||||||
// 4. 创建子代理
|
|
||||||
let agent = self.create_subagent(&session, system_prompt)?;
|
|
||||||
|
|
||||||
// 5. 使用历史继续执行
|
|
||||||
let result = self
|
|
||||||
.execute_task_with_history(agent, &session, additional_prompt)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
// 6. 更新会话状态
|
|
||||||
match result {
|
|
||||||
Ok(tool_result) => {
|
|
||||||
let mut session = session;
|
|
||||||
session.mark_completed(tool_result.summary.clone());
|
|
||||||
self.task_repository.save_task_session(&session).await?;
|
|
||||||
Ok(tool_result)
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
let mut session = session;
|
|
||||||
session.mark_failed(e.to_string());
|
|
||||||
self.task_repository.save_task_session(&session).await?;
|
|
||||||
Err(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_message(&self, _task_id: &str, _message: String) -> Result<(), TaskError> {
|
|
||||||
// TODO: 实现双向通信
|
|
||||||
// 需要在 TaskSession 中添加 pending_messages 队列
|
|
||||||
Err(TaskError::InvalidArguments("send_message not implemented yet".to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn cleanup_expired(&self) -> Result<usize, TaskError> {
|
|
||||||
self.task_repository
|
|
||||||
.cleanup_expired_tasks(self.config.ttl_hours)
|
|
||||||
.await
|
|
||||||
.map_err(TaskError::from)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,153 +0,0 @@
|
|||||||
use async_trait::async_trait;
|
|
||||||
use serde_json::json;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use crate::tools::{Tool, ToolContext, ToolResult};
|
|
||||||
|
|
||||||
use super::runtime::SubAgentRuntime;
|
|
||||||
use super::types::{TaskDefinition, TaskToolArgs};
|
|
||||||
|
|
||||||
/// Task 工具 - 创建和管理子代理
|
|
||||||
pub struct TaskTool {
|
|
||||||
runtime: Arc<dyn SubAgentRuntime>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TaskTool {
|
|
||||||
pub fn new(runtime: Arc<dyn SubAgentRuntime>) -> Self {
|
|
||||||
Self { runtime }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Tool for TaskTool {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"task"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
"Launch a specialized subagent to handle complex, multi-step tasks. \
|
|
||||||
Subagents run in isolated contexts and can work in parallel. \
|
|
||||||
Use 'general' type for complex tasks, 'explore' type for read-only exploration. \
|
|
||||||
You can resume a previous task by providing its task_id."
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"description": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Short description (3-5 words) of what this task does",
|
|
||||||
"maxLength": 50
|
|
||||||
},
|
|
||||||
"prompt": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Detailed instructions for the subagent to execute"
|
|
||||||
},
|
|
||||||
"subagent_type": {
|
|
||||||
"type": "string",
|
|
||||||
"enum": ["general", "explore"],
|
|
||||||
"default": "general",
|
|
||||||
"description": "Type of subagent: 'general' for complex multi-step tasks, 'explore' for read-only search/exploration"
|
|
||||||
},
|
|
||||||
"task_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Optional: Resume an existing task session by providing its task_id"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["description", "prompt"]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn read_only(&self) -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
fn exclusive(&self) -> bool {
|
|
||||||
// Task 工具创建子代理,不应与其他工具并发执行
|
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
|
||||||
// Task 工具必须通过 execute_with_context 获取父会话信息
|
|
||||||
Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(
|
|
||||||
"task tool requires tool context with session_id, chat_id, and channel_name"
|
|
||||||
.to_string(),
|
|
||||||
),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute_with_context(
|
|
||||||
&self,
|
|
||||||
context: &ToolContext,
|
|
||||||
args: serde_json::Value,
|
|
||||||
) -> anyhow::Result<ToolResult> {
|
|
||||||
// 1. 解析参数
|
|
||||||
let task_args: TaskToolArgs = serde_json::from_value(args.clone())
|
|
||||||
.map_err(|e| anyhow::anyhow!("invalid task arguments: {}", e))?;
|
|
||||||
|
|
||||||
// 2. 验证描述长度
|
|
||||||
let word_count = task_args.description.split_whitespace().count();
|
|
||||||
if task_args.description.len() > 50 || word_count > 7 || word_count < 1 {
|
|
||||||
return Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(
|
|
||||||
"description should be 1-5 words, max 50 characters".to_string(),
|
|
||||||
),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 验证上下文
|
|
||||||
if context.session_id.is_none() {
|
|
||||||
return Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some("task tool requires session_id in context".to_string()),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if context.chat_id.is_none() {
|
|
||||||
return Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some("task tool requires chat_id in context".to_string()),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if context.channel_name.is_none() {
|
|
||||||
return Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some("task tool requires channel_name in context".to_string()),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. 执行任务
|
|
||||||
let result = if let Some(task_id) = task_args.task_id {
|
|
||||||
// 恢复现有任务
|
|
||||||
self.runtime
|
|
||||||
.resume(&task_id, context, task_args.prompt)
|
|
||||||
.await
|
|
||||||
} else {
|
|
||||||
// 创建新任务
|
|
||||||
let task_def = TaskDefinition::from(task_args);
|
|
||||||
self.runtime.spawn(context, task_def).await
|
|
||||||
};
|
|
||||||
|
|
||||||
// 5. 构建返回结果
|
|
||||||
match result {
|
|
||||||
Ok(task_result) => Ok(ToolResult {
|
|
||||||
success: true,
|
|
||||||
output: serde_json::to_string(&task_result)?,
|
|
||||||
error: None,
|
|
||||||
}),
|
|
||||||
Err(e) => Ok(ToolResult {
|
|
||||||
success: false,
|
|
||||||
output: String::new(),
|
|
||||||
error: Some(e.to_string()),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,190 +0,0 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
/// 子代理会话状态
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "lowercase")]
|
|
||||||
pub enum TaskSessionState {
|
|
||||||
/// 正在执行
|
|
||||||
Running,
|
|
||||||
/// 已完成
|
|
||||||
Completed,
|
|
||||||
/// 已失败
|
|
||||||
Failed,
|
|
||||||
/// 已超时
|
|
||||||
Timeout,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for TaskSessionState {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::Running
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 子代理类型
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
|
|
||||||
#[serde(rename_all = "lowercase")]
|
|
||||||
pub enum SubagentType {
|
|
||||||
/// 通用型 - 处理复杂多步骤任务
|
|
||||||
#[default]
|
|
||||||
General,
|
|
||||||
/// 探索型 - 只读搜索代理
|
|
||||||
Explore,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SubagentType {
|
|
||||||
pub fn as_str(&self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Self::General => "general",
|
|
||||||
Self::Explore => "explore",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_str(s: &str) -> Option<Self> {
|
|
||||||
match s {
|
|
||||||
"general" => Some(Self::General),
|
|
||||||
"explore" => Some(Self::Explore),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 任务会话记录
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct TaskSession {
|
|
||||||
/// 任务唯一 ID (UUID)
|
|
||||||
pub id: String,
|
|
||||||
/// 子代理独立的 session_id(存储在 message 表)
|
|
||||||
pub session_id: String,
|
|
||||||
/// 父会话 ID (用于关联)
|
|
||||||
pub parent_session_id: String,
|
|
||||||
/// 父 chat_id
|
|
||||||
pub parent_chat_id: String,
|
|
||||||
/// 父 channel_name
|
|
||||||
pub parent_channel_name: String,
|
|
||||||
/// 任务描述
|
|
||||||
pub description: String,
|
|
||||||
/// 子代理类型
|
|
||||||
pub subagent_type: SubagentType,
|
|
||||||
/// 当前状态
|
|
||||||
pub state: TaskSessionState,
|
|
||||||
/// 创建时间
|
|
||||||
pub created_at: i64,
|
|
||||||
/// 最后更新时间
|
|
||||||
pub updated_at: i64,
|
|
||||||
/// 执行摘要
|
|
||||||
pub summary: Option<String>,
|
|
||||||
/// 错误信息
|
|
||||||
pub error: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TaskSession {
|
|
||||||
pub fn new(
|
|
||||||
parent_session_id: String,
|
|
||||||
parent_chat_id: String,
|
|
||||||
parent_channel_name: String,
|
|
||||||
description: String,
|
|
||||||
subagent_type: SubagentType,
|
|
||||||
) -> Self {
|
|
||||||
let id = format!("task:{}", uuid::Uuid::new_v4());
|
|
||||||
let session_id = format!("sub:{}:{}", parent_session_id, id);
|
|
||||||
let now = current_timestamp();
|
|
||||||
Self {
|
|
||||||
id,
|
|
||||||
session_id,
|
|
||||||
parent_session_id,
|
|
||||||
parent_chat_id,
|
|
||||||
parent_channel_name,
|
|
||||||
description,
|
|
||||||
subagent_type,
|
|
||||||
state: TaskSessionState::Running,
|
|
||||||
created_at: now,
|
|
||||||
updated_at: now,
|
|
||||||
summary: None,
|
|
||||||
error: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 标记完成
|
|
||||||
pub fn mark_completed(&mut self, summary: String) {
|
|
||||||
self.state = TaskSessionState::Completed;
|
|
||||||
self.summary = Some(summary);
|
|
||||||
self.updated_at = current_timestamp();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 标记失败
|
|
||||||
pub fn mark_failed(&mut self, error: String) {
|
|
||||||
self.state = TaskSessionState::Failed;
|
|
||||||
self.error = Some(error);
|
|
||||||
self.updated_at = current_timestamp();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 标记超时
|
|
||||||
pub fn mark_timeout(&mut self) {
|
|
||||||
self.state = TaskSessionState::Timeout;
|
|
||||||
self.error = Some("Task execution timed out".to_string());
|
|
||||||
self.updated_at = current_timestamp();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 任务工具参数
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
|
||||||
pub struct TaskToolArgs {
|
|
||||||
/// 简短描述(3-5词)
|
|
||||||
pub description: String,
|
|
||||||
/// 详细指令
|
|
||||||
pub prompt: String,
|
|
||||||
/// 子代理类型
|
|
||||||
#[serde(default)]
|
|
||||||
pub subagent_type: SubagentType,
|
|
||||||
/// 恢复现有会话的 task_id
|
|
||||||
#[serde(default)]
|
|
||||||
pub task_id: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 任务定义(用于 SubAgentRuntime::spawn)
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct TaskDefinition {
|
|
||||||
pub description: String,
|
|
||||||
pub prompt: String,
|
|
||||||
pub subagent_type: SubagentType,
|
|
||||||
pub max_execution_secs: Option<u64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<TaskToolArgs> for TaskDefinition {
|
|
||||||
fn from(args: TaskToolArgs) -> Self {
|
|
||||||
Self {
|
|
||||||
description: args.description,
|
|
||||||
prompt: args.prompt,
|
|
||||||
subagent_type: args.subagent_type,
|
|
||||||
max_execution_secs: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 任务句柄(运行中任务)
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct TaskHandle {
|
|
||||||
pub task_id: String,
|
|
||||||
pub session_id: String,
|
|
||||||
pub status: TaskSessionState,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 任务执行结果
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
|
||||||
pub struct TaskToolResult {
|
|
||||||
/// 状态: success/failed/timeout
|
|
||||||
pub status: String,
|
|
||||||
/// 任务完成总结
|
|
||||||
pub summary: String,
|
|
||||||
/// 详细输出
|
|
||||||
pub output: String,
|
|
||||||
/// 会话 ID(用于恢复)
|
|
||||||
pub task_id: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn current_timestamp() -> i64 {
|
|
||||||
std::time::SystemTime::now()
|
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
|
||||||
.expect("system clock before unix epoch")
|
|
||||||
.as_millis() as i64
|
|
||||||
}
|
|
||||||
@ -15,8 +15,6 @@ pub struct ToolContext {
|
|||||||
pub session_id: Option<String>,
|
pub session_id: Option<String>,
|
||||||
pub message_id: Option<String>,
|
pub message_id: Option<String>,
|
||||||
pub message_seq: Option<i64>,
|
pub message_seq: Option<i64>,
|
||||||
/// 子代理标识,用于标注消息来源
|
|
||||||
pub subagent_description: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user