新增跨session消息发送能力

This commit is contained in:
xiaoski 2026-05-04 00:32:24 +08:00
parent 24d3407b05
commit 98eb7bea3d
17 changed files with 636 additions and 133 deletions

View File

@ -1,9 +1,5 @@
# PicoBot
## Maintenance
- **Update this file on any architectural change** — module boundaries, data flow, key constraints, or build/test commands must be reflected here
## Build & Run
- `cargo build` — build the binary
@ -12,18 +8,20 @@
## Config
- Config file: `~/.picobot/config.json` or `./config.json` (fallback order)
- `.env` is loaded and env var placeholders `<VAR_NAME>` are substituted into config
- Config file: `~/.picobot/config.json` or `./config.json` (fallback order, see `src/config/mod.rs:213`)
- `.env` is loaded manually (not via dotenv crate); env var placeholders `<VAR_NAME>` in config JSON are substituted
- Config example: `config.example.json`
## Tests
- `cargo test --lib` — run unit tests (FAILS: `src/session/session.rs:657` missing `workspace_dir` field in test helper)
- `cargo test --test test_integration -- --ignored` — run integration tests (requires `tests/test.env` with API keys)
- `cargo test --lib` — run unit tests (runs all `#[test]` in `src/`)
- `cargo test --test test_integration -- --ignored` — run integration tests (also `test_tool_calling`, `test_request_format`)
- **All** integration tests require `tests/test.env` with real API keys; copy from `tests/test.env.example` and fill in keys
- Integration tests are `#[ignore]` by default; use `-- --ignored` to run them
## Reference
- `reference/` — third-party reference implementations (nanobot, Mini-Agent, zeroclaw); not part of this project; use for similar functionality patterns
- `reference/` — third-party reference implementations (nanobot, Mini-Agent, zeroclaw); not part of this project; do not modify
## Architecture
@ -51,8 +49,13 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
| `session` | Conversation session lifecycle, dialog operations | `SessionManager`, `Session` |
| `agent` | LLM call loop, tool execution, context compression | `AgentLoop` |
| `providers` | LLM API clients (OpenAI-compatible, Anthropic) | `LLMProvider` trait, factory `create_provider()` |
| `tools` | Agent tools (bash, file operations, http, web, get_skill) | `ToolRegistry`, `Tool` trait |
| `tools` | Agent tools (bash, file ops, http, web, get_skill) | `ToolRegistry`, `Tool` trait |
| `skills` | Skills loading, management, and prompt building | `SkillsLoader`, `Skill` |
| `storage` | SQLite persistence for sessions and messages | `Storage`, `SessionMeta`, `MessageMeta` |
| `observability` | Observer pattern for agent/tool telemetry events | `Observer` trait, `ObserverEvent`, `MultiObserver` |
| `protocol` | WebSocket protocol message types | `WsInbound`, `WsOutbound`, `SessionSummary` |
| `config` | Config loading, env substitution, path resolution | `Config`, `LLMProviderConfig` |
| `logging` | Tracing initialization with file rotation | `init_logging()`, `init_logging_console_only()` |
### Functional Boundaries
@ -68,9 +71,7 @@ Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionM
### Key Constraints
- Gateway **changes working directory** to workspace on startup (`src/gateway/mod.rs:31`)
- Session/message persistence uses SQLite via `sqlx`; DB stored in workspace as `.picobot_sessions.db` by default
- `ChannelManager` owns the `MessageBus` and all channel instances
- `OutboundDispatcher` routes outbound messages to the correct channel via `ChannelManager`
## Known Issues
- (No known issues at this time)
- Config `.env` loading uses `unsafe { env::set_var(...) }` — don't refactor to safer patterns without understanding side effects

View File

@ -20,7 +20,7 @@ clap = { version = "4", features = ["derive"] }
dirs = "6.0.0"
prost = "0.14"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "local-time"] }
tracing-appender = "0.2"
anyhow = "1.0"
mime_guess = "2.0"

View File

@ -390,8 +390,16 @@ impl AgentLoop {
});
}
// Execute tool calls
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
// Execute tool calls — log tool names and args before execution
{
let tools_info: Vec<String> = response.tool_calls.iter()
.map(|tc| {
let args = serde_json::to_string(&tc.arguments).unwrap_or_default();
format!("{}:{}", tc.name, args)
})
.collect();
tracing::info!(iteration, count = response.tool_calls.len(), tools = %tools_info.join(", "), "Tool calls detected, executing tools");
}
// Add assistant message with tool calls
let assistant_message = ChatMessage::assistant_with_tool_calls(

View File

@ -45,6 +45,7 @@ impl SystemPromptBuilder {
Box::new(UserProfileSection),
Box::new(DateTimeSection),
Box::new(RuntimeSection),
Box::new(CrossChannelSection),
],
}
}
@ -233,6 +234,48 @@ impl PromptSection for DateTimeSection {
}
}
/// Cross-channel messaging and system notification guidance for LLM.
pub struct CrossChannelSection;
impl PromptSection for CrossChannelSection {
fn name(&self) -> &str {
"cross_channel"
}
fn build(&self, _ctx: &PromptContext<'_>) -> String {
r#"## 关于跨渠道消息和系统通知
`source`
### source.kind = "system_notification"
- `system_name`:
- `task_id`: ID
### source.kind = "cross_channel"
- `from_channel`: "feishu"
- `from_user_id`: ID
### send_message
使 `send_message`
- `target_chat_id`: ID
1. `<channel>:<chat_id>`
2. `<channel>:<chat_id>:<dialog_id>`
- `content`:
- `origin`: 使 session_id
`[message from X to Y]`
LLM /
###
-
- "#
.to_string()
}
}
/// Runtime environment information.
pub struct RuntimeSection;

View File

@ -73,6 +73,28 @@ pub struct ChatMessage {
pub tool_name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub source: Option<MessageSource>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SourceKind {
#[serde(rename = "system_notification")]
SystemNotification,
#[serde(rename = "cross_channel")]
CrossChannel,
#[serde(rename = "external_trigger")]
ExternalTrigger,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageSource {
pub kind: SourceKind,
pub from_channel: Option<String>,
pub from_session: Option<String>,
pub from_user_id: Option<String>,
pub system_name: Option<String>,
pub task_id: Option<String>,
}
impl ChatMessage {
@ -86,6 +108,7 @@ impl ChatMessage {
tool_call_id: None,
tool_name: None,
tool_calls: None,
source: None,
}
}
@ -99,6 +122,7 @@ impl ChatMessage {
tool_call_id: None,
tool_name: None,
tool_calls: None,
source: None,
}
}
@ -112,6 +136,7 @@ impl ChatMessage {
tool_call_id: None,
tool_name: None,
tool_calls: None,
source: None,
}
}
@ -125,6 +150,21 @@ impl ChatMessage {
tool_call_id: None,
tool_name: None,
tool_calls: Some(tool_calls),
source: None,
}
}
pub fn assistant_with_source(content: impl Into<String>, source: MessageSource) -> Self {
Self {
id: uuid::Uuid::new_v4().to_string(),
role: "assistant".to_string(),
content: content.into(),
media_refs: Vec::new(),
timestamp: current_timestamp(),
tool_call_id: None,
tool_name: None,
tool_calls: None,
source: Some(source),
}
}
@ -138,6 +178,7 @@ impl ChatMessage {
tool_call_id: None,
tool_name: None,
tool_calls: None,
source: None,
}
}
@ -151,6 +192,7 @@ impl ChatMessage {
tool_call_id: Some(tool_call_id.into()),
tool_name: Some(tool_name.into()),
tool_calls: None,
source: None,
}
}
}

View File

@ -2,7 +2,7 @@ pub mod dispatcher;
pub mod message;
pub use dispatcher::OutboundDispatcher;
pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, OutboundMessage};
pub use message::{ChatMessage, ContentBlock, ControlMessage, InboundMessage, MediaItem, MessageSource, OutboundMessage, SourceKind};
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};

View File

@ -24,6 +24,14 @@ impl ChannelManager {
}
}
pub fn with_bus(cli_chat_channel: Arc<crate::channels::CliChatChannel>, bus: Arc<MessageBus>) -> Self {
Self {
channels: Arc::new(RwLock::new(HashMap::new())),
cli_chat_channel,
bus,
}
}
/// Get a reference to the MessageBus
pub fn bus(&self) -> Arc<MessageBus> {
self.bus.clone()
@ -99,6 +107,11 @@ impl ChannelManager {
self.channels.read().await.get(name).cloned()
}
/// Get list of registered channel names
pub async fn list_channel_names(&self) -> Vec<String> {
self.channels.read().await.keys().cloned().collect()
}
/// Dispatch an outbound message to the appropriate channel
pub async fn dispatch(&self, msg: OutboundMessage) -> Result<(), ChannelError> {
let channel_name = &msg.channel;

View File

@ -5,7 +5,7 @@ use std::sync::Arc;
use axum::{routing, Router};
use tokio::net::TcpListener;
use crate::bus::{ControlMessage, OutboundDispatcher};
use crate::bus::{ControlMessage, MessageBus, OutboundDispatcher};
use crate::channels::{ChannelManager, CliChatChannel};
use crate::channels::base::{Channel, ChannelError};
use crate::config::{Config, expand_path, ensure_workspace_dir};
@ -15,7 +15,7 @@ use crate::session::SessionManager;
pub struct GatewayState {
pub config: Config,
pub workspace_dir: std::path::PathBuf,
pub session_manager: SessionManager,
pub session_manager: Arc<SessionManager>,
pub channel_manager: ChannelManager,
}
@ -53,21 +53,32 @@ impl GatewayState {
);
tracing::info!("Session storage: {}", db_path.display());
let session_manager = SessionManager::new(session_ttl_hours, provider_config.clone(), storage.clone())?;
// Create MessageBus first (shared by SessionManager and ChannelManager)
let bus = MessageBus::new(100);
// Create SessionManager with bus injection
let session_manager = SessionManager::new(session_ttl_hours, provider_config.clone(), storage.clone(), bus.clone())?;
let session_manager = Arc::new(session_manager);
// Start background cleanup task (default 60 minutes)
let cleanup_interval = config.gateway.cleanup_interval_minutes.unwrap_or(60);
Arc::new(session_manager.clone()).start_cleanup_task(cleanup_interval);
session_manager.clone().start_cleanup_task(cleanup_interval);
tracing::info!("Session cleanup task started (interval: {} min)", cleanup_interval);
// Create CLI Chat Channel first (needed for ChannelManager)
// Create ChannelManager and init channels
let cli_chat_channel = Arc::new(CliChatChannel::new());
let channel_manager = ChannelManager::new(cli_chat_channel);
let channel_manager = ChannelManager::with_bus(cli_chat_channel, bus);
channel_manager.init(&config, workspace_path.clone()).await
.map_err(|e| format!("Failed to init channels: {}", e))?;
// Register send_message tool with available channel names
let available_channels = channel_manager.list_channel_names().await;
session_manager.register_outbound_tool(available_channels);
Ok(Self {
config,
workspace_dir: workspace_path,
session_manager,
session_manager: session_manager.clone(),
channel_manager,
})
}
@ -231,11 +242,7 @@ pub async fn run(host: Option<String>, port: Option<u16>) -> Result<(), Box<dyn
let state = Arc::new(GatewayState::new().await?);
// Initialize and start channels with workspace directory
state.channel_manager.init(
&state.config,
state.workspace_dir.clone(),
).await?;
// Start all channels (init already done in GatewayState::new)
state.channel_manager.start_all().await?;
// Start message processing (inbound processor + control processor + outbound dispatcher)

View File

@ -4,6 +4,7 @@ use tracing_subscriber::{
fmt,
layer::SubscriberExt,
util::SubscriberInitExt,
fmt::time::LocalTime,
EnvFilter,
};
@ -44,12 +45,14 @@ pub fn init_logging() {
let file_layer = fmt::layer()
.with_writer(file_appender)
.with_timer(LocalTime::rfc_3339())
.with_ansi(false)
.with_target(true)
.with_level(true)
.with_thread_ids(true);
let console_layer = fmt::layer()
.with_timer(LocalTime::rfc_3339())
.with_target(true)
.with_level(true);
@ -68,6 +71,7 @@ pub fn init_logging_console_only() {
.unwrap_or_else(|_| EnvFilter::new("info"));
let console_layer = fmt::layer()
.with_timer(LocalTime::rfc_3339())
.with_target(true)
.with_level(true);

View File

@ -117,10 +117,12 @@ struct AnthropicTool {
#[derive(Deserialize)]
struct AnthropicResponse {
id: String,
model: String,
id: Option<String>,
model: Option<String>,
#[serde(default)]
content: Vec<AnthropicContent>,
usage: AnthropicUsage,
#[serde(default)]
usage: Option<AnthropicUsage>,
}
#[derive(Deserialize)]
@ -138,7 +140,9 @@ enum AnthropicContent {
#[derive(Deserialize)]
struct AnthropicUsage {
#[serde(default)]
input_tokens: u32,
#[serde(default)]
output_tokens: u32,
}
@ -167,9 +171,28 @@ impl LLMProvider for AnthropicProvider {
messages: request
.messages
.iter()
.map(|m| AnthropicMessage {
role: m.role.clone(),
content: convert_content_blocks(&m.content),
.map(|m| {
let role = if m.role == "tool" {
// Anthropic uses "user" role for tool result messages
"user".to_string()
} else {
m.role.clone()
};
let content = if let Some(ref tc_id) = m.tool_call_id {
// Tool result: wrap as tool_result content block
let output = m.content.iter()
.filter_map(|b| match b { ContentBlock::Text { text } => Some(text.as_str()), _ => None })
.collect::<Vec<_>>()
.join("");
vec![serde_json::json!({
"type": "tool_result",
"tool_use_id": tc_id,
"content": output,
})]
} else {
convert_content_blocks(&m.content)
};
AnthropicMessage { role, content }
})
.collect(),
max_tokens,
@ -191,7 +214,24 @@ impl LLMProvider for AnthropicProvider {
let resp = req_builder.json(&body).send().await?;
let anthropic_resp: AnthropicResponse = resp.json().await?;
let status = resp.status();
let body_text = resp.text().await?;
if !status.is_success() {
let error_msg = serde_json::from_str::<serde_json::Value>(&body_text)
.ok()
.and_then(|v| {
v.get("error")
.and_then(|e| e.get("message"))
.and_then(|m| m.as_str())
.map(|s| s.to_string())
})
.unwrap_or_else(|| body_text.clone());
return Err(format!("API error ({}): {}", status.as_u16(), error_msg).into());
}
let anthropic_resp: AnthropicResponse = serde_json::from_str(&body_text)
.map_err(|e| format!("decode error: {} | body: {}", e, &body_text))?;
let mut content = String::new();
let mut tool_calls = Vec::new();
@ -218,15 +258,14 @@ impl LLMProvider for AnthropicProvider {
}
Ok(ChatCompletionResponse {
id: anthropic_resp.id,
model: anthropic_resp.model,
id: anthropic_resp.id.unwrap_or_default(),
model: anthropic_resp.model.unwrap_or_default(),
content,
tool_calls,
usage: Usage {
prompt_tokens: anthropic_resp.usage.input_tokens,
completion_tokens: anthropic_resp.usage.output_tokens,
total_tokens: anthropic_resp.usage.input_tokens
+ anthropic_resp.usage.output_tokens,
prompt_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0),
completion_tokens: anthropic_resp.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
total_tokens: anthropic_resp.usage.as_ref().map(|u| u.input_tokens + u.output_tokens).unwrap_or(0),
},
})
}

View File

@ -5,7 +5,7 @@ use std::time::{Duration, Instant};
use tokio::sync::{Mutex, mpsc};
use uuid::Uuid;
use crate::bus::ChatMessage;
use crate::bus::{ChatMessage, MediaItem, MessageSource, OutboundMessage, SourceKind};
use crate::storage::{Storage, StorageError};
use std::sync::Arc as StdArc;
@ -26,10 +26,10 @@ use crate::providers::{create_provider, LLMProvider};
use crate::session::session_id::{UnifiedSessionId, DEFAULT_DIALOG_ID};
use crate::session::events::DialogInfo;
use crate::skills::SkillsLoader;
use crate::tools::{
BashTool, CalculatorTool, FileEditTool, FileReadTool, FileWriteTool,
GetSkillTool, HttpRequestTool, ToolRegistry, WebFetchTool,
};
use crate::tools::{ToolRegistry, create_default_tools};
use crate::bus::MessageBus;
use crate::tools::OutboundMessenger;
use crate::tools::SendMessageTool;
/// Generate a short ID (8 characters) from a UUID
fn short_id() -> String {
@ -133,6 +133,7 @@ impl Session {
tool_call_id: m.tool_call_id,
tool_name: m.tool_name,
tool_calls: m.tool_calls.map(|tc| serde_json::from_str(&tc).unwrap_or_default()),
source: m.source.and_then(|s| serde_json::from_str(&s).ok()),
}
}).collect();
@ -190,6 +191,7 @@ impl Session {
tool_call_id: message.tool_call_id.clone(),
tool_name: message.tool_name.clone(),
tool_calls: message.tool_calls.as_ref().map(|tc| serde_json::to_string(tc).unwrap_or_default()),
source: message.source.as_ref().map(|s| serde_json::to_string(s).unwrap_or_default()),
created_at: now,
};
storage.append_message_with_retry(&self.id.to_string(), &msg_meta).await?;
@ -547,6 +549,8 @@ pub struct SessionManager {
tools: Arc<ToolRegistry>,
skills_loader: Arc<SkillsLoader>,
storage: Arc<Storage>,
bus: Arc<MessageBus>,
current_source_session: Arc<Mutex<Option<String>>>,
}
struct SessionManagerInner {
@ -558,23 +562,7 @@ struct SessionManagerInner {
current_sessions: HashMap<String, String>,
}
fn create_default_tools(skills_loader: Arc<SkillsLoader>) -> ToolRegistry {
let mut registry = ToolRegistry::new();
registry.register(CalculatorTool::new());
registry.register(FileReadTool::new());
registry.register(FileWriteTool::new());
registry.register(FileEditTool::new());
registry.register(BashTool::new());
registry.register(HttpRequestTool::new(
vec!["*".to_string()],
1_000_000,
30,
false,
));
registry.register(WebFetchTool::new(50_000, 30));
registry.register(GetSkillTool::new(skills_loader));
registry
}
/// 斜杠命令定义
#[derive(Debug, Clone)]
@ -649,6 +637,7 @@ impl SessionManager {
session_ttl_hours: u64,
provider_config: LLMProviderConfig,
storage: Arc<Storage>,
bus: Arc<MessageBus>,
) -> Result<Self, AgentError> {
let skills_loader = SkillsLoader::new();
skills_loader.load_skills();
@ -667,9 +656,17 @@ impl SessionManager {
tools,
skills_loader,
storage,
bus,
current_source_session: Arc::new(Mutex::new(None)),
})
}
/// Register the send_message tool (requires self in Arc)
pub fn register_outbound_tool(self: &Arc<Self>, available_channels: Vec<String>) {
let messenger: Arc<dyn OutboundMessenger> = self.clone();
self.tools.register(SendMessageTool::new(messenger, available_channels));
}
pub fn tools(&self) -> Arc<ToolRegistry> {
self.tools.clone()
}
@ -1047,65 +1044,111 @@ impl SessionManager {
Err(AgentError::Other("clear_dialog_history not available".to_string()))
}
/// Get or activate a specific session by its full UnifiedSessionId.
/// Returns an error if the session does not exist in storage.
/// If the session was expired from memory but still in storage,
/// it will be restored (reactivated).
pub async fn get_or_activate_session(
&self,
unified_id: &UnifiedSessionId,
) -> Result<Arc<Mutex<Session>>, AgentError> {
let session_id_str = unified_id.to_string();
match self.storage.get_session(&session_id_str).await {
Ok(_) => self.get_or_create_session(unified_id).await,
Err(StorageError::NotFound(_)) => {
Err(AgentError::Other(format!("session not found: {}", unified_id)))
}
Err(e) => Err(AgentError::Other(format!("storage error: {}", e))),
}
}
async fn resolve_dialog_id(
&self,
channel: &str,
chat_id: &str,
) -> Result<UnifiedSessionId, AgentError> {
let chat_scope = format!("{}:{}", channel, chat_id);
let current_id = {
self.inner.lock().await.current_sessions.get(&chat_scope).cloned()
};
if let Some(ref current_id) = current_id {
match self.storage.get_session(current_id).await {
Ok(_) => {
let parts: Vec<&str> = current_id.split(':').collect();
if parts.len() == 3 {
return Ok(UnifiedSessionId::new(channel, chat_id, parts[2]));
}
}
Err(_) => {}
}
}
let ttl_millis = self.inner.lock().await.session_ttl.as_millis() as i64;
match self.storage.find_active_session(channel, chat_id, ttl_millis).await {
Ok(Some(meta)) => Ok(UnifiedSessionId::new(channel, chat_id, &meta.dialog_id)),
_ => {
let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?;
Ok(new_id)
}
}
}
/// Send a system notification (no LLM triggered).
///
/// Flow:
/// 1. Resolve target session (resolve_dialog_id)
/// 2. Write assistant message with source tag to history
/// 3. Publish OutboundMessage via bus to target channel
pub async fn send_notification(
&self,
channel: &str,
chat_id: &str,
content: &str,
system_name: &str,
task_id: Option<&str>,
) -> Result<(), AgentError> {
let unified_id = self.resolve_dialog_id(channel, chat_id).await?;
let session = self.get_or_create_session(&unified_id).await?;
{
let mut guard = session.lock().await;
let source = MessageSource {
kind: SourceKind::SystemNotification,
from_channel: None,
from_session: None,
from_user_id: None,
system_name: Some(system_name.to_string()),
task_id: task_id.map(|s| s.to_string()),
};
let msg = ChatMessage::assistant_with_source(content, source);
guard.add_message(msg, true).await
.map_err(|e| AgentError::Other(format!("persist error: {}", e)))?;
}
let outbound = OutboundMessage {
channel: channel.to_string(),
chat_id: chat_id.to_string(),
content: content.to_string(),
reply_to: None,
media: vec![],
metadata: HashMap::new(),
};
self.bus.publish_outbound(outbound).await
.map_err(|e| AgentError::Other(format!("bus publish error: {}", e)))?;
Ok(())
}
pub async fn handle_message(
&self,
channel: &str,
_sender_id: &str,
chat_id: &str,
content: &str,
media: Vec<crate::bus::MediaItem>,
media: Vec<MediaItem>,
) -> Result<HandleResult, AgentError> {
// Channel messages never carry dialog_id — routing is entirely via current_sessions
let unified_id = {
let chat_scope = format!("{}:{}", channel, chat_id);
let current_session_id = {
let inner = self.inner.lock().await;
inner.current_sessions.get(&chat_scope).cloned()
};
if let Some(current_id) = current_session_id {
// Verify current session still exists in Storage
match self.storage.get_session(&current_id).await {
Ok(_) => {
// Current session still valid, extract dialog_id
let parts: Vec<&str> = current_id.split(':').collect();
if parts.len() == 3 {
UnifiedSessionId::new(channel, chat_id, parts[2])
} else {
// Malformed, fallback to find or create
let ttl_millis = self.inner.lock().await.session_ttl.as_millis() as i64;
match self.storage.find_active_session(channel, chat_id, ttl_millis).await {
Ok(Some(m)) => UnifiedSessionId::new(channel, chat_id, &m.dialog_id),
_ => {
let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?;
new_id
}
}
}
}
Err(_) => {
// Current session no longer exists, create new
let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?;
new_id
}
}
} else {
// No current session tracked, find active or create new
let ttl_millis = self.inner.lock().await.session_ttl.as_millis() as i64;
tracing::debug!(channel, chat_id, ttl_millis, "No current_sessions entry, searching Storage for active session");
match self.storage.find_active_session(channel, chat_id, ttl_millis).await {
Ok(Some(meta)) => {
tracing::debug!(session_id = %meta.id, dialog_id = %meta.dialog_id, last_active_at = %meta.last_active_at, "Found active session in Storage");
UnifiedSessionId::new(channel, chat_id, &meta.dialog_id)
}
Ok(None) | Err(_) => {
tracing::debug!("No active session found in Storage, creating new session");
// Create new session
let (new_id, _) = self.create_session(channel, chat_id, None, String::new()).await?;
new_id
}
}
}
};
let unified_id = self.resolve_dialog_id(channel, chat_id).await?;
*self.current_source_session.lock().await = Some(unified_id.to_string());
tracing::debug!(unified_id = %unified_id, "handle_message resolved unified_id");
let session = self.get_or_create_session(&unified_id).await?;
@ -1121,9 +1164,11 @@ impl SessionManager {
match result {
Ok((_new_session_id, response)) => {
*self.current_source_session.lock().await = None;
return Ok(HandleResult::CommandOutput(response));
}
Err(e) => {
*self.current_source_session.lock().await = None;
return Ok(HandleResult::CommandOutput(e.to_string()));
}
}
@ -1183,6 +1228,8 @@ impl SessionManager {
"Agent response received"
);
*self.current_source_session.lock().await = None;
Ok(HandleResult::AgentResponse(response))
}
@ -1203,6 +1250,74 @@ impl SessionManager {
}
}
#[async_trait::async_trait]
impl OutboundMessenger for SessionManager {
async fn send_message(
&self,
channel: &str,
chat_id: &str,
dialog_id: Option<&str>,
content: &str,
mut source: MessageSource,
) -> Result<(), String> {
// Fill origin from current source session if not provided
if source.from_session.is_none() {
source.from_session = self.current_source_session.lock().await.clone();
}
let (target_sid, session) = if let Some(did) = dialog_id {
let sid = UnifiedSessionId::new(channel, chat_id, did);
let session = self.get_or_activate_session(&sid).await
.map_err(|e| e.to_string())?;
(sid, session)
} else {
let sid = self.resolve_dialog_id(channel, chat_id).await
.map_err(|e| e.to_string())?;
let session = self.get_or_create_session(&sid).await
.map_err(|e| e.to_string())?;
(sid, session)
};
// Build message prefix: [message from <origin> to <channel:chat_id:dialog_id>]
let target_id = target_sid.to_string();
let origin = source.from_session.as_deref().unwrap_or("unknown");
let origin_id = source.from_session.clone();
let prefix = format!("[message from {} to {}] ", origin, target_id);
let marked_content = format!("{}\n{}", prefix, content);
// Write source-tagged assistant message to target session history
{
let mut guard = session.lock().await;
let msg = ChatMessage::assistant_with_source(marked_content.clone(), source);
guard.add_message(msg, true).await
.map_err(|e| e.to_string())?;
}
// Restore active dialog if source and target share channel:chat_id but differ in dialog_id
if let Some(ref origin_id) = origin_id {
let parts: Vec<&str> = origin_id.split(':').collect();
if parts.len() == 3 && parts[0] == channel && parts[1] == chat_id && parts[2] != target_sid.dialog_id {
let scope = format!("{}:{}", channel, chat_id);
self.inner.lock().await.current_sessions.insert(scope, origin_id.clone());
}
}
// Publish OutboundMessage via bus to target channel
let outbound = OutboundMessage {
channel: channel.to_string(),
chat_id: chat_id.to_string(),
content: marked_content,
reply_to: None,
media: vec![],
metadata: HashMap::new(),
};
self.bus.publish_outbound(outbound).await
.map_err(|e| e.to_string())?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@ -11,5 +11,6 @@ pub struct MessageMeta {
pub tool_call_id: Option<String>,
pub tool_name: Option<String>,
pub tool_calls: Option<String>,
pub source: Option<String>,
pub created_at: i64,
}

View File

@ -66,6 +66,7 @@ impl Storage {
tool_call_id TEXT,
tool_name TEXT,
tool_calls TEXT,
source TEXT,
created_at INTEGER NOT NULL,
FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE
)
@ -83,6 +84,14 @@ impl Storage {
.execute(&self.pool)
.await?;
// Migration: add source column if upgrading from older schema
sqlx::query(
r#"ALTER TABLE messages ADD COLUMN source TEXT"#,
)
.execute(&self.pool)
.await
.ok();
Ok(())
}
@ -260,8 +269,8 @@ impl Storage {
pub async fn append_message(&self, session_id: &str, msg: &crate::storage::message::MessageMeta) -> Result<i64, StorageError> {
sqlx::query(
r#"
INSERT INTO messages (id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO messages (id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(&msg.id)
@ -273,6 +282,7 @@ impl Storage {
.bind(&msg.tool_call_id)
.bind(&msg.tool_name)
.bind(&msg.tool_calls)
.bind(&msg.source)
.bind(msg.created_at)
.execute(self.pool())
.await?;
@ -300,7 +310,7 @@ impl Storage {
) -> Result<Vec<crate::storage::message::MessageMeta>, StorageError> {
let rows = sqlx::query(
r#"
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, created_at
SELECT id, session_id, seq, role, content, media_refs, tool_call_id, tool_name, tool_calls, source, created_at
FROM messages
WHERE session_id = ? AND seq >= ?
ORDER BY seq ASC
@ -323,6 +333,7 @@ impl Storage {
tool_call_id: row.get("tool_call_id"),
tool_name: row.get("tool_name"),
tool_calls: row.get("tool_calls"),
source: row.get("source"),
created_at: row.get("created_at"),
})
.collect())
@ -486,6 +497,7 @@ mod tests {
tool_call_id: None,
tool_name: None,
tool_calls: None,
source: None,
created_at: 1000,
};

View File

@ -7,6 +7,7 @@ pub mod get_skill;
pub mod http_request;
pub mod registry;
pub mod schema;
pub mod send_message;
pub mod traits;
pub mod web_fetch;
@ -19,5 +20,30 @@ pub use get_skill::GetSkillTool;
pub use http_request::HttpRequestTool;
pub use registry::ToolRegistry;
pub use schema::{CleaningStrategy, SchemaCleanr};
pub use traits::{Tool, ToolResult};
pub use send_message::SendMessageTool;
pub use traits::{OutboundMessenger, Tool, ToolResult};
pub use web_fetch::WebFetchTool;
use std::sync::Arc;
use crate::skills::SkillsLoader;
/// Create the base tool registry (without send_message).
/// `send_message` tool is registered later via `SessionManager::register_outbound_tool()`
/// once the available channel names are known.
pub fn create_default_tools(skills_loader: Arc<SkillsLoader>) -> ToolRegistry {
let registry = ToolRegistry::new();
registry.register(CalculatorTool::new());
registry.register(FileReadTool::new());
registry.register(FileWriteTool::new());
registry.register(FileEditTool::new());
registry.register(BashTool::new());
registry.register(HttpRequestTool::new(
vec!["*".to_string()],
1_000_000,
30,
false,
));
registry.register(WebFetchTool::new(50_000, 30));
registry.register(GetSkillTool::new(skills_loader));
registry
}

View File

@ -1,36 +1,39 @@
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::providers::{Tool, ToolFunction};
use super::traits::Tool as ToolTrait;
pub struct ToolRegistry {
tools: HashMap<String, Box<dyn ToolTrait>>,
tools: Mutex<HashMap<String, Arc<dyn ToolTrait>>>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
tools: Mutex::new(HashMap::new()),
}
}
pub fn register<T: ToolTrait + 'static>(&mut self, tool: T) {
self.tools.insert(tool.name().to_string(), Box::new(tool));
pub fn register<T: ToolTrait + 'static>(&self, tool: T) {
self.tools.lock().unwrap().insert(tool.name().to_string(), Arc::new(tool));
}
pub fn get(&self, name: &str) -> Option<&Box<dyn ToolTrait>> {
self.tools.get(name)
pub fn get(&self, name: &str) -> Option<Arc<dyn ToolTrait>> {
self.tools.lock().unwrap().get(name).cloned()
}
/// Get all registered tools.
/// Used for concurrent tool execution when we need to look up tools by name.
pub fn get_all(&self) -> Vec<&Box<dyn ToolTrait>> {
self.tools.values().collect()
pub fn get_all(&self) -> Vec<Arc<dyn ToolTrait>> {
self.tools.lock().unwrap().values().cloned().collect()
}
pub fn get_definitions(&self) -> Vec<Tool> {
self.tools
.lock()
.unwrap()
.values()
.map(|tool| Tool {
tool_type: "function".to_string(),
@ -44,15 +47,20 @@ impl ToolRegistry {
}
pub fn has_tools(&self) -> bool {
!self.tools.is_empty()
!self.tools.lock().unwrap().is_empty()
}
pub fn tool_names(&self) -> Vec<String> {
self.tools.keys().cloned().collect()
self.tools.lock().unwrap().keys().cloned().collect()
}
pub fn iter(&self) -> impl Iterator<Item = (&String, &Box<dyn ToolTrait>)> {
self.tools.iter()
pub fn iter(&self) -> Vec<(String, Arc<dyn ToolTrait>)> {
self.tools
.lock()
.unwrap()
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect()
}
}

171
src/tools/send_message.rs Normal file
View File

@ -0,0 +1,171 @@
use std::sync::Arc;
use std::collections::HashSet;
use async_trait::async_trait;
use crate::bus::{MessageSource, SourceKind};
use super::traits::{OutboundMessenger, Tool, ToolResult};
pub struct SendMessageTool {
messenger: Arc<dyn OutboundMessenger>,
available_channels: HashSet<String>,
}
impl SendMessageTool {
pub fn new(messenger: Arc<dyn OutboundMessenger>, available_channels: Vec<String>) -> Self {
Self {
messenger,
available_channels: available_channels.into_iter().collect(),
}
}
}
/// Parse target_chat_id into (channel, chat_id, optional dialog_id).
/// Accepts two formats:
/// - Two-part: `<channel>:<chat_id>` → sends to latest active session for that chat
/// - Three-part: `<channel>:<chat_id>:<dialog_id>` → sends to specific session
fn parse_target_chat_id(raw: &str) -> Result<(&str, &str, Option<&str>), String> {
let parts: Vec<&str> = raw.split(':').collect();
match parts.len() {
2 => {
if parts[0].is_empty() || parts[1].is_empty() {
Err(format!("Invalid target_chat_id format '{}': channel and chat_id must not be empty", raw))
} else {
Ok((parts[0], parts[1], None))
}
}
3 => {
if parts[0].is_empty() || parts[1].is_empty() || parts[2].is_empty() {
Err(format!("Invalid target_chat_id format '{}': all three parts must not be empty", raw))
} else {
Ok((parts[0], parts[1], Some(parts[2])))
}
}
_ => Err(format!(
"Invalid target_chat_id format '{}'. Expected <channel>:<chat_id> or <channel>:<chat_id>:<dialog_id>",
raw
)),
}
}
#[async_trait]
impl Tool for SendMessageTool {
fn name(&self) -> &str {
"send_message"
}
fn description(&self) -> &str {
"向指定渠道的会话发送消息。用于在用户请求下向其他渠道发送内容。\
target_chat_id <channel>:<chat_id>\
<channel>:<chat_id>:<dialog_id>"
}
fn parameters_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"target_chat_id": {
"type": "string",
"description": "目标会话ID。支持两种格式: 1) <channel>:<chat_id> 发送到该聊天下最新活跃会话, 无则自动创建; 2) <channel>:<chat_id>:<dialog_id> 发送到指定会话, 过期则自动激活。channel 可选值: feishu, cli_chat"
},
"content": {
"type": "string",
"description": "要发送的消息内容"
},
"origin": {
"type": "string",
"description": "可选。消息来源标识。不填则自动使用当前会话的完整 session_id (<channel>:<chat_id>:<dialog_id>)"
}
},
"required": ["target_chat_id", "content"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let raw_id = args["target_chat_id"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing target_chat_id"))?;
let content = args["content"]
.as_str()
.ok_or_else(|| anyhow::anyhow!("missing content"))?;
// 1. Parse target_chat_id
let (channel, chat_id, dialog_id) = parse_target_chat_id(raw_id)
.map_err(|e| anyhow::anyhow!(e))?;
// 2. Validate channel
if !self.available_channels.contains(channel) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Channel '{}' is not available. Available channels: {}",
channel,
self.available_channels.iter().cloned().collect::<Vec<_>>().join(", ")
)),
});
}
let from_session = args["origin"].as_str().map(|s| s.to_string());
let source = MessageSource {
kind: SourceKind::CrossChannel,
from_channel: Some("tool".to_string()),
from_session,
from_user_id: None,
system_name: None,
task_id: None,
};
// 3. Send via messenger
match self.messenger
.send_message(channel, chat_id, dialog_id, content, source)
.await
{
Ok(()) => Ok(ToolResult {
success: true,
output: "消息已发送".to_string(),
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_target_chat_id_two_part() {
let (ch, cid, did) = parse_target_chat_id("feishu:oc_abc123").unwrap();
assert_eq!(ch, "feishu");
assert_eq!(cid, "oc_abc123");
assert!(did.is_none());
}
#[test]
fn test_parse_target_chat_id_three_part() {
let (ch, cid, did) = parse_target_chat_id("feishu:oc_abc123:dialog1").unwrap();
assert_eq!(ch, "feishu");
assert_eq!(cid, "oc_abc123");
assert_eq!(did, Some("dialog1"));
}
#[test]
fn test_parse_target_chat_id_invalid_one_part() {
assert!(parse_target_chat_id("feishu").is_err());
}
#[test]
fn test_parse_target_chat_id_empty_parts() {
assert!(parse_target_chat_id("feishu:").is_err());
assert!(parse_target_chat_id(":chat_id").is_err());
assert!(parse_target_chat_id("feishu::dialog").is_err());
}
}

View File

@ -1,4 +1,5 @@
use async_trait::async_trait;
use crate::bus::MessageSource;
#[derive(Debug, Clone)]
pub struct ToolResult {
@ -29,3 +30,15 @@ pub trait Tool: Send + Sync + 'static {
false
}
}
#[async_trait]
pub trait OutboundMessenger: Send + Sync {
async fn send_message(
&self,
channel: &str,
chat_id: &str,
dialog_id: Option<&str>,
content: &str,
source: MessageSource,
) -> Result<(), String>;
}