diff --git a/.gitignore b/.gitignore index 16751f8..b04e58c 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,4 @@ reference/** .env *.env -AGENTS.md -CLAUDE.md Cargo.lock diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..d5ec9a9 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,73 @@ +# PicoBot + +## Maintenance + +- **Update this file on any architectural change** — module boundaries, data flow, key constraints, or build/test commands must be reflected here + +## Build & Run + +- `cargo build` — build the binary +- `cargo run -- gateway` — start gateway server (binds `127.0.0.1:19876` by default) +- `cargo run -- chat` — connect to gateway as CLI client (default `ws://127.0.0.1:19876/ws`) + +## Config + +- Config file: `~/.picobot/config.json` or `./config.json` (fallback order) +- `.env` is loaded and env var placeholders `` are substituted into config +- Config example: `config.example.json` + +## Tests + +- `cargo test --lib` — run unit tests (FAILS: `src/session/session.rs:657` missing `workspace_dir` field in test helper) +- `cargo test --test test_integration -- --ignored` — run integration tests (requires `tests/test.env` with API keys) + +## Reference + +- `reference/` — third-party reference implementations (nanobot, Mini-Agent, zeroclaw); not part of this project; use for similar functionality patterns + +## Architecture + +### Modes + +- **Gateway mode** (`cargo run -- gateway`): HTTP/WebSocket server; owns `GatewayState` which holds all services +- **Client mode** (`cargo run -- chat`): TUI chat client; connects to gateway via WebSocket, purely for user interaction + +### Core Data Flow + +``` +Channel → MessageBus → SessionManager → AgentLoop → (tools) → SessionManager → MessageBus → OutboundDispatcher → Channel + ↑ + ControlChannel ──→ SessionManager (dialog ops: create/switch/archive/delete) +``` + +### Modules + +| Module | Responsibility | Key Types | +|--------|---------------|-----------| +| `gateway` | Server lifecycle, HTTP/WS endpoints, owns `GatewayState` | `GatewayState`, `run()` | +| `client` | TUI rendering, WebSocket client for CLI chat | `App`, `run()` | +| `channels` | External integrations (Feishu, CLI chat) | `ChannelManager`, `Channel` trait | +| `bus` | Async message queue (inbound/outbound/control channels) | `MessageBus`, `InboundMessage`, `OutboundMessage`, `ControlMessage` | +| `session` | Conversation session lifecycle, dialog operations | `SessionManager`, `Session` | +| `agent` | LLM call loop, tool execution, context compression | `AgentLoop` | +| `providers` | LLM API clients (OpenAI-compatible, Anthropic) | `LLMProvider` trait, factory `create_provider()` | +| `tools` | Agent tools (bash, file operations, http, web) | `ToolRegistry`, `Tool` trait | + +### Functional Boundaries + +- **Channels** only send/receive messages via `MessageBus`; they know nothing about sessions or LLM +- **MessageBus** is a pure async queue; it routes nothing, just passes messages +- **SessionManager** owns session state and dialog operations; it does NOT call LLM directly +- **AgentLoop** receives dialog events from `SessionManager`, calls LLM via `providers`, executes tools, returns text responses +- **Providers** are pure HTTP clients; no bus/session/channel awareness +- **Tools** are executed by `AgentLoop`; they receive raw arguments and return string results + +### Key Constraints + +- Gateway **changes working directory** to workspace on startup (`src/gateway/mod.rs:31`) +- `ChannelManager` owns the `MessageBus` and all channel instances +- `OutboundDispatcher` routes outbound messages to the correct channel via `ChannelManager` + +## Known Issues + +- `src/session/session.rs:657` — `LLMProviderConfig` struct requires `workspace_dir` but test helper at line 656-669 doesn't provide it; test code needs `workspace_dir: PathBuf::new()` added diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..7457da4 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,3 @@ +# Claude Code + +Read `AGENTS.md` for project context, build commands, architecture, and conventions. diff --git a/Cargo.toml b/Cargo.toml index 9f104ba..f557e7f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,3 +32,5 @@ ratatui = "0.27" crossterm = { version = "0.28", features = ["event-stream"] } termimad = "0.34" textwrap = "0.16" +chrono = "0.4" +hostname = "0.3" diff --git a/config.example.json b/config.example.json new file mode 100644 index 0000000..43ef441 --- /dev/null +++ b/config.example.json @@ -0,0 +1,67 @@ +{ + "providers": { + "aliyun": { + "type": "openai", + "base_url": "https://api.openai.com/v1", + "api_key": "", + "extra_headers": {} + }, + "openai": { + "type": "openai", + "base_url": "https://api.openai.com/v1", + "api_key": "", + "extra_headers": {} + }, + "anthropic": { + "type": "anthropic", + "base_url": "https://api.anthropic.com/v1", + "api_key": "", + "extra_headers": {} + } + }, + "models": { + "qwen-plus": { + "model_id": "qwen-plus", + "temperature": 0.0, + "max_tokens": 8192 + }, + "gpt-4o": { + "model_id": "gpt-4o", + "temperature": 0.7, + "max_tokens": 4096 + }, + "claude-sonnet-4-20250514": { + "model_id": "claude-sonnet-4-20250514", + "temperature": 0.7, + "max_tokens": 8192 + } + }, + "agents": { + "default": { + "provider": "aliyun", + "model": "qwen-plus", + "max_tool_iterations": 20, + "token_limit": 128000 + } + }, + "gateway": { + "host": "127.0.0.1", + "port": 19876, + "session_ttl_hours": 168 + }, + "client": { + "gateway_url": "ws://127.0.0.1:19876/ws" + }, + "channels": { + "feishu": { + "enabled": true, + "app_id": "", + "app_secret": "", + "allow_from": ["*"], + "agent": "default", + "media_dir": "~/.picobot/media/feishu", + "reaction_emoji": "Typing" + } + }, + "workspace_dir": "~/.picobot/workspace" +} diff --git a/src/agent/agent_loop.rs b/src/agent/agent_loop.rs index 8e0c9a8..8bcfa10 100644 --- a/src/agent/agent_loop.rs +++ b/src/agent/agent_loop.rs @@ -1,3 +1,4 @@ +use crate::agent::system_prompt::build_system_prompt; use crate::bus::message::ContentBlock; use crate::bus::ChatMessage; use crate::config::LLMProviderConfig; @@ -9,6 +10,7 @@ use crate::tools::ToolRegistry; use std::collections::VecDeque; use std::hash::{Hash, Hasher}; use std::io::Read; +use std::path::PathBuf; use std::sync::Arc; use std::time::Instant; @@ -222,6 +224,8 @@ pub struct AgentLoop { tools: Arc, observer: Option>, max_iterations: usize, + workspace_dir: PathBuf, + model_name: String, } #[derive(Debug, Clone)] @@ -234,6 +238,8 @@ impl AgentLoop { /// Create a new AgentLoop with a provider created from config. pub fn new(provider_config: LLMProviderConfig) -> Result { let max_iterations = provider_config.max_tool_iterations; + let model_name = provider_config.model_id.clone(); + let workspace_dir = provider_config.workspace_dir.clone(); let provider = create_provider(provider_config) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; @@ -242,12 +248,16 @@ impl AgentLoop { tools: Arc::new(ToolRegistry::new()), observer: None, max_iterations, + workspace_dir, + model_name, }) } /// Create a new AgentLoop with provider created from config and given tools. pub fn with_tools(provider_config: LLMProviderConfig, tools: Arc) -> Result { let max_iterations = provider_config.max_tool_iterations; + let model_name = provider_config.model_id.clone(); + let workspace_dir = provider_config.workspace_dir.clone(); let provider = create_provider(provider_config) .map_err(|e| AgentError::ProviderCreation(e.to_string()))?; @@ -256,16 +266,20 @@ impl AgentLoop { tools, observer: None, max_iterations, + workspace_dir, + model_name, }) } /// Create a new AgentLoop with an existing shared provider. - pub fn with_provider(provider: Arc, max_iterations: usize) -> Self { + pub fn with_provider(provider: Arc, max_iterations: usize, model_name: String, workspace_dir: PathBuf) -> Self { Self { provider, tools: Arc::new(ToolRegistry::new()), observer: None, max_iterations, + workspace_dir, + model_name, } } @@ -274,15 +288,25 @@ impl AgentLoop { provider: Arc, tools: Arc, max_iterations: usize, + model_name: String, + workspace_dir: PathBuf, ) -> Self { Self { provider, tools, observer: None, max_iterations, + workspace_dir, + model_name, } } + /// Set the workspace directory. + pub fn with_workspace_dir(mut self, dir: PathBuf) -> Self { + self.workspace_dir = dir; + self + } + /// Set an observer for tracking events. pub fn with_observer(mut self, observer: Arc) -> Self { self.observer = Some(observer); @@ -304,6 +328,15 @@ impl AgentLoop { #[cfg(debug_assertions)] tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process"); + // Build and inject system prompt if not present + let has_system = messages.first().map_or(false, |m| m.role == "system"); + if !has_system { + let system_prompt = build_system_prompt(&self.workspace_dir, &self.model_name, &self.tools); + #[cfg(debug_assertions)] + tracing::debug!("System prompt injected:\n{}", system_prompt); + messages.insert(0, ChatMessage::system(system_prompt)); + } + // Track tool calls for loop detection let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default()); let mut emitted_messages = Vec::new(); diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 4dd5762..1aace0c 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,5 +1,7 @@ pub mod agent_loop; pub mod context_compressor; +pub mod system_prompt; pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult}; pub use context_compressor::ContextCompressor; +pub use system_prompt::{build_system_prompt, PromptContext, PromptSection, SystemPromptBuilder}; diff --git a/src/agent/system_prompt.rs b/src/agent/system_prompt.rs new file mode 100644 index 0000000..b2575bf --- /dev/null +++ b/src/agent/system_prompt.rs @@ -0,0 +1,353 @@ +//! System prompt construction for PicoBot agent. +//! +//! This module provides a modular framework for building system prompts +//! using the SystemPromptBuilder pattern. +//! +//! Configuration: +//! - USER.md is loaded from ~/.picobot/USER.md (user's personal configuration) + +use crate::tools::ToolRegistry; +use std::fmt::Write; +use std::path::Path; + +/// Maximum characters per injected workspace file. +pub const BOOTSTRAP_MAX_CHARS: usize = 16_000; + +/// Context for building system prompts. +pub struct PromptContext<'a> { + pub workspace_dir: &'a Path, + pub model_name: &'a str, + pub tools: &'a ToolRegistry, +} + +/// Trait for system prompt sections. +pub trait PromptSection: Send + Sync { + fn name(&self) -> &str; + fn build(&self, ctx: &PromptContext<'_>) -> String; +} + +/// Builder for constructing system prompts from modular sections. +#[derive(Default)] +pub struct SystemPromptBuilder { + sections: Vec>, +} + +impl SystemPromptBuilder { + /// Create a new builder with default sections. + pub fn with_defaults() -> Self { + Self { + sections: vec![ + Box::new(ToolHonestySection), + Box::new(NoToolNarrationSection), + Box::new(ToolsSection), + Box::new(YourTaskSection), + Box::new(SafetySection), + Box::new(WorkspaceSection), + Box::new(UserProfileSection), + Box::new(DateTimeSection), + Box::new(RuntimeSection), + ], + } + } + + /// Add a custom section to the builder. + pub fn add_section(mut self, section: Box) -> Self { + self.sections.push(section); + self + } + + /// Build the complete system prompt. + pub fn build(&self, ctx: &PromptContext<'_>) -> String { + let mut output = String::with_capacity(8192); + for section in &self.sections { + let part = section.build(ctx); + if part.trim().is_empty() { + continue; + } + output.push_str(part.trim_end()); + output.push_str("\n\n"); + } + output + } +} + +// === Prompt Section Implementations === + +/// Critical rule: never fabricate tool results. +pub struct ToolHonestySection; + +impl PromptSection for ToolHonestySection { + fn name(&self) -> &str { + "tool_honesty" + } + + fn build(&self, _ctx: &PromptContext<'_>) -> String { + "## CRITICAL: Tool Honesty + +- NEVER fabricate, invent, or guess tool results. If a tool returns empty results, say \"No results found.\" +- If a tool call fails, report the error - never make up data to fill the gap. +- When unsure whether a tool call succeeded, ask the user rather than guessing." + .to_string() + } +} + +/// Critical rule: never narrate tool usage. +pub struct NoToolNarrationSection; + +impl PromptSection for NoToolNarrationSection { + fn name(&self) -> &str { + "no_narration" + } + + fn build(&self, _ctx: &PromptContext<'_>) -> String { + "## CRITICAL: No Tool Narration + +NEVER narrate, announce, describe, or explain your tool usage to the user. +Do NOT say things like \"Let me check...\", \"I will use bash to...\", \"I'll fetch that for you\", \"Searching now...\", or similar. +The user must ONLY see the final answer. Tool calls are invisible infrastructure - never reference them. +If you catch yourself starting a sentence about what tool you are about to use or just used, DELETE it and give the answer directly." + .to_string() + } +} + +/// List of available tools. +pub struct ToolsSection; + +impl PromptSection for ToolsSection { + fn name(&self) -> &str { + "tools" + } + + fn build(&self, ctx: &PromptContext<'_>) -> String { + if !ctx.tools.has_tools() { + return String::new(); + } + + let mut output = String::from("## Tools\n\nYou have access to the following tools:\n\n"); + for (name, tool) in ctx.tools.iter() { + let _ = writeln!(output, "- **{}**: {}", name, tool.description()); + } + output + } +} + +/// Instructions for the task. +pub struct YourTaskSection; + +impl PromptSection for YourTaskSection { + fn name(&self) -> &str { + "your_task" + } + + fn build(&self, _ctx: &PromptContext<'_>) -> String { + "## Your Task + +When the user sends a message, ACT on it. Use the tools to fulfill their request. +Do NOT: summarize this configuration, describe your capabilities, respond with meta-commentary, or output step-by-step instructions. +Instead: use tools directly when needed, and give the final answer when done." + .to_string() + } +} + +/// Safety guidelines. +pub struct SafetySection; + +impl PromptSection for SafetySection { + fn name(&self) -> &str { + "safety" + } + + fn build(&self, _ctx: &PromptContext<'_>) -> String { + "## Safety + +- Do not exfiltrate private data. +- Do not run destructive commands without asking. +- Do not bypass oversight or approval mechanisms. +- Prefer safe operations over risky ones. +- When in doubt, ask before acting externally." + .to_string() + } +} + +/// Workspace directory information and guidelines. +pub struct WorkspaceSection; + +impl PromptSection for WorkspaceSection { + fn name(&self) -> &str { + "workspace" + } + + fn build(&self, ctx: &PromptContext<'_>) -> String { + // Try to get absolute path + let abs_path = ctx + .workspace_dir + .canonicalize() + .unwrap_or_else(|_| ctx.workspace_dir.to_path_buf()); + format!( + "## Workspace\n\nWorking directory: `{}`\n\n### File Storage Guidelines\n\n- **Generated files**: Store all generated files (code, documents, artifacts) in the workspace directory or its subdirectories.\n- **Downloaded files**: Save downloaded files to the workspace directory, organized by task.\n- **One task, one folder**: Create a dedicated subfolder for each task or project (e.g., `task_2024_01_01/`).\n- **Temporary files**: If files are only needed during processing and won't be kept, use `/tmp/` or create a temp folder (e.g., `/tmp/picobot_task_xxx/`) instead of cluttering the workspace.\n\n### Working Directory Structure\n\nThe workspace is your home base for this session. Keep it organized by creating subdirectories for different tasks.", + abs_path.display() + ) + } +} + +/// User profile from ~/.picobot/USER.md. +pub struct UserProfileSection; + +impl PromptSection for UserProfileSection { + fn name(&self) -> &str { + "user_profile" + } + + fn build(&self, _ctx: &PromptContext<'_>) -> String { + let mut output = String::from("## User Profile\n\n"); + + // Load USER.md from ~/.picobot/USER.md + if let Some(user_config_dir) = get_user_config_dir() { + if let Some(content) = + load_file_from_dir(&user_config_dir, "USER.md", BOOTSTRAP_MAX_CHARS) + { + output.push_str(&content); + return output; + } + } + + // No USER.md found, return empty + String::new() + } +} + +/// Current date and time. +pub struct DateTimeSection; + +impl PromptSection for DateTimeSection { + fn name(&self) -> &str { + "datetime" + } + + fn build(&self, _ctx: &PromptContext<'_>) -> String { + let now = chrono::Local::now(); + format!( + "## Current Date & Time\n\n{} ({})", + now.format("%Y-%m-%d %H:%M:%S"), + now.format("%Z") + ) + } +} + +/// Runtime environment information. +pub struct RuntimeSection; + +impl PromptSection for RuntimeSection { + fn name(&self) -> &str { + "runtime" + } + + fn build(&self, ctx: &PromptContext<'_>) -> String { + let host = hostname::get() + .map(|h| h.to_string_lossy().to_string()) + .unwrap_or_else(|_| "unknown".to_string()); + format!( + "## Runtime\n\nHost: {} | OS: {} | Model: {}", + host, + std::env::consts::OS, + ctx.model_name + ) + } +} + +// === Helper Functions === + +/// Get user config directory (~/.picobot/). +fn get_user_config_dir() -> Option { + dirs::home_dir().map(|home| home.join(".picobot")) +} + +/// Load a file from specified directory with truncation. +fn load_file_from_dir(dir: &Path, filename: &str, max_chars: usize) -> Option { + let path = dir.join(filename); + match std::fs::read_to_string(&path) { + Ok(content) => { + let trimmed = content.trim(); + if trimmed.is_empty() { + return None; + } + let truncated = if trimmed.chars().count() > max_chars { + trimmed + .char_indices() + .nth(max_chars) + .map(|(idx, _)| &trimmed[..idx]) + .unwrap_or(trimmed) + .to_string() + + &format!( + "\n\n[... truncated at {} characters - use file_read for full file]", + max_chars + ) + } else { + trimmed.to_string() + }; + Some(truncated) + } + Err(_) => None, + } +} + +/// Build a complete system prompt with default configuration. +pub fn build_system_prompt(workspace_dir: &Path, model_name: &str, tools: &ToolRegistry) -> String { + let ctx = PromptContext { + workspace_dir, + model_name, + tools, + }; + SystemPromptBuilder::with_defaults().build(&ctx) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn test_builder_creates_sections() { + let temp_dir = std::env::temp_dir(); + let tools = ToolRegistry::new(); + + let ctx = PromptContext { + workspace_dir: &temp_dir, + model_name: "test-model", + tools: &tools, + }; + + let prompt = SystemPromptBuilder::with_defaults().build(&ctx); + + assert!(prompt.contains("## CRITICAL: Tool Honesty")); + assert!(prompt.contains("## CRITICAL: No Tool Narration")); + assert!(prompt.contains("## Safety")); + assert!(prompt.contains("## Workspace")); + assert!(prompt.contains("## Current Date & Time")); + assert!(prompt.contains("## Runtime")); + } + + #[test] + fn test_load_file_from_dir() { + let temp_dir = tempfile::tempdir().unwrap(); + let test_file = temp_dir.path().join("TEST.md"); + std::fs::write(&test_file, "Hello, world!").unwrap(); + + let content = load_file_from_dir(temp_dir.path(), "TEST.md", 100); + assert_eq!(content, Some("Hello, world!".to_string())); + + let content = load_file_from_dir(temp_dir.path(), "NOT_EXIST.md", 100); + assert_eq!(content, None); + } + + #[test] + fn test_build_system_prompt() { + let temp_dir = std::env::temp_dir(); + let tools = ToolRegistry::new(); + + let prompt = build_system_prompt(&temp_dir, "test-model", &tools); + + assert!(!prompt.is_empty()); + assert!(prompt.contains("test-model")); + } +} diff --git a/src/channels/feishu.rs b/src/channels/feishu.rs index b33fb25..c95a883 100644 --- a/src/channels/feishu.rs +++ b/src/channels/feishu.rs @@ -173,9 +173,13 @@ struct ParsedMessage { impl FeishuChannel { pub fn new( - config: FeishuChannelConfig, - _provider_config: LLMProviderConfig, + mut config: FeishuChannelConfig, + workspace_dir: &Path, ) -> Result { + // Override media_dir to use workspace_dir/media/feishu + let media_dir = workspace_dir.join("media").join("feishu"); + config.media_dir = media_dir.to_string_lossy().to_string(); + Ok(Self { config, http_client: reqwest::Client::new(), diff --git a/src/channels/manager.rs b/src/channels/manager.rs index cf6a547..b144dd3 100644 --- a/src/channels/manager.rs +++ b/src/channels/manager.rs @@ -43,19 +43,19 @@ impl ChannelManager { pub async fn init( &self, config: &Config, - _provider_config: crate::config::LLMProviderConfig, + workspace_dir: std::path::PathBuf, ) -> Result<(), ChannelError> { // Initialize Feishu channel if enabled if let Some(feishu_config) = config.channels.get("feishu") { if feishu_config.enabled { - let channel = FeishuChannel::new(feishu_config.clone(), _provider_config) + let channel = FeishuChannel::new(feishu_config.clone(), &workspace_dir) .map_err(|e| ChannelError::Other(format!("Failed to create Feishu channel: {}", e)))?; self.channels .write() .await .insert("feishu".to_string(), Arc::new(channel)); - tracing::info!("Feishu channel registered"); + tracing::info!("Feishu channel registered (media_dir: {}/media/feishu)", workspace_dir.display()); } else { tracing::info!("Feishu channel disabled in config"); } diff --git a/src/config/mod.rs b/src/config/mod.rs index e2fecce..eed56ca 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -5,6 +5,39 @@ use std::env; use std::fs; use std::path::{Path, PathBuf}; +/// Get the user configuration directory (~/.picobot) +pub fn get_user_config_dir() -> PathBuf { + dirs::home_dir() + .unwrap_or_else(|| PathBuf::from(".")) + .join(".picobot") +} + +/// Get the default workspace directory (~/.picobot/workspace) +pub fn get_default_workspace_dir() -> PathBuf { + get_user_config_dir().join("workspace") +} + +/// Expand ~ in path to user home directory +pub fn expand_path(path: &str) -> PathBuf { + if path.starts_with("~/") { + dirs::home_dir() + .unwrap_or_else(|| PathBuf::from(".")) + .join(&path[2..]) + } else { + PathBuf::from(path) + } +} + +/// Ensure workspace directory exists, create if needed +pub fn ensure_workspace_dir(path: &Path) -> Result { + if !path.exists() { + tracing::info!("Creating workspace directory: {}", path.display()); + fs::create_dir_all(path)?; + } + // Return canonical path + path.canonicalize().or_else(|_| Ok(path.to_path_buf())) +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Config { pub providers: HashMap, @@ -16,6 +49,12 @@ pub struct Config { pub client: ClientConfig, #[serde(default)] pub channels: HashMap, + #[serde(default = "default_workspace_dir")] + pub workspace_dir: String, +} + +fn default_workspace_dir() -> String { + get_default_workspace_dir().to_string_lossy().to_string() } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -40,8 +79,10 @@ fn default_allow_from() -> Vec { } fn default_media_dir() -> String { - let home = dirs::home_dir().unwrap_or_else(|| std::path::PathBuf::from(".")); - home.join(".picobot/media/feishu").to_string_lossy().to_string() + get_user_config_dir() + .join("media/feishu") + .to_string_lossy() + .to_string() } fn default_reaction_emoji() -> String { @@ -146,11 +187,11 @@ pub struct LLMProviderConfig { pub model_extra: HashMap, pub max_tool_iterations: usize, pub token_limit: usize, + pub workspace_dir: PathBuf, } fn get_default_config_path() -> PathBuf { - let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")); - home.join(".picobot").join("config.json") + get_user_config_dir().join("config.json") } impl Config { @@ -186,13 +227,19 @@ impl Config { } pub fn get_provider_config(&self, agent_name: &str) -> Result { - let agent = self.agents.get(agent_name) + let agent = self + .agents + .get(agent_name) .ok_or(ConfigError::AgentNotFound(agent_name.to_string()))?; - let provider = self.providers.get(&agent.provider) + let provider = self + .providers + .get(&agent.provider) .ok_or(ConfigError::ProviderNotFound(agent.provider.clone()))?; - let model = self.models.get(&agent.model) + let model = self + .models + .get(&agent.model) .ok_or(ConfigError::ModelNotFound(agent.model.clone()))?; Ok(LLMProviderConfig { @@ -207,6 +254,7 @@ impl Config { model_extra: model.extra.clone(), max_tool_iterations: agent.max_tool_iterations, token_limit: agent.token_limit, + workspace_dir: expand_path(&self.workspace_dir), }) } } @@ -260,18 +308,19 @@ fn resolve_env_placeholders(content: &str) -> String { re.replace_all(content, |caps: ®ex::Captures| { let var_name = &caps[1]; env::var(var_name).unwrap_or_else(|_| caps[0].to_string()) - }).to_string() + }) + .to_string() } #[cfg(test)] mod tests { use super::*; - fn write_test_config() -> tempfile::NamedTempFile { - let file = tempfile::NamedTempFile::new().unwrap(); - std::fs::write( - file.path(), - r#"{ + fn write_test_config() -> tempfile::NamedTempFile { + let file = tempfile::NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + r#"{ "providers": { "aliyun": { "type": "openai", @@ -306,15 +355,15 @@ mod tests { "port": 19876 } }"#, - ) - .unwrap(); - file - } + ) + .unwrap(); + file + } #[test] fn test_config_load() { - let file = write_test_config(); - let config = Config::load(file.path().to_str().unwrap()).unwrap(); + let file = write_test_config(); + let config = Config::load(file.path().to_str().unwrap()).unwrap(); // Check providers assert!(config.providers.contains_key("volcengine")); diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index c755085..52d6710 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -8,12 +8,13 @@ use tokio::net::TcpListener; use crate::bus::{ControlMessage, OutboundDispatcher}; use crate::channels::{ChannelManager, CliChatChannel}; use crate::channels::base::{Channel, ChannelError}; -use crate::config::Config; +use crate::config::{Config, expand_path, ensure_workspace_dir}; use crate::logging; use crate::session::SessionManager; pub struct GatewayState { pub config: Config, + pub workspace_dir: std::path::PathBuf, pub session_manager: SessionManager, pub channel_manager: ChannelManager, } @@ -22,8 +23,20 @@ impl GatewayState { pub fn new() -> Result> { let config = Config::load_default()?; + // Initialize workspace directory: expand path and ensure it exists + let workspace_path = expand_path(&config.workspace_dir); + let workspace_path = ensure_workspace_dir(&workspace_path)?; + + // Switch current working directory to workspace + std::env::set_current_dir(&workspace_path) + .map_err(|e| format!("Failed to switch to workspace directory {}: {}", workspace_path.display(), e))?; + + tracing::info!("Using workspace directory: {}", workspace_path.display()); + // Get provider config for SessionManager - let provider_config = config.get_provider_config("default")?; + let mut provider_config = config.get_provider_config("default")?; + // Override workspace_dir with the ensured path + provider_config.workspace_dir = workspace_path.clone(); // Session TTL from config (default 4 hours) let session_ttl_hours = config.gateway.session_ttl_hours.unwrap_or(4); @@ -36,6 +49,7 @@ impl GatewayState { Ok(Self { config, + workspace_dir: workspace_path, session_manager, channel_manager, }) @@ -188,13 +202,10 @@ pub async fn run(host: Option, port: Option) -> Result<(), Box Vec { self.tools.keys().cloned().collect() } + + pub fn iter(&self) -> impl Iterator)> { + self.tools.iter() + } } impl Default for ToolRegistry {