可观测性改善,llm api兼容性改善
This commit is contained in:
parent
98eb7bea3d
commit
0e146a8f2a
@ -226,6 +226,7 @@ pub struct AgentLoop {
|
||||
max_iterations: usize,
|
||||
workspace_dir: PathBuf,
|
||||
model_name: String,
|
||||
notify_tx: Option<tokio::sync::mpsc::UnboundedSender<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -247,6 +248,7 @@ impl AgentLoop {
|
||||
provider: Arc::from(provider),
|
||||
tools: Arc::new(ToolRegistry::new()),
|
||||
observer: None,
|
||||
notify_tx: None,
|
||||
max_iterations,
|
||||
workspace_dir,
|
||||
model_name,
|
||||
@ -265,6 +267,7 @@ impl AgentLoop {
|
||||
provider: Arc::from(provider),
|
||||
tools,
|
||||
observer: None,
|
||||
notify_tx: None,
|
||||
max_iterations,
|
||||
workspace_dir,
|
||||
model_name,
|
||||
@ -277,6 +280,7 @@ impl AgentLoop {
|
||||
provider,
|
||||
tools: Arc::new(ToolRegistry::new()),
|
||||
observer: None,
|
||||
notify_tx: None,
|
||||
max_iterations,
|
||||
workspace_dir,
|
||||
model_name,
|
||||
@ -295,6 +299,7 @@ impl AgentLoop {
|
||||
provider,
|
||||
tools,
|
||||
observer: None,
|
||||
notify_tx: None,
|
||||
max_iterations,
|
||||
workspace_dir,
|
||||
model_name,
|
||||
@ -313,6 +318,11 @@ impl AgentLoop {
|
||||
self
|
||||
}
|
||||
|
||||
/// Attaches a notification channel to this agent loop (builder style).
///
/// When a channel is attached, the loop sends one human-readable line on it
/// for every tool call it detects; failures to send are ignored by the loop.
///
/// # Arguments
/// * `tx` - sending half of an unbounded channel that receives the
///   notification strings.
///
/// # Returns
/// The same `AgentLoop`, now carrying `tx`. The value is consumed and
/// returned, so dropping the result discards the configured loop.
#[must_use = "with_notify consumes the loop and returns the configured instance"]
pub fn with_notify(mut self, tx: tokio::sync::mpsc::UnboundedSender<String>) -> Self {
    self.notify_tx = Some(tx);
    self
}
|
||||
|
||||
/// Returns a reference to the shared tool registry handle held by this loop.
///
/// The caller gets `&Arc<ToolRegistry>` and may clone the `Arc` if it needs
/// its own handle.
pub fn tools(&self) -> &Arc<ToolRegistry> {
|
||||
&self.tools
|
||||
}
|
||||
@ -390,12 +400,16 @@ impl AgentLoop {
|
||||
});
|
||||
}
|
||||
|
||||
// Execute tool calls — log tool names and args before execution
|
||||
// Execute tool calls — log and notify immediately
|
||||
{
|
||||
let tools_info: Vec<String> = response.tool_calls.iter()
|
||||
.map(|tc| {
|
||||
let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
|
||||
format!("{}:{}", tc.name, args)
|
||||
let s = format!("{}:{}", tc.name, args);
|
||||
if let Some(ref tx) = self.notify_tx {
|
||||
let _ = tx.send(format!("调用工具 {}", s));
|
||||
}
|
||||
s
|
||||
})
|
||||
.collect();
|
||||
tracing::info!(iteration, count = response.tool_calls.len(), tools = %tools_info.join(", "), "Tool calls detected, executing tools");
|
||||
|
||||
@ -38,7 +38,6 @@ impl SystemPromptBuilder {
|
||||
Self {
|
||||
sections: vec![
|
||||
Box::new(ToolHonestySection),
|
||||
Box::new(NoToolNarrationSection),
|
||||
Box::new(YourTaskSection),
|
||||
Box::new(SafetySection),
|
||||
Box::new(WorkspaceSection),
|
||||
@ -82,30 +81,11 @@ impl PromptSection for ToolHonestySection {
|
||||
}
|
||||
|
||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||
"## CRITICAL: Tool Honesty
|
||||
"## 关键规则:工具诚实性
|
||||
|
||||
- NEVER fabricate, invent, or guess tool results. If a tool returns empty results, say \"No results found.\"
|
||||
- If a tool call fails, report the error - never make up data to fill the gap.
|
||||
- When unsure whether a tool call succeeded, ask the user rather than guessing."
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Critical rule: never narrate tool usage.
|
||||
pub struct NoToolNarrationSection;
|
||||
|
||||
impl PromptSection for NoToolNarrationSection {
|
||||
fn name(&self) -> &str {
|
||||
"no_narration"
|
||||
}
|
||||
|
||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||
"## CRITICAL: No Tool Narration
|
||||
|
||||
NEVER narrate, announce, describe, or explain your tool usage to the user.
|
||||
Do NOT say things like \"Let me check...\", \"I will use bash to...\", \"I'll fetch that for you\", \"Searching now...\", or similar.
|
||||
The user must ONLY see the final answer. Tool calls are invisible infrastructure - never reference them.
|
||||
If you catch yourself starting a sentence about what tool you are about to use or just used, DELETE it and give the answer directly."
|
||||
- 绝对不要编造、虚构或猜测工具结果。如果工具返回空结果,说\"没有找到结果\"。
|
||||
- 如果工具调用失败,报告错误——绝不要编造数据来填补空白。
|
||||
- 当不确定工具调用是否成功时,询问用户而不是猜测。"
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
@ -123,7 +103,7 @@ impl PromptSection for ToolsSection {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let mut output = String::from("## Tools\n\nYou have access to the following tools:\n\n");
|
||||
let mut output = String::from("## 工具\n\n你可以使用以下工具:\n\n");
|
||||
for (name, tool) in ctx.tools.iter() {
|
||||
let _ = writeln!(output, "- **{}**: {}", name, tool.description());
|
||||
}
|
||||
@ -140,11 +120,11 @@ impl PromptSection for YourTaskSection {
|
||||
}
|
||||
|
||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||
"## Your Task
|
||||
"## 你的任务
|
||||
|
||||
When the user sends a message, ACT on it. Use the tools to fulfill their request.
|
||||
Do NOT: summarize this configuration, describe your capabilities, respond with meta-commentary, or output step-by-step instructions.
|
||||
Instead: use tools directly when needed, and give the final answer when done."
|
||||
当用户发送消息时,立即行动。使用工具来完成他们的请求。
|
||||
不要:总结此配置、描述你的能力、用元评论回复、或输出逐步指令。
|
||||
而是:在需要时直接使用工具,完成后给出最终答案。"
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
@ -158,13 +138,13 @@ impl PromptSection for SafetySection {
|
||||
}
|
||||
|
||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||
"## Safety
|
||||
"## 安全规则
|
||||
|
||||
- Do not exfiltrate private data.
|
||||
- Do not run destructive commands without asking.
|
||||
- Do not bypass oversight or approval mechanisms.
|
||||
- Prefer safe operations over risky ones.
|
||||
- When in doubt, ask before acting externally."
|
||||
- 不要泄露隐私数据。
|
||||
- 未经询问不要执行破坏性命令。
|
||||
- 不要绕过监督或审批机制。
|
||||
- 优先选择安全操作而非风险操作。
|
||||
- 不确定时,在外部操作前先询问。"
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
@ -184,7 +164,7 @@ impl PromptSection for WorkspaceSection {
|
||||
.canonicalize()
|
||||
.unwrap_or_else(|_| ctx.workspace_dir.to_path_buf());
|
||||
format!(
|
||||
"## Workspace\n\nWorking directory: `{}`\n\n### File Storage Guidelines\n\n- **Generated files**: Store all generated files (code, documents, artifacts) in the workspace directory or its subdirectories.\n- **Downloaded files**: Save downloaded files to the workspace directory, organized by task.\n- **One task, one folder**: Create a dedicated subfolder for each task or project (e.g., `task_2024_01_01/`).\n- **Temporary files**: If files are only needed during processing and won't be kept, use `/tmp/` or create a temp folder (e.g., `/tmp/picobot_task_xxx/`) instead of cluttering the workspace.\n\n### Working Directory Structure\n\nThe workspace is your home base for this session. Keep it organized by creating subdirectories for different tasks.",
|
||||
"## 工作目录\n\n工作目录:`{}`\n\n### 文件存储规范\n\n- **生成的文件**:将所有生成的文件(代码、文档、制品)存放在工作目录或其子目录中。\n- **下载的文件**:将下载的文件保存到工作目录,按任务整理。\n- **一个任务一个文件夹**:为每个任务或项目创建专用的子文件夹(如 `task_2024_01_01/`)。\n- **临时文件**:如果文件仅在处理期间需要且不保留,使用 `/tmp/` 或创建临时文件夹(如 `/tmp/picobot_task_xxx/`),以免弄乱工作目录。\n\n### 目录结构\n\n工作目录是你在本会话中的操作大本营。通过为不同任务创建子目录来保持整洁。",
|
||||
abs_path.display()
|
||||
)
|
||||
}
|
||||
@ -199,7 +179,7 @@ impl PromptSection for UserProfileSection {
|
||||
}
|
||||
|
||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||
let mut output = String::from("## User Profile\n\n");
|
||||
let mut output = String::from("## 用户配置\n\n");
|
||||
|
||||
// Load USER.md from ~/.picobot/USER.md
|
||||
if let Some(user_config_dir) = get_user_config_dir() {
|
||||
@ -227,7 +207,7 @@ impl PromptSection for DateTimeSection {
|
||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||
let now = chrono::Local::now();
|
||||
format!(
|
||||
"## Current Date & Time\n\n{} ({})",
|
||||
"## 当前日期与时间\n\n{} ({})",
|
||||
now.format("%Y-%m-%d %H:%M:%S"),
|
||||
now.format("%Z")
|
||||
)
|
||||
@ -289,7 +269,7 @@ impl PromptSection for RuntimeSection {
|
||||
.map(|h| h.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|_| "unknown".to_string());
|
||||
format!(
|
||||
"## Runtime\n\nHost: {} | OS: {} | Model: {}",
|
||||
"## 运行环境\n\n主机: {} | 操作系统: {} | 模型: {}",
|
||||
host,
|
||||
std::env::consts::OS,
|
||||
ctx.model_name
|
||||
@ -321,7 +301,7 @@ fn load_file_from_dir(dir: &Path, filename: &str, max_chars: usize) -> Option<St
|
||||
.unwrap_or(trimmed)
|
||||
.to_string()
|
||||
+ &format!(
|
||||
"\n\n[... truncated at {} characters - use file_read for full file]",
|
||||
"\n\n[... 已截断至 {} 字符 - 使用 file_read 获取完整文件]",
|
||||
max_chars
|
||||
)
|
||||
} else {
|
||||
@ -361,12 +341,11 @@ mod tests {
|
||||
|
||||
let prompt = SystemPromptBuilder::with_defaults().build(&ctx);
|
||||
|
||||
assert!(prompt.contains("## CRITICAL: Tool Honesty"));
|
||||
assert!(prompt.contains("## CRITICAL: No Tool Narration"));
|
||||
assert!(prompt.contains("## Safety"));
|
||||
assert!(prompt.contains("## Workspace"));
|
||||
assert!(prompt.contains("## Current Date & Time"));
|
||||
assert!(prompt.contains("## Runtime"));
|
||||
assert!(prompt.contains("## 关键规则:工具诚实性"));
|
||||
assert!(prompt.contains("## 安全规则"));
|
||||
assert!(prompt.contains("## 工作目录"));
|
||||
assert!(prompt.contains("## 当前日期与时间"));
|
||||
assert!(prompt.contains("## 运行环境"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@ -479,10 +479,16 @@ impl Channel for CliChatChannel {
|
||||
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||
let clients = self.clients.lock().await.clone();
|
||||
for client in clients {
|
||||
let outbound = WsOutbound::AssistantResponse {
|
||||
id: short_id(),
|
||||
content: msg.content.clone(),
|
||||
role: "assistant".to_string(),
|
||||
let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification") {
|
||||
WsOutbound::SystemNotification {
|
||||
content: msg.content.clone(),
|
||||
}
|
||||
} else {
|
||||
WsOutbound::AssistantResponse {
|
||||
id: short_id(),
|
||||
content: msg.content.clone(),
|
||||
role: "assistant".to_string(),
|
||||
}
|
||||
};
|
||||
let _ = client.sender.send(outbound).await;
|
||||
}
|
||||
|
||||
@ -6,13 +6,8 @@ use std::collections::HashMap;
|
||||
use crate::bus::message::ContentBlock;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
|
||||
use super::traits::Usage;
|
||||
|
||||
fn serialize_content_blocks<S>(blocks: &[serde_json::Value], serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string()))
|
||||
}
|
||||
use std::sync::Arc;
|
||||
use crate::storage::Storage;
|
||||
|
||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
|
||||
blocks.iter().map(|b| match b {
|
||||
@ -62,6 +57,7 @@ pub struct AnthropicProvider {
|
||||
temperature: Option<f32>,
|
||||
max_tokens: Option<u32>,
|
||||
model_extra: HashMap<String, serde_json::Value>,
|
||||
storage: Option<Arc<Storage>>,
|
||||
}
|
||||
|
||||
impl AnthropicProvider {
|
||||
@ -85,8 +81,13 @@ impl AnthropicProvider {
|
||||
temperature,
|
||||
max_tokens,
|
||||
model_extra,
|
||||
storage: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores a shared `Storage` handle on this provider.
///
/// Once set, `chat` records each request/response (and any error) through
/// `Storage::append_llm_call`.
///
/// NOTE(review): this appears to be an inherent method of `AnthropicProvider`.
/// The `LLMProvider` trait declares its own `set_storage` with an empty default
/// body, and callers that hold a `Box<dyn LLMProvider>` dispatch to the trait
/// method. Unless the trait impl overrides `set_storage` and forwards here,
/// those calls will leave `self.storage` as `None` — confirm.
pub fn set_storage(&mut self, storage: Arc<Storage>) {
|
||||
self.storage = Some(storage);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@ -104,7 +105,6 @@ struct AnthropicRequest {
|
||||
#[derive(Serialize)]
|
||||
struct AnthropicMessage {
|
||||
role: String,
|
||||
#[serde(serialize_with = "serialize_content_blocks")]
|
||||
content: Vec<serde_json::Value>,
|
||||
}
|
||||
|
||||
@ -128,14 +128,23 @@ struct AnthropicResponse {
|
||||
#[derive(Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum AnthropicContent {
|
||||
Text { text: String },
|
||||
Thinking { thinking: String },
|
||||
Text {
|
||||
#[serde(alias = "content")]
|
||||
text: String,
|
||||
},
|
||||
Thinking {
|
||||
#[serde(alias = "content")]
|
||||
thinking: String,
|
||||
},
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
#[serde(alias = "arguments")]
|
||||
input: serde_json::Value,
|
||||
},
|
||||
#[serde(other)]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@ -152,6 +161,7 @@ impl LLMProvider for AnthropicProvider {
|
||||
&self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let start = std::time::Instant::now();
|
||||
let url = format!("{}/v1/messages", self.base_url);
|
||||
let max_tokens = request.max_tokens.or(self.max_tokens).unwrap_or(1024);
|
||||
|
||||
@ -190,7 +200,19 @@ impl LLMProvider for AnthropicProvider {
|
||||
"content": output,
|
||||
})]
|
||||
} else {
|
||||
convert_content_blocks(&m.content)
|
||||
let mut blocks = convert_content_blocks(&m.content);
|
||||
// Append tool_use blocks from assistant messages with tool calls
|
||||
if let Some(ref tool_calls) = m.tool_calls {
|
||||
for tc in tool_calls {
|
||||
blocks.push(serde_json::json!({
|
||||
"type": "tool_use",
|
||||
"id": tc.id,
|
||||
"name": tc.name,
|
||||
"input": tc.arguments,
|
||||
}));
|
||||
}
|
||||
}
|
||||
blocks
|
||||
};
|
||||
AnthropicMessage { role, content }
|
||||
})
|
||||
@ -212,10 +234,14 @@ impl LLMProvider for AnthropicProvider {
|
||||
req_builder = req_builder.header(key.as_str(), value.as_str());
|
||||
}
|
||||
|
||||
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
||||
tracing::debug!(req_body = %req_body_str, "LLM request");
|
||||
|
||||
let resp = req_builder.json(&body).send().await?;
|
||||
|
||||
let status = resp.status();
|
||||
let body_text = resp.text().await?;
|
||||
tracing::debug!(status = %status, resp_body = %body_text, "LLM response");
|
||||
|
||||
if !status.is_success() {
|
||||
let error_msg = serde_json::from_str::<serde_json::Value>(&body_text)
|
||||
@ -227,11 +253,33 @@ impl LLMProvider for AnthropicProvider {
|
||||
.map(|s| s.to_string())
|
||||
})
|
||||
.unwrap_or_else(|| body_text.clone());
|
||||
if let Some(ref storage) = self.storage {
|
||||
let _ = storage.append_llm_call(
|
||||
&self.name, &self.model_id, &req_body_str,
|
||||
Some(&body_text), Some(&error_msg),
|
||||
start.elapsed().as_millis() as u64,
|
||||
).await;
|
||||
}
|
||||
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
|
||||
}
|
||||
|
||||
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text)
|
||||
.map_err(|e| format!("decode error: {} | body: {}", e, &body_text))?;
|
||||
.map_err(|e| {
|
||||
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
|
||||
if let Some(ref storage) = self.storage {
|
||||
let name = self.name.clone();
|
||||
let model = self.model_id.clone();
|
||||
let req = req_body_str.clone();
|
||||
let resp_body = body_text.clone();
|
||||
let dur = start.elapsed().as_millis() as u64;
|
||||
let err = err_msg.clone();
|
||||
let s = storage.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await;
|
||||
});
|
||||
}
|
||||
err_msg
|
||||
})?;
|
||||
|
||||
let mut content = String::new();
|
||||
let mut tool_calls = Vec::new();
|
||||
@ -247,6 +295,7 @@ impl LLMProvider for AnthropicProvider {
|
||||
}
|
||||
}
|
||||
AnthropicContent::Thinking { .. } => {}
|
||||
AnthropicContent::Unknown => {}
|
||||
AnthropicContent::ToolUse { id, name, input } => {
|
||||
tool_calls.push(ToolCall {
|
||||
id: id.clone(),
|
||||
@ -257,7 +306,7 @@ impl LLMProvider for AnthropicProvider {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ChatCompletionResponse {
|
||||
let response = ChatCompletionResponse {
|
||||
id: anthropic_resp.id.unwrap_or_default(),
|
||||
model: anthropic_resp.model.unwrap_or_default(),
|
||||
content,
|
||||
@ -267,7 +316,20 @@ impl LLMProvider for AnthropicProvider {
|
||||
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
|
||||
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
|
||||
},
|
||||
})
|
||||
};
|
||||
|
||||
if let Some(ref storage) = self.storage {
|
||||
let _ = storage.append_llm_call(
|
||||
&self.name,
|
||||
&self.model_id,
|
||||
&req_body_str,
|
||||
Some(&body_text),
|
||||
None,
|
||||
start.elapsed().as_millis() as u64,
|
||||
).await;
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
fn ptype(&self) -> &str {
|
||||
|
||||
@ -7,6 +7,8 @@ use std::collections::HashMap;
|
||||
use crate::bus::message::ContentBlock;
|
||||
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
|
||||
use super::traits::Usage;
|
||||
use std::sync::Arc;
|
||||
use crate::storage::Storage;
|
||||
|
||||
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
|
||||
if blocks.len() == 1 {
|
||||
@ -32,6 +34,7 @@ pub struct OpenAIProvider {
|
||||
temperature: Option<f32>,
|
||||
max_tokens: Option<u32>,
|
||||
model_extra: HashMap<String, serde_json::Value>,
|
||||
storage: Option<Arc<Storage>>,
|
||||
}
|
||||
|
||||
impl OpenAIProvider {
|
||||
@ -55,9 +58,14 @@ impl OpenAIProvider {
|
||||
temperature,
|
||||
max_tokens,
|
||||
model_extra,
|
||||
storage: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores a shared `Storage` handle on this provider.
///
/// Once set, `chat` records each request/response (and any error) through
/// `Storage::append_llm_call`.
///
/// NOTE(review): this appears to be an inherent method of `OpenAIProvider`.
/// The `LLMProvider` trait declares its own `set_storage` with an empty default
/// body, and callers that hold a `Box<dyn LLMProvider>` dispatch to the trait
/// method. Unless the trait impl overrides `set_storage` and forwards here,
/// those calls will leave `self.storage` as `None` — confirm.
pub fn set_storage(&mut self, storage: Arc<Storage>) {
|
||||
self.storage = Some(storage);
|
||||
}
|
||||
|
||||
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
||||
let mut body = json!({
|
||||
"model": self.model_id,
|
||||
@ -162,6 +170,7 @@ impl LLMProvider for OpenAIProvider {
|
||||
&self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let start = std::time::Instant::now();
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
|
||||
let body = self.build_request_body(&request);
|
||||
@ -200,24 +209,44 @@ impl LLMProvider for OpenAIProvider {
|
||||
req_builder = req_builder.header(key.as_str(), value.as_str());
|
||||
}
|
||||
|
||||
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
|
||||
tracing::debug!(req_body = %req_body_str, "LLM request");
|
||||
|
||||
let resp = req_builder.json(&body).send().await?;
|
||||
|
||||
let status = resp.status();
|
||||
let text = resp.text().await?;
|
||||
|
||||
// Debug: Log LLM response (only in debug builds)
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let resp_preview: String = text.chars().take(100).collect();
|
||||
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), "LLM response (first 100 chars shown)");
|
||||
}
|
||||
tracing::debug!(status = %status, resp_body = %text, "LLM response");
|
||||
|
||||
if !status.is_success() {
|
||||
return Err(format!("API error {}: {}", status, text).into());
|
||||
let error = format!("API error {}: {}", status, text);
|
||||
if let Some(ref storage) = self.storage {
|
||||
let _ = storage.append_llm_call(
|
||||
&self.name, &self.model_id, &req_body_str,
|
||||
Some(&text), Some(&error),
|
||||
start.elapsed().as_millis() as u64,
|
||||
).await;
|
||||
}
|
||||
return Err(error.into());
|
||||
}
|
||||
|
||||
let openai_resp: OpenAIResponse = serde_json::from_str(&text)
|
||||
.map_err(|e| format!("decode error: {} | body: {}", e, &text))?;
|
||||
.map_err(|e| {
|
||||
let err_msg = format!("decode error: {} | body: {}", e, &text);
|
||||
if let Some(ref storage) = self.storage {
|
||||
let name = self.name.clone();
|
||||
let model = self.model_id.clone();
|
||||
let req = req_body_str.clone();
|
||||
let resp = text.clone();
|
||||
let dur = start.elapsed().as_millis() as u64;
|
||||
let err = err_msg.clone();
|
||||
let s = storage.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await;
|
||||
});
|
||||
}
|
||||
err_msg
|
||||
})?;
|
||||
|
||||
let content = openai_resp.choices[0]
|
||||
.message
|
||||
@ -237,7 +266,7 @@ impl LLMProvider for OpenAIProvider {
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(ChatCompletionResponse {
|
||||
let response = ChatCompletionResponse {
|
||||
id: openai_resp.id,
|
||||
model: openai_resp.model,
|
||||
content,
|
||||
@ -247,7 +276,17 @@ impl LLMProvider for OpenAIProvider {
|
||||
completion_tokens: openai_resp.usage.completion_tokens,
|
||||
total_tokens: openai_resp.usage.total_tokens,
|
||||
},
|
||||
})
|
||||
};
|
||||
|
||||
if let Some(ref storage) = self.storage {
|
||||
let _ = storage.append_llm_call(
|
||||
&self.name, &self.model_id, &req_body_str,
|
||||
Some(&text), None,
|
||||
start.elapsed().as_millis() as u64,
|
||||
).await;
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
fn ptype(&self) -> &str {
|
||||
|
||||
@ -123,4 +123,6 @@ pub trait LLMProvider: Send + Sync {
|
||||
fn name(&self) -> &str;
|
||||
|
||||
fn model_id(&self) -> &str;
|
||||
|
||||
/// Hands the provider a shared `Storage` handle.
///
/// The default implementation ignores the handle and does nothing, so a
/// provider only receives storage if its `impl LLMProvider` overrides this
/// method.
///
/// NOTE(review): an inherent `set_storage` on a concrete provider does not
/// override this default; calls through `dyn LLMProvider` will still land
/// here. Verify each provider's trait impl forwards to its own setter.
fn set_storage(&mut self, _storage: std::sync::Arc<crate::storage::Storage>) {}
|
||||
}
|
||||
|
||||
@ -2,7 +2,7 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind};
|
||||
@ -21,7 +21,6 @@ use crate::config::LLMProviderConfig;
|
||||
use crate::agent::{AgentLoop, AgentError, ContextCompressor};
|
||||
use crate::agent::system_prompt::build_system_prompt;
|
||||
use crate::agent::context_compressor::ContextCompressionConfig;
|
||||
use crate::protocol::WsOutbound;
|
||||
use crate::providers::{create_provider, LLMProvider};
|
||||
use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID};
|
||||
use crate::session::events::DialogInfo;
|
||||
@ -49,7 +48,6 @@ pub struct Session {
|
||||
messages: Vec<ChatMessage>,
|
||||
seq_counter: i64,
|
||||
|
||||
pub user_tx: mpsc::Sender<WsOutbound>,
|
||||
provider_config: LLMProviderConfig,
|
||||
provider: Arc<dyn LLMProvider>,
|
||||
tools: Arc<ToolRegistry>,
|
||||
@ -63,14 +61,16 @@ impl Session {
|
||||
pub async fn new(
|
||||
id: UnifiedSessionId,
|
||||
provider_config: LLMProviderConfig,
|
||||
user_tx: mpsc::Sender<WsOutbound>,
|
||||
tools: Arc<ToolRegistry>,
|
||||
storage: Option<StdArc<Storage>>,
|
||||
routing_info: String,
|
||||
title: String,
|
||||
) -> Result<Self, AgentError> {
|
||||
let provider_box = create_provider(provider_config.clone())
|
||||
let mut provider_box = create_provider(provider_config.clone())
|
||||
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||||
if let Some(ref s) = storage {
|
||||
provider_box.set_storage(s.clone());
|
||||
}
|
||||
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
|
||||
|
||||
let compressor_config = ContextCompressionConfig {
|
||||
@ -89,7 +89,6 @@ impl Session {
|
||||
total_message_count: 0,
|
||||
messages: Vec::new(),
|
||||
seq_counter: 1,
|
||||
user_tx,
|
||||
provider_config: provider_config.clone(),
|
||||
provider: provider.clone(),
|
||||
tools,
|
||||
@ -103,7 +102,6 @@ impl Session {
|
||||
pub async fn from_storage(
|
||||
id: UnifiedSessionId,
|
||||
provider_config: LLMProviderConfig,
|
||||
user_tx: mpsc::Sender<WsOutbound>,
|
||||
tools: Arc<ToolRegistry>,
|
||||
storage: StdArc<Storage>,
|
||||
) -> Result<Self, AgentError> {
|
||||
@ -113,8 +111,9 @@ impl Session {
|
||||
let messages = storage.load_messages(&id.to_string(), 0).await
|
||||
.map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?;
|
||||
|
||||
let provider_box = create_provider(provider_config.clone())
|
||||
let mut provider_box = create_provider(provider_config.clone())
|
||||
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
|
||||
provider_box.set_storage(storage.clone());
|
||||
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
|
||||
|
||||
let compressor_config = ContextCompressionConfig {
|
||||
@ -123,6 +122,7 @@ impl Session {
|
||||
};
|
||||
|
||||
// Convert MessageMeta to ChatMessage
|
||||
// Clear tool_call_id/tool_name — they're not valid across API sessions
|
||||
let chat_messages: Vec<ChatMessage> = messages.into_iter().map(|m| {
|
||||
ChatMessage {
|
||||
id: m.id,
|
||||
@ -130,8 +130,8 @@ impl Session {
|
||||
content: m.content,
|
||||
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
|
||||
timestamp: m.created_at,
|
||||
tool_call_id: m.tool_call_id,
|
||||
tool_name: m.tool_name,
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_calls: m.tool_calls.map(|tc| serde_json::from_str(&tc).unwrap_or_default()),
|
||||
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
|
||||
}
|
||||
@ -149,7 +149,6 @@ impl Session {
|
||||
total_message_count,
|
||||
messages: chat_messages,
|
||||
seq_counter,
|
||||
user_tx,
|
||||
provider_config: provider_config.clone(),
|
||||
provider: provider.clone(),
|
||||
tools,
|
||||
@ -252,18 +251,6 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
/// Forwards `msg` on this session's `user_tx` channel.
///
/// The send result is deliberately discarded, so a closed or dropped
/// receiver is not reported to the caller.
pub async fn send(&self, msg: WsOutbound) {
|
||||
let _ = self.user_tx.send(msg).await;
|
||||
}
|
||||
|
||||
/// 发送系统通知(不记录进 session 历史)
|
||||
pub async fn send_system_notification(&self, content: &str) {
|
||||
let msg = WsOutbound::SystemNotification {
|
||||
content: content.to_string(),
|
||||
};
|
||||
let _ = self.user_tx.send(msg).await;
|
||||
}
|
||||
|
||||
/// 将 session 元数据写回 Storage
|
||||
pub async fn persist_session_meta(&self) -> Result<(), StorageError> {
|
||||
if let Some(ref storage) = self.storage {
|
||||
@ -364,6 +351,14 @@ impl Session {
|
||||
))
|
||||
}
|
||||
|
||||
/// 创建一个附通知通道的 AgentLoop 实例
|
||||
pub fn create_agent_with_notify(
|
||||
&self,
|
||||
notify_tx: tokio::sync::mpsc::UnboundedSender<String>,
|
||||
) -> Result<AgentLoop, AgentError> {
|
||||
Ok(self.create_agent()?.with_notify(notify_tx))
|
||||
}
|
||||
|
||||
/// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills)
|
||||
pub fn build_system_prompt(&self, skills_prompt: &str) -> String {
|
||||
let base_prompt = build_system_prompt(
|
||||
@ -874,11 +869,9 @@ impl SessionManager {
|
||||
self.storage.upsert_session(&meta).await
|
||||
.map_err(|e| AgentError::Other(format!("failed to create session in storage: {}", e)))?;
|
||||
|
||||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
||||
let session = Session::new(
|
||||
unified_id.clone(),
|
||||
self.provider_config.clone(),
|
||||
user_tx,
|
||||
self.tools.clone(),
|
||||
Some(self.storage.clone()),
|
||||
routing_info,
|
||||
@ -909,11 +902,9 @@ impl SessionManager {
|
||||
match self.storage.get_session(&session_id_str).await {
|
||||
Ok(meta) => {
|
||||
tracing::debug!(session_id = %session_id_str, last_active_at = %meta.last_active_at, message_count = %meta.message_count, "Restoring session from Storage");
|
||||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
||||
let session = Session::from_storage(
|
||||
unified_id.clone(),
|
||||
self.provider_config.clone(),
|
||||
user_tx,
|
||||
self.tools.clone(),
|
||||
self.storage.clone(),
|
||||
).await?;
|
||||
@ -932,11 +923,9 @@ impl SessionManager {
|
||||
}
|
||||
|
||||
// Create new session
|
||||
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
|
||||
let session = Session::new(
|
||||
unified_id.clone(),
|
||||
self.provider_config.clone(),
|
||||
user_tx,
|
||||
self.tools.clone(),
|
||||
Some(self.storage.clone()),
|
||||
String::new(),
|
||||
@ -1175,6 +1164,30 @@ impl SessionManager {
|
||||
}
|
||||
|
||||
// Normal message handling through LLM
|
||||
let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
|
||||
// Spawn notification publisher — sends immediately when tools are detected
|
||||
{
|
||||
let bus = self.bus.clone();
|
||||
let ch = channel.to_string();
|
||||
let cid = chat_id.to_string();
|
||||
tokio::spawn(async move {
|
||||
while let Some(notif) = notify_rx.recv().await {
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert("_type".to_string(), "notification".to_string());
|
||||
let outbound = OutboundMessage {
|
||||
channel: ch.clone(),
|
||||
chat_id: cid.clone(),
|
||||
content: notif,
|
||||
reply_to: None,
|
||||
media: vec![],
|
||||
metadata,
|
||||
};
|
||||
let _ = bus.publish_outbound(outbound).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let response: String = {
|
||||
let mut session_guard = session.lock().await;
|
||||
|
||||
@ -1202,7 +1215,7 @@ impl SessionManager {
|
||||
.compress_if_needed(history)
|
||||
.await?;
|
||||
|
||||
let agent = session_guard.create_agent()?;
|
||||
let agent = session_guard.create_agent_with_notify(notify_tx)?;
|
||||
let result = agent.process(history).await?;
|
||||
|
||||
for msg in result.emitted_messages {
|
||||
@ -1322,7 +1335,6 @@ impl OutboundMessenger for SessionManager {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
fn test_provider_config() -> LLMProviderConfig {
|
||||
LLMProviderConfig {
|
||||
|
||||
@ -92,6 +92,70 @@ impl Storage {
|
||||
.await
|
||||
.ok();
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS llm_calls (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
created_at INTEGER NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
model TEXT NOT NULL,
|
||||
request_body TEXT NOT NULL,
|
||||
response_body TEXT,
|
||||
error TEXT,
|
||||
duration_ms INTEGER
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn append_llm_call(
|
||||
&self,
|
||||
provider: &str,
|
||||
model: &str,
|
||||
request_body: &str,
|
||||
response_body: Option<&str>,
|
||||
error: Option<&str>,
|
||||
duration_ms: u64,
|
||||
) -> Result<(), StorageError> {
|
||||
let now = chrono::Utc::now().timestamp_millis();
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO llm_calls (created_at, provider, model, request_body, response_body, error, duration_ms)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(now)
|
||||
.bind(provider)
|
||||
.bind(model)
|
||||
.bind(request_body)
|
||||
.bind(response_body)
|
||||
.bind(error)
|
||||
.bind(duration_ms as i64)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
|
||||
// Prune to keep last 1000 records
|
||||
self.prune_llm_calls(1000).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Deletes old rows from `llm_calls`, keeping only the newest id range.
///
/// Every row whose `id` is less than or equal to `MAX(id) - max_records` is
/// removed (`COALESCE` makes the threshold `-max_records` on an empty table,
/// so nothing is deleted then). Retention is therefore by id range rather
/// than by row count; the two coincide while ids are contiguous.
///
/// # Errors
/// Returns a `StorageError` if the `DELETE` statement fails.
async fn prune_llm_calls(&self, max_records: i64) -> Result<(), StorageError> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
DELETE FROM llm_calls WHERE id <= (
|
||||
SELECT COALESCE(MAX(id), 0) - ? FROM llm_calls
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(max_records)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user