PicoBot/src/gateway/agent_factory.rs
oudecheng 631c61fea2 feat(agent): 支持子代理最大嵌套深度控制
- 在配置结构体中新增 max_nesting_depth 字段,设置子代理最大嵌套深度
- 在 AgentFactory、todo_read、todo_write 等处初始化 nesting_depth 字段为 0
- 允许 Task 工具注册,使用 max_nesting_depth 控制子代理嵌套层数
- SubAgentRuntimeConfig 新增 max_nesting_depth 配置项,默认值为 1
- TaskTool 新增 max_nesting_depth 字段和带深度限制的构造函数
- 任务执行时增加嵌套深度校验,超过最大深度返回错误提示,防止无限递归创建子代理
2026-06-17 14:43:55 +08:00

92 lines
3.4 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::sync::Arc;
use crate::agent::{AgentError, AgentLoop, CompositeSystemPromptProvider};
use crate::config::LLMProviderConfig;
use crate::gateway::agent_prompt_provider::AgentPromptProvider;
use crate::gateway::todo_prompt_provider::TodoPromptProvider;
use crate::skills::{SkillPromptProvider, SkillRuntime};
use crate::storage::persistent_session_id;
use crate::storage::PromptInjectionRepository;
use crate::tools::{ToolContext, ToolRegistry};
/// Builds [`AgentLoop`] instances that all share one tool registry, one skill
/// runtime and one prompt-injection repository.
///
/// Every shared dependency is held behind an `Arc`, so cloning the factory only
/// duplicates the handles, not the underlying objects.
#[derive(Clone)]
pub(crate) struct AgentFactory {
    /// Tool registry handed to every agent this factory creates.
    tools: Arc<ToolRegistry>,
    /// Skill runtime shared with each agent and with its skill prompt provider.
    skills: Arc<SkillRuntime>,
    /// Forwarded unchanged to `AgentPromptProvider::new`.
    /// NOTE(review): presumably the re-injection interval for the agent prompt — confirm.
    reinject_every: usize,
    /// Prompt-injection repository forwarded to `AgentPromptProvider::new`.
    prompt_repository: Arc<dyn PromptInjectionRepository>,
}
/// Per-request inputs consumed by `AgentFactory::create`.
///
/// Borrowed string fields only need to live for the duration of the build call;
/// owned fields are moved into the resulting agent.
pub(crate) struct AgentBuildRequest<'a> {
    /// Name of the channel the request arrived on.
    pub(crate) channel_name: &'a str,
    /// Chat id that, together with `channel_name`, identifies the persistent session.
    pub(crate) session_chat_id: &'a str,
    /// Optional chat id that tools should target instead of `session_chat_id`.
    pub(crate) notification_chat_id: Option<&'a str>,
    /// Id of the sender of the triggering message, if known.
    pub(crate) sender_id: Option<&'a str>,
    /// Id of the triggering message, if known.
    pub(crate) message_id: Option<&'a str>,
    /// LLM provider configuration the agent is built with.
    pub(crate) provider_config: LLMProviderConfig,
    /// Current topic id (optional); used by tools such as todo that isolate their
    /// state per topic.
    pub(crate) topic_id: Option<String>,
    /// Receiving end of the cancellation signal (optional); the agent checks it on
    /// every iteration to see whether it has been cancelled.
    pub(crate) cancel_token: Option<tokio::sync::watch::Receiver<()>>,
}
impl AgentFactory {
    /// Creates a factory that shares the given tool registry, skill runtime and
    /// prompt-injection repository across every agent it builds.
    ///
    /// `reinject_every` is stored and later forwarded unchanged to
    /// `AgentPromptProvider::new`.
    pub(crate) fn new(
        tools: Arc<ToolRegistry>,
        skills: Arc<SkillRuntime>,
        reinject_every: usize,
        prompt_repository: Arc<dyn PromptInjectionRepository>,
    ) -> Self {
        Self {
            tools,
            skills,
            reinject_every,
            prompt_repository,
        }
    }

    /// Builds an [`AgentLoop`] for a single request.
    ///
    /// The agent receives:
    /// - a composite system prompt made of the agent, skill and todo prompt
    ///   providers, in that order;
    /// - the factory's shared tools and skills;
    /// - a [`ToolContext`] describing the request's origin;
    /// - the request's cancellation receiver, when one is supplied.
    ///
    /// # Errors
    ///
    /// Propagates the [`AgentError`] returned by
    /// `AgentLoop::with_tools_and_system_prompt_provider` when the agent cannot be
    /// constructed.
    pub(crate) fn create(&self, request: AgentBuildRequest<'_>) -> Result<AgentLoop, AgentError> {
        let session_id = persistent_session_id(request.channel_name, request.session_chat_id);

        // Compose the system prompt from the agent, skill and todo providers.
        let system_prompt_provider = Arc::new(CompositeSystemPromptProvider::new(vec![
            Box::new(AgentPromptProvider::new(
                self.reinject_every,
                request.provider_config.clone(),
                self.prompt_repository.clone(),
            )),
            Box::new(SkillPromptProvider::new(self.skills.clone())),
            Box::new(TodoPromptProvider::new()),
        ]));

        let agent = AgentLoop::with_tools_and_system_prompt_provider(
            request.provider_config,
            self.tools.clone(),
            system_prompt_provider,
            Some(self.skills.clone()),
        )?;

        // Prefer notification_chat_id; fall back to session_chat_id otherwise.
        let tool_chat_id = request
            .notification_chat_id
            .unwrap_or(request.session_chat_id);

        let agent = agent.with_tool_context(ToolContext {
            channel_name: Some(request.channel_name.to_string()),
            sender_id: request.sender_id.map(str::to_string),
            chat_id: Some(tool_chat_id.to_string()),
            session_id: Some(session_id),
            // `request` is owned by this function, so the topic id is moved, not cloned.
            topic_id: request.topic_id,
            message_id: request.message_id.map(str::to_string),
            message_seq: None,
            subagent_description: None,
            // Factory-built agents start at depth 0.
            // NOTE(review): presumably this marks a top-level (non-sub-agent) loop — confirm.
            nesting_depth: 0,
        });

        // Inject the cancellation receiver, if the caller provided one.
        Ok(match request.cancel_token {
            Some(token) => agent.with_cancel_token(token),
            None => agent,
        })
    }
}