feat: 添加子智能体支持到保存话题和会话功能,优化数据持久化

This commit is contained in:
oudecheng 2026-05-20 17:52:46 +08:00
parent 8d530dcd6b
commit 49475783a2
11 changed files with 365 additions and 57 deletions

View File

@ -44,13 +44,30 @@ impl InputAdapter for ChannelInputAdapter {
if trimmed == "/save" {
return Ok(Some(Command::SaveTopic {
filepath: None,
include_subagents: false,
}));
}
if let Some(filepath) = trimmed.strip_prefix("/save ") {
let filepath = filepath.trim();
if let Some(args) = trimmed.strip_prefix("/save ") {
let args = args.trim();
let parts: Vec<&str> = args.split_whitespace().collect();
// 解析参数
let mut include_subagents = false;
let mut filepath = None;
for part in parts {
if part == "+sub" {
include_subagents = true;
} else if !part.is_empty() {
// 非特殊参数视为文件路径
filepath = Some(part.to_string());
}
}
return Ok(Some(Command::SaveTopic {
filepath: Some(filepath.to_string()),
filepath,
include_subagents,
}));
}
@ -59,24 +76,35 @@ impl InputAdapter for ChannelInputAdapter {
return Ok(Some(Command::SaveSession {
filepath: None,
include_all: false,
include_subagents: false,
}));
}
if let Some(args) = trimmed.strip_prefix("/save-session ") {
let args = args.trim();
// 解析参数:可能是 "all"、路径、或 "all 路径"
let (include_all, filepath) = if args == "all" {
// /save-session all - 保存全部消息
(true, None)
} else if args.starts_with("all ") {
// /save-session all <filepath> - 保存全部消息到指定路径
let path = args[4..].trim();
(true, Some(path.to_string()))
} else {
// /save-session <filepath> - 保存活跃消息到指定路径
(false, Some(args.to_string()))
};
return Ok(Some(Command::SaveSession { filepath, include_all }));
let parts: Vec<&str> = args.split_whitespace().collect();
// 解析参数
let mut include_all = false;
let mut include_subagents = false;
let mut filepath = None;
for part in parts {
if part == "all" {
include_all = true;
} else if part == "+sub" {
include_subagents = true;
} else if !part.is_empty() {
// 非特殊参数视为文件路径
filepath = Some(part.to_string());
}
}
return Ok(Some(Command::SaveSession {
filepath,
include_all,
include_subagents,
}));
}
// 解析 /list 命令

View File

@ -45,13 +45,30 @@ impl InputAdapter for CliInputAdapter {
if trimmed == "/save" {
return Ok(Some(Command::SaveTopic {
filepath: None,
include_subagents: false,
}));
}
if let Some(filepath) = trimmed.strip_prefix("/save ") {
let filepath = filepath.trim();
if let Some(args) = trimmed.strip_prefix("/save ") {
let args = args.trim();
let parts: Vec<&str> = args.split_whitespace().collect();
// 解析参数
let mut include_subagents = false;
let mut filepath = None;
for part in parts {
if part == "+sub" {
include_subagents = true;
} else if !part.is_empty() {
// 非特殊参数视为文件路径
filepath = Some(part.to_string());
}
}
return Ok(Some(Command::SaveTopic {
filepath: Some(filepath.to_string()),
filepath,
include_subagents,
}));
}
@ -60,24 +77,35 @@ impl InputAdapter for CliInputAdapter {
return Ok(Some(Command::SaveSession {
filepath: None,
include_all: false,
include_subagents: false,
}));
}
if let Some(args) = trimmed.strip_prefix("/save-session ") {
let args = args.trim();
// 解析参数:可能是 "all"、路径、或 "all 路径"
let (include_all, filepath) = if args == "all" {
// /save-session all - 保存全部消息
(true, None)
} else if args.starts_with("all ") {
// /save-session all <filepath> - 保存全部消息到指定路径
let path = args[4..].trim();
(true, Some(path.to_string()))
} else {
// /save-session <filepath> - 保存活跃消息到指定路径
(false, Some(args.to_string()))
};
return Ok(Some(Command::SaveSession { filepath, include_all }));
let parts: Vec<&str> = args.split_whitespace().collect();
// 解析参数
let mut include_all = false;
let mut include_subagents = false;
let mut filepath = None;
for part in parts {
if part == "all" {
include_all = true;
} else if part == "+sub" {
include_subagents = true;
} else if !part.is_empty() {
// 非特殊参数视为文件路径
filepath = Some(part.to_string());
}
}
return Ok(Some(Command::SaveSession {
filepath,
include_all,
include_subagents,
}));
}
// 解析 /list 命令

View File

@ -11,6 +11,7 @@ pub mod switch_session;
pub use save_session::{
escape_yaml_string, format_message_content, format_timestamp,
generate_messages_markdown, generate_system_prompt_markdown,
generate_subagent_tasks_markdown, load_subagent_data, SubagentTaskData,
};
use crate::bus::ChatMessage;

View File

@ -5,6 +5,7 @@ use crate::command::handler::{CommandHandler, CommandMetadata, InChatCommandHand
use crate::command::response::{CommandError, CommandResponse, MessageKind};
use crate::command::Command;
use crate::storage::{SessionRecord, SessionStore};
use crate::tools::task::repository::TaskRepository;
use crate::agent::AgentError;
use async_trait::async_trait;
use chrono::{Local, TimeZone};
@ -17,8 +18,10 @@ use std::sync::Arc;
/// * `session_id` - 会话ID
/// * `filepath` - 可选的文件路径
/// * `include_all` - 是否包含 cutoff 之前的所有消息
/// * `include_subagents` - 是否包含子智能体消息
/// * `store` - 会话存储
/// * `provider_config` - LLM提供者配置
/// * `task_repository` - 任务存储(可选,用于查询子智能体)
/// * `system_prompt_provider` - 系统提示词提供者
///
/// # Returns
/// 返回保存的文件路径
@ -26,7 +29,9 @@ pub async fn save_session_to_file(
session_id: &str,
filepath: Option<String>,
include_all: bool,
include_subagents: bool,
store: &SessionStore,
task_repository: Option<&dyn TaskRepository>,
system_prompt_provider: &dyn SystemPromptProvider,
) -> Result<PathBuf, String> {
// 获取会话记录
@ -46,6 +51,13 @@ pub async fn save_session_to_file(
.map_err(|e| format!("Failed to load messages: {}", e))?
};
// 加载子智能体消息(如果启用)
let subagent_data = if include_subagents {
load_subagent_data(session_id, store, task_repository).await
} else {
Vec::new()
};
// 计算用户消息数(用于系统提示词构建)
let user_message_count = messages.iter().filter(|m| m.role == "user").count();
@ -53,7 +65,7 @@ pub async fn save_session_to_file(
let system_prompt = build_system_prompt(system_prompt_provider, &record, user_message_count);
// 生成 Markdown 内容
let markdown = generate_markdown(&record, &system_prompt, &messages);
let markdown = generate_markdown_with_subagents(&record, &system_prompt, &messages, &subagent_data);
// 确定输出路径
let output_path = resolve_filepath(filepath, &record);
@ -78,6 +90,7 @@ pub async fn save_session_to_file(
/// 将当前会话内容(系统提示词和消息历史)保存到 Markdown 文件
pub struct SaveSessionCommandHandler {
store: Arc<SessionStore>,
task_repository: Arc<dyn TaskRepository>,
system_prompt_provider: Arc<dyn SystemPromptProvider>,
}
@ -86,10 +99,16 @@ impl SaveSessionCommandHandler {
///
/// # Arguments
/// * `store` - 会话存储
/// * `system_prompt_provider` - 系统提示词提供者(负责构建完整的系统提示词)
pub fn new(store: Arc<SessionStore>, system_prompt_provider: Arc<dyn SystemPromptProvider>) -> Self {
/// * `task_repository` - 任务存储(用于查询子智能体)
/// * `system_prompt_provider` - 系统提示词提供者
pub fn new(
store: Arc<SessionStore>,
task_repository: Arc<dyn TaskRepository>,
system_prompt_provider: Arc<dyn SystemPromptProvider>,
) -> Self {
Self {
store,
task_repository,
system_prompt_provider,
}
}
@ -121,8 +140,8 @@ impl CommandHandler for SaveSessionCommandHandler {
ctx: CommandContext,
) -> Result<CommandResponse, CommandError> {
match cmd {
Command::SaveSession { filepath, include_all } => {
handle_save_session(self, filepath, include_all, ctx).await
Command::SaveSession { filepath, include_all, include_subagents } => {
handle_save_session(self, filepath, include_all, include_subagents, ctx).await
}
_ => unreachable!(),
}
@ -134,12 +153,14 @@ async fn handle_save_session(
handler: &SaveSessionCommandHandler,
filepath: Option<String>,
include_all: bool,
include_subagents: bool,
ctx: CommandContext,
) -> Result<CommandResponse, CommandError> {
tracing::debug!(
ctx_session_id = ?ctx.session_id,
ctx_chat_id = ?ctx.chat_id,
channel = %ctx.channel_name,
include_subagents = include_subagents,
"SaveSession command received"
);
@ -174,7 +195,9 @@ async fn handle_save_session(
session_id,
filepath,
include_all,
include_subagents,
&*handler.store,
Some(handler.task_repository.as_ref()),
&*handler.system_prompt_provider,
)
.await
@ -202,6 +225,174 @@ async fn handle_save_session(
.with_metadata("message_count", &message_count.to_string()))
}
/// 子智能体任务数据
#[derive(Debug)]
pub struct SubagentTaskData {
pub task_id: String,
pub session_id: String,
pub description: String,
pub subagent_type: String,
pub state: String,
pub created_at: i64,
pub messages: Vec<crate::bus::ChatMessage>,
}
/// 加载子智能体数据
pub async fn load_subagent_data(
parent_session_id: &str,
store: &SessionStore,
task_repository: Option<&dyn TaskRepository>,
) -> Vec<SubagentTaskData> {
let Some(repo) = task_repository else {
return Vec::new();
};
// 获取所有子任务
let tasks = match repo.list_tasks_for_session(parent_session_id).await {
Ok(tasks) => tasks,
Err(e) => {
tracing::warn!(error = %e, "Failed to list tasks for session");
return Vec::new();
}
};
let mut result = Vec::new();
for task in tasks {
// 加载子智能体的消息
let messages = match store.load_all_messages(&task.session_id) {
Ok(msgs) => msgs,
Err(e) => {
tracing::warn!(error = %e, task_id = %task.id, "Failed to load subagent messages");
Vec::new()
}
};
result.push(SubagentTaskData {
task_id: task.id,
session_id: task.session_id,
description: task.description,
subagent_type: task.subagent_type.as_str().to_string(),
state: format!("{:?}", task.state),
created_at: task.created_at,
messages,
});
}
result
}
/// 生成 Markdown 内容(包含子智能体)
pub fn generate_markdown_with_subagents(
record: &SessionRecord,
system_prompt: &Option<SystemPrompt>,
messages: &[crate::bus::ChatMessage],
subagent_data: &[SubagentTaskData],
) -> String {
let mut output = String::new();
// YAML frontmatter
output.push_str("---\n");
output.push_str(&format!("title: {}\n", escape_yaml_string(&record.title)));
output.push_str(&format!("session_id: {}\n", record.id));
output.push_str(&format!("channel: {}\n", record.channel_name));
output.push_str(&format!("chat_id: {}\n", record.chat_id));
output.push_str(&format!(
"created_at: {}\n",
format_timestamp(record.created_at)
));
output.push_str(&format!(
"updated_at: {}\n",
format_timestamp(record.updated_at)
));
output.push_str(&format!(
"last_active_at: {}\n",
format_timestamp(record.last_active_at)
));
output.push_str(&format!("message_count: {}\n", messages.len()));
if !subagent_data.is_empty() {
output.push_str(&format!("subagent_count: {}\n", subagent_data.len()));
}
output.push_str("---\n\n");
// 系统提示词
output.push_str(&generate_system_prompt_markdown(system_prompt));
// 子智能体任务(如果有)
if !subagent_data.is_empty() {
output.push_str(&generate_subagent_tasks_markdown(subagent_data));
}
// 主会话消息历史
output.push_str(&generate_messages_markdown(messages));
output
}
/// 生成子智能体任务 Markdown
pub fn generate_subagent_tasks_markdown(subagent_data: &[SubagentTaskData]) -> String {
let mut output = String::new();
output.push_str("# Subagent Tasks\n\n");
for task in subagent_data {
output.push_str(&format!("## Task: {} ({})", task.description, task.subagent_type));
output.push('\n');
output.push_str(&format!("**Task ID:** `{}`\n\n", task.task_id));
output.push_str(&format!("**Session ID:** `{}`\n\n", task.session_id));
output.push_str(&format!("**Status:** {}\n\n", task.state));
output.push_str(&format!("**Created:** {}\n\n", format_timestamp(task.created_at)));
output.push_str(&format!("**Message Count:** {}\n\n", task.messages.len()));
// 子智能体消息
if !task.messages.is_empty() {
output.push_str("### Messages\n\n");
for (idx, msg) in task.messages.iter().enumerate() {
output.push_str(&format!("#### Message {}\n\n", idx + 1));
output.push_str(&format!("**Role:** {}\n\n", msg.role));
output.push_str(&format!("**Time:** {}\n\n", format_timestamp(msg.timestamp)));
if let Some(ref reasoning) = msg.reasoning_content {
output.push_str("**Reasoning:**\n");
output.push_str("```\n");
output.push_str(reasoning);
output.push_str("\n```\n\n");
}
output.push_str("**Content:**\n\n");
if msg.content.is_empty() {
output.push_str("*empty*\n\n");
} else {
output.push_str(&format!("{}\n\n", format_message_content(&msg.content)));
}
// 工具调用
if let Some(ref calls) = msg.tool_calls {
if !calls.is_empty() {
output.push_str("**Tool Calls:**\n\n");
for call in calls {
output.push_str(&format!("- **{}** (`{}`)\n", call.name, call.id));
output.push_str(" ```json\n");
let args_json = serde_json::to_string_pretty(&call.arguments)
.unwrap_or_else(|_| call.arguments.to_string());
for line in args_json.lines() {
output.push_str(&format!(" {}\n", line));
}
output.push_str(" ```\n");
}
output.push('\n');
}
}
output.push_str("---\n\n");
}
}
output.push_str("---\n\n");
}
output
}
/// 构建系统提示词
fn build_system_prompt(
provider: &dyn SystemPromptProvider,
@ -440,14 +631,20 @@ pub fn resolve_filepath(filepath: Option<String>, record: &SessionRecord) -> Pat
/// 用于处理 Feishu/WeChat 等通道中直接输入的 /save 命令
pub struct SaveSessionInChatHandler {
store: Arc<SessionStore>,
task_repository: Arc<dyn TaskRepository>,
system_prompt_provider: Arc<dyn SystemPromptProvider>,
}
impl SaveSessionInChatHandler {
/// 创建新的 InChat 保存会话命令处理器
pub fn new(store: Arc<SessionStore>, system_prompt_provider: Arc<dyn SystemPromptProvider>) -> Self {
pub fn new(
store: Arc<SessionStore>,
task_repository: Arc<dyn TaskRepository>,
system_prompt_provider: Arc<dyn SystemPromptProvider>,
) -> Self {
Self {
store,
task_repository,
system_prompt_provider,
}
}
@ -465,7 +662,7 @@ impl InChatCommandHandler for SaveSessionInChatHandler {
inbound: &InboundMessage,
session_manager: &crate::gateway::session::SessionManager,
) -> Result<Option<String>, AgentError> {
let Command::SaveSession { filepath, include_all } = cmd else {
let Command::SaveSession { filepath, include_all, include_subagents } = cmd else {
return Ok(None);
};
@ -486,7 +683,9 @@ impl InChatCommandHandler for SaveSessionInChatHandler {
&session_id,
filepath,
include_all,
include_subagents,
&*self.store,
Some(self.task_repository.as_ref()),
&*self.system_prompt_provider,
)
.await;
@ -623,11 +822,12 @@ mod tests {
#[test]
fn test_can_handle() {
let store = Arc::new(SessionStore::in_memory().unwrap());
let task_repository = Arc::new(crate::tools::task::repository::InMemoryTaskRepository::new());
let provider = Arc::new(TestSystemPromptProvider);
let handler = SaveSessionCommandHandler::new(store, provider);
let handler = SaveSessionCommandHandler::new(store, task_repository, provider);
assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false }));
assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true }));
assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: false, include_subagents: false }));
assert!(handler.can_handle(&Command::SaveSession { filepath: None, include_all: true, include_subagents: false }));
assert!(!handler.can_handle(&Command::CreateSession { title: None }));
assert!(!handler.can_handle(&Command::SaveTopic { filepath: None }));
}

View File

@ -4,12 +4,14 @@ use crate::command::context::CommandContext;
use crate::command::handler::{CommandHandler, CommandMetadata};
use crate::command::handlers::{
escape_yaml_string, format_timestamp, generate_messages_markdown,
generate_system_prompt_markdown, get_messages_from_session,
generate_subagent_tasks_markdown, generate_system_prompt_markdown,
get_messages_from_session, load_subagent_data, SubagentTaskData,
};
use crate::command::response::{CommandError, CommandResponse, MessageKind};
use crate::command::Command;
use crate::gateway::session::SessionManager;
use crate::storage::{SessionStore, TopicRecord};
use crate::tools::task::repository::TaskRepository;
use async_trait::async_trait;
use chrono::Local;
use std::path::PathBuf;
@ -19,9 +21,11 @@ use std::sync::Arc;
pub async fn save_topic_to_file(
topic_id: &str,
filepath: Option<String>,
include_subagents: bool,
store: &SessionStore,
task_repository: Option<&dyn TaskRepository>,
system_prompt_provider: &dyn SystemPromptProvider,
messages: &[ChatMessage], // ← 从外部传入的消息(已压缩的 active history
messages: &[ChatMessage],
) -> Result<PathBuf, String> {
// 获取话题记录
let topic = store
@ -38,8 +42,15 @@ pub async fn save_topic_to_file(
let user_message_count = messages.iter().filter(|m| m.role == "user").count();
let system_prompt = build_system_prompt(system_prompt_provider, &session, user_message_count);
// 加载子智能体消息(如果启用)
let subagent_data = if include_subagents {
load_subagent_data(&topic.session_id, store, task_repository).await
} else {
Vec::new()
};
// 生成 Markdown 内容
let markdown = generate_topic_markdown(&topic, &system_prompt, messages);
let markdown = generate_topic_markdown(&topic, &system_prompt, messages, &subagent_data);
// 确定输出路径
let output_path = resolve_topic_filepath(filepath, &topic);
@ -79,6 +90,7 @@ fn generate_topic_markdown(
topic: &TopicRecord,
system_prompt: &Option<SystemPrompt>,
messages: &[crate::bus::ChatMessage],
subagent_data: &[SubagentTaskData],
) -> String {
let mut output = String::new();
@ -103,11 +115,19 @@ fn generate_topic_markdown(
format_timestamp(topic.last_active_at)
));
output.push_str(&format!("message_count: {}\n", messages.len()));
if !subagent_data.is_empty() {
output.push_str(&format!("subagent_count: {}\n", subagent_data.len()));
}
output.push_str("---\n\n");
// 系统提示词(复用公共函数)
output.push_str(&generate_system_prompt_markdown(system_prompt));
// 子智能体任务(如果有)
if !subagent_data.is_empty() {
output.push_str(&generate_subagent_tasks_markdown(subagent_data));
}
// 消息历史(复用公共函数)
output.push_str(&generate_messages_markdown(messages));
@ -153,6 +173,7 @@ fn resolve_topic_filepath(filepath: Option<String>, topic: &TopicRecord) -> Path
/// 保存话题命令处理器
pub struct SaveTopicCommandHandler {
store: Arc<SessionStore>,
task_repository: Arc<dyn TaskRepository>,
system_prompt_provider: Arc<dyn SystemPromptProvider>,
session_manager: Option<SessionManager>,
}
@ -160,10 +181,12 @@ pub struct SaveTopicCommandHandler {
impl SaveTopicCommandHandler {
pub fn new(
store: Arc<SessionStore>,
task_repository: Arc<dyn TaskRepository>,
system_prompt_provider: Arc<dyn SystemPromptProvider>,
) -> Self {
Self {
store,
task_repository,
system_prompt_provider,
session_manager: None,
}
@ -195,7 +218,9 @@ impl CommandHandler for SaveTopicCommandHandler {
ctx: CommandContext,
) -> Result<CommandResponse, CommandError> {
match cmd {
Command::SaveTopic { filepath } => handle_save_topic(self, filepath, ctx).await,
Command::SaveTopic { filepath, include_subagents } => {
handle_save_topic(self, filepath, include_subagents, ctx).await
}
_ => unreachable!(),
}
}
@ -204,12 +229,14 @@ impl CommandHandler for SaveTopicCommandHandler {
async fn handle_save_topic(
handler: &SaveTopicCommandHandler,
filepath: Option<String>,
include_subagents: bool,
ctx: CommandContext,
) -> Result<CommandResponse, CommandError> {
tracing::debug!(
ctx_topic_id = ?ctx.topic_id,
ctx_session_id = ?ctx.session_id,
channel = %ctx.channel_name,
include_subagents = include_subagents,
"SaveTopic command received"
);
@ -238,7 +265,9 @@ async fn handle_save_topic(
let output_path = save_topic_to_file(
topic_id,
filepath,
include_subagents,
&*handler.store,
Some(handler.task_repository.as_ref()),
&*handler.system_prompt_provider,
&messages,
)

View File

@ -14,11 +14,15 @@ pub enum Command {
/// 创建新话题(在同一个 Session 内)
CreateSession { title: Option<String> },
/// 保存当前话题内容到 Markdown 文件
SaveTopic { filepath: Option<String> },
SaveTopic {
filepath: Option<String>,
include_subagents: bool,
},
/// 保存会话内容到 Markdown 文件
SaveSession {
filepath: Option<String>,
include_all: bool,
include_subagents: bool,
},
/// 列出当前 Session 的所有话题
ListSessions { include_archived: bool },

View File

@ -38,6 +38,7 @@ use crate::config::LLMProviderConfig;
use crate::logging;
use crate::scheduler::Scheduler;
use crate::skills::SkillRuntime;
use crate::tools::task::repository::TaskRepository;
use agent_task_executor::{AgentTaskExecutor, SchedulerMaintenanceService};
use outbound_dispatcher::OutboundDispatcher;
use processor::InboundProcessor;
@ -50,6 +51,7 @@ pub struct GatewayState {
pub session_manager: SessionManager,
pub channel_manager: ChannelManager,
pub bus: Arc<MessageBus>,
pub task_repository: Arc<dyn TaskRepository>,
}
impl GatewayState {
@ -72,7 +74,7 @@ impl GatewayState {
let channel_manager = ChannelManager::new();
let bus = channel_manager.bus();
let session_manager = build_session_manager_with_sender(
let (session_manager, task_repository) = build_session_manager_with_sender(
agent_prompt_reinject_every,
show_tool_results,
config.time.timezone.clone(),
@ -91,6 +93,7 @@ impl GatewayState {
session_manager,
channel_manager,
bus,
task_repository,
})
}

View File

@ -80,12 +80,14 @@ impl InboundProcessor {
// 注册 save_session 处理器
command_router.register(Box::new(SaveSessionCommandHandler::new(
store.clone(),
session_manager.task_repository(),
system_prompt_provider.clone(),
)));
// 注册 save_topic 处理器
command_router.register(Box::new(SaveTopicCommandHandler::new(
store.clone(),
session_manager.task_repository(),
system_prompt_provider,
).with_session_manager(session_manager.clone())));

View File

@ -13,6 +13,7 @@ use crate::tools::{
DefaultSubAgentRuntime, InMemoryTaskRepository, NoopSessionMessageSender,
SessionMessageSender, SubAgentRuntimeConfig, ToolRegistry,
};
use crate::tools::task::repository::TaskRepository;
use super::agent_factory::AgentFactory;
use super::cli_session::CliSessionService;
@ -35,7 +36,7 @@ pub(crate) fn build_session_manager(
task_config: TaskConfig,
chat_history_ttl_hours: Option<u64>,
session_ttl_hours: Option<u64>,
) -> Result<SessionManager, AgentError> {
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
build_session_manager_with_sender(
agent_prompt_reinject_every,
show_tool_results,
@ -63,7 +64,7 @@ pub(crate) fn build_session_manager_with_sender(
task_config: TaskConfig,
chat_history_ttl_hours: Option<u64>,
session_ttl_hours: Option<u64>,
) -> Result<SessionManager, AgentError> {
) -> Result<(SessionManager, Arc<dyn TaskRepository>), AgentError> {
let store = Arc::new(
SessionStore::new()
.map_err(|err| AgentError::Other(format!("session store init error: {}", err)))?,
@ -96,7 +97,7 @@ pub(crate) fn build_session_manager_with_sender(
);
// 创建 SubAgentRuntime如果 task 工具启用)
let factory = if task_config.enabled {
let (factory, task_repository): (_, Arc<dyn TaskRepository>) = if task_config.enabled {
let task_repository = Arc::new(InMemoryTaskRepository::new());
let subagent_tools = Arc::new(factory.build_subagent_tools());
@ -111,15 +112,16 @@ pub(crate) fn build_session_manager_with_sender(
let subagent_runtime = Arc::new(DefaultSubAgentRuntime::new(
runtime_config,
task_repository,
task_repository.clone(),
conversations.clone(),
subagent_tools,
provider_config.clone(),
));
factory.with_subagent_runtime(subagent_runtime)
(factory.with_subagent_runtime(subagent_runtime), task_repository)
} else {
factory
// 如果 task 工具未启用,创建一个空的内存仓库
(factory, Arc::new(InMemoryTaskRepository::new()))
};
let tools = Arc::new(factory.build());
@ -151,7 +153,7 @@ pub(crate) fn build_session_manager_with_sender(
let memory_maintenance =
MemoryMaintenanceCoordinator::new(store.clone(), provider_configs.clone());
Ok(SessionManager::from_services(SessionManagerServices {
Ok((SessionManager::from_services(SessionManagerServices {
tools: tools as Arc<ToolRegistry>,
skills,
store,
@ -161,5 +163,6 @@ pub(crate) fn build_session_manager_with_sender(
messages,
scheduled_tasks,
memory_maintenance,
}))
task_repository: task_repository.clone(),
}), task_repository))
}

View File

@ -8,6 +8,7 @@ use crate::scheduler::ScheduledAgentTaskOptions;
use crate::skills::SkillRuntime;
use crate::storage::{ConversationRepository, PromptInjectionRepository, SessionRecord, SessionStore, SkillEventRepository};
use crate::tools::ToolRegistry;
use crate::tools::task::repository::TaskRepository;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
@ -458,6 +459,7 @@ pub struct SessionManager {
messages: SessionMessageService,
scheduled_tasks: ScheduledAgentTaskService,
memory_maintenance: MemoryMaintenanceCoordinator,
task_repository: Arc<dyn TaskRepository>,
}
pub(crate) struct SessionManagerServices {
@ -470,6 +472,7 @@ pub(crate) struct SessionManagerServices {
pub(crate) messages: SessionMessageService,
pub(crate) scheduled_tasks: ScheduledAgentTaskService,
pub(crate) memory_maintenance: MemoryMaintenanceCoordinator,
pub(crate) task_repository: Arc<dyn TaskRepository>,
}
impl SessionManager {
@ -484,6 +487,7 @@ impl SessionManager {
messages: services.messages,
scheduled_tasks: services.scheduled_tasks,
memory_maintenance: services.memory_maintenance,
task_repository: services.task_repository,
}
}
@ -511,6 +515,7 @@ impl SessionManager {
chat_history_ttl_hours,
session_ttl_hours,
)
.map(|(session_manager, _)| session_manager)
}
pub fn tools(&self) -> Arc<ToolRegistry> {
@ -525,6 +530,10 @@ impl SessionManager {
self.show_tool_results
}
pub fn task_repository(&self) -> Arc<dyn TaskRepository> {
self.task_repository.clone()
}
pub fn skills(&self) -> Arc<SkillRuntime> {
self.skills.clone()
}

View File

@ -244,6 +244,7 @@ async fn handle_inbound(
router.register(Box::new(LoadSessionCommandHandler::new(store.clone())));
router.register(Box::new(SaveSessionCommandHandler::new(
store.clone(),
state.task_repository.clone(),
system_prompt_provider.clone(),
)));
// 注册 help 处理器