feat: 移除任务管理相关功能,简化工具配置和依赖

This commit is contained in:
ooodc 2026-05-16 09:08:40 +08:00
parent 020b7aa77a
commit 9bf57c1132
12 changed files with 4 additions and 1162 deletions

View File

@ -104,67 +104,11 @@ impl Default for SkillsConfig {
pub struct ToolsConfig {
#[serde(default)]
pub disabled: Vec<String>,
#[serde(default)]
pub task: TaskConfig,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TaskConfig {
#[serde(default = "default_task_enabled")]
pub enabled: bool,
#[serde(default = "default_task_max_execution_secs")]
pub max_execution_secs: u64,
#[serde(default = "default_task_ttl_hours")]
pub ttl_hours: u64,
#[serde(default = "default_task_allowed_tools")]
pub allowed_tools: Vec<String>,
}
fn default_task_enabled() -> bool {
true
}
fn default_task_max_execution_secs() -> u64 {
300
}
fn default_task_ttl_hours() -> u64 {
24
}
fn default_task_allowed_tools() -> Vec<String> {
vec![
"file_read".to_string(),
"file_edit".to_string(),
"file_write".to_string(),
"bash".to_string(),
"http_request".to_string(),
"web_fetch".to_string(),
"memory_search".to_string(),
"get_time".to_string(),
"calculator".to_string(),
"skill_activate".to_string(),
"skill_list".to_string(),
]
}
impl Default for ToolsConfig {
fn default() -> Self {
Self {
disabled: Vec::new(),
task: TaskConfig::default(),
}
}
}
impl Default for TaskConfig {
fn default() -> Self {
Self {
enabled: default_task_enabled(),
max_execution_secs: default_task_max_execution_secs(),
ttl_hours: default_task_ttl_hours(),
allowed_tools: default_task_allowed_tools(),
}
Self { disabled: Vec::new() }
}
}

View File

@ -9,7 +9,7 @@ use crate::storage::{
ConversationRepository, MemoryRepository, PromptInjectionRepository, SchedulerJobRepository,
SessionStore, SkillEventRepository,
};
use crate::tools::{InMemoryTaskRepository, NoopSessionMessageSender, SessionMessageSender, ToolRegistry};
use crate::tools::{NoopSessionMessageSender, SessionMessageSender, ToolRegistry};
use super::agent_factory::AgentFactory;
use super::cli_session::CliSessionService;
@ -74,10 +74,6 @@ pub(crate) fn build_session_manager_with_sender(
let memories: Arc<dyn MemoryRepository> = store.clone();
let scheduler_jobs: Arc<dyn SchedulerJobRepository> = store.clone();
let skill_events: Arc<dyn SkillEventRepository> = store.clone();
// TaskRepository 使用内存存储(后续可以改为 SQLite
let task_repository = Arc::new(InMemoryTaskRepository::new());
let tools = Arc::new(
ToolRegistryFactory::new(
skills.clone(),
@ -89,7 +85,6 @@ pub(crate) fn build_session_manager_with_sender(
default_timezone,
disabled_tools,
)
.with_task_deps(task_repository, provider_config.clone())
.build(),
);

View File

@ -1,14 +1,13 @@
use std::collections::HashSet;
use std::sync::Arc;
use crate::config::LLMProviderConfig;
use crate::skills::SkillRuntime;
use crate::storage::{MemoryRepository, SchedulerJobRepository, SkillEventRepository};
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool, HttpRequestTool,
MemoryManageTool, MemorySearchTool, SchedulerManageTool, SessionMessageSender,
SessionSendTool, SkillActivateTool, SkillListTool, SkillManageTool, TaskRuntime,
TaskTool, TimeTool, ToolRegistry, WebFetchTool,
SessionSendTool, SkillActivateTool, SkillListTool, SkillManageTool, TimeTool, ToolRegistry,
WebFetchTool,
};
pub(crate) struct ToolRegistryFactory {
@ -20,9 +19,6 @@ pub(crate) struct ToolRegistryFactory {
known_agents: HashSet<String>,
default_timezone: String,
disabled_tools: HashSet<String>,
// Task 工具需要的依赖
task_repository: Option<Arc<dyn crate::tools::TaskRepository>>,
provider_config: Option<LLMProviderConfig>,
}
impl ToolRegistryFactory {
@ -45,22 +41,9 @@ impl ToolRegistryFactory {
known_agents,
default_timezone,
disabled_tools,
task_repository: None,
provider_config: None,
}
}
/// 设置 Task 工具的依赖
pub(crate) fn with_task_deps(
mut self,
task_repository: Arc<dyn crate::tools::TaskRepository>,
provider_config: LLMProviderConfig,
) -> Self {
self.task_repository = Some(task_repository);
self.provider_config = Some(provider_config);
self
}
fn is_enabled(&self, tool_name: &str) -> bool {
!self.disabled_tools.contains(tool_name)
}
@ -125,89 +108,6 @@ impl ToolRegistryFactory {
registry.register(WebFetchTool::new(50_000, 30));
}
// 注册 Task 工具(需要 TaskRuntime
if self.is_enabled("task") {
if let (Some(task_repository), Some(provider_config)) =
(&self.task_repository, &self.provider_config)
{
// 先创建一个临时的 ToolRegistry不含 task 工具)用于子代理
let subagent_tools = Arc::new(self.build_without_task());
let task_runtime = Arc::new(TaskRuntime::new(
crate::tools::TaskRuntimeConfig::default(),
task_repository.clone(),
subagent_tools,
provider_config.clone(),
));
registry.register(TaskTool::new(task_runtime));
}
}
registry
}
/// 构建不含 task 工具的注册表(用于子代理)
fn build_without_task(&self) -> ToolRegistry {
let mut registry = ToolRegistry::new();
if self.is_enabled("calculator") {
registry.register(CalculatorTool::new());
}
if self.is_enabled("get_time") {
registry.register(TimeTool::new(self.default_timezone.clone()));
}
if self.is_enabled("file_read") {
registry.register(FileReadTool::new());
}
if self.is_enabled("file_write") {
registry.register(FileWriteTool::new());
}
if self.is_enabled("file_edit") {
registry.register(FileEditTool::new());
}
if self.is_enabled("memory_search") {
registry.register(MemorySearchTool::new(self.memories.clone()));
}
if self.is_enabled("memory_manage") {
registry.register(MemoryManageTool::new(self.memories.clone()));
}
if self.is_enabled("session_send") {
registry.register(SessionSendTool::new(self.session_message_sender.clone()));
}
if self.is_enabled("scheduler_manage") {
registry.register(SchedulerManageTool::new(
self.scheduler_jobs.clone(),
self.known_agents.clone(),
));
}
if self.is_enabled("skill_activate") {
registry.register(SkillActivateTool::new(
self.skills.clone(),
self.skill_events.clone(),
));
}
if self.is_enabled("skill_list") {
registry.register(SkillListTool::new(self.skills.clone()));
}
if self.is_enabled("skill_manage") {
registry.register(SkillManageTool::new(self.skills.clone()));
}
if self.is_enabled("bash") {
registry.register(BashTool::new());
}
if self.is_enabled("http_request") {
registry.register(HttpRequestTool::new(
vec!["*".to_string()],
1_000_000,
30,
false,
));
}
if self.is_enabled("web_fetch") {
registry.register(WebFetchTool::new(50_000, 30));
}
// 注意:不注册 task 工具,防止子代理递归创建子代理
registry
}
}

View File

@ -12,7 +12,6 @@ pub mod session_send;
pub mod schema;
pub mod skill_activate;
pub mod skill_manage;
pub mod task;
pub mod time;
pub mod traits;
pub mod web_fetch;
@ -34,9 +33,6 @@ pub use session_send::{
pub use schema::{CleaningStrategy, SchemaCleanr};
pub use skill_activate::SkillActivateTool;
pub use skill_manage::{SkillListTool, SkillManageTool};
pub use task::TaskTool;
pub use time::TimeTool;
pub use traits::{Tool, ToolContext, ToolResult};
pub use web_fetch::WebFetchTool;
pub(crate) use task::{TaskRepository, TaskRuntime, TaskRuntimeConfig, InMemoryTaskRepository};

View File

@ -1,48 +0,0 @@
use crate::storage::StorageError;
/// 任务错误类型
#[derive(Debug, thiserror::Error)]
pub enum TaskError {
#[error("Task session not found: {0}")]
SessionNotFound(String),
#[error("Invalid parent session")]
InvalidParentSession,
#[error("Failed to create subagent: {0}")]
AgentCreationFailed(String),
#[error("Execution failed: {0}")]
ExecutionFailed(String),
#[error("Task execution timed out")]
Timeout,
#[error("Repository error: {0}")]
RepositoryError(#[from] StorageError),
#[error("Serialization error: {0}")]
SerializationError(#[from] serde_json::Error),
#[error("Missing required context: {0}")]
MissingContext(String),
#[error("Invalid arguments: {0}")]
InvalidArguments(String),
}
impl TaskError {
pub fn as_status(&self) -> &'static str {
match self {
Self::Timeout => "timeout",
Self::SessionNotFound(_) => "failed",
Self::InvalidParentSession => "failed",
Self::AgentCreationFailed(_) => "failed",
Self::ExecutionFailed(_) => "failed",
Self::RepositoryError(_) => "failed",
Self::SerializationError(_) => "failed",
Self::MissingContext(_) => "failed",
Self::InvalidArguments(_) => "failed",
}
}
}

View File

@ -1,11 +0,0 @@
mod error;
mod prompt;
mod repository;
mod runtime;
mod tool;
mod types;
pub(crate) use repository::TaskRepository;
pub(crate) use repository::InMemoryTaskRepository;
pub(crate) use runtime::{TaskRuntime, TaskRuntimeConfig};
pub use tool::TaskTool;

View File

@ -1,125 +0,0 @@
use crate::bus::ChatMessage;
use super::types::SubagentType;
/// 子代理系统提示词构建器
pub struct SubagentPromptBuilder;
impl SubagentPromptBuilder {
/// 构建子代理系统提示词
pub fn build(
subagent_type: SubagentType,
description: &str,
_prompt: &str,
) -> String {
match subagent_type {
SubagentType::General => Self::build_general_prompt(description),
SubagentType::Explore => Self::build_explore_prompt(description),
}
}
/// 构建恢复任务的提示词
pub fn build_resume_prompt(session_description: &str, additional_prompt: &str) -> String {
format!(
"你正在继续执行一个之前创建的子代理任务。\n\n\
: {}\n\n\
: {}\n\n\
:\n\
1. \n\
2. \n\
3. \n\
4. \n\n\
: 访",
session_description, additional_prompt
)
}
/// 构建带上下文摘要的探索提示词
/// 预留功能:用于 Explore 类型子代理继承主代理上下文
#[allow(dead_code)]
pub fn build_explore_prompt_with_context(
description: &str,
parent_history: &[ChatMessage],
) -> String {
let context_summary = Self::extract_context_summary(parent_history);
format!(
"你是一个只读探索代理,用于代码库探索和信息收集。\n\n\
: {}\n\n\
:\n{}\n\n\
:\n\
1. 使file_read, bash \n\
2. \n\
3. \n\
4. \n\n\
: ",
description, context_summary
)
}
fn build_general_prompt(description: &str) -> String {
format!(
"你是一个专注的子代理,正在执行一个独立任务。\n\n\
: {}\n\n\
:\n\
1. \n\
2. 使\n\
3. \n\
4. \n\n\
: 访",
description
)
}
fn build_explore_prompt(description: &str) -> String {
format!(
"你是一个只读探索代理,用于代码库探索和信息收集。\n\n\
: {}\n\n\
:\n\
1. 使\n\
2. \n\
3. \n\
4. \n\n\
: ",
description
)
}
/// 提取最近用户消息的上下文摘要
/// 预留功能:用于 Explore 类型子代理继承主代理上下文
#[allow(dead_code)]
fn extract_context_summary(history: &[ChatMessage]) -> String {
history
.iter()
.filter(|m| m.role == "user")
.rev()
.take(5)
.map(|m| {
let content = &m.content;
if content.len() > 100 {
format!("- {}", content.chars().take(100).collect::<String>())
} else {
format!("- {}", content)
}
})
.collect::<Vec<_>>()
.join("\n")
}
}
/// 从子代理输出提取简洁摘要
pub fn extract_summary(content: &str) -> String {
// 取第一段或前 500 字符
let first_paragraph = content
.lines()
.take_while(|line| !line.trim().is_empty())
.collect::<Vec<_>>()
.join("\n");
if first_paragraph.len() > 500 {
first_paragraph.chars().take(500).collect()
} else if first_paragraph.is_empty() {
content.chars().take(200).collect()
} else {
first_paragraph
}
}

View File

@ -1,98 +0,0 @@
use std::collections::HashMap;
use std::sync::RwLock;
use async_trait::async_trait;
use crate::storage::StorageError;
use super::types::TaskSession;
/// 任务持久化接口
#[async_trait]
pub trait TaskRepository: Send + Sync + 'static {
/// 保存任务会话
async fn save_task_session(&self, session: &TaskSession) -> Result<(), StorageError>;
/// 加载任务会话
async fn load_task_session(&self, task_id: &str) -> Result<Option<TaskSession>, StorageError>;
/// 删除任务会话
async fn delete_task_session(&self, task_id: &str) -> Result<bool, StorageError>;
/// 列出父会话的所有任务
async fn list_tasks_for_session(
&self,
parent_session_id: &str,
) -> Result<Vec<TaskSession>, StorageError>;
/// 清理过期任务(超过指定小时)
async fn cleanup_expired_tasks(&self, ttl_hours: u64) -> Result<usize, StorageError>;
}
/// 内存实现(用于测试)
pub struct InMemoryTaskRepository {
sessions: RwLock<HashMap<String, TaskSession>>,
}
impl InMemoryTaskRepository {
pub fn new() -> Self {
Self {
sessions: RwLock::new(HashMap::new()),
}
}
}
impl Default for InMemoryTaskRepository {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl TaskRepository for InMemoryTaskRepository {
async fn save_task_session(&self, session: &TaskSession) -> Result<(), StorageError> {
self.sessions
.write()
.unwrap()
.insert(session.id.clone(), session.clone());
Ok(())
}
async fn load_task_session(&self, task_id: &str) -> Result<Option<TaskSession>, StorageError> {
Ok(self.sessions.read().unwrap().get(task_id).cloned())
}
async fn delete_task_session(&self, task_id: &str) -> Result<bool, StorageError> {
Ok(self.sessions.write().unwrap().remove(task_id).is_some())
}
async fn list_tasks_for_session(
&self,
parent_session_id: &str,
) -> Result<Vec<TaskSession>, StorageError> {
Ok(self
.sessions
.read()
.unwrap()
.values()
.filter(|s| s.parent_session_id == parent_session_id)
.cloned()
.collect())
}
async fn cleanup_expired_tasks(&self, ttl_hours: u64) -> Result<usize, StorageError> {
let now = current_timestamp();
let ttl_millis = ttl_hours * 3600 * 1000;
let mut sessions = self.sessions.write().unwrap();
let before = sessions.len();
sessions.retain(|_, s| now - s.updated_at < ttl_millis as i64);
Ok(before - sessions.len())
}
}
fn current_timestamp() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("system clock before unix epoch")
.as_millis() as i64
}

View File

@ -1,370 +0,0 @@
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use crate::agent::{AgentLoop, SystemPromptContext, SystemPromptProvider};
use crate::bus::ChatMessage;
use crate::config::LLMProviderConfig;
use crate::tools::{ToolContext, ToolRegistry};
use super::error::TaskError;
use super::prompt::{extract_summary, SubagentPromptBuilder};
use super::repository::TaskRepository;
use super::types::{SubagentType, TaskSession, TaskToolResult};
/// 子代理运行时配置
#[derive(Debug, Clone)]
pub struct TaskRuntimeConfig {
/// 子代理可用的工具列表(白名单)
pub allowed_tools: HashSet<String>,
/// 最大执行时间(秒)
pub max_execution_secs: u64,
/// 探索类型的最大工具调用次数
pub explore_max_tool_calls: usize,
/// 任务 TTL小时
pub ttl_hours: u64,
}
impl Default for TaskRuntimeConfig {
fn default() -> Self {
Self {
allowed_tools: HashSet::from([
"file_read".to_string(),
"file_edit".to_string(),
"file_write".to_string(),
"bash".to_string(),
"http_request".to_string(),
"web_fetch".to_string(),
"memory_search".to_string(),
"get_time".to_string(),
"calculator".to_string(),
"skill_activate".to_string(),
"skill_list".to_string(),
]),
max_execution_secs: 300, // 5分钟
explore_max_tool_calls: 20,
ttl_hours: 24,
}
}
}
/// 静态系统提示词提供者
pub struct StaticSystemPromptProvider {
prompt: String,
}
impl StaticSystemPromptProvider {
pub fn new(prompt: String) -> Self {
Self { prompt }
}
}
impl SystemPromptProvider for StaticSystemPromptProvider {
fn build(&self, _context: &SystemPromptContext) -> Option<crate::agent::SystemPrompt> {
Some(crate::agent::SystemPrompt {
content: self.prompt.clone(),
context: Some("subagent".to_string()),
})
}
}
/// 子代理运行时管理器
pub struct TaskRuntime {
config: TaskRuntimeConfig,
repository: Arc<dyn TaskRepository>,
tools: Arc<ToolRegistry>,
provider_config: LLMProviderConfig,
}
impl TaskRuntime {
pub fn new(
config: TaskRuntimeConfig,
repository: Arc<dyn TaskRepository>,
tools: Arc<ToolRegistry>,
provider_config: LLMProviderConfig,
) -> Self {
Self {
config,
repository,
tools,
provider_config,
}
}
/// 获取子代理工具注册表(排除 task 工具防止递归)
fn get_subagent_tools(&self) -> Arc<ToolRegistry> {
// 创建一个新的工具注册表,只包含允许的工具
// 这里简化处理,直接使用传入的 tools假设 task 工具不会注册进去)
// 实际实现中需要过滤 allowed_tools
self.tools.clone()
}
/// 创建新的子代理会话并执行
pub async fn spawn(
&self,
parent_context: &ToolContext,
description: String,
prompt: String,
subagent_type: SubagentType,
) -> Result<TaskToolResult, TaskError> {
// 1. 验证上下文
let session_id = parent_context
.session_id
.clone()
.ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?;
let chat_id = parent_context
.chat_id
.clone()
.ok_or_else(|| TaskError::MissingContext("chat_id".to_string()))?;
let channel_name = parent_context
.channel_name
.clone()
.ok_or_else(|| TaskError::MissingContext("channel_name".to_string()))?;
// 2. 创建任务会话
let session = TaskSession::new(
session_id,
chat_id,
channel_name,
description.clone(),
subagent_type,
);
// 3. 保存会话
self.repository.save_task_session(&session).await?;
// 4. 构建子代理系统提示词
let system_prompt = SubagentPromptBuilder::build(subagent_type, &description, &prompt);
// 5. 创建子代理
let agent = self.create_subagent(&session, system_prompt)?;
// 6. 执行任务
let result = self.execute_task(agent, &session, prompt).await;
// 7. 更新会话状态并保存
match result {
Ok(tool_result) => {
let mut session = session;
session.mark_completed(tool_result.summary.clone());
self.repository.save_task_session(&session).await?;
Ok(tool_result)
}
Err(e) => {
let mut session = session;
let status = e.as_status();
if status == "timeout" {
session.mark_timeout();
} else {
session.mark_failed(e.to_string());
}
self.repository.save_task_session(&session).await?;
Err(e)
}
}
}
/// 恢复现有任务会话
pub async fn resume(
&self,
task_id: &str,
parent_context: &ToolContext,
additional_prompt: String,
) -> Result<TaskToolResult, TaskError> {
// 1. 加载现有会话
let session = self
.repository
.load_task_session(task_id)
.await?
.ok_or_else(|| TaskError::SessionNotFound(task_id.to_string()))?;
// 2. 验证父会话匹配
let parent_session_id = parent_context
.session_id
.clone()
.ok_or_else(|| TaskError::MissingContext("session_id".to_string()))?;
if session.parent_session_id != parent_session_id {
return Err(TaskError::InvalidParentSession);
}
// 3. 构建恢复提示词
let system_prompt = SubagentPromptBuilder::build_resume_prompt(
&session.description,
&additional_prompt,
);
// 4. 创建子代理
let agent = self.create_subagent(&session, system_prompt)?;
// 5. 使用历史继续执行
let result = self
.execute_task_with_history(agent, &session, additional_prompt)
.await;
// 6. 更新会话状态
match result {
Ok(tool_result) => {
let mut session = session;
session.mark_completed(tool_result.summary.clone());
self.repository.save_task_session(&session).await?;
Ok(tool_result)
}
Err(e) => {
let mut session = session;
session.mark_failed(e.to_string());
self.repository.save_task_session(&session).await?;
Err(e)
}
}
}
/// 创建子代理实例
fn create_subagent(
&self,
session: &TaskSession,
system_prompt: String,
) -> Result<AgentLoop, TaskError> {
let prompt_provider = Arc::new(StaticSystemPromptProvider::new(system_prompt));
// 获取子代理工具注册表
let subagent_tools = self.get_subagent_tools();
// 直接创建 AgentLoop使用自定义的提示词提供者
AgentLoop::with_tools_and_system_prompt_provider(
self.provider_config.clone(),
subagent_tools,
prompt_provider,
None, // 子代理不需要 skill provider
)
.map(|agent| {
agent.with_tool_context(ToolContext {
channel_name: Some(session.parent_channel_name.clone()),
sender_id: None,
chat_id: Some(session.parent_chat_id.clone()), // 使用父会话 chat_id发送到飞书用户
session_id: Some(session.id.clone()),
message_id: None,
message_seq: None,
subagent_description: Some(session.description.clone()), // 子代理标识
})
})
.map_err(|e| TaskError::AgentCreationFailed(e.to_string()))
}
/// 执行任务(带超时控制)
async fn execute_task(
&self,
agent: AgentLoop,
session: &TaskSession,
prompt: String,
) -> Result<TaskToolResult, TaskError> {
// 构建初始消息
let history = vec![ChatMessage::user(prompt)];
let system_prompt_context = SystemPromptContext {
session_id: Some(session.id.clone()),
chat_id: session.id.clone(),
user_message_count: 1,
};
// 设置超时
let timeout_duration = Duration::from_secs(self.config.max_execution_secs);
let result = tokio::time::timeout(
timeout_duration,
agent.process(history, Some(&system_prompt_context)),
)
.await;
match result {
Ok(Ok(process_result)) => {
let final_message = process_result.final_response;
Ok(TaskToolResult {
status: "success".to_string(),
summary: extract_summary(&final_message.content),
output: final_message.content,
task_id: session.id.clone(),
})
}
Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())),
Err(_) => Err(TaskError::Timeout),
}
}
/// 使用历史继续执行
async fn execute_task_with_history(
&self,
agent: AgentLoop,
session: &TaskSession,
additional_prompt: String,
) -> Result<TaskToolResult, TaskError> {
// 构建历史 + 新消息
let mut history = session.history.clone();
history.push(ChatMessage::user(additional_prompt));
let user_message_count = history.iter().filter(|m| m.role == "user").count();
let system_prompt_context = SystemPromptContext {
session_id: Some(session.id.clone()),
chat_id: session.id.clone(),
user_message_count,
};
let timeout_duration = Duration::from_secs(self.config.max_execution_secs);
let result = tokio::time::timeout(
timeout_duration,
agent.process(history, Some(&system_prompt_context)),
)
.await;
match result {
Ok(Ok(process_result)) => {
let final_message = process_result.final_response;
Ok(TaskToolResult {
status: "success".to_string(),
summary: extract_summary(&final_message.content),
output: final_message.content,
task_id: session.id.clone(),
})
}
Ok(Err(e)) => Err(TaskError::ExecutionFailed(e.to_string())),
Err(_) => Err(TaskError::Timeout),
}
}
/// 清理过期任务
pub async fn cleanup_expired(&self) -> Result<usize, TaskError> {
self.repository
.cleanup_expired_tasks(self.config.ttl_hours)
.await
.map_err(TaskError::from)
}
/// 创建用于测试的实例
#[cfg(test)]
pub fn new_for_test() -> Self {
use super::repository::InMemoryTaskRepository;
use std::collections::HashMap;
Self {
config: TaskRuntimeConfig::default(),
repository: Arc::new(InMemoryTaskRepository::new()),
tools: Arc::new(ToolRegistry::new()),
provider_config: LLMProviderConfig {
provider_type: "openai".to_string(),
name: "test".to_string(),
base_url: "https://test.local/v1".to_string(),
api_key: "test-key".to_string(),
extra_headers: HashMap::new(),
llm_timeout_secs: 120,
memory_maintenance_timeout_secs: 600,
model_id: "test-model".to_string(),
temperature: None,
max_tokens: None,
context_window_tokens: None,
model_extra: HashMap::new(),
max_tool_iterations: 100,
tool_result_max_chars: 20_000,
context_tool_result_trim_chars: 2_000,
},
}
}
}

View File

@ -1,170 +0,0 @@
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
use crate::tools::{Tool, ToolContext, ToolResult};
use super::runtime::TaskRuntime;
use super::types::TaskToolArgs;
/// Task 工具 - 创建和管理子代理
pub struct TaskTool {
runtime: Arc<TaskRuntime>,
}
impl TaskTool {
pub fn new(runtime: Arc<TaskRuntime>) -> Self {
Self { runtime }
}
}
#[async_trait]
impl Tool for TaskTool {
fn name(&self) -> &str {
"task"
}
fn description(&self) -> &str {
"Launch a specialized subagent to handle complex, multi-step tasks. \
Subagents run in isolated contexts and can work in parallel. \
Use 'general' type for complex tasks, 'explore' type for read-only exploration. \
You can resume a previous task by providing its task_id."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"description": {
"type": "string",
"description": "Short description (3-5 words) of what this task does",
"maxLength": 50
},
"prompt": {
"type": "string",
"description": "Detailed instructions for the subagent to execute"
},
"subagent_type": {
"type": "string",
"enum": ["general", "explore"],
"default": "general",
"description": "Type of subagent: 'general' for complex multi-step tasks, 'explore' for read-only search/exploration"
},
"task_id": {
"type": "string",
"description": "Optional: Resume an existing task session by providing its task_id"
}
},
"required": ["description", "prompt"]
})
}
fn read_only(&self) -> bool {
false
}
fn exclusive(&self) -> bool {
// Task 工具创建子代理,不应与其他工具并发执行
true
}
async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
// Task 工具必须通过 execute_with_context 获取父会话信息
Ok(ToolResult {
success: false,
output: String::new(),
error: Some(
"task tool requires tool context with session_id, chat_id, and channel_name"
.to_string(),
),
})
}
async fn execute_with_context(
&self,
context: &ToolContext,
args: serde_json::Value,
) -> anyhow::Result<ToolResult> {
// 1. 解析参数
let task_args: TaskToolArgs = serde_json::from_value(args.clone())
.map_err(|e| anyhow::anyhow!("invalid task arguments: {}", e))?;
// 2. 验证描述长度
let word_count = task_args.description.split_whitespace().count();
if task_args.description.len() > 50 || word_count > 7 || word_count < 1 {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(
"description should be 1-5 words, max 50 characters".to_string(),
),
});
}
// 3. 验证上下文
if context.session_id.is_none() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("task tool requires session_id in context".to_string()),
});
}
if context.chat_id.is_none() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("task tool requires chat_id in context".to_string()),
});
}
if context.channel_name.is_none() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("task tool requires channel_name in context".to_string()),
});
}
// 4. 执行任务
let result = if let Some(task_id) = task_args.task_id {
// 恢复现有任务
self.runtime
.resume(&task_id, context, task_args.prompt)
.await
} else {
// 创建新任务
self.runtime
.spawn(
context,
task_args.description,
task_args.prompt,
task_args.subagent_type,
)
.await
};
// 5. 构建返回结果
match result {
Ok(task_result) => Ok(ToolResult {
success: true,
output: serde_json::to_string(&task_result)?,
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e.to_string()),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_task_tool_name_and_description() {
// 简单验证工具名称
assert!(!TaskTool::name(&TaskTool::new(Arc::new(TaskRuntime::new_for_test()))).is_empty());
}
}

View File

@ -1,169 +0,0 @@
use serde::{Deserialize, Serialize};
use crate::bus::ChatMessage;
/// 子代理会话状态
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TaskSessionState {
/// 正在执行
Running,
/// 已完成
Completed,
/// 已失败
Failed,
/// 已超时
Timeout,
}
impl Default for TaskSessionState {
fn default() -> Self {
Self::Running
}
}
/// 子代理类型
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SubagentType {
/// 通用型 - 处理复杂多步骤任务
#[default]
General,
/// 探索型 - 只读搜索代理
Explore,
}
impl SubagentType {
pub fn as_str(&self) -> &'static str {
match self {
Self::General => "general",
Self::Explore => "explore",
}
}
pub fn from_str(s: &str) -> Option<Self> {
match s {
"general" => Some(Self::General),
"explore" => Some(Self::Explore),
_ => None,
}
}
}
/// 任务会话记录
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskSession {
/// 任务唯一 ID (UUID)
pub id: String,
/// 父会话 ID (用于关联)
pub parent_session_id: String,
/// 父 chat_id
pub parent_chat_id: String,
/// 父 channel_name
pub parent_channel_name: String,
/// 任务描述
pub description: String,
/// 子代理类型
pub subagent_type: SubagentType,
/// 当前状态
pub state: TaskSessionState,
/// 会话历史(子代理的对话)
#[serde(default)]
pub history: Vec<ChatMessage>,
/// 创建时间
pub created_at: i64,
/// 最后更新时间
pub updated_at: i64,
/// 执行摘要
pub summary: Option<String>,
/// 错误信息
pub error: Option<String>,
}
impl TaskSession {
pub fn new(
parent_session_id: String,
parent_chat_id: String,
parent_channel_name: String,
description: String,
subagent_type: SubagentType,
) -> Self {
let now = current_timestamp();
Self {
id: format!("task:{}", uuid::Uuid::new_v4()),
parent_session_id,
parent_chat_id,
parent_channel_name,
description,
subagent_type,
state: TaskSessionState::Running,
history: Vec::new(),
created_at: now,
updated_at: now,
summary: None,
error: None,
}
}
/// 添加消息到历史
pub fn add_message(&mut self, message: ChatMessage) {
self.history.push(message);
self.updated_at = current_timestamp();
}
/// 标记完成
pub fn mark_completed(&mut self, summary: String) {
self.state = TaskSessionState::Completed;
self.summary = Some(summary);
self.updated_at = current_timestamp();
}
/// 标记失败
pub fn mark_failed(&mut self, error: String) {
self.state = TaskSessionState::Failed;
self.error = Some(error);
self.updated_at = current_timestamp();
}
/// 标记超时
pub fn mark_timeout(&mut self) {
self.state = TaskSessionState::Timeout;
self.error = Some("Task execution timed out".to_string());
self.updated_at = current_timestamp();
}
}
/// 任务工具参数
#[derive(Debug, Clone, Deserialize)]
pub struct TaskToolArgs {
/// 简短描述3-5词
pub description: String,
/// 详细指令
pub prompt: String,
/// 子代理类型
#[serde(default)]
pub subagent_type: SubagentType,
/// 恢复现有会话的 task_id
#[serde(default)]
pub task_id: Option<String>,
}
/// 任务执行结果
#[derive(Debug, Clone, Serialize)]
pub struct TaskToolResult {
/// 状态: success/failed/timeout
pub status: String,
/// 任务完成总结
pub summary: String,
/// 详细输出
pub output: String,
/// 会话 ID用于恢复
pub task_id: String,
}
fn current_timestamp() -> i64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("system clock before unix epoch")
.as_millis() as i64
}

View File

@ -15,8 +15,6 @@ pub struct ToolContext {
pub session_id: Option<String>,
pub message_id: Option<String>,
pub message_seq: Option<i64>,
/// 子代理标识,用于标注消息来源
pub subagent_description: Option<String>,
}
#[async_trait]