可观测性改善,llm api兼容性改善

This commit is contained in:
xiaoski 2026-05-04 23:04:28 +08:00
parent 98eb7bea3d
commit 0e146a8f2a
8 changed files with 286 additions and 108 deletions

View File

@ -226,6 +226,7 @@ pub struct AgentLoop {
max_iterations: usize, max_iterations: usize,
workspace_dir: PathBuf, workspace_dir: PathBuf,
model_name: String, model_name: String,
notify_tx: Option<tokio::sync::mpsc::UnboundedSender<String>>,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -247,6 +248,7 @@ impl AgentLoop {
provider: Arc::from(provider), provider: Arc::from(provider),
tools: Arc::new(ToolRegistry::new()), tools: Arc::new(ToolRegistry::new()),
observer: None, observer: None,
notify_tx: None,
max_iterations, max_iterations,
workspace_dir, workspace_dir,
model_name, model_name,
@ -265,6 +267,7 @@ impl AgentLoop {
provider: Arc::from(provider), provider: Arc::from(provider),
tools, tools,
observer: None, observer: None,
notify_tx: None,
max_iterations, max_iterations,
workspace_dir, workspace_dir,
model_name, model_name,
@ -277,6 +280,7 @@ impl AgentLoop {
provider, provider,
tools: Arc::new(ToolRegistry::new()), tools: Arc::new(ToolRegistry::new()),
observer: None, observer: None,
notify_tx: None,
max_iterations, max_iterations,
workspace_dir, workspace_dir,
model_name, model_name,
@ -295,6 +299,7 @@ impl AgentLoop {
provider, provider,
tools, tools,
observer: None, observer: None,
notify_tx: None,
max_iterations, max_iterations,
workspace_dir, workspace_dir,
model_name, model_name,
@ -313,6 +318,11 @@ impl AgentLoop {
self self
} }
pub fn with_notify(mut self, tx: tokio::sync::mpsc::UnboundedSender<String>) -> Self {
self.notify_tx = Some(tx);
self
}
pub fn tools(&self) -> &Arc<ToolRegistry> { pub fn tools(&self) -> &Arc<ToolRegistry> {
&self.tools &self.tools
} }
@ -390,12 +400,16 @@ impl AgentLoop {
}); });
} }
// Execute tool calls — log tool names and args before execution // Execute tool calls — log and notify immediately
{ {
let tools_info: Vec<String> = response.tool_calls.iter() let tools_info: Vec<String> = response.tool_calls.iter()
.map(|tc| { .map(|tc| {
let args = serde_json::to_string(&tc.arguments).unwrap_or_default(); let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
format!("{}:{}", tc.name, args) let s = format!("{}:{}", tc.name, args);
if let Some(ref tx) = self.notify_tx {
let _ = tx.send(format!("调用工具 {}", s));
}
s
}) })
.collect(); .collect();
tracing::info!(iteration, count = response.tool_calls.len(), tools = %tools_info.join(", "), "Tool calls detected, executing tools"); tracing::info!(iteration, count = response.tool_calls.len(), tools = %tools_info.join(", "), "Tool calls detected, executing tools");

View File

@ -38,7 +38,6 @@ impl SystemPromptBuilder {
Self { Self {
sections: vec![ sections: vec![
Box::new(ToolHonestySection), Box::new(ToolHonestySection),
Box::new(NoToolNarrationSection),
Box::new(YourTaskSection), Box::new(YourTaskSection),
Box::new(SafetySection), Box::new(SafetySection),
Box::new(WorkspaceSection), Box::new(WorkspaceSection),
@ -82,30 +81,11 @@ impl PromptSection for ToolHonestySection {
} }
fn build(&self, _ctx: &PromptContext<'_>) -> String { fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## CRITICAL: Tool Honesty "## 关键规则:工具诚实性
- NEVER fabricate, invent, or guess tool results. If a tool returns empty results, say \"No results found.\" - \"没有找到结果\"
- If a tool call fails, report the error - never make up data to fill the gap. -
- When unsure whether a tool call succeeded, ask the user rather than guessing." - "
.to_string()
}
}
/// Critical rule: never narrate tool usage.
pub struct NoToolNarrationSection;
impl PromptSection for NoToolNarrationSection {
fn name(&self) -> &str {
"no_narration"
}
fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## CRITICAL: No Tool Narration
NEVER narrate, announce, describe, or explain your tool usage to the user.
Do NOT say things like \"Let me check...\", \"I will use bash to...\", \"I'll fetch that for you\", \"Searching now...\", or similar.
The user must ONLY see the final answer. Tool calls are invisible infrastructure - never reference them.
If you catch yourself starting a sentence about what tool you are about to use or just used, DELETE it and give the answer directly."
.to_string() .to_string()
} }
} }
@ -123,7 +103,7 @@ impl PromptSection for ToolsSection {
return String::new(); return String::new();
} }
let mut output = String::from("## Tools\n\nYou have access to the following tools:\n\n"); let mut output = String::from("## 工具\n\n你可以使用以下工具:\n\n");
for (name, tool) in ctx.tools.iter() { for (name, tool) in ctx.tools.iter() {
let _ = writeln!(output, "- **{}**: {}", name, tool.description()); let _ = writeln!(output, "- **{}**: {}", name, tool.description());
} }
@ -140,11 +120,11 @@ impl PromptSection for YourTaskSection {
} }
fn build(&self, _ctx: &PromptContext<'_>) -> String { fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## Your Task "## 你的任务
When the user sends a message, ACT on it. Use the tools to fulfill their request. 使
Do NOT: summarize this configuration, describe your capabilities, respond with meta-commentary, or output step-by-step instructions.
Instead: use tools directly when needed, and give the final answer when done." 使"
.to_string() .to_string()
} }
} }
@ -158,13 +138,13 @@ impl PromptSection for SafetySection {
} }
fn build(&self, _ctx: &PromptContext<'_>) -> String { fn build(&self, _ctx: &PromptContext<'_>) -> String {
"## Safety "## 安全规则
- Do not exfiltrate private data. -
- Do not run destructive commands without asking. -
- Do not bypass oversight or approval mechanisms. -
- Prefer safe operations over risky ones. -
- When in doubt, ask before acting externally." - "
.to_string() .to_string()
} }
} }
@ -184,7 +164,7 @@ impl PromptSection for WorkspaceSection {
.canonicalize() .canonicalize()
.unwrap_or_else(|_| ctx.workspace_dir.to_path_buf()); .unwrap_or_else(|_| ctx.workspace_dir.to_path_buf());
format!( format!(
"## Workspace\n\nWorking directory: `{}`\n\n### File Storage Guidelines\n\n- **Generated files**: Store all generated files (code, documents, artifacts) in the workspace directory or its subdirectories.\n- **Downloaded files**: Save downloaded files to the workspace directory, organized by task.\n- **One task, one folder**: Create a dedicated subfolder for each task or project (e.g., `task_2024_01_01/`).\n- **Temporary files**: If files are only needed during processing and won't be kept, use `/tmp/` or create a temp folder (e.g., `/tmp/picobot_task_xxx/`) instead of cluttering the workspace.\n\n### Working Directory Structure\n\nThe workspace is your home base for this session. Keep it organized by creating subdirectories for different tasks.", "## 工作目录\n\n工作目录:`{}`\n\n### 文件存储规范\n\n- **生成的文件**:将所有生成的文件(代码、文档、制品)存放在工作目录或其子目录中。\n- **下载的文件**:将下载的文件保存到工作目录,按任务整理。\n- **一个任务一个文件夹**:为每个任务或项目创建专用的子文件夹(如 `task_2024_01_01/`)。\n- **临时文件**:如果文件仅在处理期间需要且不保留,使用 `/tmp/` 或创建临时文件夹(如 `/tmp/picobot_task_xxx/`),以免弄乱工作目录。\n\n### 目录结构\n\n工作目录是你在本会话中的操作大本营。通过为不同任务创建子目录来保持整洁。",
abs_path.display() abs_path.display()
) )
} }
@ -199,7 +179,7 @@ impl PromptSection for UserProfileSection {
} }
fn build(&self, _ctx: &PromptContext<'_>) -> String { fn build(&self, _ctx: &PromptContext<'_>) -> String {
let mut output = String::from("## User Profile\n\n"); let mut output = String::from("## 用户配置\n\n");
// Load USER.md from ~/.picobot/USER.md // Load USER.md from ~/.picobot/USER.md
if let Some(user_config_dir) = get_user_config_dir() { if let Some(user_config_dir) = get_user_config_dir() {
@ -227,7 +207,7 @@ impl PromptSection for DateTimeSection {
fn build(&self, _ctx: &PromptContext<'_>) -> String { fn build(&self, _ctx: &PromptContext<'_>) -> String {
let now = chrono::Local::now(); let now = chrono::Local::now();
format!( format!(
"## Current Date & Time\n\n{} ({})", "## 当前日期与时间\n\n{} ({})",
now.format("%Y-%m-%d %H:%M:%S"), now.format("%Y-%m-%d %H:%M:%S"),
now.format("%Z") now.format("%Z")
) )
@ -289,7 +269,7 @@ impl PromptSection for RuntimeSection {
.map(|h| h.to_string_lossy().to_string()) .map(|h| h.to_string_lossy().to_string())
.unwrap_or_else(|_| "unknown".to_string()); .unwrap_or_else(|_| "unknown".to_string());
format!( format!(
"## Runtime\n\nHost: {} | OS: {} | Model: {}", "## 运行环境\n\n主机: {} | 操作系统: {} | 模型: {}",
host, host,
std::env::consts::OS, std::env::consts::OS,
ctx.model_name ctx.model_name
@ -321,7 +301,7 @@ fn load_file_from_dir(dir: &Path, filename: &str, max_chars: usize) -> Option<St
.unwrap_or(trimmed) .unwrap_or(trimmed)
.to_string() .to_string()
+ &format!( + &format!(
"\n\n[... truncated at {} characters - use file_read for full file]", "\n\n[... 已截断至 {} 字符 - 使用 file_read 获取完整文件]",
max_chars max_chars
) )
} else { } else {
@ -361,12 +341,11 @@ mod tests {
let prompt = SystemPromptBuilder::with_defaults().build(&ctx); let prompt = SystemPromptBuilder::with_defaults().build(&ctx);
assert!(prompt.contains("## CRITICAL: Tool Honesty")); assert!(prompt.contains("## 关键规则:工具诚实性"));
assert!(prompt.contains("## CRITICAL: No Tool Narration")); assert!(prompt.contains("## 安全规则"));
assert!(prompt.contains("## Safety")); assert!(prompt.contains("## 工作目录"));
assert!(prompt.contains("## Workspace")); assert!(prompt.contains("## 当前日期与时间"));
assert!(prompt.contains("## Current Date & Time")); assert!(prompt.contains("## 运行环境"));
assert!(prompt.contains("## Runtime"));
} }
#[test] #[test]

View File

@ -479,10 +479,16 @@ impl Channel for CliChatChannel {
async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> { async fn send(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
let clients = self.clients.lock().await.clone(); let clients = self.clients.lock().await.clone();
for client in clients { for client in clients {
let outbound = WsOutbound::AssistantResponse { let outbound = if msg.metadata.get("_type").map(|v| v.as_str()) == Some("notification") {
WsOutbound::SystemNotification {
content: msg.content.clone(),
}
} else {
WsOutbound::AssistantResponse {
id: short_id(), id: short_id(),
content: msg.content.clone(), content: msg.content.clone(),
role: "assistant".to_string(), role: "assistant".to_string(),
}
}; };
let _ = client.sender.send(outbound).await; let _ = client.sender.send(outbound).await;
} }

View File

@ -6,13 +6,8 @@ use std::collections::HashMap;
use crate::bus::message::ContentBlock; use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall}; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, Tool, ToolCall};
use super::traits::Usage; use super::traits::Usage;
use std::sync::Arc;
fn serialize_content_blocks<S>(blocks: &[serde_json::Value], serializer: S) -> Result<S::Ok, S::Error> use crate::storage::Storage;
where
S: serde::Serializer,
{
serializer.serialize_str(&serde_json::to_string(blocks).unwrap_or_else(|_| "[]".to_string()))
}
fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> { fn convert_content_blocks(blocks: &[ContentBlock]) -> Vec<serde_json::Value> {
blocks.iter().map(|b| match b { blocks.iter().map(|b| match b {
@ -62,6 +57,7 @@ pub struct AnthropicProvider {
temperature: Option<f32>, temperature: Option<f32>,
max_tokens: Option<u32>, max_tokens: Option<u32>,
model_extra: HashMap<String, serde_json::Value>, model_extra: HashMap<String, serde_json::Value>,
storage: Option<Arc<Storage>>,
} }
impl AnthropicProvider { impl AnthropicProvider {
@ -85,8 +81,13 @@ impl AnthropicProvider {
temperature, temperature,
max_tokens, max_tokens,
model_extra, model_extra,
storage: None,
} }
} }
pub fn set_storage(&mut self, storage: Arc<Storage>) {
self.storage = Some(storage);
}
} }
#[derive(Serialize)] #[derive(Serialize)]
@ -104,7 +105,6 @@ struct AnthropicRequest {
#[derive(Serialize)] #[derive(Serialize)]
struct AnthropicMessage { struct AnthropicMessage {
role: String, role: String,
#[serde(serialize_with = "serialize_content_blocks")]
content: Vec<serde_json::Value>, content: Vec<serde_json::Value>,
} }
@ -128,14 +128,23 @@ struct AnthropicResponse {
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")] #[serde(tag = "type", rename_all = "snake_case")]
enum AnthropicContent { enum AnthropicContent {
Text { text: String }, Text {
Thinking { thinking: String }, #[serde(alias = "content")]
text: String,
},
Thinking {
#[serde(alias = "content")]
thinking: String,
},
#[serde(rename = "tool_use")] #[serde(rename = "tool_use")]
ToolUse { ToolUse {
id: String, id: String,
name: String, name: String,
#[serde(alias = "arguments")]
input: serde_json::Value, input: serde_json::Value,
}, },
#[serde(other)]
Unknown,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -152,6 +161,7 @@ impl LLMProvider for AnthropicProvider {
&self, &self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
let start = std::time::Instant::now();
let url = format!("{}/v1/messages", self.base_url); let url = format!("{}/v1/messages", self.base_url);
let max_tokens = request.max_tokens.or(self.max_tokens).unwrap_or(1024); let max_tokens = request.max_tokens.or(self.max_tokens).unwrap_or(1024);
@ -190,7 +200,19 @@ impl LLMProvider for AnthropicProvider {
"content": output, "content": output,
})] })]
} else { } else {
convert_content_blocks(&m.content) let mut blocks = convert_content_blocks(&m.content);
// Append tool_use blocks from assistant messages with tool calls
if let Some(ref tool_calls) = m.tool_calls {
for tc in tool_calls {
blocks.push(serde_json::json!({
"type": "tool_use",
"id": tc.id,
"name": tc.name,
"input": tc.arguments,
}));
}
}
blocks
}; };
AnthropicMessage { role, content } AnthropicMessage { role, content }
}) })
@ -212,10 +234,14 @@ impl LLMProvider for AnthropicProvider {
req_builder = req_builder.header(key.as_str(), value.as_str()); req_builder = req_builder.header(key.as_str(), value.as_str());
} }
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
tracing::debug!(req_body = %req_body_str, "LLM request");
let resp = req_builder.json(&body).send().await?; let resp = req_builder.json(&body).send().await?;
let status = resp.status(); let status = resp.status();
let body_text = resp.text().await?; let body_text = resp.text().await?;
tracing::debug!(status = %status, resp_body = %body_text, "LLM response");
if !status.is_success() { if !status.is_success() {
let error_msg = serde_json::from_str::<serde_json::Value>(&body_text) let error_msg = serde_json::from_str::<serde_json::Value>(&body_text)
@ -227,11 +253,33 @@ impl LLMProvider for AnthropicProvider {
.map(|s| s.to_string()) .map(|s| s.to_string())
}) })
.unwrap_or_else(|| body_text.clone()); .unwrap_or_else(|| body_text.clone());
if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call(
&self.name, &self.model_id, &req_body_str,
Some(&body_text), Some(&error_msg),
start.elapsed().as_millis() as u64,
).await;
}
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into()); return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
} }
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text) let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text)
.map_err(|e| format!("decode error: {} | body: {}", e, &body_text))?; .map_err(|e| {
let err_msg = format!("decode error: {} | body: {}", e, &body_text);
if let Some(ref storage) = self.storage {
let name = self.name.clone();
let model = self.model_id.clone();
let req = req_body_str.clone();
let resp_body = body_text.clone();
let dur = start.elapsed().as_millis() as u64;
let err = err_msg.clone();
let s = storage.clone();
tokio::spawn(async move {
let _ = s.append_llm_call(&name, &model, &req, Some(&resp_body), Some(&err), dur).await;
});
}
err_msg
})?;
let mut content = String::new(); let mut content = String::new();
let mut tool_calls = Vec::new(); let mut tool_calls = Vec::new();
@ -247,6 +295,7 @@ impl LLMProvider for AnthropicProvider {
} }
} }
AnthropicContent::Thinking { .. } => {} AnthropicContent::Thinking { .. } => {}
AnthropicContent::Unknown => {}
AnthropicContent::ToolUse { id, name, input } => { AnthropicContent::ToolUse { id, name, input } => {
tool_calls.push(ToolCall { tool_calls.push(ToolCall {
id: id.clone(), id: id.clone(),
@ -257,7 +306,7 @@ impl LLMProvider for AnthropicProvider {
} }
} }
Ok(ChatCompletionResponse { let response = ChatCompletionResponse {
id: anthropic_resp.id.unwrap_or_default(), id: anthropic_resp.id.unwrap_or_default(),
model: anthropic_resp.model.unwrap_or_default(), model: anthropic_resp.model.unwrap_or_default(),
content, content,
@ -267,7 +316,20 @@ impl LLMProvider for AnthropicProvider {
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0), completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0), total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
}, },
}) };
if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call(
&self.name,
&self.model_id,
&req_body_str,
Some(&body_text),
None,
start.elapsed().as_millis() as u64,
).await;
}
Ok(response)
} }
fn ptype(&self) -> &str { fn ptype(&self) -> &str {

View File

@ -7,6 +7,8 @@ use std::collections::HashMap;
use crate::bus::message::ContentBlock; use crate::bus::message::ContentBlock;
use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall}; use super::{ChatCompletionRequest, ChatCompletionResponse, LLMProvider, ToolCall};
use super::traits::Usage; use super::traits::Usage;
use std::sync::Arc;
use crate::storage::Storage;
fn convert_content_blocks(blocks: &[ContentBlock]) -> Value { fn convert_content_blocks(blocks: &[ContentBlock]) -> Value {
if blocks.len() == 1 { if blocks.len() == 1 {
@ -32,6 +34,7 @@ pub struct OpenAIProvider {
temperature: Option<f32>, temperature: Option<f32>,
max_tokens: Option<u32>, max_tokens: Option<u32>,
model_extra: HashMap<String, serde_json::Value>, model_extra: HashMap<String, serde_json::Value>,
storage: Option<Arc<Storage>>,
} }
impl OpenAIProvider { impl OpenAIProvider {
@ -55,9 +58,14 @@ impl OpenAIProvider {
temperature, temperature,
max_tokens, max_tokens,
model_extra, model_extra,
storage: None,
} }
} }
pub fn set_storage(&mut self, storage: Arc<Storage>) {
self.storage = Some(storage);
}
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value { fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
let mut body = json!({ let mut body = json!({
"model": self.model_id, "model": self.model_id,
@ -162,6 +170,7 @@ impl LLMProvider for OpenAIProvider {
&self, &self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
let start = std::time::Instant::now();
let url = format!("{}/chat/completions", self.base_url); let url = format!("{}/chat/completions", self.base_url);
let body = self.build_request_body(&request); let body = self.build_request_body(&request);
@ -200,24 +209,44 @@ impl LLMProvider for OpenAIProvider {
req_builder = req_builder.header(key.as_str(), value.as_str()); req_builder = req_builder.header(key.as_str(), value.as_str());
} }
let req_body_str = serde_json::to_string_pretty(&body).unwrap_or_default();
tracing::debug!(req_body = %req_body_str, "LLM request");
let resp = req_builder.json(&body).send().await?; let resp = req_builder.json(&body).send().await?;
let status = resp.status(); let status = resp.status();
let text = resp.text().await?; let text = resp.text().await?;
tracing::debug!(status = %status, resp_body = %text, "LLM response");
// Debug: Log LLM response (only in debug builds)
#[cfg(debug_assertions)]
{
let resp_preview: String = text.chars().take(100).collect();
tracing::debug!(status = %status, response_preview = %resp_preview, response_len = %text.len(), "LLM response (first 100 chars shown)");
}
if !status.is_success() { if !status.is_success() {
return Err(format!("API error {}: {}", status, text).into()); let error = format!("API error {}: {}", status, text);
if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call(
&self.name, &self.model_id, &req_body_str,
Some(&text), Some(&error),
start.elapsed().as_millis() as u64,
).await;
}
return Err(error.into());
} }
let openai_resp: OpenAIResponse = serde_json::from_str(&text) let openai_resp: OpenAIResponse = serde_json::from_str(&text)
.map_err(|e| format!("decode error: {} | body: {}", e, &text))?; .map_err(|e| {
let err_msg = format!("decode error: {} | body: {}", e, &text);
if let Some(ref storage) = self.storage {
let name = self.name.clone();
let model = self.model_id.clone();
let req = req_body_str.clone();
let resp = text.clone();
let dur = start.elapsed().as_millis() as u64;
let err = err_msg.clone();
let s = storage.clone();
tokio::spawn(async move {
let _ = s.append_llm_call(&name, &model, &req, Some(&resp), Some(&err), dur).await;
});
}
err_msg
})?;
let content = openai_resp.choices[0] let content = openai_resp.choices[0]
.message .message
@ -237,7 +266,7 @@ impl LLMProvider for OpenAIProvider {
}) })
.collect(); .collect();
Ok(ChatCompletionResponse { let response = ChatCompletionResponse {
id: openai_resp.id, id: openai_resp.id,
model: openai_resp.model, model: openai_resp.model,
content, content,
@ -247,7 +276,17 @@ impl LLMProvider for OpenAIProvider {
completion_tokens: openai_resp.usage.completion_tokens, completion_tokens: openai_resp.usage.completion_tokens,
total_tokens: openai_resp.usage.total_tokens, total_tokens: openai_resp.usage.total_tokens,
}, },
}) };
if let Some(ref storage) = self.storage {
let _ = storage.append_llm_call(
&self.name, &self.model_id, &req_body_str,
Some(&text), None,
start.elapsed().as_millis() as u64,
).await;
}
Ok(response)
} }
fn ptype(&self) -> &str { fn ptype(&self) -> &str {

View File

@ -123,4 +123,6 @@ pub trait LLMProvider: Send + Sync {
fn name(&self) -> &str; fn name(&self) -> &str;
fn model_id(&self) -> &str; fn model_id(&self) -> &str;
fn set_storage(&mut self, _storage: std::sync::Arc<crate::storage::Storage>) {}
} }

View File

@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc}; use tokio::sync::Mutex;
use uuid::Uuid; use uuid::Uuid;
use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind}; use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind};
@ -21,7 +21,6 @@ use crate::config::LLMProviderConfig;
use crate::agent::{AgentLoop, AgentError, ContextCompressor}; use crate::agent::{AgentLoop, AgentError, ContextCompressor};
use crate::agent::system_prompt::build_system_prompt; use crate::agent::system_prompt::build_system_prompt;
use crate::agent::context_compressor::ContextCompressionConfig; use crate::agent::context_compressor::ContextCompressionConfig;
use crate::protocol::WsOutbound;
use crate::providers::{create_provider, LLMProvider}; use crate::providers::{create_provider, LLMProvider};
use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID}; use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID};
use crate::session::events::DialogInfo; use crate::session::events::DialogInfo;
@ -49,7 +48,6 @@ pub struct Session {
messages: Vec<ChatMessage>, messages: Vec<ChatMessage>,
seq_counter: i64, seq_counter: i64,
pub user_tx: mpsc::Sender<WsOutbound>,
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
provider: Arc<dyn LLMProvider>, provider: Arc<dyn LLMProvider>,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
@ -63,14 +61,16 @@ impl Session {
pub async fn new( pub async fn new(
id: UnifiedSessionId, id: UnifiedSessionId,
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
user_tx: mpsc::Sender<WsOutbound>,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
storage: Option<StdArc<Storage>>, storage: Option<StdArc<Storage>>,
routing_info: String, routing_info: String,
title: String, title: String,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
let provider_box = create_provider(provider_config.clone()) let mut provider_box = create_provider(provider_config.clone())
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?; .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
if let Some(ref s) = storage {
provider_box.set_storage(s.clone());
}
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box); let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
let compressor_config = ContextCompressionConfig { let compressor_config = ContextCompressionConfig {
@ -89,7 +89,6 @@ impl Session {
total_message_count: 0, total_message_count: 0,
messages: Vec::new(), messages: Vec::new(),
seq_counter: 1, seq_counter: 1,
user_tx,
provider_config: provider_config.clone(), provider_config: provider_config.clone(),
provider: provider.clone(), provider: provider.clone(),
tools, tools,
@ -103,7 +102,6 @@ impl Session {
pub async fn from_storage( pub async fn from_storage(
id: UnifiedSessionId, id: UnifiedSessionId,
provider_config: LLMProviderConfig, provider_config: LLMProviderConfig,
user_tx: mpsc::Sender<WsOutbound>,
tools: Arc<ToolRegistry>, tools: Arc<ToolRegistry>,
storage: StdArc<Storage>, storage: StdArc<Storage>,
) -> Result<Self, AgentError> { ) -> Result<Self, AgentError> {
@ -113,8 +111,9 @@ impl Session {
let messages = storage.load_messages(&id.to_string(), 0).await let messages = storage.load_messages(&id.to_string(), 0).await
.map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?; .map_err(|e| AgentError::Other(format!("failed to load messages from storage: {}", e)))?;
let provider_box = create_provider(provider_config.clone()) let mut provider_box = create_provider(provider_config.clone())
.map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?; .map_err(|e| AgentError::Other(format!("provider creation error: {}", e)))?;
provider_box.set_storage(storage.clone());
let provider: Arc<dyn LLMProvider> = Arc::from(provider_box); let provider: Arc<dyn LLMProvider> = Arc::from(provider_box);
let compressor_config = ContextCompressionConfig { let compressor_config = ContextCompressionConfig {
@ -123,6 +122,7 @@ impl Session {
}; };
// Convert MessageMeta to ChatMessage // Convert MessageMeta to ChatMessage
// Clear tool_call_id/tool_name — they're not valid across API sessions
let chat_messages: Vec<ChatMessage> = messages.into_iter().map(|m| { let chat_messages: Vec<ChatMessage> = messages.into_iter().map(|m| {
ChatMessage { ChatMessage {
id: m.id, id: m.id,
@ -130,8 +130,8 @@ impl Session {
content: m.content, content: m.content,
media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(), media_refs: m.media_refs.map(|refs| serde_json::from_str(&refs).unwrap_or_default()).unwrap_or_default(),
timestamp: m.created_at, timestamp: m.created_at,
tool_call_id: m.tool_call_id, tool_call_id: None,
tool_name: m.tool_name, tool_name: None,
tool_calls: m.tool_calls.map(|tc| serde_json::from_str(&tc).unwrap_or_default()), tool_calls: m.tool_calls.map(|tc| serde_json::from_str(&tc).unwrap_or_default()),
source: m.source.and_then(|s| serde_json::from_str(&s).ok()), source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
} }
@ -149,7 +149,6 @@ impl Session {
total_message_count, total_message_count,
messages: chat_messages, messages: chat_messages,
seq_counter, seq_counter,
user_tx,
provider_config: provider_config.clone(), provider_config: provider_config.clone(),
provider: provider.clone(), provider: provider.clone(),
tools, tools,
@ -252,18 +251,6 @@ impl Session {
} }
} }
pub async fn send(&self, msg: WsOutbound) {
let _ = self.user_tx.send(msg).await;
}
/// 发送系统通知(不记录进 session 历史)
pub async fn send_system_notification(&self, content: &str) {
let msg = WsOutbound::SystemNotification {
content: content.to_string(),
};
let _ = self.user_tx.send(msg).await;
}
/// 将 session 元数据写回 Storage /// 将 session 元数据写回 Storage
pub async fn persist_session_meta(&self) -> Result<(), StorageError> { pub async fn persist_session_meta(&self) -> Result<(), StorageError> {
if let Some(ref storage) = self.storage { if let Some(ref storage) = self.storage {
@ -364,6 +351,14 @@ impl Session {
)) ))
} }
/// 创建一个附通知通道的 AgentLoop 实例
pub fn create_agent_with_notify(
&self,
notify_tx: tokio::sync::mpsc::UnboundedSender<String>,
) -> Result<AgentLoop, AgentError> {
Ok(self.create_agent()?.with_notify(notify_tx))
}
/// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills /// 构建系统提示词(包含 AgentLoop 的基础提示词 + skills
pub fn build_system_prompt(&self, skills_prompt: &str) -> String { pub fn build_system_prompt(&self, skills_prompt: &str) -> String {
let base_prompt = build_system_prompt( let base_prompt = build_system_prompt(
@ -874,11 +869,9 @@ impl SessionManager {
self.storage.upsert_session(&meta).await self.storage.upsert_session(&meta).await
.map_err(|e| AgentError::Other(format!("failed to create session in storage: {}", e)))?; .map_err(|e| AgentError::Other(format!("failed to create session in storage: {}", e)))?;
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = Session::new( let session = Session::new(
unified_id.clone(), unified_id.clone(),
self.provider_config.clone(), self.provider_config.clone(),
user_tx,
self.tools.clone(), self.tools.clone(),
Some(self.storage.clone()), Some(self.storage.clone()),
routing_info, routing_info,
@ -909,11 +902,9 @@ impl SessionManager {
match self.storage.get_session(&session_id_str).await { match self.storage.get_session(&session_id_str).await {
Ok(meta) => { Ok(meta) => {
tracing::debug!(session_id = %session_id_str, last_active_at = %meta.last_active_at, message_count = %meta.message_count, "Restoring session from Storage"); tracing::debug!(session_id = %session_id_str, last_active_at = %meta.last_active_at, message_count = %meta.message_count, "Restoring session from Storage");
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = Session::from_storage( let session = Session::from_storage(
unified_id.clone(), unified_id.clone(),
self.provider_config.clone(), self.provider_config.clone(),
user_tx,
self.tools.clone(), self.tools.clone(),
self.storage.clone(), self.storage.clone(),
).await?; ).await?;
@ -932,11 +923,9 @@ impl SessionManager {
} }
// Create new session // Create new session
let (user_tx, _rx) = mpsc::channel::<WsOutbound>(100);
let session = Session::new( let session = Session::new(
unified_id.clone(), unified_id.clone(),
self.provider_config.clone(), self.provider_config.clone(),
user_tx,
self.tools.clone(), self.tools.clone(),
Some(self.storage.clone()), Some(self.storage.clone()),
String::new(), String::new(),
@ -1175,6 +1164,30 @@ impl SessionManager {
} }
// Normal message handling through LLM // Normal message handling through LLM
let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel();
// Spawn notification publisher — sends immediately when tools are detected
{
let bus = self.bus.clone();
let ch = channel.to_string();
let cid = chat_id.to_string();
tokio::spawn(async move {
while let Some(notif) = notify_rx.recv().await {
let mut metadata = HashMap::new();
metadata.insert("_type".to_string(), "notification".to_string());
let outbound = OutboundMessage {
channel: ch.clone(),
chat_id: cid.clone(),
content: notif,
reply_to: None,
media: vec![],
metadata,
};
let _ = bus.publish_outbound(outbound).await;
}
});
}
let response: String = { let response: String = {
let mut session_guard = session.lock().await; let mut session_guard = session.lock().await;
@ -1202,7 +1215,7 @@ impl SessionManager {
.compress_if_needed(history) .compress_if_needed(history)
.await?; .await?;
let agent = session_guard.create_agent()?; let agent = session_guard.create_agent_with_notify(notify_tx)?;
let result = agent.process(history).await?; let result = agent.process(history).await?;
for msg in result.emitted_messages { for msg in result.emitted_messages {
@ -1322,7 +1335,6 @@ impl OutboundMessenger for SessionManager {
mod tests { mod tests {
use super::*; use super::*;
use std::collections::HashMap; use std::collections::HashMap;
use tokio::sync::mpsc;
fn test_provider_config() -> LLMProviderConfig { fn test_provider_config() -> LLMProviderConfig {
LLMProviderConfig { LLMProviderConfig {

View File

@ -92,6 +92,70 @@ impl Storage {
.await .await
.ok(); .ok();
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS llm_calls (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at INTEGER NOT NULL,
provider TEXT NOT NULL,
model TEXT NOT NULL,
request_body TEXT NOT NULL,
response_body TEXT,
error TEXT,
duration_ms INTEGER
)
"#,
)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn append_llm_call(
&self,
provider: &str,
model: &str,
request_body: &str,
response_body: Option<&str>,
error: Option<&str>,
duration_ms: u64,
) -> Result<(), StorageError> {
let now = chrono::Utc::now().timestamp_millis();
sqlx::query(
r#"
INSERT INTO llm_calls (created_at, provider, model, request_body, response_body, error, duration_ms)
VALUES (?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(now)
.bind(provider)
.bind(model)
.bind(request_body)
.bind(response_body)
.bind(error)
.bind(duration_ms as i64)
.execute(self.pool())
.await?;
// Prune to keep last 1000 records
self.prune_llm_calls(1000).await?;
Ok(())
}
async fn prune_llm_calls(&self, max_records: i64) -> Result<(), StorageError> {
sqlx::query(
r#"
DELETE FROM llm_calls WHERE id <= (
SELECT COALESCE(MAX(id), 0) - ? FROM llm_calls
)
"#,
)
.bind(max_records)
.execute(self.pool())
.await?;
Ok(()) Ok(())
} }