Compare commits
No commits in common. "881c73c79f31b7b53d37870e83a992df1506a135" and "90228a4d49efa4f5369012dae4e5ce627cb2b2d4" have entirely different histories.
881c73c79f
...
90228a4d49
2
.gitignore
vendored
2
.gitignore
vendored
@ -11,5 +11,3 @@ PicoBot.code-workspace
|
|||||||
.picobot
|
.picobot
|
||||||
.claude
|
.claude
|
||||||
output
|
output
|
||||||
.python-version
|
|
||||||
pyproject.toml
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ version = "0.1.0"
|
|||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart", "stream"] }
|
reqwest = { version = "0.13.2", default-features = false, features = ["json", "rustls", "multipart"] }
|
||||||
dotenv = "0.15"
|
dotenv = "0.15"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
regex = "1.0"
|
regex = "1.0"
|
||||||
|
|||||||
@ -9,6 +9,5 @@ pub use agent_loop::{
|
|||||||
pub use context_compressor::ContextCompressor;
|
pub use context_compressor::ContextCompressor;
|
||||||
pub use runtime_config::AgentRuntimeConfig;
|
pub use runtime_config::AgentRuntimeConfig;
|
||||||
pub use system_prompt::{
|
pub use system_prompt::{
|
||||||
CompositeSystemPromptProvider, generate_system_env_prompt, SystemPrompt, SystemPromptContext,
|
CompositeSystemPromptProvider, SystemPrompt, SystemPromptContext, SystemPromptProvider,
|
||||||
SystemPromptProvider,
|
|
||||||
};
|
};
|
||||||
|
|||||||
@ -70,33 +70,6 @@ impl SystemPromptProvider for CompositeSystemPromptProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 生成系统环境信息提示词
|
|
||||||
/// 供主智能体和子智能体共享使用
|
|
||||||
pub fn generate_system_env_prompt(config: &crate::config::LLMProviderConfig) -> String {
|
|
||||||
use std::env;
|
|
||||||
use std::env::consts::{ARCH, OS};
|
|
||||||
|
|
||||||
let os_name = match OS {
|
|
||||||
"windows" => "Windows",
|
|
||||||
"linux" => "Linux",
|
|
||||||
"macos" => "macOS",
|
|
||||||
"freebsd" => "FreeBSD",
|
|
||||||
_ => OS,
|
|
||||||
};
|
|
||||||
|
|
||||||
// 使用 platform 模块获取 Shell 信息,与 bash 工具保持一致
|
|
||||||
let shell_info = crate::platform::ShellInfo::default();
|
|
||||||
let shell = shell_info.executable;
|
|
||||||
let cwd = env::current_dir()
|
|
||||||
.map(|p| p.display().to_string())
|
|
||||||
.unwrap_or_else(|_| "unknown".to_string());
|
|
||||||
|
|
||||||
format!(
|
|
||||||
"## 系统环境\n- 操作系统: {}\n- 架构: {}\n- Shell: {}\n- 当前工作目录: {}\n- 模型提供商: {}\n- 模型: {}",
|
|
||||||
os_name, ARCH, shell, cwd, config.name, config.model_id
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
@ -257,24 +257,10 @@ pub fn generate_markdown(
|
|||||||
/// 格式化消息内容
|
/// 格式化消息内容
|
||||||
///
|
///
|
||||||
/// 如果内容包含特殊字符,使用代码块包装
|
/// 如果内容包含特殊字符,使用代码块包装
|
||||||
/// 使用比内容中最大连续反引号数量多1的反引号来包裹,避免嵌套冲突
|
|
||||||
pub fn format_message_content(content: &str) -> String {
|
pub fn format_message_content(content: &str) -> String {
|
||||||
// 如果内容包含表格标记或换行符,使用代码块包裹以保留格式
|
// 如果内容包含代码块标记或表格标记,使用原始格式
|
||||||
if content.contains("| ") || content.contains('\n') {
|
if content.contains("```") || content.contains("| ") {
|
||||||
// 计算内容中连续反引号的最大数量
|
format!("```\n{}\n```", content)
|
||||||
let max_backticks = content
|
|
||||||
.chars()
|
|
||||||
.fold((0, 0), |(max_count, current_count), c| {
|
|
||||||
if c == '`' {
|
|
||||||
(max_count, current_count + 1)
|
|
||||||
} else {
|
|
||||||
(max_count.max(current_count), 0)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.0;
|
|
||||||
// 使用比最大数量多1的反引号来包裹(至少3个)
|
|
||||||
let fence = "`".repeat(max_backticks.max(3) + 1);
|
|
||||||
format!("{}\n{}\n{}", fence, content, fence)
|
|
||||||
} else {
|
} else {
|
||||||
content.to_string()
|
content.to_string()
|
||||||
}
|
}
|
||||||
@ -567,40 +553,10 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_format_message_content() {
|
fn test_format_message_content() {
|
||||||
// 普通单行文本 - 原样返回
|
|
||||||
assert_eq!(format_message_content("hello"), "hello");
|
assert_eq!(format_message_content("hello"), "hello");
|
||||||
|
|
||||||
// 单行包含反引号 - 原样返回(单行不需要包裹)
|
|
||||||
assert_eq!(format_message_content("`code`"), "`code`");
|
|
||||||
|
|
||||||
// 包含换行符 - 使用4个反引号包裹(最小)
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
format_message_content("line1\nline2\nline3"),
|
format_message_content("```code```"),
|
||||||
"````\nline1\nline2\nline3\n````"
|
"```\n```code```\n```"
|
||||||
);
|
|
||||||
|
|
||||||
// 包含表格标记 - 使用4个反引号包裹
|
|
||||||
assert_eq!(
|
|
||||||
format_message_content("| col1 | col2 |"),
|
|
||||||
"````\n| col1 | col2 |\n````"
|
|
||||||
);
|
|
||||||
|
|
||||||
// 多行内容包含3个反引号(代码块标记)- 使用4个反引号包裹
|
|
||||||
assert_eq!(
|
|
||||||
format_message_content("```code```\nmore"),
|
|
||||||
"````\n```code```\nmore\n````"
|
|
||||||
);
|
|
||||||
|
|
||||||
// 多行内容包含多行代码块
|
|
||||||
assert_eq!(
|
|
||||||
format_message_content("```\ncode\n```\nmore"),
|
|
||||||
"````\n```\ncode\n```\nmore\n````"
|
|
||||||
);
|
|
||||||
|
|
||||||
// 多行内容包含4个反引号 - 使用5个反引号包裹
|
|
||||||
assert_eq!(
|
|
||||||
format_message_content("````code````\nmore"),
|
|
||||||
"`````\n````code````\nmore\n`````"
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -38,7 +38,6 @@ pub(crate) struct FinalizeAgentResultRequest<'a> {
|
|||||||
pub(crate) metadata: &'a HashMap<String, String>,
|
pub(crate) metadata: &'a HashMap<String, String>,
|
||||||
pub(crate) suppress_live_tool_calls: bool,
|
pub(crate) suppress_live_tool_calls: bool,
|
||||||
pub(crate) execution_kind: &'a str,
|
pub(crate) execution_kind: &'a str,
|
||||||
pub(crate) original_topic_id: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct FinalizedAgentResult {
|
pub(crate) struct FinalizedAgentResult {
|
||||||
@ -79,14 +78,10 @@ impl AgentExecutionService {
|
|||||||
session: &mut Session,
|
session: &mut Session,
|
||||||
request: FinalizeAgentResultRequest<'_>,
|
request: FinalizeAgentResultRequest<'_>,
|
||||||
) -> Result<FinalizedAgentResult, AgentError> {
|
) -> Result<FinalizedAgentResult, AgentError> {
|
||||||
// 检查是否是最新的用户回合
|
if !session.matches_current_user_turn(request.chat_id, request.user_message) {
|
||||||
let is_current_turn =
|
|
||||||
session.matches_current_user_turn(request.chat_id, request.user_message);
|
|
||||||
|
|
||||||
if !is_current_turn {
|
|
||||||
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
|
let (latest_user_id, latest_user_preview, compression_in_flight, history_len) =
|
||||||
session.stale_result_diagnostics(request.chat_id);
|
session.stale_result_diagnostics(request.chat_id);
|
||||||
tracing::info!(
|
tracing::warn!(
|
||||||
channel = %request.channel_name,
|
channel = %request.channel_name,
|
||||||
chat_id = %request.chat_id,
|
chat_id = %request.chat_id,
|
||||||
user_message_id = %request.user_message.id,
|
user_message_id = %request.user_message.id,
|
||||||
@ -95,66 +90,41 @@ impl AgentExecutionService {
|
|||||||
compression_in_flight,
|
compression_in_flight,
|
||||||
history_len,
|
history_len,
|
||||||
execution_kind = %request.execution_kind,
|
execution_kind = %request.execution_kind,
|
||||||
original_topic_id = ?request.original_topic_id,
|
"Skipping stale agent result because a newer user message is already present"
|
||||||
"User switched topic during agent execution - saving result to original topic"
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
return Ok(FinalizedAgentResult {
|
||||||
|
outbound_messages: Vec::new(),
|
||||||
|
should_schedule_compaction: false,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// 确定保存消息的话题 ID
|
session
|
||||||
// 如果是最新回合,使用当前话题;否则使用原始话题
|
.append_persisted_messages(request.chat_id, request.result.emitted_messages.clone())?;
|
||||||
let target_topic_id = if is_current_turn {
|
|
||||||
session.current_topic(request.chat_id)
|
|
||||||
} else {
|
|
||||||
request.original_topic_id.as_deref()
|
|
||||||
};
|
|
||||||
|
|
||||||
// 将结果消息保存到确定的话题
|
let outbound_messages = request
|
||||||
if let Some(topic_id) = target_topic_id {
|
.result
|
||||||
if let Err(err) = session.append_messages_to_topic(
|
.emitted_messages
|
||||||
request.chat_id,
|
.iter()
|
||||||
topic_id,
|
.filter(|message| {
|
||||||
&request.result.emitted_messages,
|
(!message.is_assistant_tool_call_message() || !request.suppress_live_tool_calls)
|
||||||
) {
|
&& should_display_message_to_user(self.show_tool_results, message)
|
||||||
tracing::error!(
|
})
|
||||||
error = %err,
|
.flat_map(|message| {
|
||||||
topic_id = %topic_id,
|
OutboundMessage::from_chat_message(
|
||||||
"Failed to append messages to topic"
|
request.channel_name,
|
||||||
);
|
request.chat_id,
|
||||||
}
|
None, // session_id
|
||||||
}
|
None,
|
||||||
|
request.metadata,
|
||||||
// 只有当是最新回合时才发送 outbound 消息给用户
|
message,
|
||||||
// 如果用户已经切换到其他话题,只保存结果,不发送消息(避免打扰)
|
)
|
||||||
let outbound_messages = if is_current_turn {
|
})
|
||||||
request
|
.collect();
|
||||||
.result
|
|
||||||
.emitted_messages
|
|
||||||
.iter()
|
|
||||||
.filter(|message| {
|
|
||||||
(!message.is_assistant_tool_call_message() || !request.suppress_live_tool_calls)
|
|
||||||
&& should_display_message_to_user(self.show_tool_results, message)
|
|
||||||
})
|
|
||||||
.flat_map(|message| {
|
|
||||||
OutboundMessage::from_chat_message(
|
|
||||||
request.channel_name,
|
|
||||||
request.chat_id,
|
|
||||||
None, // session_id
|
|
||||||
None,
|
|
||||||
request.metadata,
|
|
||||||
message,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.collect()
|
|
||||||
} else {
|
|
||||||
Vec::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
// 只有当是最新回合时才触发历史压缩
|
|
||||||
let should_schedule_compaction = is_current_turn;
|
|
||||||
|
|
||||||
Ok(FinalizedAgentResult {
|
Ok(FinalizedAgentResult {
|
||||||
outbound_messages,
|
outbound_messages,
|
||||||
should_schedule_compaction,
|
should_schedule_compaction: true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -162,7 +132,7 @@ impl AgentExecutionService {
|
|||||||
&self,
|
&self,
|
||||||
request: MessageExecutionRequest<'_>,
|
request: MessageExecutionRequest<'_>,
|
||||||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||||||
let (history, agent, user_message, user_message_count, original_topic_id) = {
|
let (history, agent, user_message, user_message_count) = {
|
||||||
let mut session_guard = request.session.lock().await;
|
let mut session_guard = request.session.lock().await;
|
||||||
|
|
||||||
session_guard.ensure_persistent_session(request.chat_id)?;
|
session_guard.ensure_persistent_session(request.chat_id)?;
|
||||||
@ -186,11 +156,6 @@ impl AgentExecutionService {
|
|||||||
let history_before = session_guard.get_or_create_history(request.chat_id).clone();
|
let history_before = session_guard.get_or_create_history(request.chat_id).clone();
|
||||||
let user_message_count = history_before.iter().filter(|m| m.role == "user").count();
|
let user_message_count = history_before.iter().filter(|m| m.role == "user").count();
|
||||||
|
|
||||||
// 在添加用户消息前,记录当前话题 ID
|
|
||||||
let original_topic_id = session_guard
|
|
||||||
.current_topic(request.chat_id)
|
|
||||||
.map(|s| s.to_string());
|
|
||||||
|
|
||||||
let user_message = session_guard.create_user_message(&enriched_content, media_refs);
|
let user_message = session_guard.create_user_message(&enriched_content, media_refs);
|
||||||
session_guard.append_persisted_message(request.chat_id, user_message.clone())?;
|
session_guard.append_persisted_message(request.chat_id, user_message.clone())?;
|
||||||
|
|
||||||
@ -207,7 +172,7 @@ impl AgentExecutionService {
|
|||||||
agent = agent.with_emitted_message_handler(handler);
|
agent = agent.with_emitted_message_handler(handler);
|
||||||
}
|
}
|
||||||
|
|
||||||
(history, agent, user_message, user_message_count, original_topic_id)
|
(history, agent, user_message, user_message_count)
|
||||||
};
|
};
|
||||||
|
|
||||||
// 构建系统提示词上下文
|
// 构建系统提示词上下文
|
||||||
@ -230,7 +195,6 @@ impl AgentExecutionService {
|
|||||||
metadata: &metadata,
|
metadata: &metadata,
|
||||||
suppress_live_tool_calls: request.live_emitter.is_some(),
|
suppress_live_tool_calls: request.live_emitter.is_some(),
|
||||||
execution_kind: "message",
|
execution_kind: "message",
|
||||||
original_topic_id,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@ -240,7 +204,7 @@ impl AgentExecutionService {
|
|||||||
&self,
|
&self,
|
||||||
request: ScheduledExecutionRequest<'_>,
|
request: ScheduledExecutionRequest<'_>,
|
||||||
) -> Result<Vec<OutboundMessage>, AgentError> {
|
) -> Result<Vec<OutboundMessage>, AgentError> {
|
||||||
let (history, agent, user_message, user_message_count, original_topic_id) = {
|
let (history, agent, user_message, user_message_count) = {
|
||||||
let mut session_guard = request.session.lock().await;
|
let mut session_guard = request.session.lock().await;
|
||||||
|
|
||||||
session_guard.ensure_persistent_session(request.chat_id)?;
|
session_guard.ensure_persistent_session(request.chat_id)?;
|
||||||
@ -266,11 +230,6 @@ impl AgentExecutionService {
|
|||||||
let history_before = session_guard.get_or_create_history(request.chat_id).clone();
|
let history_before = session_guard.get_or_create_history(request.chat_id).clone();
|
||||||
let user_message_count = history_before.iter().filter(|m| m.role == "user").count();
|
let user_message_count = history_before.iter().filter(|m| m.role == "user").count();
|
||||||
|
|
||||||
// 在添加用户消息前,记录当前话题 ID
|
|
||||||
let original_topic_id = session_guard
|
|
||||||
.current_topic(request.chat_id)
|
|
||||||
.map(|s| s.to_string());
|
|
||||||
|
|
||||||
let user_message = session_guard.create_user_message(request.prompt, Vec::new());
|
let user_message = session_guard.create_user_message(request.prompt, Vec::new());
|
||||||
session_guard.append_persisted_message(request.chat_id, user_message.clone())?;
|
session_guard.append_persisted_message(request.chat_id, user_message.clone())?;
|
||||||
|
|
||||||
@ -280,13 +239,13 @@ impl AgentExecutionService {
|
|||||||
|
|
||||||
let agent = session_guard.create_agent_with_provider_config(
|
let agent = session_guard.create_agent_with_provider_config(
|
||||||
request.chat_id,
|
request.chat_id,
|
||||||
request.notification_chat_id, // 传入真实 chat_id
|
request.notification_chat_id, // 传入真实 chat_id
|
||||||
Some(request.sender_id),
|
Some(request.sender_id),
|
||||||
Some(&user_message.id),
|
Some(&user_message.id),
|
||||||
request.provider_config.clone(),
|
request.provider_config.clone(),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
(history, agent, user_message, user_message_count, original_topic_id)
|
(history, agent, user_message, user_message_count)
|
||||||
};
|
};
|
||||||
|
|
||||||
// 构建系统提示词上下文
|
// 构建系统提示词上下文
|
||||||
@ -308,7 +267,6 @@ impl AgentExecutionService {
|
|||||||
metadata: request.metadata,
|
metadata: request.metadata,
|
||||||
suppress_live_tool_calls: false,
|
suppress_live_tool_calls: false,
|
||||||
execution_kind: "scheduled_task",
|
execution_kind: "scheduled_task",
|
||||||
original_topic_id,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
use std::env;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
@ -103,9 +104,26 @@ fn strip_comments_and_whitespace(content: &str) -> String {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// 生成系统环境信息提示词
|
/// 生成系统环境信息提示词
|
||||||
/// 使用 agent 模块的共享实现
|
|
||||||
pub(crate) fn generate_system_environment_prompt(config: &LLMProviderConfig) -> String {
|
pub(crate) fn generate_system_environment_prompt(config: &LLMProviderConfig) -> String {
|
||||||
crate::agent::generate_system_env_prompt(config)
|
use std::env::consts::{ARCH, OS};
|
||||||
|
|
||||||
|
let os_name = match OS {
|
||||||
|
"windows" => "Windows",
|
||||||
|
"linux" => "Linux",
|
||||||
|
"macos" => "macOS",
|
||||||
|
"freebsd" => "FreeBSD",
|
||||||
|
_ => OS,
|
||||||
|
};
|
||||||
|
|
||||||
|
let shell = env::var("SHELL").unwrap_or_else(|_| "unknown".to_string());
|
||||||
|
let cwd = env::current_dir()
|
||||||
|
.map(|p| p.display().to_string())
|
||||||
|
.unwrap_or_else(|_| "unknown".to_string());
|
||||||
|
|
||||||
|
format!(
|
||||||
|
"## 系统环境\n- 操作系统: {}\n- 架构: {}\n- Shell: {}\n- 当前工作目录: {}\n- 模型提供商: {}\n- 模型: {}",
|
||||||
|
os_name, ARCH, shell, cwd, config.name, config.model_id
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn persist_memory_summary(path: &Path, markdown_body: &str) -> Result<(), AgentError> {
|
fn persist_memory_summary(path: &Path, markdown_body: &str) -> Result<(), AgentError> {
|
||||||
|
|||||||
@ -106,7 +106,6 @@ pub(crate) fn build_session_manager_with_sender(
|
|||||||
explore_max_execution_secs: task_config.explore_max_execution_secs,
|
explore_max_execution_secs: task_config.explore_max_execution_secs,
|
||||||
explore_max_tool_calls: 20,
|
explore_max_tool_calls: 20,
|
||||||
ttl_hours: task_config.ttl_hours,
|
ttl_hours: task_config.ttl_hours,
|
||||||
skills_index: skills.system_index_prompt(),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let subagent_runtime = Arc::new(DefaultSubAgentRuntime::new(
|
let subagent_runtime = Arc::new(DefaultSubAgentRuntime::new(
|
||||||
|
|||||||
@ -305,16 +305,6 @@ impl Session {
|
|||||||
self.history.append_persisted_messages(chat_id, messages)
|
self.history.append_persisted_messages(chat_id, messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 将消息保存到指定话题(直接写入数据库,不更新内存历史)
|
|
||||||
pub fn append_messages_to_topic(
|
|
||||||
&self,
|
|
||||||
chat_id: &str,
|
|
||||||
topic_id: &str,
|
|
||||||
messages: &[ChatMessage],
|
|
||||||
) -> Result<(), AgentError> {
|
|
||||||
self.history.append_to_topic(chat_id, topic_id, messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
|
pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
|
||||||
if media_refs.is_empty() {
|
if media_refs.is_empty() {
|
||||||
ChatMessage::user(content)
|
ChatMessage::user(content)
|
||||||
|
|||||||
@ -212,27 +212,6 @@ impl SessionHistory {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 将消息保存到指定话题(直接写入数据库,不更新内存历史)
|
|
||||||
/// 用于异步执行结果保存到原始话题的场景
|
|
||||||
pub(crate) fn append_to_topic(
|
|
||||||
&self,
|
|
||||||
chat_id: &str,
|
|
||||||
topic_id: &str,
|
|
||||||
messages: &[ChatMessage],
|
|
||||||
) -> Result<(), AgentError> {
|
|
||||||
let session_id = self.persistent_session_id(chat_id);
|
|
||||||
|
|
||||||
for message in messages {
|
|
||||||
self.conversations
|
|
||||||
.append_message_with_topic(&session_id, Some(topic_id), message)
|
|
||||||
.map_err(|err| {
|
|
||||||
AgentError::Other(format!("append message to topic error: {}", err))
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn latest_user_message(&self, chat_id: &str) -> Option<&ChatMessage> {
|
pub(crate) fn latest_user_message(&self, chat_id: &str) -> Option<&ChatMessage> {
|
||||||
self.get_history(chat_id)
|
self.get_history(chat_id)
|
||||||
.and_then(|history| history.iter().rev().find(|message| message.role == "user"))
|
.and_then(|history| history.iter().rev().find(|message| message.role == "user"))
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures_util::StreamExt;
|
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{Value, json};
|
use serde_json::{Value, json};
|
||||||
@ -12,98 +11,6 @@ use crate::domain::messages::ContentBlock;
|
|||||||
|
|
||||||
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
|
const INTERNAL_MODEL_EXTRA_KEYS: &[&str] = &["tool_call_arguments_json", "mock_response_content"];
|
||||||
|
|
||||||
/// 流式响应中的工具调用增量
|
|
||||||
#[derive(Debug, Default)]
|
|
||||||
struct StreamingToolCall {
|
|
||||||
id: String,
|
|
||||||
name: String,
|
|
||||||
arguments: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 流式响应累积器
|
|
||||||
#[derive(Debug, Default)]
|
|
||||||
struct StreamingAccumulator {
|
|
||||||
content: String,
|
|
||||||
reasoning_content: Option<String>,
|
|
||||||
tool_calls: HashMap<usize, StreamingToolCall>,
|
|
||||||
response_id: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StreamingAccumulator {
|
|
||||||
fn new() -> Self {
|
|
||||||
Self::default()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 添加内容增量
|
|
||||||
fn add_content(&mut self, delta: &str) {
|
|
||||||
self.content.push_str(delta);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 添加推理内容增量
|
|
||||||
fn add_reasoning_content(&mut self, delta: &str) {
|
|
||||||
if self.reasoning_content.is_none() {
|
|
||||||
self.reasoning_content = Some(String::new());
|
|
||||||
}
|
|
||||||
self.reasoning_content.as_mut().unwrap().push_str(delta);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 添加工具调用增量
|
|
||||||
fn add_tool_call(&mut self, index: usize, id: Option<&str>, name: Option<&str>, arguments: Option<&str>) {
|
|
||||||
let entry = self.tool_calls.entry(index).or_insert_with(StreamingToolCall::default);
|
|
||||||
|
|
||||||
if let Some(id) = id {
|
|
||||||
entry.id = id.to_string();
|
|
||||||
}
|
|
||||||
if let Some(name) = name {
|
|
||||||
entry.name = name.to_string();
|
|
||||||
}
|
|
||||||
if let Some(args) = arguments {
|
|
||||||
entry.arguments.push_str(args);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 设置响应 ID
|
|
||||||
fn set_response_id(&mut self, id: String) {
|
|
||||||
if self.response_id.is_empty() {
|
|
||||||
self.response_id = id;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 构建最终的 ChatCompletionResponse
|
|
||||||
fn build_response(self, model: String) -> ChatCompletionResponse {
|
|
||||||
let tool_calls: Vec<ToolCall> = self.tool_calls
|
|
||||||
.into_iter()
|
|
||||||
.filter(|(_, call)| !call.id.is_empty() && !call.name.is_empty())
|
|
||||||
.map(|(_, call)| {
|
|
||||||
let arguments = serde_json::from_str(&call.arguments)
|
|
||||||
.unwrap_or_else(|_| serde_json::Value::Null);
|
|
||||||
ToolCall {
|
|
||||||
id: call.id,
|
|
||||||
name: call.name,
|
|
||||||
arguments,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
ChatCompletionResponse {
|
|
||||||
id: if self.response_id.is_empty() {
|
|
||||||
format!("stream-{}", uuid::Uuid::new_v4())
|
|
||||||
} else {
|
|
||||||
self.response_id
|
|
||||||
},
|
|
||||||
model,
|
|
||||||
content: self.content,
|
|
||||||
reasoning_content: self.reasoning_content,
|
|
||||||
tool_calls,
|
|
||||||
usage: Usage {
|
|
||||||
prompt_tokens: 0,
|
|
||||||
completion_tokens: 0,
|
|
||||||
total_tokens: 0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
fn format_error_chain(error: &(dyn std::error::Error + 'static)) -> String {
|
||||||
let mut details = vec![error.to_string()];
|
let mut details = vec![error.to_string()];
|
||||||
let mut current = error.source();
|
let mut current = error.source();
|
||||||
@ -210,14 +117,6 @@ impl OpenAIProvider {
|
|||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 检查是否启用流式输出,默认启用
|
|
||||||
fn is_streaming_enabled(&self) -> bool {
|
|
||||||
self.model_extra
|
|
||||||
.get("enable_streaming")
|
|
||||||
.and_then(|value| value.as_bool())
|
|
||||||
.unwrap_or(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
|
fn normalize_tool_arguments(&self, arguments: &Value) -> Value {
|
||||||
match arguments {
|
match arguments {
|
||||||
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
|
Value::String(raw) => serde_json::from_str(raw).unwrap_or_else(|_| arguments.clone()),
|
||||||
@ -260,199 +159,6 @@ impl OpenAIProvider {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 内部流式聊天实现
|
|
||||||
async fn chat_streaming(
|
|
||||||
&self,
|
|
||||||
request: &ChatCompletionRequest,
|
|
||||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
|
||||||
tracing::debug!(provider = %self.name, model = %self.model_id, "Starting streaming chat");
|
|
||||||
|
|
||||||
let url = format!("{}/chat/completions", self.base_url);
|
|
||||||
|
|
||||||
let mut body = self.build_request_body(request);
|
|
||||||
// 启用流式输出
|
|
||||||
body["stream"] = json!(true);
|
|
||||||
|
|
||||||
let mut req_builder = self
|
|
||||||
.client
|
|
||||||
.post(&url)
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.header("Accept", "text/event-stream");
|
|
||||||
|
|
||||||
for (key, value) in &self.extra_headers {
|
|
||||||
req_builder = req_builder.header(key.as_str(), value.as_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
let resp = req_builder.json(&body).send().await.map_err(|err| {
|
|
||||||
format_transport_error_context(
|
|
||||||
&self.name,
|
|
||||||
&self.model_id,
|
|
||||||
&url,
|
|
||||||
self.llm_timeout_secs,
|
|
||||||
&err,
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let status = resp.status();
|
|
||||||
if !status.is_success() {
|
|
||||||
let text = resp.text().await.unwrap_or_default();
|
|
||||||
return Err(format!("API error {}: {}", status, text).into());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut accumulator = StreamingAccumulator::new();
|
|
||||||
|
|
||||||
// 读取 SSE 流
|
|
||||||
let mut stream = resp.bytes_stream();
|
|
||||||
let mut buffer = String::new();
|
|
||||||
let mut done_received = false;
|
|
||||||
|
|
||||||
while let Some(chunk_result) = stream.next().await {
|
|
||||||
let chunk = chunk_result?;
|
|
||||||
let text = String::from_utf8_lossy(&chunk);
|
|
||||||
buffer.push_str(&text);
|
|
||||||
|
|
||||||
// 处理缓冲区中的完整行
|
|
||||||
while let Some(newline_pos) = buffer.find('\n') {
|
|
||||||
let line = buffer[..newline_pos].to_string();
|
|
||||||
buffer = buffer[newline_pos + 1..].to_string();
|
|
||||||
|
|
||||||
let line = line.trim();
|
|
||||||
if line.is_empty() || line.starts_with(':') {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// SSE 格式: data: {...}
|
|
||||||
if let Some(data) = line.strip_prefix("data: ") {
|
|
||||||
if data == "[DONE]" {
|
|
||||||
// 流结束
|
|
||||||
done_received = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析 JSON
|
|
||||||
match serde_json::from_str::<Value>(data) {
|
|
||||||
Ok(json) => {
|
|
||||||
// 提取响应 ID
|
|
||||||
if let Some(id) = json.get("id").and_then(|v| v.as_str()) {
|
|
||||||
accumulator.set_response_id(id.to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
// 提取 choices
|
|
||||||
if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
|
|
||||||
for choice in choices {
|
|
||||||
if let Some(delta) = choice.get("delta") {
|
|
||||||
// 提取内容增量
|
|
||||||
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
|
||||||
accumulator.add_content(content);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 提取推理内容增量
|
|
||||||
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
|
|
||||||
accumulator.add_reasoning_content(reasoning);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 提取工具调用增量
|
|
||||||
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
|
|
||||||
tracing::debug!(tool_calls_count = tool_calls.len(), "Received tool_calls in delta");
|
|
||||||
for tool_call in tool_calls {
|
|
||||||
let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
|
|
||||||
|
|
||||||
let id = tool_call.get("id").and_then(|v| v.as_str());
|
|
||||||
let name = tool_call.get("function")
|
|
||||||
.and_then(|f| f.get("name"))
|
|
||||||
.and_then(|n| n.as_str());
|
|
||||||
let arguments = tool_call.get("function")
|
|
||||||
.and_then(|f| f.get("arguments"))
|
|
||||||
.and_then(|a| a.as_str());
|
|
||||||
|
|
||||||
tracing::debug!(
|
|
||||||
index = index,
|
|
||||||
id = ?id,
|
|
||||||
name = ?name,
|
|
||||||
arguments = ?arguments,
|
|
||||||
"Tool call delta received"
|
|
||||||
);
|
|
||||||
|
|
||||||
accumulator.add_tool_call(index, id, name, arguments);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::debug!(
|
|
||||||
error = %e,
|
|
||||||
data = %data,
|
|
||||||
"Failed to parse SSE data"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if done_received {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 处理缓冲区中剩余的内容
|
|
||||||
for line in buffer.lines() {
|
|
||||||
let line = line.trim();
|
|
||||||
if line.is_empty() || line.starts_with(':') {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(data) = line.strip_prefix("data: ") {
|
|
||||||
if data == "[DONE]" {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Ok(json) = serde_json::from_str::<Value>(data) {
|
|
||||||
if let Some(id) = json.get("id").and_then(|v| v.as_str()) {
|
|
||||||
accumulator.set_response_id(id.to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
|
|
||||||
for choice in choices {
|
|
||||||
if let Some(delta) = choice.get("delta") {
|
|
||||||
if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
|
|
||||||
accumulator.add_content(content);
|
|
||||||
}
|
|
||||||
if let Some(reasoning) = delta.get("reasoning_content").and_then(|r| r.as_str()) {
|
|
||||||
accumulator.add_reasoning_content(reasoning);
|
|
||||||
}
|
|
||||||
if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
|
|
||||||
for tool_call in tool_calls {
|
|
||||||
let index = tool_call.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
|
|
||||||
let id = tool_call.get("id").and_then(|v| v.as_str());
|
|
||||||
let name = tool_call.get("function")
|
|
||||||
.and_then(|f| f.get("name"))
|
|
||||||
.and_then(|n| n.as_str());
|
|
||||||
let arguments = tool_call.get("function")
|
|
||||||
.and_then(|f| f.get("arguments"))
|
|
||||||
.and_then(|a| a.as_str());
|
|
||||||
accumulator.add_tool_call(index, id, name, arguments);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = accumulator.build_response(self.model_id.clone());
|
|
||||||
tracing::debug!(
|
|
||||||
content_len = response.content.len(),
|
|
||||||
tool_calls_count = response.tool_calls.len(),
|
|
||||||
has_reasoning = response.reasoning_content.is_some(),
|
|
||||||
"Streaming response built"
|
|
||||||
);
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
||||||
let mut body = json!({
|
let mut body = json!({
|
||||||
"model": self.model_id,
|
"model": self.model_id,
|
||||||
@ -575,26 +281,6 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
&self,
|
&self,
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
// 检查是否启用流式输出
|
|
||||||
if self.is_streaming_enabled() {
|
|
||||||
// 优先尝试流式输出
|
|
||||||
match self.chat_streaming(&request).await {
|
|
||||||
Ok(response) => return Ok(response),
|
|
||||||
Err(e) => {
|
|
||||||
tracing::debug!(
|
|
||||||
provider = %self.name,
|
|
||||||
model = %self.model_id,
|
|
||||||
error = %e,
|
|
||||||
"Streaming failed, falling back to non-streaming"
|
|
||||||
);
|
|
||||||
// 流式失败,回退到非流式
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
tracing::debug!(provider = %self.name, model = %self.model_id, "Streaming disabled, using non-streaming");
|
|
||||||
}
|
|
||||||
|
|
||||||
// 非流式回退实现
|
|
||||||
let url = format!("{}/chat/completions", self.base_url);
|
let url = format!("{}/chat/completions", self.base_url);
|
||||||
|
|
||||||
let body = self.build_request_body(&request);
|
let body = self.build_request_body(&request);
|
||||||
|
|||||||
@ -1,30 +1,18 @@
|
|||||||
use super::types::SubagentType;
|
use super::types::SubagentType;
|
||||||
use crate::config::LLMProviderConfig;
|
|
||||||
|
|
||||||
/// 子代理系统提示词构建器
|
/// 子代理系统提示词构建器
|
||||||
pub struct SubagentPromptBuilder;
|
pub struct SubagentPromptBuilder;
|
||||||
|
|
||||||
impl SubagentPromptBuilder {
|
impl SubagentPromptBuilder {
|
||||||
/// 构建子代理系统提示词(包含系统环境信息和技能索引)
|
/// 构建子代理系统提示词
|
||||||
pub fn build(
|
pub fn build(
|
||||||
subagent_type: SubagentType,
|
subagent_type: SubagentType,
|
||||||
description: &str,
|
description: &str,
|
||||||
_prompt: &str,
|
_prompt: &str,
|
||||||
config: &LLMProviderConfig,
|
|
||||||
skills_index: Option<&str>,
|
|
||||||
) -> String {
|
) -> String {
|
||||||
let base_prompt = match subagent_type {
|
match subagent_type {
|
||||||
SubagentType::General => Self::build_general_prompt(description),
|
SubagentType::General => Self::build_general_prompt(description),
|
||||||
SubagentType::Explore => Self::build_explore_prompt(description),
|
SubagentType::Explore => Self::build_explore_prompt(description),
|
||||||
};
|
|
||||||
let env_info = crate::agent::generate_system_env_prompt(config);
|
|
||||||
|
|
||||||
// 组合提示词:基础 + 环境 + 技能索引(可选)
|
|
||||||
match skills_index {
|
|
||||||
Some(index) if !index.is_empty() => {
|
|
||||||
format!("{}\n\n{}\n\n{}", base_prompt, env_info, index)
|
|
||||||
}
|
|
||||||
_ => format!("{}\n\n{}", base_prompt, env_info),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -28,8 +28,6 @@ pub struct SubAgentRuntimeConfig {
|
|||||||
pub explore_max_tool_calls: usize,
|
pub explore_max_tool_calls: usize,
|
||||||
/// 任务 TTL(小时)
|
/// 任务 TTL(小时)
|
||||||
pub ttl_hours: u64,
|
pub ttl_hours: u64,
|
||||||
/// 技能索引(可选,预生成的技能列表字符串)
|
|
||||||
pub skills_index: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for SubAgentRuntimeConfig {
|
impl Default for SubAgentRuntimeConfig {
|
||||||
@ -53,7 +51,6 @@ impl Default for SubAgentRuntimeConfig {
|
|||||||
explore_max_execution_secs: 600, // 10分钟
|
explore_max_execution_secs: 600, // 10分钟
|
||||||
explore_max_tool_calls: 20,
|
explore_max_tool_calls: 20,
|
||||||
ttl_hours: 24,
|
ttl_hours: 24,
|
||||||
skills_index: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -284,8 +281,6 @@ impl SubAgentRuntime for DefaultSubAgentRuntime {
|
|||||||
task.subagent_type,
|
task.subagent_type,
|
||||||
&task.description,
|
&task.description,
|
||||||
&task.prompt,
|
&task.prompt,
|
||||||
&self.provider_config,
|
|
||||||
self.config.skills_index.as_deref(),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
// 5. 创建子代理
|
// 5. 创建子代理
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user