新增跨session消息发送能力
This commit is contained in:
parent
24d3407b05
commit
98eb7bea3d
29
AGENTS.md
29
AGENTS.md
@ -1,9 +1,5 @@
|
||||
# PicoBot
|
||||
|
||||
## Maintenance
|
||||
|
||||
- **Update this file on any architectural change** — module boundaries, data flow, key constraints, or build/test commands must be reflected here
|
||||
|
||||
## Build & Run
|
||||
|
||||
- `cargo build` — build the binary
|
||||
@ -12,18 +8,20 @@
|
||||
|
||||
## Config
|
||||
|
||||
- Config file: `~/.picobot/config.json` or `./config.json` (fallback order)
|
||||
- `.env` is loaded and env var placeholders `<VAR_NAME>` are substituted into config
|
||||
- Config file: `~/.picobot/config.json` or `./config.json` (fallback order, see `src/config/mod.rs:213`)
|
||||
- `.env` is loaded manually (not via dotenv crate); env var placeholders `<VAR_NAME>` in config JSON are substituted
|
||||
- Config example: `config.example.json`
|
||||
|
||||
## Tests
|
||||
|
||||
- `cargo test --lib` — run unit tests (FAILS: `src/session/session.rs:657` missing `workspace_dir` field in test helper)
|
||||
- `cargo test --test test_integration -- --ignored` — run integration tests (requires `tests/test.env` with API keys)
|
||||
- `cargo test --lib` — run unit tests (runs all `#[test]` in `src/`)
|
||||
- `cargo test --test test_integration -- --ignored` — run integration tests (also `test_tool_calling`, `test_request_format`)
|
||||
- **All** integration tests require `tests/test.env` with real API keys; copy from `tests/test.env.example` and fill in keys
|
||||
- Integration tests are `#[ignore]` by default; use `-- --ignored` to run them
|
||||
|
||||
## Reference
|
||||
|
||||
- `reference/` — third-party reference implementations (nanobot, Mini-Agent, zeroclaw); not part of this project; use for similar functionality patterns
|
||||
- `reference/` — third-party reference implementations (nanobot, Mini-Agent, zeroclaw); not part of this project; do not modify
|
||||
|
||||
## Architecture
|
||||
|
||||
@ -51,8 +49,13 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
|
||||
| `session` | Conversation session lifecycle, dialog operations | `SessionManager`, `Session` |
|
||||
| `agent` | LLM call loop, tool execution, context compression | `AgentLoop` |
|
||||
| `providers` | LLM API clients (OpenAI-compatible, Anthropic) | `LLMProvider` trait, factory `create_provider()` |
|
||||
| `tools` | Agent tools (bash, file operations, http, web, get_skill) | `ToolRegistry`, `Tool` trait |
|
||||
| `tools` | Agent tools (bash, file ops, http, web, get_skill) | `ToolRegistry`, `Tool` trait |
|
||||
| `skills` | Skills loading, management, and prompt building | `SkillsLoader`, `Skill` |
|
||||
| `storage` | SQLite persistence for sessions and messages | `Storage`, `SessionMeta`, `MessageMeta` |
|
||||
| `observability` | Observer pattern for agent/tool telemetry events | `Observer` trait, `ObserverEvent`, `MultiObserver` |
|
||||
| `protocol` | WebSocket protocol message types | `WsInbound`, `WsOutbound`, `SessionSummary` |
|
||||
| `config` | Config loading, env substitution, path resolution | `Config`, `LLMProviderConfig` |
|
||||
| `logging` | Tracing initialization with file rotation | `init_logging()`, `init_logging_console_only()` |
|
||||
|
||||
### Functional Boundaries
|
||||
|
||||
@ -68,9 +71,7 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
|
||||
### Key Constraints
|
||||
|
||||
- Gateway **changes working directory** to workspace on startup (`src/gateway/mod.rs:31`)
|
||||
- Session/message persistence uses SQLite via `sqlx`; DB stored in workspace as `.picobot_sessions.db` by default
|
||||
- `ChannelManager` owns the `MessageBus` and all channel instances
|
||||
- `OutboundDispatcher` routes outbound messages to the correct channel via `ChannelManager`
|
||||
|
||||
## Known Issues
|
||||
|
||||
- (No known issues at this time)
|
||||
- Config `.env` loading uses `unsafe { env::set_var(...) }` — don't refactor to safer patterns without understanding side effects
|
||||
|
||||
@ -20,7 +20,7 @@ clap = { version = "4", features = ["derive"] }
|
||||
dirs = "6.0.0"
|
||||
prost = "0.14"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "local-time"] }
|
||||
tracing-appender = "0.2"
|
||||
anyhow = "1.0"
|
||||
mime_guess = "2.0"
|
||||
|
||||
@ -390,8 +390,16 @@ impl AgentLoop {
|
||||
});
|
||||
}
|
||||
|
||||
// Execute tool calls
|
||||
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
|
||||
// Execute tool calls — log tool names and args before execution
|
||||
{
|
||||
let tools_info: Vec<String> = response.tool_calls.iter()
|
||||
.map(|tc| {
|
||||
let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
|
||||
format!("{}:{}", tc.name, args)
|
||||
})
|
||||
.collect();
|
||||
tracing::info!(iteration, count = response.tool_calls.len(), tools = %tools_info.join(", "), "Tool calls detected, executing tools");
|
||||
}
|
||||
|
||||
// Add assistant message with tool calls
|
||||
let assistant_message = ChatMessage::assistant_with_tool_calls(
|
||||
|
||||
@ -45,6 +45,7 @@ impl SystemPromptBuilder {
|
||||
Box::new(UserProfileSection),
|
||||
Box::new(DateTimeSection),
|
||||
Box::new(RuntimeSection),
|
||||
Box::new(CrossChannelSection),
|
||||
],
|
||||
}
|
||||
}
|
||||
@ -233,6 +234,48 @@ impl PromptSection for DateTimeSection {
|
||||
}
|
||||
}
|
||||
|
||||
/// Cross-channel messaging and system notification guidance for LLM.
|
||||
pub struct CrossChannelSection;
|
||||
|
||||
impl PromptSection for CrossChannelSection {
|
||||
fn name(&self) -> &str {
|
||||
"cross_channel"
|
||||
}
|
||||
|
||||
fn build(&self, _ctx: &PromptContext<'_>) -> String {
|
||||
r#"## 关于跨渠道消息和系统通知
|
||||
|
||||
当前对话中可能出现带有 `source` 标记的消息,这些消息不是用户直接输入:
|
||||
|
||||
### 系统通知(source.kind = "system_notification")
|
||||
来自机器人内部系统(如定时任务、后台任务)的通知。
|
||||
- `system_name`: 发出通知的系统名称
|
||||
- `task_id`: 关联的任务 ID
|
||||
|
||||
### 跨渠道消息(source.kind = "cross_channel")
|
||||
来自其他渠道的消息被写入当前对话。
|
||||
- `from_channel`: 来源渠道(如 "feishu")
|
||||
- `from_user_id`: 来源用户 ID
|
||||
|
||||
### send_message 工具
|
||||
|
||||
使用 `send_message` 向其他渠道发送消息。参数:
|
||||
- `target_chat_id`: 目标会话ID,支持两种格式:
|
||||
1. `<channel>:<chat_id>` — 发送到该聊天下最新活跃的会话,若没有活跃会话则自动创建
|
||||
2. `<channel>:<chat_id>:<dialog_id>` — 发送到指定会话,若会话已过期则自动激活
|
||||
- `content`: 要发送的消息内容
|
||||
- `origin`(可选): 消息来源标识,不填则自动使用当前会话的完整 session_id
|
||||
|
||||
跨渠道消息到达目标会话时,内容前会带有 `[message from X to Y]` 标记,
|
||||
表示该消息的来源和目标。目标会话的 LLM 应将此理解为来自其他渠道/会话的消息。
|
||||
|
||||
### 处理建议
|
||||
- 系统通知:可以提及但不建议以此为由改变对话主题
|
||||
- 跨渠道消息:当用户提及相关事务时可关联这些消息"#
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Runtime environment information.
|
||||
pub struct RuntimeSection;
|
||||
|
||||
|
||||
@ -73,6 +73,28 @@ pub struct ChatMessage {
|
||||
pub tool_name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub source: Option<MessageSource>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum SourceKind {
|
||||
#[serde(rename = "system_notification")]
|
||||
SystemNotification,
|
||||
#[serde(rename = "cross_channel")]
|
||||
CrossChannel,
|
||||
#[serde(rename = "external_trigger")]
|
||||
ExternalTrigger,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MessageSource {
|
||||
pub kind: SourceKind,
|
||||
pub from_channel: Option<String>,
|
||||
pub from_session: Option<String>,
|
||||
pub from_user_id: Option<String>,
|
||||
pub system_name: Option<String>,
|
||||
pub task_id: Option<String>,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
@ -86,6 +108,7 @@ impl ChatMessage {
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_calls: None,
|
||||
source: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -99,6 +122,7 @@ impl ChatMessage {
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_calls: None,
|
||||
source: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -112,6 +136,7 @@ impl ChatMessage {
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_calls: None,
|
||||
source: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -125,6 +150,21 @@ impl ChatMessage {
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
source: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assistant_with_source(content: impl Into<String>, source: MessageSource) -> Self {
|
||||
Self {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
role: "assistant".to_string(),
|
||||
content: content.into(),
|
||||
media_refs: Vec::new(),
|
||||
timestamp: current_timestamp(),
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_calls: None,
|
||||
source: Some(source),
|
||||
}
|
||||
}
|
||||
|
||||
@ -138,6 +178,7 @@ impl ChatMessage {
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_calls: None,
|
||||
source: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -151,6 +192,7 @@ impl ChatMessage {
|
||||
tool_call_id: Some(tool_call_id.into()),
|
||||
tool_name: Some(tool_name.into()),
|
||||
tool_calls: None,
|
||||
source: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,7 +2,7 @@ pub mod dispatcher;
|
||||
pub mod message;
|
||||
|
||||
pub use dispatcher::OutboundDispatcher;
|
||||
pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, OutboundMessage};
|
||||
pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MessageSource, OutboundMessage, SourceKind};
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
|
||||
@ -24,6 +24,14 @@ impl ChannelManager {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_bus(cli_chat_channel: Arc<crate::channels::CliChatChannel>, bus: Arc<MessageBus>) -> Self {
|
||||
Self {
|
||||
channels: Arc::new(RwLock::new(HashMap::new())),
|
||||
cli_chat_channel,
|
||||
bus,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a reference to the MessageBus
|
||||
pub fn bus(&self) -> Arc<MessageBus> {
|
||||
self.bus.clone()
|
||||
@ -99,6 +107,11 @@ impl ChannelManager {
|
||||
self.channels.read().await.get(name).cloned()
|
||||
}
|
||||
|
||||
/// Get list of registered channel names
|
||||
pub async fn list_channel_names(&self) -> Vec<String> {
|
||||
self.channels.read().await.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Dispatch an outbound message to the appropriate channel
|
||||
pub async fn dispatch(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
|
||||
let channel_name = &msg.channel;
|
||||
|
||||
@ -5,7 +5,7 @@ use std::sync::Arc;
|
||||
use axum::{routing, Router};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
use crate::bus::{ControlMessage, OutboundDispatcher};
|
||||
use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher};
|
||||
use crate::channels::{ChannelManager, CliChatChannel};
|
||||
use crate::channels::base::{Channel, ChannelError};
|
||||
use crate::config::{Config, expand_path, ensure_workspace_dir};
|
||||
@ -15,7 +15,7 @@ use crate::session::SessionManager;
|
||||
pub struct GatewayState {
|
||||
pub config: Config,
|
||||
pub workspace_dir: std::path::PathBuf,
|
||||
pub session_manager: SessionManager,
|
||||
pub session_manager: Arc<SessionManager>,
|
||||
pub channel_manager: ChannelManager,
|
||||
}
|
||||
|
||||
@ -53,21 +53,32 @@ impl GatewayState {
|
||||
);
|
||||
tracing::info!("Session storage: {}", db_path.display());
|
||||
|
||||
let session_manager = SessionManager::new(session_ttl_hours, provider_config.clone(), storage.clone())?;
|
||||
// Create MessageBus first (shared by SessionManager and ChannelManager)
|
||||
let bus = MessageBus::new(100);
|
||||
|
||||
// Create SessionManager with bus injection
|
||||
let session_manager = SessionManager::new(session_ttl_hours, provider_config.clone(), storage.clone(), bus.clone())?;
|
||||
let session_manager = Arc::new(session_manager);
|
||||
|
||||
// Start background cleanup task (default 60 minutes)
|
||||
let cleanup_interval = config.gateway.cleanup_interval_minutes.unwrap_or(60);
|
||||
Arc::new(session_manager.clone()).start_cleanup_task(cleanup_interval);
|
||||
session_manager.clone().start_cleanup_task(cleanup_interval);
|
||||
tracing::info!("Session cleanup task started (interval: {} min)", cleanup_interval);
|
||||
|
||||
// Create CLI Chat Channel first (needed for ChannelManager)
|
||||
// Create ChannelManager and init channels
|
||||
let cli_chat_channel = Arc::new(CliChatChannel::new());
|
||||
let channel_manager = ChannelManager::new(cli_chat_channel);
|
||||
let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus);
|
||||
channel_manager.init(&config, workspace_path.clone()).await
|
||||
.map_err(|e| format!("Failed to init channels: {}", e))?;
|
||||
|
||||
// Register send_message tool with available channel names
|
||||
let available_channels = channel_manager.list_channel_names().await;
|
||||
session_manager.register_outbound_tool(available_channels);
|
||||
|
||||
Ok(Self {
|
||||
config,
|
||||
workspace_dir: workspace_path,
|
||||
session_manager,
|
||||
session_manager: session_manager.clone(),
|
||||
channel_manager,
|
||||
})
|
||||
}
|
||||
@ -231,11 +242,7 @@ pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn
|
||||
|
||||
let state = Arc::new(GatewayState::new().await?);
|
||||
|
||||
// Initialize and start channels with workspace directory
|
||||
state.channel_manager.init(
|
||||
&state.config,
|
||||
state.workspace_dir.clone(),
|
||||
).await?;
|
||||
// Start all channels (init already done in GatewayState::new)
|
||||
state.channel_manager.start_all().await?;
|
||||
|
||||
// Start message processing (inbound processor + control processor + outbound dispatcher)
|
||||
|
||||
@ -4,6 +4,7 @@ use tracing_subscriber::{
|
||||
fmt,
|
||||
layer::SubscriberExt,
|
||||
util::SubscriberInitExt,
|
||||
fmt::time::LocalTime,
|
||||
EnvFilter,
|
||||
};
|
||||
|
||||
@ -44,12 +45,14 @@ pub fn init_logging() {
|
||||
|
||||
let file_layer = fmt::layer()
|
||||
.with_writer(file_appender)
|
||||
.with_timer(LocalTime::rfc_3339())
|
||||
.with_ansi(false)
|
||||
.with_target(true)
|
||||
.with_level(true)
|
||||
.with_thread_ids(true);
|
||||
|
||||
let console_layer = fmt::layer()
|
||||
.with_timer(LocalTime::rfc_3339())
|
||||
.with_target(true)
|
||||
.with_level(true);
|
||||
|
||||
@ -68,6 +71,7 @@ pub fn init_logging_console_only() {
|
||||
.unwrap_or_else(|_| EnvFilter::new("info"));
|
||||
|
||||
let console_layer = fmt::layer()
|
||||
.with_timer(LocalTime::rfc_3339())
|
||||
.with_target(true)
|
||||
.with_level(true);
|
||||
|
||||
|
||||
@ -117,10 +117,12 @@ struct AnthropicTool {
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AnthropicResponse {
|
||||
id: String,
|
||||
model: String,
|
||||
id: Option<String>,
|
||||
model: Option<String>,
|
||||
#[serde(default)]
|
||||
content: Vec<AnthropicContent>,
|
||||
usage: AnthropicUsage,
|
||||
#[serde(default)]
|
||||
usage: Option<AnthropicUsage>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@ -138,7 +140,9 @@ enum AnthropicContent {
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AnthropicUsage {
|
||||
#[serde(default)]
|
||||
input_tokens: u32,
|
||||
#[serde(default)]
|
||||
output_tokens: u32,
|
||||
}
|
||||
|
||||
@ -167,9 +171,28 @@ impl LLMProvider for AnthropicProvider {
|
||||
messages: request
|
||||
.messages
|
||||
.iter()
|
||||
.map(|m| AnthropicMessage {
|
||||
role: m.role.clone(),
|
||||
content: convert_content_blocks(&m.content),
|
||||
.map(|m| {
|
||||
let role = if m.role == "tool" {
|
||||
// Anthropic uses "user" role for tool result messages
|
||||
"user".to_string()
|
||||
} else {
|
||||
m.role.clone()
|
||||
};
|
||||
let content = if let Some(ref tc_id) = m.tool_call_id {
|
||||
// Tool result: wrap as tool_result content block
|
||||
let output = m.content.iter()
|
||||
.filter_map(|b| match b { ContentBlock::Text { text } => Some(text.as_str()), _ => None })
|
||||
.collect::<Vec<_>>()
|
||||
.join("");
|
||||
vec![serde_json::json!({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tc_id,
|
||||
"content": output,
|
||||
})]
|
||||
} else {
|
||||
convert_content_blocks(&m.content)
|
||||
};
|
||||
AnthropicMessage { role, content }
|
||||
})
|
||||
.collect(),
|
||||
max_tokens,
|
||||
@ -191,7 +214,24 @@ impl LLMProvider for AnthropicProvider {
|
||||
|
||||
let resp = req_builder.json(&body).send().await?;
|
||||
|
||||
let anthropic_resp: AnthropicResponse = resp.json().await?;
|
||||
let status = resp.status();
|
||||
let body_text = resp.text().await?;
|
||||
|
||||
if !status.is_success() {
|
||||
let error_msg = serde_json::from_str::<serde_json::Value>(&body_text)
|
||||
.ok()
|
||||
.and_then(|v| {
|
||||
v.get("error")
|
||||
.and_then(|e| e.get("message"))
|
||||
.and_then(|m| m.as_str())
|
||||
.map(|s| s.to_string())
|
||||
})
|
||||
.unwrap_or_else(|| body_text.clone());
|
||||
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
|
||||
}
|
||||
|
||||
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text)
|
||||
.map_err(|e| format!("decode error: {} | body: {}", e, &body_text))?;
|
||||
|
||||
let mut content = String::new();
|
||||
let mut tool_calls = Vec::new();
|
||||
@ -218,15 +258,14 @@ impl LLMProvider for AnthropicProvider {
|
||||
}
|
||||
|
||||
Ok(ChatCompletionResponse {
|
||||
id: anthropic_resp.id,
|
||||
model: anthropic_resp.model,
|
||||
id: anthropic_resp.id.unwrap_or_default(),
|
||||
model: anthropic_resp.model.unwrap_or_default(),
|
||||
content,
|
||||
tool_calls,
|
||||
usage: Usage {
|
||||
prompt_tokens: anthropic_resp.usage.input_tokens,
|
||||
completion_tokens: anthropic_resp.usage.output_tokens,
|
||||
total_tokens: anthropic_resp.usage.input_tokens
|
||||
+ anthropic_resp.usage.output_tokens,
|
||||
prompt_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0),
|
||||
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
|
||||
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@ -5,7 +5,7 @@ use std::time::{Duration, Instant};
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::bus::ChatMessage;
|
||||
use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind};
|
||||
use crate::storage::{Storage, StorageError};
|
||||
use std::sync::Arc as StdArc;
|
||||
|
||||
@ -26,10 +26,10 @@ use crate::providers::{create_provider, LLMProvider};
|
||||
use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID};
|
||||
use crate::session::events::DialogInfo;
|
||||
use crate::skills::SkillsLoader;
|
||||
use crate::tools::{
|
||||
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
|
||||
GetSkillTool, HttpRequestTool, ToolRegistry, WebFetchTool,
|
||||
};
|
||||
use crate::tools::{ToolRegistry, create_default_tools};
|
||||
use crate::bus::MessageBus;
|
||||
use crate::tools::OutboundMessenger;
|
||||
use crate::tools::SendMessageTool;
|
||||
|
||||
/// Generate a short ID (8 characters) from a UUID
|
||||
fn short_id() -> String {
|
||||
@ -133,6 +133,7 @@ impl Session {
|
||||
tool_call_id: m.tool_call_id,
|
||||
tool_name: m.tool_name,
|
||||
tool_calls: m.tool_calls.map(|tc| serde_json::from_str(&tc).unwrap_or_default()),
|
||||
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
|
||||
}
|
||||
}).collect();
|
||||
|
||||
@ -190,6 +191,7 @@ impl Session {
|
||||
tool_call_id: message.tool_call_id.clone(),
|
||||
tool_name: message.tool_name.clone(),
|
||||
tool_calls: message.tool_calls.as_ref().map(|tc| serde_json::to_string(tc).unwrap_or_default()),
|
||||
source: message.source.as_ref().map(|s| serde_json::to_string(s).unwrap_or_default()),
|
||||
created_at: now,
|
||||
};
|
||||
storage.append_message_with_retry(&self.id.to_string(), &msg_meta).await?;
|
||||
@ -547,6 +549,8 @@ pub struct SessionManager {
|
||||
tools: Arc<ToolRegistry>,
|
||||
skills_loader: Arc<SkillsLoader>,
|
||||
storage: Arc<Storage>,
|
||||
bus: Arc<MessageBus>,
|
||||
current_source_session: Arc<Mutex<Option<String>>>,
|
||||
}
|
||||
|
||||
struct SessionManagerInner {
|
||||
@ -558,23 +562,7 @@ struct SessionManagerInner {
|
||||
current_sessions: HashMap<String, String>,
|
||||
}
|
||||
|
||||
fn create_default_tools(skills_loader: Arc<SkillsLoader>) -> ToolRegistry {
|
||||
let mut registry = ToolRegistry::new();
|
||||
registry.register(CalculatorTool::new());
|
||||
registry.register(FileReadTool::new());
|
||||
registry.register(FileWriteTool::new());
|
||||
registry.register(FileEditTool::new());
|
||||
registry.register(BashTool::new());
|
||||
registry.register(HttpRequestTool::new(
|
||||
vec!["*".to_string()],
|
||||
1_000_000,
|
||||
30,
|
||||
false,
|
||||
));
|
||||
registry.register(WebFetchTool::new(50_000, 30));
|
||||
registry.register(GetSkillTool::new(skills_loader));
|
||||
registry
|
||||
}
|
||||
|
||||
|
||||
/// 斜杠命令定义
|
||||
#[derive(Debug, Clone)]
|
||||
@ -649,6 +637,7 @@ impl SessionManager {
|
||||
session_ttl_hours: u64,
|
||||
provider_config: LLMProviderConfig,
|
||||
storage: Arc<Storage>,
|
||||
bus: Arc<MessageBus>,
|
||||
) -> Result<Self, AgentError> {
|
||||
let skills_loader = SkillsLoader::new();
|
||||
skills_loader.load_skills();
|
||||
@ -667,9 +656,17 @@ impl SessionManager {
|
||||
tools,
|
||||
skills_loader,
|
||||
storage,
|
||||
bus,
|
||||
current_source_session: Arc::new(Mutex::new(None)),
|
||||
})
|
||||
}
|
||||
|
||||
/// Register the send_message tool (requires self in Arc)
|
||||
pub fn register_outbound_tool(self: &Arc<Self>, available_channels: Vec<String>) {
|
||||
let messenger: Arc<dyn OutboundMessenger> = self.clone();
|
||||
self.tools.register(SendMessageTool::new(messenger, available_channels));
|
||||
}
|
||||
|
||||
pub fn tools(&self) -> Arc<ToolRegistry> {
|
||||
self.tools.clone()
|
||||
}
|
||||
@ -1047,65 +1044,111 @@ impl SessionManager {
|
||||
Err(AgentError::Other("clear_dialog_history not available".to_string()))
|
||||
}
|
||||
|
||||
/// Get or activate a specific session by its full UnifiedSessionId.
|
||||
/// Returns an error if the session does not exist in storage.
|
||||
/// If the session was expired from memory but still in storage,
|
||||
/// it will be restored (reactivated).
|
||||
pub async fn get_or_activate_session(
|
||||
&self,
|
||||
unified_id: &UnifiedSessionId,
|
||||
) -> Result<Arc<Mutex<Session>>, AgentError> {
|
||||
let session_id_str = unified_id.to_string();
|
||||
match self.storage.get_session(&session_id_str).await {
|
||||
Ok(_) => self.get_or_create_session(unified_id).await,
|
||||
Err(StorageError::NotFound(_)) => {
|
||||
Err(AgentError::Other(format!("session not found: {}", unified_id)))
|
||||
}
|
||||
Err(e) => Err(AgentError::Other(format!("storage error: {}", e))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn resolve_dialog_id(
|
||||
&self,
|
||||
channel: &str,
|
||||
chat_id: &str,
|
||||
) -> Result<UnifiedSessionId, AgentError> {
|
||||
let chat_scope = format!("{}:{}", channel, chat_id);
|
||||
let current_id = {
|
||||
self.inner.lock().await.current_sessions.get(&chat_scope).cloned()
|
||||
};
|
||||
|
||||
if let Some(ref current_id) = current_id {
|
||||
match self.storage.get_session(current_id).await {
|
||||
Ok(_) => {
|
||||
let parts: Vec<&str> = current_id.split(':').collect();
|
||||
if parts.len() == 3 {
|
||||
return Ok(UnifiedSessionId::new(channel, chat_id, parts[2]));
|
||||
}
|
||||
}
|
||||
Err(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
let ttl_millis = self.inner.lock().await.session_ttl.as_millis() as i64;
|
||||
match self.storage.find_active_session(channel, chat_id, ttl_millis).await {
|
||||
Ok(Some(meta)) => Ok(UnifiedSessionId::new(channel, chat_id, &meta.dialog_id)),
|
||||
_ => {
|
||||
let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?;
|
||||
Ok(new_id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a system notification (no LLM triggered).
|
||||
///
|
||||
/// Flow:
|
||||
/// 1. Resolve target session (resolve_dialog_id)
|
||||
/// 2. Write assistant message with source tag to history
|
||||
/// 3. Publish OutboundMessage via bus to target channel
|
||||
pub async fn send_notification(
|
||||
&self,
|
||||
channel: &str,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
system_name: &str,
|
||||
task_id: Option<&str>,
|
||||
) -> Result<(), AgentError> {
|
||||
let unified_id = self.resolve_dialog_id(channel, chat_id).await?;
|
||||
let session = self.get_or_create_session(&unified_id).await?;
|
||||
{
|
||||
let mut guard = session.lock().await;
|
||||
let source = MessageSource {
|
||||
kind: SourceKind::SystemNotification,
|
||||
from_channel: None,
|
||||
from_session: None,
|
||||
from_user_id: None,
|
||||
system_name: Some(system_name.to_string()),
|
||||
task_id: task_id.map(|s| s.to_string()),
|
||||
};
|
||||
let msg = ChatMessage::assistant_with_source(content, source);
|
||||
guard.add_message(msg, true).await
|
||||
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
|
||||
}
|
||||
|
||||
let outbound = OutboundMessage {
|
||||
channel: channel.to_string(),
|
||||
chat_id: chat_id.to_string(),
|
||||
content: content.to_string(),
|
||||
reply_to: None,
|
||||
media: vec![],
|
||||
metadata: HashMap::new(),
|
||||
};
|
||||
self.bus.publish_outbound(outbound).await
|
||||
.map_err(|e| AgentError::Other(format!("bus publish error: {}", e)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn handle_message(
|
||||
&self,
|
||||
channel: &str,
|
||||
_sender_id: &str,
|
||||
chat_id: &str,
|
||||
content: &str,
|
||||
media: Vec<crate::bus::MediaItem>,
|
||||
media: Vec<MediaItem>,
|
||||
) -> Result<HandleResult, AgentError> {
|
||||
// Channel messages never carry dialog_id — routing is entirely via current_sessions
|
||||
let unified_id = {
|
||||
let chat_scope = format!("{}:{}", channel, chat_id);
|
||||
let current_session_id = {
|
||||
let inner = self.inner.lock().await;
|
||||
inner.current_sessions.get(&chat_scope).cloned()
|
||||
};
|
||||
if let Some(current_id) = current_session_id {
|
||||
// Verify current session still exists in Storage
|
||||
match self.storage.get_session(¤t_id).await {
|
||||
Ok(_) => {
|
||||
// Current session still valid, extract dialog_id
|
||||
let parts: Vec<&str> = current_id.split(':').collect();
|
||||
if parts.len() == 3 {
|
||||
UnifiedSessionId::new(channel, chat_id, parts[2])
|
||||
} else {
|
||||
// Malformed, fallback to find or create
|
||||
let ttl_millis = self.inner.lock().await.session_ttl.as_millis() as i64;
|
||||
match self.storage.find_active_session(channel, chat_id, ttl_millis).await {
|
||||
Ok(Some(m)) => UnifiedSessionId::new(channel, chat_id, &m.dialog_id),
|
||||
_ => {
|
||||
let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?;
|
||||
new_id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Current session no longer exists, create new
|
||||
let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?;
|
||||
new_id
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No current session tracked, find active or create new
|
||||
let ttl_millis = self.inner.lock().await.session_ttl.as_millis() as i64;
|
||||
tracing::debug!(channel, chat_id, ttl_millis, "No current_sessions entry, searching Storage for active session");
|
||||
match self.storage.find_active_session(channel, chat_id, ttl_millis).await {
|
||||
Ok(Some(meta)) => {
|
||||
tracing::debug!(session_id = %meta.id, dialog_id = %meta.dialog_id, last_active_at = %meta.last_active_at, "Found active session in Storage");
|
||||
UnifiedSessionId::new(channel, chat_id, &meta.dialog_id)
|
||||
}
|
||||
Ok(None) | Err(_) => {
|
||||
tracing::debug!("No active session found in Storage, creating new session");
|
||||
// Create new session
|
||||
let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?;
|
||||
new_id
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
let unified_id = self.resolve_dialog_id(channel, chat_id).await?;
|
||||
*self.current_source_session.lock().await = Some(unified_id.to_string());
|
||||
tracing::debug!(unified_id = %unified_id, "handle_message resolved unified_id");
|
||||
let session = self.get_or_create_session(&unified_id).await?;
|
||||
|
||||
@ -1121,9 +1164,11 @@ impl SessionManager {
|
||||
|
||||
match result {
|
||||
Ok((_new_session_id, response)) => {
|
||||
*self.current_source_session.lock().await = None;
|
||||
return Ok(HandleResult::CommandOutput(response));
|
||||
}
|
||||
Err(e) => {
|
||||
*self.current_source_session.lock().await = None;
|
||||
return Ok(HandleResult::CommandOutput(e.to_string()));
|
||||
}
|
||||
}
|
||||
@ -1183,6 +1228,8 @@ impl SessionManager {
|
||||
"Agent response received"
|
||||
);
|
||||
|
||||
*self.current_source_session.lock().await = None;
|
||||
|
||||
Ok(HandleResult::AgentResponse(response))
|
||||
}
|
||||
|
||||
@ -1203,6 +1250,74 @@ impl SessionManager {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OutboundMessenger for SessionManager {
|
||||
async fn send_message(
|
||||
&self,
|
||||
channel: &str,
|
||||
chat_id: &str,
|
||||
dialog_id: Option<&str>,
|
||||
content: &str,
|
||||
mut source: MessageSource,
|
||||
) -> Result<(), String> {
|
||||
// Fill origin from current source session if not provided
|
||||
if source.from_session.is_none() {
|
||||
source.from_session = self.current_source_session.lock().await.clone();
|
||||
}
|
||||
|
||||
let (target_sid, session) = if let Some(did) = dialog_id {
|
||||
let sid = UnifiedSessionId::new(channel, chat_id, did);
|
||||
let session = self.get_or_activate_session(&sid).await
|
||||
.map_err(|e| e.to_string())?;
|
||||
(sid, session)
|
||||
} else {
|
||||
let sid = self.resolve_dialog_id(channel, chat_id).await
|
||||
.map_err(|e| e.to_string())?;
|
||||
let session = self.get_or_create_session(&sid).await
|
||||
.map_err(|e| e.to_string())?;
|
||||
(sid, session)
|
||||
};
|
||||
|
||||
// Build message prefix: [message from <origin> to <channel:chat_id:dialog_id>]
|
||||
let target_id = target_sid.to_string();
|
||||
let origin = source.from_session.as_deref().unwrap_or("unknown");
|
||||
let origin_id = source.from_session.clone();
|
||||
let prefix = format!("[message from {} to {}] ", origin, target_id);
|
||||
let marked_content = format!("{}\n{}", prefix, content);
|
||||
|
||||
// Write source-tagged assistant message to target session history
|
||||
{
|
||||
let mut guard = session.lock().await;
|
||||
let msg = ChatMessage::assistant_with_source(marked_content.clone(), source);
|
||||
guard.add_message(msg, true).await
|
||||
.map_err(|e| e.to_string())?;
|
||||
}
|
||||
|
||||
// Restore active dialog if source and target share channel:chat_id but differ in dialog_id
|
||||
if let Some(ref origin_id) = origin_id {
|
||||
let parts: Vec<&str> = origin_id.split(':').collect();
|
||||
if parts.len() == 3 && parts[0] == channel && parts[1] == chat_id && parts[2] != target_sid.dialog_id {
|
||||
let scope = format!("{}:{}", channel, chat_id);
|
||||
self.inner.lock().await.current_sessions.insert(scope, origin_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Publish OutboundMessage via bus to target channel
|
||||
let outbound = OutboundMessage {
|
||||
channel: channel.to_string(),
|
||||
chat_id: chat_id.to_string(),
|
||||
content: marked_content,
|
||||
reply_to: None,
|
||||
media: vec![],
|
||||
metadata: HashMap::new(),
|
||||
};
|
||||
self.bus.publish_outbound(outbound).await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@ -11,5 +11,6 @@ pub struct MessageMeta {
|
||||
pub tool_call_id: Option<String>,
|
||||
pub tool_name: Option<String>,
|
||||
pub tool_calls: Option<String>,
|
||||
pub source: Option<String>,
|
||||
pub created_at: i64,
|
||||
}
|
||||
|
||||
@ -66,6 +66,7 @@ impl Storage {
|
||||
tool_call_id TEXT,
|
||||
tool_name TEXT,
|
||||
tool_calls TEXT,
|
||||
source TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
|
||||
)
|
||||
@ -83,6 +84,14 @@ impl Storage {
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
// Migration: add source column if upgrading from older schema
|
||||
sqlx::query(
|
||||
r#"ALTER TABLE messages ADD COLUMN source TEXT"#,
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -260,8 +269,8 @@ impl Storage {
|
||||
pub async fn append_message(&self, session_id: &str, msg: &crate::storage::message::MessageMeta) -> Result<i64, StorageError> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO messages (id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO messages (id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(&msg.id)
|
||||
@ -273,6 +282,7 @@ impl Storage {
|
||||
.bind(&msg.tool_call_id)
|
||||
.bind(&msg.tool_name)
|
||||
.bind(&msg.tool_calls)
|
||||
.bind(&msg.source)
|
||||
.bind(msg.created_at)
|
||||
.execute(self.pool())
|
||||
.await?;
|
||||
@ -300,7 +310,7 @@ impl Storage {
|
||||
) -> Result<Vec<crate::storage::message::MessageMeta>, StorageError> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, created_at
|
||||
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
|
||||
FROM messages
|
||||
WHERE session_id = ? AND seq >= ?
|
||||
ORDER BY seq ASC
|
||||
@ -323,6 +333,7 @@ impl Storage {
|
||||
tool_call_id: row.get("tool_call_id"),
|
||||
tool_name: row.get("tool_name"),
|
||||
tool_calls: row.get("tool_calls"),
|
||||
source: row.get("source"),
|
||||
created_at: row.get("created_at"),
|
||||
})
|
||||
.collect())
|
||||
@ -486,6 +497,7 @@ mod tests {
|
||||
tool_call_id: None,
|
||||
tool_name: None,
|
||||
tool_calls: None,
|
||||
source: None,
|
||||
created_at: 1000,
|
||||
};
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ pub mod get_skill;
|
||||
pub mod http_request;
|
||||
pub mod registry;
|
||||
pub mod schema;
|
||||
pub mod send_message;
|
||||
pub mod traits;
|
||||
pub mod web_fetch;
|
||||
|
||||
@ -19,5 +20,30 @@ pub use get_skill::GetSkillTool;
|
||||
pub use http_request::HttpRequestTool;
|
||||
pub use registry::ToolRegistry;
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use traits::{Tool, ToolResult};
|
||||
pub use send_message::SendMessageTool;
|
||||
pub use traits::{OutboundMessenger, Tool, ToolResult};
|
||||
pub use web_fetch::WebFetchTool;
|
||||
|
||||
use std::sync::Arc;
|
||||
use crate::skills::SkillsLoader;
|
||||
|
||||
/// Create the base tool registry (without send_message).
|
||||
/// `send_message` tool is registered later via `SessionManager::register_outbound_tool()`
|
||||
/// once the available channel names are known.
|
||||
pub fn create_default_tools(skills_loader: Arc<SkillsLoader>) -> ToolRegistry {
|
||||
let registry = ToolRegistry::new();
|
||||
registry.register(CalculatorTool::new());
|
||||
registry.register(FileReadTool::new());
|
||||
registry.register(FileWriteTool::new());
|
||||
registry.register(FileEditTool::new());
|
||||
registry.register(BashTool::new());
|
||||
registry.register(HttpRequestTool::new(
|
||||
vec!["*".to_string()],
|
||||
1_000_000,
|
||||
30,
|
||||
false,
|
||||
));
|
||||
registry.register(WebFetchTool::new(50_000, 30));
|
||||
registry.register(GetSkillTool::new(skills_loader));
|
||||
registry
|
||||
}
|
||||
|
||||
@ -1,36 +1,39 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::providers::{Tool, ToolFunction};
|
||||
|
||||
use super::traits::Tool as ToolTrait;
|
||||
|
||||
pub struct ToolRegistry {
|
||||
tools: HashMap<String, Box<dyn ToolTrait>>,
|
||||
tools: Mutex<HashMap<String, Arc<dyn ToolTrait>>>,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tools: HashMap::new(),
|
||||
tools: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register<T: ToolTrait + 'static>(&mut self, tool: T) {
|
||||
self.tools.insert(tool.name().to_string(), Box::new(tool));
|
||||
pub fn register<T: ToolTrait + 'static>(&self, tool: T) {
|
||||
self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool));
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Option<&Box<dyn ToolTrait>> {
|
||||
self.tools.get(name)
|
||||
pub fn get(&self, name: &str) -> Option<Arc<dyn ToolTrait>> {
|
||||
self.tools.lock().unwrap().get(name).cloned()
|
||||
}
|
||||
|
||||
/// Get all registered tools.
|
||||
/// Used for concurrent tool execution when we need to look up tools by name.
|
||||
pub fn get_all(&self) -> Vec<&Box<dyn ToolTrait>> {
|
||||
self.tools.values().collect()
|
||||
pub fn get_all(&self) -> Vec<Arc<dyn ToolTrait>> {
|
||||
self.tools.lock().unwrap().values().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn get_definitions(&self) -> Vec<Tool> {
|
||||
self.tools
|
||||
.lock()
|
||||
.unwrap()
|
||||
.values()
|
||||
.map(|tool| Tool {
|
||||
tool_type: "function".to_string(),
|
||||
@ -44,15 +47,20 @@ impl ToolRegistry {
|
||||
}
|
||||
|
||||
pub fn has_tools(&self) -> bool {
|
||||
!self.tools.is_empty()
|
||||
!self.tools.lock().unwrap().is_empty()
|
||||
}
|
||||
|
||||
pub fn tool_names(&self) -> Vec<String> {
|
||||
self.tools.keys().cloned().collect()
|
||||
self.tools.lock().unwrap().keys().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = (&String, &Box<dyn ToolTrait>)> {
|
||||
self.tools.iter()
|
||||
pub fn iter(&self) -> Vec<(String, Arc<dyn ToolTrait>)> {
|
||||
self.tools
|
||||
.lock()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|(k, v)| (k.clone(), v.clone()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
171
src/tools/send_message.rs
Normal file
171
src/tools/send_message.rs
Normal file
@ -0,0 +1,171 @@
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::bus::{MessageSource, SourceKind};
|
||||
|
||||
use super::traits::{OutboundMessenger, Tool, ToolResult};
|
||||
|
||||
pub struct SendMessageTool {
|
||||
messenger: Arc<dyn OutboundMessenger>,
|
||||
available_channels: HashSet<String>,
|
||||
}
|
||||
|
||||
impl SendMessageTool {
|
||||
pub fn new(messenger: Arc<dyn OutboundMessenger>, available_channels: Vec<String>) -> Self {
|
||||
Self {
|
||||
messenger,
|
||||
available_channels: available_channels.into_iter().collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse target_chat_id into (channel, chat_id, optional dialog_id).
|
||||
/// Accepts two formats:
|
||||
/// - Two-part: `<channel>:<chat_id>` → sends to latest active session for that chat
|
||||
/// - Three-part: `<channel>:<chat_id>:<dialog_id>` → sends to specific session
|
||||
fn parse_target_chat_id(raw: &str) -> Result<(&str, &str, Option<&str>), String> {
|
||||
let parts: Vec<&str> = raw.split(':').collect();
|
||||
match parts.len() {
|
||||
2 => {
|
||||
if parts[0].is_empty() || parts[1].is_empty() {
|
||||
Err(format!("Invalid target_chat_id format '{}': channel and chat_id must not be empty", raw))
|
||||
} else {
|
||||
Ok((parts[0], parts[1], None))
|
||||
}
|
||||
}
|
||||
3 => {
|
||||
if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() {
|
||||
Err(format!("Invalid target_chat_id format '{}': all three parts must not be empty", raw))
|
||||
} else {
|
||||
Ok((parts[0], parts[1], Some(parts[2])))
|
||||
}
|
||||
}
|
||||
_ => Err(format!(
|
||||
"Invalid target_chat_id format '{}'. Expected <channel>:<chat_id> or <channel>:<chat_id>:<dialog_id>",
|
||||
raw
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for SendMessageTool {
|
||||
fn name(&self) -> &str {
|
||||
"send_message"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"向指定渠道的会话发送消息。用于在用户请求下向其他渠道发送内容。\
|
||||
target_chat_id 支持两种格式:<channel>:<chat_id>(发送到该聊天下最新活跃会话)\
|
||||
或 <channel>:<chat_id>:<dialog_id>(发送到指定会话,过期则自动激活)"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"target_chat_id": {
|
||||
"type": "string",
|
||||
"description": "目标会话ID。支持两种格式: 1) <channel>:<chat_id> 发送到该聊天下最新活跃会话, 无则自动创建; 2) <channel>:<chat_id>:<dialog_id> 发送到指定会话, 过期则自动激活。channel 可选值: feishu, cli_chat"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "要发送的消息内容"
|
||||
},
|
||||
"origin": {
|
||||
"type": "string",
|
||||
"description": "可选。消息来源标识。不填则自动使用当前会话的完整 session_id (<channel>:<chat_id>:<dialog_id>)"
|
||||
}
|
||||
},
|
||||
"required": ["target_chat_id", "content"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let raw_id = args["target_chat_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing target_chat_id"))?;
|
||||
let content = args["content"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing content"))?;
|
||||
|
||||
// 1. Parse target_chat_id
|
||||
let (channel, chat_id, dialog_id) = parse_target_chat_id(raw_id)
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
|
||||
// 2. Validate channel
|
||||
if !self.available_channels.contains(channel) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Channel '{}' is not available. Available channels: {}",
|
||||
channel,
|
||||
self.available_channels.iter().cloned().collect::<Vec<_>>().join(", ")
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
let from_session = args["origin"].as_str().map(|s| s.to_string());
|
||||
|
||||
let source = MessageSource {
|
||||
kind: SourceKind::CrossChannel,
|
||||
from_channel: Some("tool".to_string()),
|
||||
from_session,
|
||||
from_user_id: None,
|
||||
system_name: None,
|
||||
task_id: None,
|
||||
};
|
||||
|
||||
// 3. Send via messenger
|
||||
match self.messenger
|
||||
.send_message(channel, chat_id, dialog_id, content, source)
|
||||
.await
|
||||
{
|
||||
Ok(()) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: "消息已发送".to_string(),
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(e),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_target_chat_id_two_part() {
|
||||
let (ch, cid, did) = parse_target_chat_id("feishu:oc_abc123").unwrap();
|
||||
assert_eq!(ch, "feishu");
|
||||
assert_eq!(cid, "oc_abc123");
|
||||
assert!(did.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_target_chat_id_three_part() {
|
||||
let (ch, cid, did) = parse_target_chat_id("feishu:oc_abc123:dialog1").unwrap();
|
||||
assert_eq!(ch, "feishu");
|
||||
assert_eq!(cid, "oc_abc123");
|
||||
assert_eq!(did, Some("dialog1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_target_chat_id_invalid_one_part() {
|
||||
assert!(parse_target_chat_id("feishu").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_target_chat_id_empty_parts() {
|
||||
assert!(parse_target_chat_id("feishu:").is_err());
|
||||
assert!(parse_target_chat_id(":chat_id").is_err());
|
||||
assert!(parse_target_chat_id("feishu::dialog").is_err());
|
||||
}
|
||||
}
|
||||
@ -1,4 +1,5 @@
|
||||
use async_trait::async_trait;
|
||||
use crate::bus::MessageSource;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolResult {
|
||||
@ -29,3 +30,15 @@ pub trait Tool: Send + Sync + 'static {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait OutboundMessenger: Send + Sync {
|
||||
async fn send_message(
|
||||
&self,
|
||||
channel: &str,
|
||||
chat_id: &str,
|
||||
dialog_id: Option<&str>,
|
||||
content: &str,
|
||||
source: MessageSource,
|
||||
) -> Result<(), String>;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user