From d8d77824f472678f7ef47b99c41a8521db61e92b Mon Sep 17 00:00:00 2001 From: fengmengqi Date: Mon, 13 Apr 2026 14:39:17 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=90=88=E5=B9=B6=E4=B8=8A=E6=B8=B8=20?= =?UTF-8?q?Rust=20=E5=AE=9E=E7=8E=B0=EF=BC=8C=E6=89=A9=E5=B1=95=20API/?= =?UTF-8?q?=E8=BF=90=E8=A1=8C=E6=97=B6/=E5=B7=A5=E5=85=B7=E9=93=BE?= =?UTF-8?q?=E8=83=BD=E5=8A=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将 claw-code/rust/crates 的完整实现合并到主 workspace,涵盖 9 个 crate 的更新与 2 个新 crate 的引入。 API 层: - 用原生 Anthropic 客户端(anthropic.rs)替换 claw_provider, 新增 prompt cache 减少重复请求开销 - 新增 HTTP 客户端构建器统一代理配置,OpenAI 兼容端增加 DashScope/Qwen 支持与抖动重试 - MessageRequest 扩展 temperature/top_p 等模型调参字段 - SSE 解析器增加 provider 上下文感知的错误信息 运行时(~11,000 行新增): - 新增 bash 命令安全校验、分支锁碰撞检测、配置文件校验 - 新增会话存储与控制面、MCP 生命周期状态机与服务端实现 - 新增权限执行引擎、策略引擎、插件生命周期管理 - 新增 worker 启动编排、任务/定时任务注册表、信任解析器 - 保留 Windows cmd /C fallback 命令/插件/工具: - commands 大幅重写,扩展 sandbox、doctor、plan 等 slash 命令 - plugins 新增 PostToolUseFailure hook 与宽容加载机制 - tools 新增 PDF 提取与 lane 补全工具 新增 crate:mock-anthropic-service(测试)、telemetry(遥测) 适配 claw-cli/server:ClawApiClient→AnthropicClient 重命名, SlashCommand::parse 返回 Result,移除 session 级 Thinking 变体, TokenUsage/ConversationMessage 补充序列化支持 --- Cargo.lock | 56 +- Cargo.toml | 32 + crates/api/Cargo.toml | 1 + crates/api/src/client.rs | 176 +- crates/api/src/error.rs | 463 +- crates/api/src/http_client.rs | 344 + crates/api/src/lib.rs | 20 +- crates/api/src/prompt_cache.rs | 735 ++ .../{claw_provider.rs => anthropic.rs} | 844 +- crates/api/src/providers/mod.rs | 947 +- crates/api/src/providers/openai_compat.rs | 897 +- crates/api/src/sse.rs | 149 +- crates/api/src/types.rs | 101 +- crates/api/tests/client_integration.rs | 483 +- crates/api/tests/openai_compat_integration.rs | 124 +- .../api/tests/provider_client_integration.rs | 20 +- crates/api/tests/proxy_integration.rs | 173 + crates/claw-cli/src/main.rs | 110 +- crates/commands/src/lib.rs | 4872 ++++++-- crates/compat-harness/src/lib.rs | 10 +- crates/lsp/src/lib.rs | 1 + crates/mock-anthropic-service/Cargo.toml | 18 + crates/mock-anthropic-service/src/lib.rs | 1123 ++ crates/mock-anthropic-service/src/main.rs | 34 + .../plugin.json | 0 .../plugin.json | 0 crates/plugins/src/hooks.rs | 229 +- crates/plugins/src/lib.rs | 636 +- crates/runtime/Cargo.toml | 5 +- crates/runtime/src/bash.rs | 107 +- crates/runtime/src/bash_validation.rs | 1004 ++ crates/runtime/src/bootstrap.rs | 57 +- crates/runtime/src/branch_lock.rs | 144 + crates/runtime/src/compact.rs | 277 +- crates/runtime/src/config.rs | 857 +- crates/runtime/src/config_validate.rs | 901 ++ crates/runtime/src/conversation.rs | 1038 +- crates/runtime/src/file_ops.rs | 303 +- crates/runtime/src/git_context.rs | 324 + crates/runtime/src/green_contract.rs | 152 + crates/runtime/src/hooks.rs | 818 +- crates/runtime/src/json.rs | 3 +- crates/runtime/src/lane_events.rs | 383 + crates/runtime/src/lib.rs | 125 +- crates/runtime/src/lsp_client.rs | 747 ++ crates/runtime/src/mcp.rs | 8 +- crates/runtime/src/mcp_client.rs | 14 + crates/runtime/src/mcp_lifecycle_hardened.rs | 843 ++ crates/runtime/src/mcp_server.rs | 440 + crates/runtime/src/mcp_stdio.rs | 1533 ++- crates/runtime/src/mcp_tool_bridge.rs | 920 ++ crates/runtime/src/oauth.rs | 28 +- crates/runtime/src/permission_enforcer.rs | 551 + crates/runtime/src/permissions.rs | 503 +- crates/runtime/src/plugin_lifecycle.rs | 533 + crates/runtime/src/policy_engine.rs | 581 + crates/runtime/src/prompt.rs | 184 +- crates/runtime/src/recovery_recipes.rs | 631 + crates/runtime/src/remote.rs | 14 +- crates/runtime/src/sandbox.rs | 41 +- crates/runtime/src/session.rs | 1182 +- crates/runtime/src/session_control.rs | 873 ++ crates/runtime/src/sse.rs | 32 +- crates/runtime/src/stale_base.rs | 429 + crates/runtime/src/stale_branch.rs | 417 + crates/runtime/src/summary_compression.rs | 300 + crates/runtime/src/task_packet.rs | 158 + crates/runtime/src/task_registry.rs | 503 + crates/runtime/src/team_cron_registry.rs | 509 + crates/runtime/src/trust_resolver.rs | 299 + crates/runtime/src/usage.rs | 41 +- crates/runtime/src/worker_boot.rs | 1180 ++ crates/runtime/tests/integration_tests.rs | 386 + crates/rusty-claude-cli/Cargo.toml | 14 +- crates/rusty-claude-cli/build.rs | 38 + crates/rusty-claude-cli/src/app.rs | 398 - crates/rusty-claude-cli/src/args.rs | 102 - crates/rusty-claude-cli/src/init.rs | 53 +- crates/rusty-claude-cli/src/input.rs | 818 +- crates/rusty-claude-cli/src/main.rs | 10228 ++++++++++++++-- crates/rusty-claude-cli/src/render.rs | 549 +- .../tests/cli_flags_and_config_defaults.rs | 298 + .../rusty-claude-cli/tests/compact_output.rs | 159 + .../tests/mock_parity_harness.rs | 884 ++ .../tests/output_format_contract.rs | 429 + .../tests/resume_slash_commands.rs | 555 + crates/server/src/lib.rs | 28 +- crates/telemetry/Cargo.toml | 13 + crates/telemetry/src/lib.rs | 526 + crates/tools/Cargo.toml | 2 + crates/tools/src/lane_completion.rs | 181 + crates/tools/src/lib.rs | 4669 ++++++- crates/tools/src/pdf_extract.rs | 548 + frontend/src/components/ChatView.tsx | 10 +- 94 files changed, 49049 insertions(+), 4429 deletions(-) create mode 100644 crates/api/src/http_client.rs create mode 100644 crates/api/src/prompt_cache.rs rename crates/api/src/providers/{claw_provider.rs => anthropic.rs} (52%) create mode 100644 crates/api/tests/proxy_integration.rs create mode 100644 crates/mock-anthropic-service/Cargo.toml create mode 100644 crates/mock-anthropic-service/src/lib.rs create mode 100644 crates/mock-anthropic-service/src/main.rs rename crates/plugins/bundled/example-bundled/{.claw-plugin => .claude-plugin}/plugin.json (100%) rename crates/plugins/bundled/sample-hooks/{.claw-plugin => .claude-plugin}/plugin.json (100%) create mode 100644 crates/runtime/src/bash_validation.rs create mode 100644 crates/runtime/src/branch_lock.rs create mode 100644 crates/runtime/src/config_validate.rs create mode 100644 crates/runtime/src/git_context.rs create mode 100644 crates/runtime/src/green_contract.rs create mode 100644 crates/runtime/src/lane_events.rs create mode 100644 crates/runtime/src/lsp_client.rs create mode 100644 crates/runtime/src/mcp_lifecycle_hardened.rs create mode 100644 crates/runtime/src/mcp_server.rs create mode 100644 crates/runtime/src/mcp_tool_bridge.rs create mode 100644 crates/runtime/src/permission_enforcer.rs create mode 100644 crates/runtime/src/plugin_lifecycle.rs create mode 100644 crates/runtime/src/policy_engine.rs create mode 100644 crates/runtime/src/recovery_recipes.rs create mode 100644 crates/runtime/src/session_control.rs create mode 100644 crates/runtime/src/stale_base.rs create mode 100644 crates/runtime/src/stale_branch.rs create mode 100644 crates/runtime/src/summary_compression.rs create mode 100644 crates/runtime/src/task_packet.rs create mode 100644 crates/runtime/src/task_registry.rs create mode 100644 crates/runtime/src/team_cron_registry.rs create mode 100644 crates/runtime/src/trust_resolver.rs create mode 100644 crates/runtime/src/worker_boot.rs create mode 100644 crates/runtime/tests/integration_tests.rs create mode 100644 crates/rusty-claude-cli/build.rs delete mode 100644 crates/rusty-claude-cli/src/app.rs delete mode 100644 crates/rusty-claude-cli/src/args.rs create mode 100644 crates/rusty-claude-cli/tests/cli_flags_and_config_defaults.rs create mode 100644 crates/rusty-claude-cli/tests/compact_output.rs create mode 100644 crates/rusty-claude-cli/tests/mock_parity_harness.rs create mode 100644 crates/rusty-claude-cli/tests/output_format_contract.rs create mode 100644 crates/rusty-claude-cli/tests/resume_slash_commands.rs create mode 100644 crates/telemetry/Cargo.toml create mode 100644 crates/telemetry/src/lib.rs create mode 100644 crates/tools/src/lane_completion.rs create mode 100644 crates/tools/src/pdf_extract.rs diff --git a/Cargo.lock b/Cargo.lock index 8526ba0..1c94994 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -34,6 +34,7 @@ dependencies = [ "runtime", "serde", "serde_json", + "telemetry", "tokio", ] @@ -347,12 +348,24 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + [[package]] name = "endian-type" version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" +[[package]] +name = "env_home" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f84e12ccf0a7ddc17a6c41c93326024c42920d7ee630d04950e6926645c0fe" + [[package]] name = "equivalent" version = "1.0.2" @@ -945,6 +958,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "mock-anthropic-service" +version = "0.1.0" +dependencies = [ + "api", + "serde_json", + "tokio", +] + [[package]] name = "nibble_vec" version = "0.1.0" @@ -1340,14 +1362,15 @@ name = "runtime" version = "0.1.0" dependencies = [ "glob", - "lsp", "plugins", "regex", "serde", "serde_json", "sha2", + "telemetry", "tokio", "walkdir", + "which", ] [[package]] @@ -1431,9 +1454,12 @@ dependencies = [ "commands", "compat-harness", "crossterm", + "mock-anthropic-service", "plugins", "pulldown-cmark", "runtime", + "rustyline", + "serde", "serde_json", "syntect", "tokio", @@ -1721,6 +1747,14 @@ dependencies = [ "yaml-rust", ] +[[package]] +name = "telemetry" +version = "0.1.0" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "thiserror" version = "2.0.18" @@ -1852,6 +1886,8 @@ name = "tools" version = "0.1.0" dependencies = [ "api", + "commands", + "flate2", "plugins", "reqwest", "runtime", @@ -2129,6 +2165,18 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "which" +version = "7.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d643ce3fd3e5b54854602a080f34fb10ab75e0b813ee32d00ca2b44fa74762" +dependencies = [ + "either", + "env_home", + "rustix 1.1.4", + "winsafe", +] + [[package]] name = "winapi" version = "0.3.9" @@ -2384,6 +2432,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" +[[package]] +name = "winsafe" +version = "0.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904" + [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/Cargo.toml b/Cargo.toml index aa2f4ea..e8e249e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,3 +21,35 @@ pedantic = { level = "warn", priority = -1 } module_name_repetitions = "allow" missing_panics_doc = "allow" missing_errors_doc = "allow" +uninlined_format_args = "allow" +map_unwrap_or = "allow" +doc_markdown = "allow" +redundant_pattern_matching = "allow" +items_after_statements = "allow" +must_use_candidate = "allow" +io_other_error = "allow" +implicit_clone = "allow" +redundant_closure_for_method_calls = "allow" +unnecessary_map_or = "allow" +manual_let_else = "allow" +format_push_string = "allow" +match_same_arms = "allow" +similar_names = "allow" +needless_continue = "allow" +assigning_clones = "allow" +cast_possible_truncation = "allow" +cast_possible_wrap = "allow" +cast_sign_loss = "allow" +cmp_owned = "allow" +collapsible_if = "allow" +too_many_lines = "allow" +wildcard_in_or_patterns = "allow" +explicit_counter_loop = "allow" +manual_repeat_n = "allow" +manual_str_repeat = "allow" +needless_borrow = "allow" +needless_pass_by_value = "allow" +single_match_else = "allow" +too_many_arguments = "allow" +unnested_or_patterns = "allow" +unused_self = "allow" diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index b9923a8..d2e009c 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -10,6 +10,7 @@ reqwest = { version = "0.12", default-features = false, features = ["json", "rus runtime = { path = "../runtime" } serde = { version = "1", features = ["derive"] } serde_json.workspace = true +telemetry = { path = "../telemetry" } tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] } [lints] diff --git a/crates/api/src/client.rs b/crates/api/src/client.rs index b596777..59ccdaf 100644 --- a/crates/api/src/client.rs +++ b/crates/api/src/client.rs @@ -1,70 +1,91 @@ use crate::error::ApiError; -use crate::providers::claw_provider::{self, AuthSource, ClawApiClient}; +use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats}; +use crate::providers::anthropic::{self, AnthropicClient, AuthSource}; use crate::providers::openai_compat::{self, OpenAiCompatClient, OpenAiCompatConfig}; -use crate::providers::{self, Provider, ProviderKind}; +use crate::providers::{self, ProviderKind}; use crate::types::{MessageRequest, MessageResponse, StreamEvent}; -async fn send_via_provider( - provider: &P, - request: &MessageRequest, -) -> Result { - provider.send_message(request).await -} - -async fn stream_via_provider( - provider: &P, - request: &MessageRequest, -) -> Result { - provider.stream_message(request).await -} - +#[allow(clippy::large_enum_variant)] #[derive(Debug, Clone)] pub enum ProviderClient { - ClawApi(ClawApiClient), + Anthropic(AnthropicClient), Xai(OpenAiCompatClient), OpenAi(OpenAiCompatClient), } impl ProviderClient { pub fn from_model(model: &str) -> Result { - Self::from_model_with_default_auth(model, None) + Self::from_model_with_anthropic_auth(model, None) } - pub fn from_model_with_default_auth( + pub fn from_model_with_anthropic_auth( model: &str, - default_auth: Option, + anthropic_auth: Option, ) -> Result { let resolved_model = providers::resolve_model_alias(model); match providers::detect_provider_kind(&resolved_model) { - ProviderKind::ClawApi => Ok(Self::ClawApi(match default_auth { - Some(auth) => ClawApiClient::from_auth(auth), - None => ClawApiClient::from_env()?, + ProviderKind::Anthropic => Ok(Self::Anthropic(match anthropic_auth { + Some(auth) => AnthropicClient::from_auth(auth), + None => AnthropicClient::from_env()?, })), ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env( OpenAiCompatConfig::xai(), )?)), - ProviderKind::OpenAi => Ok(Self::OpenAi(OpenAiCompatClient::from_env( - OpenAiCompatConfig::openai(), - )?)), + ProviderKind::OpenAi => { + // DashScope models (qwen-*) also return ProviderKind::OpenAi because they + // speak the OpenAI wire format, but they need the DashScope config which + // reads DASHSCOPE_API_KEY and points at dashscope.aliyuncs.com. + let config = match providers::metadata_for_model(&resolved_model) { + Some(meta) if meta.auth_env == "DASHSCOPE_API_KEY" => { + OpenAiCompatConfig::dashscope() + } + _ => OpenAiCompatConfig::openai(), + }; + Ok(Self::OpenAi(OpenAiCompatClient::from_env(config)?)) + } } } #[must_use] pub const fn provider_kind(&self) -> ProviderKind { match self { - Self::ClawApi(_) => ProviderKind::ClawApi, + Self::Anthropic(_) => ProviderKind::Anthropic, Self::Xai(_) => ProviderKind::Xai, Self::OpenAi(_) => ProviderKind::OpenAi, } } + #[must_use] + pub fn with_prompt_cache(self, prompt_cache: PromptCache) -> Self { + match self { + Self::Anthropic(client) => Self::Anthropic(client.with_prompt_cache(prompt_cache)), + other => other, + } + } + + #[must_use] + pub fn prompt_cache_stats(&self) -> Option { + match self { + Self::Anthropic(client) => client.prompt_cache_stats(), + Self::Xai(_) | Self::OpenAi(_) => None, + } + } + + #[must_use] + pub fn take_last_prompt_cache_record(&self) -> Option { + match self { + Self::Anthropic(client) => client.take_last_prompt_cache_record(), + Self::Xai(_) | Self::OpenAi(_) => None, + } + } + pub async fn send_message( &self, request: &MessageRequest, ) -> Result { match self { - Self::ClawApi(client) => send_via_provider(client, request).await, - Self::Xai(client) | Self::OpenAi(client) => send_via_provider(client, request).await, + Self::Anthropic(client) => client.send_message(request).await, + Self::Xai(client) | Self::OpenAi(client) => client.send_message(request).await, } } @@ -73,10 +94,12 @@ impl ProviderClient { request: &MessageRequest, ) -> Result { match self { - Self::ClawApi(client) => stream_via_provider(client, request) + Self::Anthropic(client) => client + .stream_message(request) .await - .map(MessageStream::ClawApi), - Self::Xai(client) | Self::OpenAi(client) => stream_via_provider(client, request) + .map(MessageStream::Anthropic), + Self::Xai(client) | Self::OpenAi(client) => client + .stream_message(request) .await .map(MessageStream::OpenAiCompat), } @@ -85,7 +108,7 @@ impl ProviderClient { #[derive(Debug)] pub enum MessageStream { - ClawApi(claw_provider::MessageStream), + Anthropic(anthropic::MessageStream), OpenAiCompat(openai_compat::MessageStream), } @@ -93,25 +116,25 @@ impl MessageStream { #[must_use] pub fn request_id(&self) -> Option<&str> { match self { - Self::ClawApi(stream) => stream.request_id(), + Self::Anthropic(stream) => stream.request_id(), Self::OpenAiCompat(stream) => stream.request_id(), } } pub async fn next_event(&mut self) -> Result, ApiError> { match self { - Self::ClawApi(stream) => stream.next_event().await, + Self::Anthropic(stream) => stream.next_event().await, Self::OpenAiCompat(stream) => stream.next_event().await, } } } -pub use claw_provider::{ +pub use anthropic::{ oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source, OAuthTokenSet, }; #[must_use] pub fn read_base_url() -> String { - claw_provider::read_base_url() + anthropic::read_base_url() } #[must_use] @@ -121,8 +144,21 @@ pub fn read_xai_base_url() -> String { #[cfg(test)] mod tests { + use std::sync::{Mutex, OnceLock}; + + use super::ProviderClient; use crate::providers::{detect_provider_kind, resolve_model_alias, ProviderKind}; + /// Serializes every test in this module that mutates process-wide + /// environment variables so concurrent test threads cannot observe + /// each other's partially-applied state. + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } + #[test] fn resolves_existing_and_grok_aliases() { assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6"); @@ -135,7 +171,71 @@ mod tests { assert_eq!(detect_provider_kind("grok-3"), ProviderKind::Xai); assert_eq!( detect_provider_kind("claude-sonnet-4-6"), - ProviderKind::ClawApi + ProviderKind::Anthropic ); } + + /// Snapshot-restore guard for a single environment variable. Mirrors + /// the pattern used in `providers/mod.rs` tests: captures the original + /// value on construction, applies the override, and restores on drop so + /// tests leave the process env untouched even when they panic. + struct EnvVarGuard { + key: &'static str, + original: Option, + } + + impl EnvVarGuard { + fn set(key: &'static str, value: Option<&str>) -> Self { + let original = std::env::var_os(key); + match value { + Some(value) => std::env::set_var(key, value), + None => std::env::remove_var(key), + } + Self { key, original } + } + } + + impl Drop for EnvVarGuard { + fn drop(&mut self) { + match self.original.take() { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } + } + + #[test] + fn dashscope_model_uses_dashscope_config_not_openai() { + // Regression: qwen-plus was being routed to OpenAiCompatConfig::openai() + // which reads OPENAI_API_KEY and points at api.openai.com, when it should + // use OpenAiCompatConfig::dashscope() which reads DASHSCOPE_API_KEY and + // points at dashscope.aliyuncs.com. + let _lock = env_lock(); + let _dashscope = EnvVarGuard::set("DASHSCOPE_API_KEY", Some("test-dashscope-key")); + let _openai = EnvVarGuard::set("OPENAI_API_KEY", None); + + let client = ProviderClient::from_model("qwen-plus"); + + // Must succeed (not fail with "missing OPENAI_API_KEY") + assert!( + client.is_ok(), + "qwen-plus with DASHSCOPE_API_KEY set should build successfully, got: {:?}", + client.err() + ); + + // Verify it's the OpenAi variant pointed at the DashScope base URL. + match client.unwrap() { + ProviderClient::OpenAi(openai_client) => { + assert!( + openai_client.base_url().contains("dashscope.aliyuncs.com"), + "qwen-plus should route to DashScope base URL (contains 'dashscope.aliyuncs.com'), got: {}", + openai_client.base_url() + ); + } + other => panic!( + "Expected ProviderClient::OpenAi for qwen-plus, got: {:?}", + other + ), + } + } } diff --git a/crates/api/src/error.rs b/crates/api/src/error.rs index 7649889..3fa4995 100644 --- a/crates/api/src/error.rs +++ b/crates/api/src/error.rs @@ -2,22 +2,55 @@ use std::env::VarError; use std::fmt::{Display, Formatter}; use std::time::Duration; +const GENERIC_FATAL_WRAPPER_MARKERS: &[&str] = &[ + "something went wrong while processing your request", + "please try again, or use /new to start a fresh session", +]; + +const CONTEXT_WINDOW_ERROR_MARKERS: &[&str] = &[ + "maximum context length", + "context window", + "context length", + "too many tokens", + "prompt is too long", + "input is too long", + "request is too large", +]; + #[derive(Debug)] pub enum ApiError { MissingCredentials { provider: &'static str, env_vars: &'static [&'static str], + /// Optional, runtime-computed hint appended to the error Display + /// output. Populated when the provider resolver can infer what the + /// user probably intended (e.g. an OpenAI key is set but Anthropic + /// was selected because no Anthropic credentials exist). + hint: Option, + }, + ContextWindowExceeded { + model: String, + estimated_input_tokens: u32, + requested_output_tokens: u32, + estimated_total_tokens: u32, + context_window_tokens: u32, }, ExpiredOAuthToken, Auth(String), InvalidApiKeyEnv(VarError), Http(reqwest::Error), Io(std::io::Error), - Json(serde_json::Error), + Json { + provider: String, + model: String, + body_snippet: String, + source: serde_json::Error, + }, Api { status: reqwest::StatusCode, error_type: Option, message: Option, + request_id: Option, body: String, retryable: bool, }, @@ -38,7 +71,48 @@ impl ApiError { provider: &'static str, env_vars: &'static [&'static str], ) -> Self { - Self::MissingCredentials { provider, env_vars } + Self::MissingCredentials { + provider, + env_vars, + hint: None, + } + } + + /// Build a `MissingCredentials` error carrying an extra, runtime-computed + /// hint string that the Display impl appends after the canonical "missing + /// credentials" message. Used by the provider resolver to + /// suggest the likely fix when the user has credentials for a different + /// provider already in the environment. + #[must_use] + pub fn missing_credentials_with_hint( + provider: &'static str, + env_vars: &'static [&'static str], + hint: impl Into, + ) -> Self { + Self::MissingCredentials { + provider, + env_vars, + hint: Some(hint.into()), + } + } + + /// Build a `Self::Json` enriched with the provider name, the model that + /// was requested, and the first 200 characters of the raw response body so + /// that callers can diagnose deserialization failures without re-running + /// the request. + #[must_use] + pub fn json_deserialize( + provider: impl Into, + model: impl Into, + body: &str, + source: serde_json::Error, + ) -> Self { + Self::Json { + provider: provider.into(), + model: model.into(), + body_snippet: truncate_body_snippet(body, 200), + source, + } } #[must_use] @@ -48,11 +122,106 @@ impl ApiError { Self::Api { retryable, .. } => *retryable, Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(), Self::MissingCredentials { .. } + | Self::ContextWindowExceeded { .. } | Self::ExpiredOAuthToken | Self::Auth(_) | Self::InvalidApiKeyEnv(_) | Self::Io(_) - | Self::Json(_) + | Self::Json { .. } + | Self::InvalidSseFrame(_) + | Self::BackoffOverflow { .. } => false, + } + } + + #[must_use] + pub fn request_id(&self) -> Option<&str> { + match self { + Self::Api { request_id, .. } => request_id.as_deref(), + Self::RetriesExhausted { last_error, .. } => last_error.request_id(), + Self::MissingCredentials { .. } + | Self::ContextWindowExceeded { .. } + | Self::ExpiredOAuthToken + | Self::Auth(_) + | Self::InvalidApiKeyEnv(_) + | Self::Http(_) + | Self::Io(_) + | Self::Json { .. } + | Self::InvalidSseFrame(_) + | Self::BackoffOverflow { .. } => None, + } + } + + #[must_use] + pub fn safe_failure_class(&self) -> &'static str { + match self { + Self::RetriesExhausted { .. } if self.is_context_window_failure() => "context_window", + Self::RetriesExhausted { .. } if self.is_generic_fatal_wrapper() => { + "provider_retry_exhausted" + } + Self::RetriesExhausted { last_error, .. } => last_error.safe_failure_class(), + Self::MissingCredentials { .. } | Self::ExpiredOAuthToken | Self::Auth(_) => { + "provider_auth" + } + Self::Api { status, .. } if matches!(status.as_u16(), 401 | 403) => "provider_auth", + Self::ContextWindowExceeded { .. } => "context_window", + Self::Api { .. } if self.is_context_window_failure() => "context_window", + Self::Api { status, .. } if status.as_u16() == 429 => "provider_rate_limit", + Self::Api { .. } if self.is_generic_fatal_wrapper() => "provider_internal", + Self::Api { .. } => "provider_error", + Self::Http(_) | Self::InvalidSseFrame(_) | Self::BackoffOverflow { .. } => { + "provider_transport" + } + Self::InvalidApiKeyEnv(_) | Self::Io(_) | Self::Json { .. } => "runtime_io", + } + } + + #[must_use] + pub fn is_generic_fatal_wrapper(&self) -> bool { + match self { + Self::Api { message, body, .. } => { + message + .as_deref() + .is_some_and(looks_like_generic_fatal_wrapper) + || looks_like_generic_fatal_wrapper(body) + } + Self::RetriesExhausted { last_error, .. } => last_error.is_generic_fatal_wrapper(), + Self::MissingCredentials { .. } + | Self::ContextWindowExceeded { .. } + | Self::ExpiredOAuthToken + | Self::Auth(_) + | Self::InvalidApiKeyEnv(_) + | Self::Http(_) + | Self::Io(_) + | Self::Json { .. } + | Self::InvalidSseFrame(_) + | Self::BackoffOverflow { .. } => false, + } + } + + #[must_use] + pub fn is_context_window_failure(&self) -> bool { + match self { + Self::ContextWindowExceeded { .. } => true, + Self::Api { + status, + message, + body, + .. + } => { + matches!(status.as_u16(), 400 | 413 | 422) + && (message + .as_deref() + .is_some_and(looks_like_context_window_error) + || looks_like_context_window_error(body)) + } + Self::RetriesExhausted { last_error, .. } => last_error.is_context_window_failure(), + Self::MissingCredentials { .. } + | Self::ExpiredOAuthToken + | Self::Auth(_) + | Self::InvalidApiKeyEnv(_) + | Self::Http(_) + | Self::Io(_) + | Self::Json { .. } | Self::InvalidSseFrame(_) | Self::BackoffOverflow { .. } => false, } @@ -62,10 +231,43 @@ impl ApiError { impl Display for ApiError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::MissingCredentials { provider, env_vars } => write!( + Self::MissingCredentials { + provider, + env_vars, + hint, + } => { + write!( + f, + "missing {provider} credentials; export {} before calling the {provider} API", + env_vars.join(" or ") + )?; + if cfg!(target_os = "windows") { + if let Some(primary) = env_vars.first() { + write!( + f, + " (on Windows, environment variables set in PowerShell only persist for the current session; use `setx {primary} ` to make it permanent, then open a new terminal, or place a `.env` file containing `{primary}=` in the current working directory)" + )?; + } else { + write!( + f, + " (on Windows, environment variables set in PowerShell only persist for the current session; use `setx` to make them permanent, then open a new terminal, or place a `.env` file in the current working directory)" + )?; + } + } + if let Some(hint) = hint { + write!(f, " — hint: {hint}")?; + } + Ok(()) + } + Self::ContextWindowExceeded { + model, + estimated_input_tokens, + requested_output_tokens, + estimated_total_tokens, + context_window_tokens, + } => write!( f, - "missing {provider} credentials; export {} before calling the {provider} API", - env_vars.join(" or ") + "context_window_blocked for {model}: estimated input {estimated_input_tokens} + requested output {requested_output_tokens} = {estimated_total_tokens} tokens exceeds the {context_window_tokens}-token context window; compact the session or reduce request size before retrying" ), Self::ExpiredOAuthToken => { write!( @@ -79,19 +281,37 @@ impl Display for ApiError { } Self::Http(error) => write!(f, "http error: {error}"), Self::Io(error) => write!(f, "io error: {error}"), - Self::Json(error) => write!(f, "json error: {error}"), + Self::Json { + provider, + model, + body_snippet, + source, + } => write!( + f, + "failed to parse {provider} response for model {model}: {source}; first 200 chars of body: {body_snippet}" + ), Self::Api { status, error_type, message, + request_id, body, .. - } => match (error_type, message) { - (Some(error_type), Some(message)) => { - write!(f, "api returned {status} ({error_type}): {message}") + } => { + if let (Some(error_type), Some(message)) = (error_type, message) { + write!(f, "api returned {status} ({error_type})")?; + if let Some(request_id) = request_id { + write!(f, " [trace {request_id}]")?; + } + write!(f, ": {message}") + } else { + write!(f, "api returned {status}")?; + if let Some(request_id) = request_id { + write!(f, " [trace {request_id}]")?; + } + write!(f, ": {body}") } - _ => write!(f, "api returned {status}: {body}"), - }, + } Self::RetriesExhausted { attempts, last_error, @@ -124,7 +344,12 @@ impl From for ApiError { impl From for ApiError { fn from(value: serde_json::Error) -> Self { - Self::Json(value) + Self::Json { + provider: "unknown".to_string(), + model: "unknown".to_string(), + body_snippet: String::new(), + source: value, + } } } @@ -133,3 +358,215 @@ impl From for ApiError { Self::InvalidApiKeyEnv(value) } } + +fn looks_like_generic_fatal_wrapper(text: &str) -> bool { + let lowered = text.to_ascii_lowercase(); + GENERIC_FATAL_WRAPPER_MARKERS + .iter() + .any(|marker| lowered.contains(marker)) +} + +fn looks_like_context_window_error(text: &str) -> bool { + let lowered = text.to_ascii_lowercase(); + CONTEXT_WINDOW_ERROR_MARKERS + .iter() + .any(|marker| lowered.contains(marker)) +} + +/// Truncate `body` so the resulting snippet contains at most `max_chars` +/// characters (counted by Unicode scalar values, not bytes), preserving the +/// leading slice of the body that the caller most often needs to inspect. +fn truncate_body_snippet(body: &str, max_chars: usize) -> String { + let mut taken_chars = 0; + let mut byte_end = 0; + for (offset, character) in body.char_indices() { + if taken_chars >= max_chars { + break; + } + taken_chars += 1; + byte_end = offset + character.len_utf8(); + } + if taken_chars >= max_chars && byte_end < body.len() { + format!("{}…", &body[..byte_end]) + } else { + body[..byte_end].to_string() + } +} + +#[cfg(test)] +mod tests { + use super::{truncate_body_snippet, ApiError}; + + #[test] + fn json_deserialize_error_includes_provider_model_and_truncated_body_snippet() { + let raw_body = format!("{}{}", "x".repeat(190), "_TAIL_PAST_200_CHARS_MARKER_"); + let source = serde_json::from_str::("{not json") + .expect_err("invalid json should fail to parse"); + + let error = ApiError::json_deserialize("Anthropic", "claude-opus-4-6", &raw_body, source); + let rendered = error.to_string(); + + assert!( + rendered.starts_with("failed to parse Anthropic response for model claude-opus-4-6: "), + "rendered error should lead with provider and model: {rendered}" + ); + assert!( + rendered.contains("first 200 chars of body: "), + "rendered error should label the body snippet: {rendered}" + ); + let snippet = rendered + .split("first 200 chars of body: ") + .nth(1) + .expect("snippet section should be present"); + assert!( + snippet.starts_with(&"x".repeat(190)), + "snippet should preserve the leading characters of the body: {snippet}" + ); + assert!( + snippet.ends_with('…'), + "snippet should signal truncation with an ellipsis: {snippet}" + ); + assert!( + !snippet.contains("_TAIL_PAST_200_CHARS_MARKER_"), + "snippet should drop characters past the 200-char cap: {snippet}" + ); + assert_eq!(error.safe_failure_class(), "runtime_io"); + assert_eq!(error.request_id(), None); + assert!(!error.is_retryable()); + } + + #[test] + fn truncate_body_snippet_keeps_short_bodies_intact() { + assert_eq!(truncate_body_snippet("hello", 200), "hello"); + assert_eq!(truncate_body_snippet("", 200), ""); + } + + #[test] + fn truncate_body_snippet_caps_long_bodies_at_max_chars() { + let body = "a".repeat(250); + let snippet = truncate_body_snippet(&body, 200); + assert_eq!(snippet.chars().count(), 201, "200 chars + ellipsis"); + assert!(snippet.ends_with('…')); + assert!(snippet.starts_with(&"a".repeat(200))); + } + + #[test] + fn truncate_body_snippet_does_not_split_multibyte_characters() { + let body = "한글한글한글한글한글한글"; + let snippet = truncate_body_snippet(body, 4); + assert_eq!(snippet, "한글한글…"); + } + + #[test] + fn detects_generic_fatal_wrapper_and_classifies_it_as_provider_internal() { + let error = ApiError::Api { + status: reqwest::StatusCode::INTERNAL_SERVER_ERROR, + error_type: Some("api_error".to_string()), + message: Some( + "Something went wrong while processing your request. Please try again, or use /new to start a fresh session." + .to_string(), + ), + request_id: Some("req_jobdori_123".to_string()), + body: String::new(), + retryable: true, + }; + + assert!(error.is_generic_fatal_wrapper()); + assert_eq!(error.safe_failure_class(), "provider_internal"); + assert_eq!(error.request_id(), Some("req_jobdori_123")); + assert!(error.to_string().contains("[trace req_jobdori_123]")); + } + + #[test] + fn retries_exhausted_preserves_nested_request_id_and_failure_class() { + let error = ApiError::RetriesExhausted { + attempts: 3, + last_error: Box::new(ApiError::Api { + status: reqwest::StatusCode::BAD_GATEWAY, + error_type: Some("api_error".to_string()), + message: Some( + "Something went wrong while processing your request. Please try again, or use /new to start a fresh session." + .to_string(), + ), + request_id: Some("req_nested_456".to_string()), + body: String::new(), + retryable: true, + }), + }; + + assert!(error.is_generic_fatal_wrapper()); + assert_eq!(error.safe_failure_class(), "provider_retry_exhausted"); + assert_eq!(error.request_id(), Some("req_nested_456")); + } + + #[test] + fn classifies_provider_context_window_errors() { + let error = ApiError::Api { + status: reqwest::StatusCode::BAD_REQUEST, + error_type: Some("invalid_request_error".to_string()), + message: Some( + "This model's maximum context length is 200000 tokens, but your request used 230000 tokens." + .to_string(), + ), + request_id: Some("req_ctx_123".to_string()), + body: String::new(), + retryable: false, + }; + + assert!(error.is_context_window_failure()); + assert_eq!(error.safe_failure_class(), "context_window"); + assert_eq!(error.request_id(), Some("req_ctx_123")); + } + + #[test] + fn missing_credentials_without_hint_renders_the_canonical_message() { + // given + let error = ApiError::missing_credentials( + "Anthropic", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + ); + + // when + let rendered = error.to_string(); + + // then + assert!( + rendered.starts_with( + "missing Anthropic credentials; export ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY before calling the Anthropic API" + ), + "rendered error should lead with the canonical missing-credential message: {rendered}" + ); + assert!( + !rendered.contains(" — hint: "), + "no hint should be appended when none is supplied: {rendered}" + ); + } + + #[test] + fn missing_credentials_with_hint_appends_the_hint_after_base_message() { + // given + let error = ApiError::missing_credentials_with_hint( + "Anthropic", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + "I see OPENAI_API_KEY is set — if you meant to use the OpenAI-compat provider, prefix your model name with `openai/` so prefix routing selects it.", + ); + + // when + let rendered = error.to_string(); + + // then + assert!( + rendered.starts_with("missing Anthropic credentials;"), + "hint should be appended, not replace the base message: {rendered}" + ); + let hint_marker = " — hint: I see OPENAI_API_KEY is set — if you meant to use the OpenAI-compat provider, prefix your model name with `openai/` so prefix routing selects it."; + assert!( + rendered.ends_with(hint_marker), + "rendered error should end with the hint: {rendered}" + ); + // Classification semantics are unaffected by the presence of a hint. + assert_eq!(error.safe_failure_class(), "provider_auth"); + assert!(!error.is_retryable()); + assert_eq!(error.request_id(), None); + } +} diff --git a/crates/api/src/http_client.rs b/crates/api/src/http_client.rs new file mode 100644 index 0000000..508a577 --- /dev/null +++ b/crates/api/src/http_client.rs @@ -0,0 +1,344 @@ +use crate::error::ApiError; + +const HTTP_PROXY_KEYS: [&str; 2] = ["HTTP_PROXY", "http_proxy"]; +const HTTPS_PROXY_KEYS: [&str; 2] = ["HTTPS_PROXY", "https_proxy"]; +const NO_PROXY_KEYS: [&str; 2] = ["NO_PROXY", "no_proxy"]; + +/// Snapshot of the proxy-related environment variables that influence the +/// outbound HTTP client. Captured up front so callers can inspect, log, and +/// test the resolved configuration without re-reading the process environment. +/// +/// When `proxy_url` is set it acts as a single catch-all proxy for both +/// HTTP and HTTPS traffic, taking precedence over the per-scheme fields. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct ProxyConfig { + pub http_proxy: Option, + pub https_proxy: Option, + pub no_proxy: Option, + /// Optional unified proxy URL that applies to both HTTP and HTTPS. + /// When set, this takes precedence over `http_proxy` and `https_proxy`. + pub proxy_url: Option, +} + +impl ProxyConfig { + /// Read proxy settings from the live process environment, honouring both + /// the upper- and lower-case spellings used by curl, git, and friends. + #[must_use] + pub fn from_env() -> Self { + Self::from_lookup(|key| std::env::var(key).ok()) + } + + /// Create a proxy configuration from a single URL that applies to both + /// HTTP and HTTPS traffic. This is the config-file alternative to setting + /// `HTTP_PROXY` and `HTTPS_PROXY` environment variables separately. + #[must_use] + pub fn from_proxy_url(url: impl Into) -> Self { + Self { + proxy_url: Some(url.into()), + ..Self::default() + } + } + + fn from_lookup(mut lookup: F) -> Self + where + F: FnMut(&str) -> Option, + { + Self { + http_proxy: first_non_empty(&HTTP_PROXY_KEYS, &mut lookup), + https_proxy: first_non_empty(&HTTPS_PROXY_KEYS, &mut lookup), + no_proxy: first_non_empty(&NO_PROXY_KEYS, &mut lookup), + proxy_url: None, + } + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.proxy_url.is_none() && self.http_proxy.is_none() && self.https_proxy.is_none() + } +} + +/// Build a `reqwest::Client` that honours the standard `HTTP_PROXY`, +/// `HTTPS_PROXY`, and `NO_PROXY` environment variables. When no proxy is +/// configured the client behaves identically to `reqwest::Client::new()`. +pub fn build_http_client() -> Result { + build_http_client_with(&ProxyConfig::from_env()) +} + +/// Infallible counterpart to [`build_http_client`] for constructors that +/// historically returned `Self` rather than `Result`. When the proxy +/// configuration is malformed we fall back to a default client so that +/// callers retain the previous behaviour and the failure surfaces on the +/// first outbound request instead of at construction time. +#[must_use] +pub fn build_http_client_or_default() -> reqwest::Client { + build_http_client().unwrap_or_else(|_| reqwest::Client::new()) +} + +/// Build a `reqwest::Client` from an explicit [`ProxyConfig`]. Used by tests +/// and by callers that want to override process-level environment lookups. +/// +/// When `config.proxy_url` is set it overrides the per-scheme `http_proxy` +/// and `https_proxy` fields and is registered as both an HTTP and HTTPS +/// proxy so a single value can route every outbound request. +pub fn build_http_client_with(config: &ProxyConfig) -> Result { + let mut builder = reqwest::Client::builder().no_proxy(); + + let no_proxy = config + .no_proxy + .as_deref() + .and_then(reqwest::NoProxy::from_string); + + let (http_proxy_url, https_proxy_url) = match config.proxy_url.as_deref() { + Some(unified) => (Some(unified), Some(unified)), + None => (config.http_proxy.as_deref(), config.https_proxy.as_deref()), + }; + + if let Some(url) = https_proxy_url { + let mut proxy = reqwest::Proxy::https(url)?; + if let Some(filter) = no_proxy.clone() { + proxy = proxy.no_proxy(Some(filter)); + } + builder = builder.proxy(proxy); + } + + if let Some(url) = http_proxy_url { + let mut proxy = reqwest::Proxy::http(url)?; + if let Some(filter) = no_proxy.clone() { + proxy = proxy.no_proxy(Some(filter)); + } + builder = builder.proxy(proxy); + } + + Ok(builder.build()?) +} + +fn first_non_empty(keys: &[&str], lookup: &mut F) -> Option +where + F: FnMut(&str) -> Option, +{ + keys.iter() + .find_map(|key| lookup(key).filter(|value| !value.is_empty())) +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use super::{build_http_client_with, ProxyConfig}; + + fn config_from_map(pairs: &[(&str, &str)]) -> ProxyConfig { + let map: HashMap = pairs + .iter() + .map(|(key, value)| ((*key).to_string(), (*value).to_string())) + .collect(); + ProxyConfig::from_lookup(|key| map.get(key).cloned()) + } + + #[test] + fn proxy_config_is_empty_when_no_env_vars_are_set() { + // given + let config = config_from_map(&[]); + + // when + let empty = config.is_empty(); + + // then + assert!(empty); + assert_eq!(config, ProxyConfig::default()); + } + + #[test] + fn proxy_config_reads_uppercase_http_https_and_no_proxy() { + // given + let pairs = [ + ("HTTP_PROXY", "http://proxy.internal:3128"), + ("HTTPS_PROXY", "http://secure.internal:3129"), + ("NO_PROXY", "localhost,127.0.0.1,.corp"), + ]; + + // when + let config = config_from_map(&pairs); + + // then + assert_eq!( + config.http_proxy.as_deref(), + Some("http://proxy.internal:3128") + ); + assert_eq!( + config.https_proxy.as_deref(), + Some("http://secure.internal:3129") + ); + assert_eq!( + config.no_proxy.as_deref(), + Some("localhost,127.0.0.1,.corp") + ); + assert!(!config.is_empty()); + } + + #[test] + fn proxy_config_falls_back_to_lowercase_keys() { + // given + let pairs = [ + ("http_proxy", "http://lower.internal:3128"), + ("https_proxy", "http://lower-secure.internal:3129"), + ("no_proxy", ".lower"), + ]; + + // when + let config = config_from_map(&pairs); + + // then + assert_eq!( + config.http_proxy.as_deref(), + Some("http://lower.internal:3128") + ); + assert_eq!( + config.https_proxy.as_deref(), + Some("http://lower-secure.internal:3129") + ); + assert_eq!(config.no_proxy.as_deref(), Some(".lower")); + } + + #[test] + fn proxy_config_prefers_uppercase_over_lowercase_when_both_set() { + // given + let pairs = [ + ("HTTP_PROXY", "http://upper.internal:3128"), + ("http_proxy", "http://lower.internal:3128"), + ]; + + // when + let config = config_from_map(&pairs); + + // then + assert_eq!( + config.http_proxy.as_deref(), + Some("http://upper.internal:3128") + ); + } + + #[test] + fn proxy_config_treats_empty_strings_as_unset() { + // given + let pairs = [("HTTP_PROXY", ""), ("http_proxy", "")]; + + // when + let config = config_from_map(&pairs); + + // then + assert!(config.http_proxy.is_none()); + } + + #[test] + fn build_http_client_succeeds_when_no_proxy_is_configured() { + // given + let config = ProxyConfig::default(); + + // when + let result = build_http_client_with(&config); + + // then + assert!(result.is_ok()); + } + + #[test] + fn build_http_client_succeeds_with_valid_http_and_https_proxies() { + // given + let config = ProxyConfig { + http_proxy: Some("http://proxy.internal:3128".to_string()), + https_proxy: Some("http://secure.internal:3129".to_string()), + no_proxy: Some("localhost,127.0.0.1".to_string()), + proxy_url: None, + }; + + // when + let result = build_http_client_with(&config); + + // then + assert!(result.is_ok()); + } + + #[test] + fn build_http_client_returns_http_error_for_invalid_proxy_url() { + // given + let config = ProxyConfig { + http_proxy: None, + https_proxy: Some("not a url".to_string()), + no_proxy: None, + proxy_url: None, + }; + + // when + let result = build_http_client_with(&config); + + // then + let error = result.expect_err("invalid proxy URL must be reported as a build failure"); + assert!( + matches!(error, crate::error::ApiError::Http(_)), + "expected ApiError::Http for invalid proxy URL, got: {error:?}" + ); + } + + #[test] + fn from_proxy_url_sets_unified_field_and_leaves_per_scheme_empty() { + // given / when + let config = ProxyConfig::from_proxy_url("http://unified.internal:3128"); + + // then + assert_eq!( + config.proxy_url.as_deref(), + Some("http://unified.internal:3128") + ); + assert!(config.http_proxy.is_none()); + assert!(config.https_proxy.is_none()); + assert!(!config.is_empty()); + } + + #[test] + fn build_http_client_succeeds_with_unified_proxy_url() { + // given + let config = ProxyConfig { + proxy_url: Some("http://unified.internal:3128".to_string()), + no_proxy: Some("localhost".to_string()), + ..ProxyConfig::default() + }; + + // when + let result = build_http_client_with(&config); + + // then + assert!(result.is_ok()); + } + + #[test] + fn proxy_url_takes_precedence_over_per_scheme_fields() { + // given – both per-scheme and unified are set + let config = ProxyConfig { + http_proxy: Some("http://per-scheme.internal:1111".to_string()), + https_proxy: Some("http://per-scheme.internal:2222".to_string()), + no_proxy: None, + proxy_url: Some("http://unified.internal:3128".to_string()), + }; + + // when – building succeeds (the unified URL is valid) + let result = build_http_client_with(&config); + + // then + assert!(result.is_ok()); + } + + #[test] + fn build_http_client_returns_error_for_invalid_unified_proxy_url() { + // given + let config = ProxyConfig::from_proxy_url("not a url"); + + // when + let result = build_http_client_with(&config); + + // then + assert!( + matches!(result, Err(crate::error::ApiError::Http(_))), + "invalid unified proxy URL should fail: {result:?}" + ); + } +} diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index 3306f53..bcf3e1b 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -1,5 +1,7 @@ mod client; mod error; +mod http_client; +mod prompt_cache; mod providers; mod sse; mod types; @@ -9,10 +11,18 @@ pub use client::{ resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient, }; pub use error::ApiError; -pub use providers::claw_provider::{AuthSource, ClawApiClient, ClawApiClient as ApiClient}; +pub use http_client::{ + build_http_client, build_http_client_or_default, build_http_client_with, ProxyConfig, +}; +pub use prompt_cache::{ + CacheBreakEvent, PromptCache, PromptCacheConfig, PromptCachePaths, PromptCacheRecord, + PromptCacheStats, +}; +pub use providers::anthropic::{AnthropicClient, AnthropicClient as ApiClient, AuthSource}; pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig}; pub use providers::{ - detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind, + detect_provider_kind, max_tokens_for_model, max_tokens_for_model_with_override, + resolve_model_alias, ProviderKind, }; pub use sse::{parse_frame, SseParser}; pub use types::{ @@ -21,3 +31,9 @@ pub use types::{ MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, }; + +pub use telemetry::{ + AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, JsonlTelemetrySink, + MemoryTelemetrySink, SessionTraceRecord, SessionTracer, TelemetryEvent, TelemetrySink, + DEFAULT_ANTHROPIC_VERSION, +}; diff --git a/crates/api/src/prompt_cache.rs b/crates/api/src/prompt_cache.rs new file mode 100644 index 0000000..0ee8663 --- /dev/null +++ b/crates/api/src/prompt_cache.rs @@ -0,0 +1,735 @@ +use std::fs; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; + +use crate::types::{MessageRequest, MessageResponse, Usage}; + +const DEFAULT_COMPLETION_TTL_SECS: u64 = 30; +const DEFAULT_PROMPT_TTL_SECS: u64 = 5 * 60; +const DEFAULT_BREAK_MIN_DROP: u32 = 2_000; +const MAX_SANITIZED_LENGTH: usize = 80; +const REQUEST_FINGERPRINT_VERSION: u32 = 1; +const REQUEST_FINGERPRINT_PREFIX: &str = "v1"; +const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325; +const FNV_PRIME: u64 = 0x0000_0100_0000_01b3; + +#[derive(Debug, Clone)] +pub struct PromptCacheConfig { + pub session_id: String, + pub completion_ttl: Duration, + pub prompt_ttl: Duration, + pub cache_break_min_drop: u32, +} + +impl PromptCacheConfig { + #[must_use] + pub fn new(session_id: impl Into) -> Self { + Self { + session_id: session_id.into(), + completion_ttl: Duration::from_secs(DEFAULT_COMPLETION_TTL_SECS), + prompt_ttl: Duration::from_secs(DEFAULT_PROMPT_TTL_SECS), + cache_break_min_drop: DEFAULT_BREAK_MIN_DROP, + } + } +} + +impl Default for PromptCacheConfig { + fn default() -> Self { + Self::new("default") + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct PromptCachePaths { + pub root: PathBuf, + pub session_dir: PathBuf, + pub completion_dir: PathBuf, + pub session_state_path: PathBuf, + pub stats_path: PathBuf, +} + +impl PromptCachePaths { + #[must_use] + pub fn for_session(session_id: &str) -> Self { + let root = base_cache_root(); + let session_dir = root.join(sanitize_path_segment(session_id)); + let completion_dir = session_dir.join("completions"); + Self { + root, + session_state_path: session_dir.join("session-state.json"), + stats_path: session_dir.join("stats.json"), + session_dir, + completion_dir, + } + } + + #[must_use] + pub fn completion_entry_path(&self, request_hash: &str) -> PathBuf { + self.completion_dir.join(format!("{request_hash}.json")) + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct PromptCacheStats { + pub tracked_requests: u64, + pub completion_cache_hits: u64, + pub completion_cache_misses: u64, + pub completion_cache_writes: u64, + pub expected_invalidations: u64, + pub unexpected_cache_breaks: u64, + pub total_cache_creation_input_tokens: u64, + pub total_cache_read_input_tokens: u64, + pub last_cache_creation_input_tokens: Option, + pub last_cache_read_input_tokens: Option, + pub last_request_hash: Option, + pub last_completion_cache_key: Option, + pub last_break_reason: Option, + pub last_cache_source: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct CacheBreakEvent { + pub unexpected: bool, + pub reason: String, + pub previous_cache_read_input_tokens: u32, + pub current_cache_read_input_tokens: u32, + pub token_drop: u32, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PromptCacheRecord { + pub cache_break: Option, + pub stats: PromptCacheStats, +} + +#[derive(Debug, Clone)] +pub struct PromptCache { + inner: Arc>, +} + +impl PromptCache { + #[must_use] + pub fn new(session_id: impl Into) -> Self { + Self::with_config(PromptCacheConfig::new(session_id)) + } + + #[must_use] + pub fn with_config(config: PromptCacheConfig) -> Self { + let paths = PromptCachePaths::for_session(&config.session_id); + let stats = read_json::(&paths.stats_path).unwrap_or_default(); + let previous = read_json::(&paths.session_state_path); + Self { + inner: Arc::new(Mutex::new(PromptCacheInner { + config, + paths, + stats, + previous, + })), + } + } + + #[must_use] + pub fn paths(&self) -> PromptCachePaths { + self.lock().paths.clone() + } + + #[must_use] + pub fn stats(&self) -> PromptCacheStats { + self.lock().stats.clone() + } + + #[must_use] + pub fn lookup_completion(&self, request: &MessageRequest) -> Option { + let request_hash = request_hash_hex(request); + let (paths, ttl) = { + let inner = self.lock(); + (inner.paths.clone(), inner.config.completion_ttl) + }; + let entry_path = paths.completion_entry_path(&request_hash); + let entry = read_json::(&entry_path); + let Some(entry) = entry else { + let mut inner = self.lock(); + inner.stats.completion_cache_misses += 1; + inner.stats.last_completion_cache_key = Some(request_hash); + persist_state(&inner); + return None; + }; + + if entry.fingerprint_version != current_fingerprint_version() { + let mut inner = self.lock(); + inner.stats.completion_cache_misses += 1; + inner.stats.last_completion_cache_key = Some(request_hash.clone()); + let _ = fs::remove_file(entry_path); + persist_state(&inner); + return None; + } + + let expired = now_unix_secs().saturating_sub(entry.cached_at_unix_secs) >= ttl.as_secs(); + let mut inner = self.lock(); + inner.stats.last_completion_cache_key = Some(request_hash.clone()); + if expired { + inner.stats.completion_cache_misses += 1; + let _ = fs::remove_file(entry_path); + persist_state(&inner); + return None; + } + + inner.stats.completion_cache_hits += 1; + apply_usage_to_stats( + &mut inner.stats, + &entry.response.usage, + &request_hash, + "completion-cache", + ); + inner.previous = Some(TrackedPromptState::from_usage( + request, + &entry.response.usage, + )); + persist_state(&inner); + Some(entry.response) + } + + #[must_use] + pub fn record_response( + &self, + request: &MessageRequest, + response: &MessageResponse, + ) -> PromptCacheRecord { + self.record_usage_internal(request, &response.usage, Some(response)) + } + + #[must_use] + pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord { + self.record_usage_internal(request, usage, None) + } + + fn record_usage_internal( + &self, + request: &MessageRequest, + usage: &Usage, + response: Option<&MessageResponse>, + ) -> PromptCacheRecord { + let request_hash = request_hash_hex(request); + let mut inner = self.lock(); + let previous = inner.previous.clone(); + let current = TrackedPromptState::from_usage(request, usage); + let cache_break = detect_cache_break(&inner.config, previous.as_ref(), ¤t); + + inner.stats.tracked_requests += 1; + apply_usage_to_stats(&mut inner.stats, usage, &request_hash, "api-response"); + if let Some(event) = &cache_break { + if event.unexpected { + inner.stats.unexpected_cache_breaks += 1; + } else { + inner.stats.expected_invalidations += 1; + } + inner.stats.last_break_reason = Some(event.reason.clone()); + } + + inner.previous = Some(current); + if let Some(response) = response { + write_completion_entry(&inner.paths, &request_hash, response); + inner.stats.completion_cache_writes += 1; + } + persist_state(&inner); + + PromptCacheRecord { + cache_break, + stats: inner.stats.clone(), + } + } + + fn lock(&self) -> std::sync::MutexGuard<'_, PromptCacheInner> { + self.inner + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } +} + +#[derive(Debug)] +struct PromptCacheInner { + config: PromptCacheConfig, + paths: PromptCachePaths, + stats: PromptCacheStats, + previous: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct CompletionCacheEntry { + cached_at_unix_secs: u64, + #[serde(default = "current_fingerprint_version")] + fingerprint_version: u32, + response: MessageResponse, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +struct TrackedPromptState { + observed_at_unix_secs: u64, + #[serde(default = "current_fingerprint_version")] + fingerprint_version: u32, + model_hash: u64, + system_hash: u64, + tools_hash: u64, + messages_hash: u64, + cache_read_input_tokens: u32, +} + +impl TrackedPromptState { + fn from_usage(request: &MessageRequest, usage: &Usage) -> Self { + let hashes = RequestFingerprints::from_request(request); + Self { + observed_at_unix_secs: now_unix_secs(), + fingerprint_version: current_fingerprint_version(), + model_hash: hashes.model, + system_hash: hashes.system, + tools_hash: hashes.tools, + messages_hash: hashes.messages, + cache_read_input_tokens: usage.cache_read_input_tokens, + } + } +} + +#[derive(Debug, Clone, Copy)] +struct RequestFingerprints { + model: u64, + system: u64, + tools: u64, + messages: u64, +} + +impl RequestFingerprints { + fn from_request(request: &MessageRequest) -> Self { + Self { + model: hash_serializable(&request.model), + system: hash_serializable(&request.system), + tools: hash_serializable(&request.tools), + messages: hash_serializable(&request.messages), + } + } +} + +fn detect_cache_break( + config: &PromptCacheConfig, + previous: Option<&TrackedPromptState>, + current: &TrackedPromptState, +) -> Option { + let previous = previous?; + if previous.fingerprint_version != current.fingerprint_version { + return Some(CacheBreakEvent { + unexpected: false, + reason: format!( + "fingerprint version changed (v{} -> v{})", + previous.fingerprint_version, current.fingerprint_version + ), + previous_cache_read_input_tokens: previous.cache_read_input_tokens, + current_cache_read_input_tokens: current.cache_read_input_tokens, + token_drop: previous + .cache_read_input_tokens + .saturating_sub(current.cache_read_input_tokens), + }); + } + let token_drop = previous + .cache_read_input_tokens + .saturating_sub(current.cache_read_input_tokens); + if token_drop < config.cache_break_min_drop { + return None; + } + + let mut reasons = Vec::new(); + if previous.model_hash != current.model_hash { + reasons.push("model changed"); + } + if previous.system_hash != current.system_hash { + reasons.push("system prompt changed"); + } + if previous.tools_hash != current.tools_hash { + reasons.push("tool definitions changed"); + } + if previous.messages_hash != current.messages_hash { + reasons.push("message payload changed"); + } + + let elapsed = current + .observed_at_unix_secs + .saturating_sub(previous.observed_at_unix_secs); + + let (unexpected, reason) = if reasons.is_empty() { + if elapsed > config.prompt_ttl.as_secs() { + ( + false, + format!("possible prompt cache TTL expiry after {elapsed}s"), + ) + } else { + ( + true, + "cache read tokens dropped while prompt fingerprint remained stable".to_string(), + ) + } + } else { + (false, reasons.join(", ")) + }; + + Some(CacheBreakEvent { + unexpected, + reason, + previous_cache_read_input_tokens: previous.cache_read_input_tokens, + current_cache_read_input_tokens: current.cache_read_input_tokens, + token_drop, + }) +} + +fn apply_usage_to_stats( + stats: &mut PromptCacheStats, + usage: &Usage, + request_hash: &str, + source: &str, +) { + stats.total_cache_creation_input_tokens += u64::from(usage.cache_creation_input_tokens); + stats.total_cache_read_input_tokens += u64::from(usage.cache_read_input_tokens); + stats.last_cache_creation_input_tokens = Some(usage.cache_creation_input_tokens); + stats.last_cache_read_input_tokens = Some(usage.cache_read_input_tokens); + stats.last_request_hash = Some(request_hash.to_string()); + stats.last_cache_source = Some(source.to_string()); +} + +fn persist_state(inner: &PromptCacheInner) { + let _ = ensure_cache_dirs(&inner.paths); + let _ = write_json(&inner.paths.stats_path, &inner.stats); + if let Some(previous) = &inner.previous { + let _ = write_json(&inner.paths.session_state_path, previous); + } +} + +fn write_completion_entry( + paths: &PromptCachePaths, + request_hash: &str, + response: &MessageResponse, +) { + let _ = ensure_cache_dirs(paths); + let entry = CompletionCacheEntry { + cached_at_unix_secs: now_unix_secs(), + fingerprint_version: current_fingerprint_version(), + response: response.clone(), + }; + let _ = write_json(&paths.completion_entry_path(request_hash), &entry); +} + +fn ensure_cache_dirs(paths: &PromptCachePaths) -> std::io::Result<()> { + fs::create_dir_all(&paths.completion_dir) +} + +fn write_json(path: &Path, value: &T) -> std::io::Result<()> { + let json = serde_json::to_vec_pretty(value) + .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?; + fs::write(path, json) +} + +fn read_json Deserialize<'de>>(path: &Path) -> Option { + let bytes = fs::read(path).ok()?; + serde_json::from_slice(&bytes).ok() +} + +fn request_hash_hex(request: &MessageRequest) -> String { + format!( + "{REQUEST_FINGERPRINT_PREFIX}-{:016x}", + hash_serializable(request) + ) +} + +fn hash_serializable(value: &T) -> u64 { + let json = serde_json::to_vec(value).unwrap_or_default(); + stable_hash_bytes(&json) +} + +fn sanitize_path_segment(value: &str) -> String { + let sanitized: String = value + .chars() + .map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '-' }) + .collect(); + if sanitized.len() <= MAX_SANITIZED_LENGTH { + return sanitized; + } + let suffix = format!("-{:x}", hash_string(value)); + format!( + "{}{}", + &sanitized[..MAX_SANITIZED_LENGTH.saturating_sub(suffix.len())], + suffix + ) +} + +fn hash_string(value: &str) -> u64 { + stable_hash_bytes(value.as_bytes()) +} + +fn base_cache_root() -> PathBuf { + if let Some(config_home) = std::env::var_os("CLAUDE_CONFIG_HOME") { + return PathBuf::from(config_home) + .join("cache") + .join("prompt-cache"); + } + if let Some(home) = std::env::var_os("HOME") { + return PathBuf::from(home) + .join(".claude") + .join("cache") + .join("prompt-cache"); + } + std::env::temp_dir().join("claude-prompt-cache") +} + +fn now_unix_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs()) +} + +const fn current_fingerprint_version() -> u32 { + REQUEST_FINGERPRINT_VERSION +} + +fn stable_hash_bytes(bytes: &[u8]) -> u64 { + let mut hash = FNV_OFFSET_BASIS; + for byte in bytes { + hash ^= u64::from(*byte); + hash = hash.wrapping_mul(FNV_PRIME); + } + hash +} + +#[cfg(test)] +mod tests { + use std::sync::{Mutex, OnceLock}; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use super::{ + detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache, + PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX, + }; + use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage}; + + fn test_env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } + + #[test] + fn path_builder_sanitizes_session_identifier() { + let paths = PromptCachePaths::for_session("session:/with spaces"); + let session_dir = paths + .session_dir + .file_name() + .and_then(|value| value.to_str()) + .expect("session dir name"); + assert_eq!(session_dir, "session--with-spaces"); + assert!(paths.completion_dir.ends_with("completions")); + assert!(paths.stats_path.ends_with("stats.json")); + assert!(paths.session_state_path.ends_with("session-state.json")); + } + + #[test] + fn request_fingerprint_drives_unexpected_break_detection() { + let request = sample_request("same"); + let previous = TrackedPromptState::from_usage( + &request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 6_000, + output_tokens: 0, + }, + ); + let current = TrackedPromptState::from_usage( + &request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 1_000, + output_tokens: 0, + }, + ); + let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), ¤t) + .expect("break should be detected"); + assert!(event.unexpected); + assert!(event.reason.contains("stable")); + } + + #[test] + fn changed_prompt_marks_break_as_expected() { + let previous_request = sample_request("first"); + let current_request = sample_request("second"); + let previous = TrackedPromptState::from_usage( + &previous_request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 6_000, + output_tokens: 0, + }, + ); + let current = TrackedPromptState::from_usage( + ¤t_request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 1_000, + output_tokens: 0, + }, + ); + let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), ¤t) + .expect("break should be detected"); + assert!(!event.unexpected); + assert!(event.reason.contains("message payload changed")); + } + + #[test] + fn completion_cache_round_trip_persists_recent_response() { + let _guard = test_env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "prompt-cache-test-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + let cache = PromptCache::new("unit-test-session"); + let request = sample_request("cache me"); + let response = sample_response(42, 12, "cached"); + + assert!(cache.lookup_completion(&request).is_none()); + let record = cache.record_response(&request, &response); + assert!(record.cache_break.is_none()); + + let cached = cache + .lookup_completion(&request) + .expect("cached response should load"); + assert_eq!(cached.content, response.content); + + let stats = cache.stats(); + assert_eq!(stats.completion_cache_hits, 1); + assert_eq!(stats.completion_cache_misses, 1); + assert_eq!(stats.completion_cache_writes, 1); + + let persisted = read_json::(&cache.paths().stats_path) + .expect("stats should persist"); + assert_eq!(persisted.completion_cache_hits, 1); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + } + + #[test] + fn distinct_requests_do_not_collide_in_completion_cache() { + let _guard = test_env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "prompt-cache-distinct-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + let cache = PromptCache::new("distinct-request-session"); + let first_request = sample_request("first"); + let second_request = sample_request("second"); + + let response = sample_response(42, 12, "cached"); + let _ = cache.record_response(&first_request, &response); + + assert!(cache.lookup_completion(&second_request).is_none()); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + } + + #[test] + fn expired_completion_entries_are_not_reused() { + let _guard = test_env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "prompt-cache-expired-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + let cache = PromptCache::with_config(PromptCacheConfig { + session_id: "expired-session".to_string(), + completion_ttl: Duration::ZERO, + ..PromptCacheConfig::default() + }); + let request = sample_request("expire me"); + let response = sample_response(7, 3, "stale"); + + let _ = cache.record_response(&request, &response); + + assert!(cache.lookup_completion(&request).is_none()); + let stats = cache.stats(); + assert_eq!(stats.completion_cache_hits, 0); + assert_eq!(stats.completion_cache_misses, 1); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + } + + #[test] + fn sanitize_path_caps_long_values() { + let long_value = "x".repeat(200); + let sanitized = sanitize_path_segment(&long_value); + assert!(sanitized.len() <= 80); + } + + #[test] + fn request_hashes_are_versioned_and_stable() { + let request = sample_request("stable"); + let first = request_hash_hex(&request); + let second = request_hash_hex(&request); + assert_eq!(first, second); + assert!(first.starts_with(REQUEST_FINGERPRINT_PREFIX)); + } + + fn sample_request(text: &str) -> MessageRequest { + MessageRequest { + model: "claude-3-7-sonnet-latest".to_string(), + max_tokens: 64, + messages: vec![InputMessage::user_text(text)], + system: Some("system".to_string()), + tools: None, + tool_choice: None, + stream: false, + ..Default::default() + } + } + + fn sample_response( + cache_read_input_tokens: u32, + output_tokens: u32, + text: &str, + ) -> MessageResponse { + MessageResponse { + id: "msg_test".to_string(), + kind: "message".to_string(), + role: "assistant".to_string(), + content: vec![OutputContentBlock::Text { + text: text.to_string(), + }], + model: "claude-3-7-sonnet-latest".to_string(), + stop_reason: Some("end_turn".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 10, + cache_creation_input_tokens: 5, + cache_read_input_tokens, + output_tokens, + }, + request_id: Some("req_test".to_string()), + } + } +} diff --git a/crates/api/src/providers/claw_provider.rs b/crates/api/src/providers/anthropic.rs similarity index 52% rename from crates/api/src/providers/claw_provider.rs rename to crates/api/src/providers/anthropic.rs index 38c523f..6e62b7d 100644 --- a/crates/api/src/providers/claw_provider.rs +++ b/crates/api/src/providers/anthropic.rs @@ -1,25 +1,33 @@ use std::collections::VecDeque; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use runtime::format_usd; use runtime::{ load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, }; use serde::Deserialize; +use serde_json::{Map, Value}; +use telemetry::{AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, SessionTracer}; use crate::error::ApiError; +use crate::http_client::build_http_client_or_default; +use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats}; -use super::{Provider, ProviderFuture}; +use super::{ + anthropic_missing_credentials, model_token_limit, resolve_model_alias, Provider, ProviderFuture, +}; use crate::sse::SseParser; -use crate::types::{MessageRequest, MessageResponse, StreamEvent}; +use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEvent, Usage}; pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; -const ANTHROPIC_VERSION: &str = "2023-06-01"; const REQUEST_ID_HEADER: &str = "request-id"; const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; -const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); -const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); -const DEFAULT_MAX_RETRIES: u32 = 2; +const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_secs(1); +const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(128); +const DEFAULT_MAX_RETRIES: u32 = 8; #[derive(Debug, Clone, PartialEq, Eq)] pub enum AuthSource { @@ -43,10 +51,7 @@ impl AuthSource { }), (Some(api_key), None) => Ok(Self::ApiKey(api_key)), (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)), - (None, None) => Err(ApiError::missing_credentials( - "Claw", - &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], - )), + (None, None) => Err(anthropic_missing_credentials()), } } @@ -106,37 +111,49 @@ impl From for AuthSource { } #[derive(Debug, Clone)] -pub struct ClawApiClient { +pub struct AnthropicClient { http: reqwest::Client, auth: AuthSource, base_url: String, max_retries: u32, initial_backoff: Duration, max_backoff: Duration, + request_profile: AnthropicRequestProfile, + session_tracer: Option, + prompt_cache: Option, + last_prompt_cache_record: Arc>>, } -impl ClawApiClient { +impl AnthropicClient { #[must_use] pub fn new(api_key: impl Into) -> Self { Self { - http: reqwest::Client::new(), + http: build_http_client_or_default(), auth: AuthSource::ApiKey(api_key.into()), base_url: DEFAULT_BASE_URL.to_string(), max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, + request_profile: AnthropicRequestProfile::default(), + session_tracer: None, + prompt_cache: None, + last_prompt_cache_record: Arc::new(Mutex::new(None)), } } #[must_use] pub fn from_auth(auth: AuthSource) -> Self { Self { - http: reqwest::Client::new(), + http: build_http_client_or_default(), auth, base_url: DEFAULT_BASE_URL.to_string(), max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, + request_profile: AnthropicRequestProfile::default(), + session_tracer: None, + prompt_cache: None, + last_prompt_cache_record: Arc::new(Mutex::new(None)), } } @@ -194,6 +211,70 @@ impl ClawApiClient { self } + #[must_use] + pub fn with_session_tracer(mut self, session_tracer: SessionTracer) -> Self { + self.session_tracer = Some(session_tracer); + self + } + + #[must_use] + pub fn with_client_identity(mut self, client_identity: ClientIdentity) -> Self { + self.request_profile.client_identity = client_identity; + self + } + + #[must_use] + pub fn with_beta(mut self, beta: impl Into) -> Self { + self.request_profile = self.request_profile.with_beta(beta); + self + } + + #[must_use] + pub fn with_extra_body_param(mut self, key: impl Into, value: Value) -> Self { + self.request_profile = self.request_profile.with_extra_body(key, value); + self + } + + #[must_use] + pub fn with_prompt_cache(mut self, prompt_cache: PromptCache) -> Self { + self.prompt_cache = Some(prompt_cache); + self + } + + #[must_use] + pub fn prompt_cache_stats(&self) -> Option { + self.prompt_cache.as_ref().map(PromptCache::stats) + } + + #[must_use] + pub fn request_profile(&self) -> &AnthropicRequestProfile { + &self.request_profile + } + + #[must_use] + pub fn session_tracer(&self) -> Option<&SessionTracer> { + self.session_tracer.as_ref() + } + + #[must_use] + pub fn prompt_cache(&self) -> Option<&PromptCache> { + self.prompt_cache.as_ref() + } + + #[must_use] + pub fn take_last_prompt_cache_record(&self) -> Option { + self.last_prompt_cache_record + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .take() + } + + #[must_use] + pub fn with_request_profile(mut self, request_profile: AnthropicRequestProfile) -> Self { + self.request_profile = request_profile; + self + } + #[must_use] pub fn auth_source(&self) -> &AuthSource { &self.auth @@ -207,15 +288,51 @@ impl ClawApiClient { stream: false, ..request.clone() }; - let response = self.send_with_retry(&request).await?; - let request_id = request_id_from_headers(response.headers()); - let mut response = response - .json::() - .await - .map_err(ApiError::from)?; + + if let Some(prompt_cache) = &self.prompt_cache { + if let Some(response) = prompt_cache.lookup_completion(&request) { + return Ok(response); + } + } + + self.preflight_message_request(&request).await?; + + let http_response = self.send_with_retry(&request).await?; + let request_id = request_id_from_headers(http_response.headers()); + let body = http_response.text().await.map_err(ApiError::from)?; + let mut response = serde_json::from_str::(&body).map_err(|error| { + ApiError::json_deserialize("Anthropic", &request.model, &body, error) + })?; if response.request_id.is_none() { response.request_id = request_id; } + + if let Some(prompt_cache) = &self.prompt_cache { + let record = prompt_cache.record_response(&request, &response); + self.store_last_prompt_cache_record(record); + } + if let Some(session_tracer) = &self.session_tracer { + session_tracer.record_analytics( + AnalyticsEvent::new("api", "message_usage") + .with_property( + "request_id", + response + .request_id + .clone() + .map_or(Value::Null, Value::String), + ) + .with_property("total_tokens", Value::from(response.total_tokens())) + .with_property( + "estimated_cost_usd", + Value::String(format_usd( + response + .usage + .estimated_cost_usd(&response.model) + .total_cost_usd(), + )), + ), + ); + } Ok(response) } @@ -223,15 +340,21 @@ impl ClawApiClient { &self, request: &MessageRequest, ) -> Result { + self.preflight_message_request(request).await?; let response = self .send_with_retry(&request.clone().with_streaming()) .await?; Ok(MessageStream { request_id: request_id_from_headers(response.headers()), response, - parser: SseParser::new(), + parser: SseParser::new().with_context("Anthropic", request.model.clone()), pending: VecDeque::new(), done: false, + request: request.clone(), + prompt_cache: self.prompt_cache.clone(), + latest_usage: None, + usage_recorded: false, + last_prompt_cache_record: Arc::clone(&self.last_prompt_cache_record), }) } @@ -249,10 +372,10 @@ impl ClawApiClient { .await .map_err(ApiError::from)?; let response = expect_success(response).await?; - response - .json::() - .await - .map_err(ApiError::from) + let body = response.text().await.map_err(ApiError::from)?; + serde_json::from_str::(&body).map_err(|error| { + ApiError::json_deserialize("Anthropic OAuth (exchange)", "n/a", &body, error) + }) } pub async fn refresh_oauth_token( @@ -269,10 +392,10 @@ impl ClawApiClient { .await .map_err(ApiError::from)?; let response = expect_success(response).await?; - response - .json::() - .await - .map_err(ApiError::from) + let body = response.text().await.map_err(ApiError::from)?; + serde_json::from_str::(&body).map_err(|error| { + ApiError::json_deserialize("Anthropic OAuth (refresh)", "n/a", &body, error) + }) } async fn send_with_retry( @@ -284,25 +407,54 @@ impl ClawApiClient { loop { attempts += 1; + if let Some(session_tracer) = &self.session_tracer { + session_tracer.record_http_request_started( + attempts, + "POST", + "/v1/messages", + Map::new(), + ); + } match self.send_raw_request(request).await { Ok(response) => match expect_success(response).await { - Ok(response) => return Ok(response), + Ok(response) => { + if let Some(session_tracer) = &self.session_tracer { + session_tracer.record_http_request_succeeded( + attempts, + "POST", + "/v1/messages", + response.status().as_u16(), + request_id_from_headers(response.headers()), + Map::new(), + ); + } + return Ok(response); + } Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + self.record_request_failure(attempts, &error); last_error = Some(error); } - Err(error) => return Err(error), + Err(error) => { + let error = enrich_bearer_auth_error(error, &self.auth); + self.record_request_failure(attempts, &error); + return Err(error); + } }, Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + self.record_request_failure(attempts, &error); last_error = Some(error); } - Err(error) => return Err(error), + Err(error) => { + self.record_request_failure(attempts, &error); + return Err(error); + } } if attempts > self.max_retries { break; } - tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; + tokio::time::sleep(self.jittered_backoff_for_attempt(attempts)?).await; } Err(ApiError::RetriesExhausted { @@ -316,15 +468,103 @@ impl ClawApiClient { request: &MessageRequest, ) -> Result { let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/')); + let mut request_body = self.request_profile.render_json_body(request)?; + strip_unsupported_beta_body_fields(&mut request_body); + let request_builder = self.build_request(&request_url).json(&request_body); + request_builder.send().await.map_err(ApiError::from) + } + + fn build_request(&self, request_url: &str) -> reqwest::RequestBuilder { let request_builder = self .http - .post(&request_url) - .header("anthropic-version", ANTHROPIC_VERSION) + .post(request_url) .header("content-type", "application/json"); let mut request_builder = self.auth.apply(request_builder); + for (header_name, header_value) in self.request_profile.header_pairs() { + request_builder = request_builder.header(header_name, header_value); + } + request_builder + } - request_builder = request_builder.json(request); - request_builder.send().await.map_err(ApiError::from) + async fn preflight_message_request(&self, request: &MessageRequest) -> Result<(), ApiError> { + // Always run the local byte-estimate guard first. This catches + // oversized requests even if the remote count_tokens endpoint is + // unreachable, misconfigured, or unimplemented (e.g., third-party + // Anthropic-compatible gateways). If byte estimation already flags + // the request as oversized, reject immediately without a network + // round trip. + super::preflight_message_request(request)?; + + let Some(limit) = model_token_limit(&request.model) else { + return Ok(()); + }; + + // Best-effort refinement using the Anthropic count_tokens endpoint. + // On any failure (network, parse, auth), fall back to the local + // byte-estimate result which already passed above. + let counted_input_tokens = match self.count_tokens(request).await { + Ok(count) => count, + Err(_) => return Ok(()), + }; + let estimated_total_tokens = counted_input_tokens.saturating_add(request.max_tokens); + if estimated_total_tokens > limit.context_window_tokens { + return Err(ApiError::ContextWindowExceeded { + model: resolve_model_alias(&request.model), + estimated_input_tokens: counted_input_tokens, + requested_output_tokens: request.max_tokens, + estimated_total_tokens, + context_window_tokens: limit.context_window_tokens, + }); + } + + Ok(()) + } + + async fn count_tokens(&self, request: &MessageRequest) -> Result { + #[derive(serde::Deserialize)] + struct CountTokensResponse { + input_tokens: u32, + } + + let request_url = format!( + "{}/v1/messages/count_tokens", + self.base_url.trim_end_matches('/') + ); + let mut request_body = self.request_profile.render_json_body(request)?; + strip_unsupported_beta_body_fields(&mut request_body); + let response = self + .build_request(&request_url) + .json(&request_body) + .send() + .await + .map_err(ApiError::from)?; + + let response = expect_success(response).await?; + let body = response.text().await.map_err(ApiError::from)?; + let parsed = serde_json::from_str::(&body).map_err(|error| { + ApiError::json_deserialize("Anthropic count_tokens", &request.model, &body, error) + })?; + Ok(parsed.input_tokens) + } + + fn record_request_failure(&self, attempt: u32, error: &ApiError) { + if let Some(session_tracer) = &self.session_tracer { + session_tracer.record_http_request_failed( + attempt, + "POST", + "/v1/messages", + error.to_string(), + error.is_retryable(), + Map::new(), + ); + } + } + + fn store_last_prompt_cache_record(&self, record: PromptCacheRecord) { + *self + .last_prompt_cache_record + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record); } fn backoff_for_attempt(&self, attempt: u32) -> Result { @@ -339,6 +579,42 @@ impl ClawApiClient { .checked_mul(multiplier) .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) } + + fn jittered_backoff_for_attempt(&self, attempt: u32) -> Result { + let base = self.backoff_for_attempt(attempt)?; + Ok(base + jitter_for_base(base)) + } +} + +/// Process-wide counter that guarantees distinct jitter samples even when +/// the system clock resolution is coarser than consecutive retry sleeps. +static JITTER_COUNTER: AtomicU64 = AtomicU64::new(0); + +/// Returns a random additive jitter in `[0, base]` to decorrelate retries +/// from multiple concurrent clients. Entropy is drawn from the nanosecond +/// wall clock mixed with a monotonic counter and run through a splitmix64 +/// finalizer; adequate for retry jitter (no cryptographic requirement). +fn jitter_for_base(base: Duration) -> Duration { + let base_nanos = u64::try_from(base.as_nanos()).unwrap_or(u64::MAX); + if base_nanos == 0 { + return Duration::ZERO; + } + let raw_nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|elapsed| u64::try_from(elapsed.as_nanos()).unwrap_or(u64::MAX)) + .unwrap_or(0); + let tick = JITTER_COUNTER.fetch_add(1, Ordering::Relaxed); + // splitmix64 finalizer — mixes the low bits so large bases still see + // jitter across their full range instead of being clamped to subsec nanos. + let mut mixed = raw_nanos + .wrapping_add(tick) + .wrapping_add(0x9E37_79B9_7F4A_7C15); + mixed = (mixed ^ (mixed >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + mixed = (mixed ^ (mixed >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + mixed ^= mixed >> 31; + // Inclusive upper bound: jitter may equal `base`, matching "up to base". + let jitter_nanos = mixed % base_nanos.saturating_add(1); + Duration::from_nanos(jitter_nanos) } impl AuthSource { @@ -367,10 +643,7 @@ impl AuthSource { } } Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)), - Ok(None) => Err(ApiError::missing_credentials( - "Claw", - &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], - )), + Ok(None) => Err(anthropic_missing_credentials()), Err(error) => Err(error), } } @@ -414,10 +687,7 @@ where } let Some(token_set) = load_saved_oauth_token()? else { - return Err(ApiError::missing_credentials( - "Claw", - &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], - )); + return Err(anthropic_missing_credentials()); }; if !oauth_token_is_expired(&token_set) { return Ok(AuthSource::BearerToken(token_set.access_token)); @@ -446,7 +716,7 @@ fn resolve_saved_oauth_token_set( let Some(refresh_token) = token_set.refresh_token.clone() else { return Err(ApiError::ExpiredOAuthToken); }; - let client = ClawApiClient::from_auth(AuthSource::None).with_base_url(read_base_url()); + let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url()); let refreshed = client_runtime_block_on(async { client .refresh_oauth_token( @@ -503,7 +773,7 @@ fn now_unix_timestamp() -> u64 { fn read_env_non_empty(key: &str) -> Result, ApiError> { match std::env::var(key) { Ok(value) if !value.is_empty() => Ok(Some(value)), - Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), + Ok(_) | Err(std::env::VarError::NotPresent) => Ok(super::dotenv_value(key)), Err(error) => Err(ApiError::from(error)), } } @@ -514,10 +784,7 @@ fn read_api_key() -> Result { auth.api_key() .or_else(|| auth.bearer_token()) .map(ToOwned::to_owned) - .ok_or(ApiError::missing_credentials( - "Claw", - &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], - )) + .ok_or_else(anthropic_missing_credentials) } #[cfg(test)] @@ -540,7 +807,7 @@ fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option( @@ -565,6 +832,11 @@ pub struct MessageStream { parser: SseParser, pending: VecDeque, done: bool, + request: MessageRequest, + prompt_cache: Option, + latest_usage: Option, + usage_recorded: bool, + last_prompt_cache_record: Arc>>, } impl MessageStream { @@ -576,6 +848,7 @@ impl MessageStream { pub async fn next_event(&mut self) -> Result, ApiError> { loop { if let Some(event) = self.pending.pop_front() { + self.observe_event(&event); return Ok(Some(event)); } @@ -598,6 +871,29 @@ impl MessageStream { } } } + + fn observe_event(&mut self, event: &StreamEvent) { + match event { + StreamEvent::MessageDelta(MessageDeltaEvent { usage, .. }) => { + self.latest_usage = Some(usage.clone()); + } + StreamEvent::MessageStop(_) => { + if !self.usage_recorded { + if let (Some(prompt_cache), Some(usage)) = + (&self.prompt_cache, self.latest_usage.as_ref()) + { + let record = prompt_cache.record_usage(&self.request, usage); + *self + .last_prompt_cache_record + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record); + } + self.usage_recorded = true; + } + } + _ => {} + } + } } async fn expect_success(response: reqwest::Response) -> Result { @@ -606,8 +902,9 @@ async fn expect_success(response: reqwest::Response) -> Result(&body).ok(); + let parsed_error = serde_json::from_str::(&body).ok(); let retryable = is_retryable_status(status); Err(ApiError::Api { @@ -618,6 +915,7 @@ async fn expect_success(response: reqwest::Response) -> Result bool { matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) } -#[derive(Debug, Deserialize)] -struct ApiErrorEnvelope { - error: ApiErrorBody, +/// Anthropic API keys (`sk-ant-*`) are accepted over the `x-api-key` header +/// and rejected with HTTP 401 "Invalid bearer token" when sent as a Bearer +/// token via `ANTHROPIC_AUTH_TOKEN`. This happens often enough in the wild +/// (users copy-paste an `sk-ant-...` key into `ANTHROPIC_AUTH_TOKEN` because +/// the env var name sounds auth-related) that a bare 401 error is useless. +/// When we detect this exact shape, append a hint to the error message that +/// points the user at the one-line fix. +const SK_ANT_BEARER_HINT: &str = "sk-ant-* keys go in ANTHROPIC_API_KEY (x-api-key header), not ANTHROPIC_AUTH_TOKEN (Bearer header). Move your key to ANTHROPIC_API_KEY."; + +fn enrich_bearer_auth_error(error: ApiError, auth: &AuthSource) -> ApiError { + let ApiError::Api { + status, + error_type, + message, + request_id, + body, + retryable, + } = error + else { + return error; + }; + if status.as_u16() != 401 { + return ApiError::Api { + status, + error_type, + message, + request_id, + body, + retryable, + }; + } + let Some(bearer_token) = auth.bearer_token() else { + return ApiError::Api { + status, + error_type, + message, + request_id, + body, + retryable, + }; + }; + if !bearer_token.starts_with("sk-ant-") { + return ApiError::Api { + status, + error_type, + message, + request_id, + body, + retryable, + }; + } + // Only append the hint when the AuthSource is pure BearerToken. If both + // api_key and bearer_token are present (`ApiKeyAndBearer`), the x-api-key + // header is already being sent alongside the Bearer header and the 401 + // is coming from a different cause — adding the hint would be misleading. + if auth.api_key().is_some() { + return ApiError::Api { + status, + error_type, + message, + request_id, + body, + retryable, + }; + } + let enriched_message = match message { + Some(existing) => Some(format!("{existing} — hint: {SK_ANT_BEARER_HINT}")), + None => Some(format!("hint: {SK_ANT_BEARER_HINT}")), + }; + ApiError::Api { + status, + error_type, + message: enriched_message, + request_id, + body, + retryable, + } +} + +/// Remove beta-only body fields that the standard `/v1/messages` and +/// `/v1/messages/count_tokens` endpoints reject as `Extra inputs are not +/// permitted`. The `betas` opt-in is communicated via the `anthropic-beta` +/// HTTP header on these endpoints, never as a JSON body field. +fn strip_unsupported_beta_body_fields(body: &mut Value) { + if let Some(object) = body.as_object_mut() { + object.remove("betas"); + // These fields are OpenAI-compatible only; Anthropic rejects them. + object.remove("frequency_penalty"); + object.remove("presence_penalty"); + // Anthropic uses "stop_sequences" not "stop". Convert if present. + if let Some(stop_val) = object.remove("stop") { + if stop_val.as_array().map_or(false, |a| !a.is_empty()) { + object.insert("stop_sequences".to_string(), stop_val); + } + } + } } #[derive(Debug, Deserialize)] -struct ApiErrorBody { - #[serde(alias = "code", rename = "type")] +struct AnthropicErrorEnvelope { + error: AnthropicErrorBody, +} + +#[derive(Debug, Deserialize)] +struct AnthropicErrorBody { + #[serde(rename = "type")] error_type: String, message: String, } @@ -652,7 +1048,7 @@ mod tests { use super::{ now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, - resolve_startup_auth_source, AuthSource, ClawApiClient, OAuthTokenSet, + resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet, }; use crate::types::{ContentBlockDelta, MessageRequest}; @@ -953,6 +1349,7 @@ mod tests { tools: None, tool_choice: None, stream: false, + ..Default::default() }; assert!(request.with_streaming().stream); @@ -960,7 +1357,7 @@ mod tests { #[test] fn backoff_doubles_until_maximum() { - let client = ClawApiClient::new("test-key").with_retry_policy( + let client = AnthropicClient::new("test-key").with_retry_policy( 3, Duration::from_millis(10), Duration::from_millis(25), @@ -979,6 +1376,58 @@ mod tests { ); } + #[test] + fn jittered_backoff_stays_within_additive_bounds_and_varies() { + let client = AnthropicClient::new("test-key").with_retry_policy( + 8, + Duration::from_secs(1), + Duration::from_secs(128), + ); + let mut samples = Vec::with_capacity(64); + for _ in 0..64 { + let base = client.backoff_for_attempt(3).expect("base attempt 3"); + let jittered = client + .jittered_backoff_for_attempt(3) + .expect("jittered attempt 3"); + assert!( + jittered >= base, + "jittered delay {jittered:?} must be at least the base {base:?}" + ); + assert!( + jittered <= base * 2, + "jittered delay {jittered:?} must not exceed base*2 {:?}", + base * 2 + ); + samples.push(jittered); + } + let distinct: std::collections::HashSet<_> = samples.iter().collect(); + assert!( + distinct.len() > 1, + "jitter should produce varied delays across samples, got {samples:?}" + ); + } + + #[test] + fn default_retry_policy_matches_exponential_schedule() { + let client = AnthropicClient::new("test-key"); + assert_eq!( + client.backoff_for_attempt(1).expect("attempt 1"), + Duration::from_secs(1) + ); + assert_eq!( + client.backoff_for_attempt(2).expect("attempt 2"), + Duration::from_secs(2) + ); + assert_eq!( + client.backoff_for_attempt(3).expect("attempt 3"), + Duration::from_secs(4) + ); + assert_eq!( + client.backoff_for_attempt(8).expect("attempt 8"), + Duration::from_secs(128) + ); + } + #[test] fn retryable_statuses_are_detected() { assert!(super::is_retryable_status( @@ -1043,4 +1492,279 @@ mod tests { Some("Bearer proxy-token") ); } + + #[test] + fn strip_unsupported_beta_body_fields_removes_betas_array() { + let mut body = serde_json::json!({ + "model": "claude-sonnet-4-6", + "max_tokens": 1024, + "betas": ["claude-code-20250219", "prompt-caching-scope-2026-01-05"], + "metadata": {"source": "test"}, + }); + + super::strip_unsupported_beta_body_fields(&mut body); + + assert!( + body.get("betas").is_none(), + "betas body field must be stripped before sending to /v1/messages" + ); + assert_eq!( + body.get("model").and_then(serde_json::Value::as_str), + Some("claude-sonnet-4-6") + ); + assert_eq!(body["max_tokens"], serde_json::json!(1024)); + assert_eq!(body["metadata"]["source"], serde_json::json!("test")); + } + + #[test] + fn strip_unsupported_beta_body_fields_is_a_noop_when_betas_absent() { + let mut body = serde_json::json!({ + "model": "claude-sonnet-4-6", + "max_tokens": 1024, + }); + let original = body.clone(); + + super::strip_unsupported_beta_body_fields(&mut body); + + assert_eq!(body, original); + } + + #[test] + fn strip_removes_openai_only_fields_and_converts_stop() { + let mut body = serde_json::json!({ + "model": "claude-sonnet-4-6", + "max_tokens": 1024, + "temperature": 0.7, + "frequency_penalty": 0.5, + "presence_penalty": 0.3, + "stop": ["\n"], + }); + + super::strip_unsupported_beta_body_fields(&mut body); + + // temperature is kept (Anthropic supports it) + assert_eq!(body["temperature"], serde_json::json!(0.7)); + // frequency_penalty and presence_penalty are removed + assert!( + body.get("frequency_penalty").is_none(), + "frequency_penalty must be stripped for Anthropic" + ); + assert!( + body.get("presence_penalty").is_none(), + "presence_penalty must be stripped for Anthropic" + ); + // stop is renamed to stop_sequences + assert!(body.get("stop").is_none(), "stop must be renamed"); + assert_eq!(body["stop_sequences"], serde_json::json!(["\n"])); + } + + #[test] + fn strip_does_not_add_empty_stop_sequences() { + let mut body = serde_json::json!({ + "model": "claude-sonnet-4-6", + "max_tokens": 1024, + "stop": [], + }); + + super::strip_unsupported_beta_body_fields(&mut body); + + assert!(body.get("stop").is_none()); + assert!( + body.get("stop_sequences").is_none(), + "empty stop should not produce stop_sequences" + ); + } + + #[test] + fn rendered_request_body_strips_betas_for_standard_messages_endpoint() { + let client = AnthropicClient::new("test-key").with_beta("tools-2026-04-01"); + let request = MessageRequest { + model: "claude-sonnet-4-6".to_string(), + max_tokens: 64, + messages: vec![], + system: None, + tools: None, + tool_choice: None, + stream: false, + ..Default::default() + }; + + let mut rendered = client + .request_profile() + .render_json_body(&request) + .expect("body should render"); + assert!( + rendered.get("betas").is_some(), + "render_json_body still emits betas; the strip helper guards the wire format", + ); + super::strip_unsupported_beta_body_fields(&mut rendered); + + assert!( + rendered.get("betas").is_none(), + "betas must not appear in /v1/messages request bodies" + ); + assert_eq!( + rendered.get("model").and_then(serde_json::Value::as_str), + Some("claude-sonnet-4-6") + ); + } + + #[test] + fn enrich_bearer_auth_error_appends_sk_ant_hint_on_401_with_pure_bearer_token() { + // given + let auth = AuthSource::BearerToken("sk-ant-api03-deadbeef".to_string()); + let error = crate::error::ApiError::Api { + status: reqwest::StatusCode::UNAUTHORIZED, + error_type: Some("authentication_error".to_string()), + message: Some("Invalid bearer token".to_string()), + request_id: Some("req_varleg_001".to_string()), + body: String::new(), + retryable: false, + }; + + // when + let enriched = super::enrich_bearer_auth_error(error, &auth); + + // then + let rendered = enriched.to_string(); + assert!( + rendered.contains("Invalid bearer token"), + "existing provider message should be preserved: {rendered}" + ); + assert!( + rendered.contains( + "sk-ant-* keys go in ANTHROPIC_API_KEY (x-api-key header), not ANTHROPIC_AUTH_TOKEN (Bearer header). Move your key to ANTHROPIC_API_KEY." + ), + "rendered error should include the sk-ant-* hint: {rendered}" + ); + assert!( + rendered.contains("[trace req_varleg_001]"), + "request id should still flow through the enriched error: {rendered}" + ); + match enriched { + crate::error::ApiError::Api { status, .. } => { + assert_eq!(status, reqwest::StatusCode::UNAUTHORIZED); + } + other => panic!("expected Api variant, got {other:?}"), + } + } + + #[test] + fn enrich_bearer_auth_error_leaves_non_401_errors_unchanged() { + // given + let auth = AuthSource::BearerToken("sk-ant-api03-deadbeef".to_string()); + let error = crate::error::ApiError::Api { + status: reqwest::StatusCode::INTERNAL_SERVER_ERROR, + error_type: Some("api_error".to_string()), + message: Some("internal server error".to_string()), + request_id: None, + body: String::new(), + retryable: true, + }; + + // when + let enriched = super::enrich_bearer_auth_error(error, &auth); + + // then + let rendered = enriched.to_string(); + assert!( + !rendered.contains("sk-ant-*"), + "non-401 errors must not be annotated with the bearer hint: {rendered}" + ); + assert!( + rendered.contains("internal server error"), + "original message must be preserved verbatim: {rendered}" + ); + } + + #[test] + fn enrich_bearer_auth_error_ignores_401_when_bearer_token_is_not_sk_ant() { + // given + let auth = AuthSource::BearerToken("oauth-access-token-opaque".to_string()); + let error = crate::error::ApiError::Api { + status: reqwest::StatusCode::UNAUTHORIZED, + error_type: Some("authentication_error".to_string()), + message: Some("Invalid bearer token".to_string()), + request_id: None, + body: String::new(), + retryable: false, + }; + + // when + let enriched = super::enrich_bearer_auth_error(error, &auth); + + // then + let rendered = enriched.to_string(); + assert!( + !rendered.contains("sk-ant-*"), + "oauth-style bearer tokens must not trigger the sk-ant-* hint: {rendered}" + ); + } + + #[test] + fn enrich_bearer_auth_error_skips_hint_when_api_key_header_is_also_present() { + // given + let auth = AuthSource::ApiKeyAndBearer { + api_key: "sk-ant-api03-legitimate".to_string(), + bearer_token: "sk-ant-api03-deadbeef".to_string(), + }; + let error = crate::error::ApiError::Api { + status: reqwest::StatusCode::UNAUTHORIZED, + error_type: Some("authentication_error".to_string()), + message: Some("Invalid bearer token".to_string()), + request_id: None, + body: String::new(), + retryable: false, + }; + + // when + let enriched = super::enrich_bearer_auth_error(error, &auth); + + // then + let rendered = enriched.to_string(); + assert!( + !rendered.contains("sk-ant-*"), + "hint should be suppressed when x-api-key header is already being sent: {rendered}" + ); + } + + #[test] + fn enrich_bearer_auth_error_ignores_401_when_auth_source_has_no_bearer() { + // given + let auth = AuthSource::ApiKey("sk-ant-api03-legitimate".to_string()); + let error = crate::error::ApiError::Api { + status: reqwest::StatusCode::UNAUTHORIZED, + error_type: Some("authentication_error".to_string()), + message: Some("Invalid x-api-key".to_string()), + request_id: None, + body: String::new(), + retryable: false, + }; + + // when + let enriched = super::enrich_bearer_auth_error(error, &auth); + + // then + let rendered = enriched.to_string(); + assert!( + !rendered.contains("sk-ant-*"), + "bearer hint must not apply when AuthSource is ApiKey-only: {rendered}" + ); + } + + #[test] + fn enrich_bearer_auth_error_passes_non_api_errors_through_unchanged() { + // given + let auth = AuthSource::BearerToken("sk-ant-api03-deadbeef".to_string()); + let error = crate::error::ApiError::InvalidSseFrame("unterminated event"); + + // when + let enriched = super::enrich_bearer_auth_error(error, &auth); + + // then + assert!(matches!( + enriched, + crate::error::ApiError::InvalidSseFrame(_) + )); + } } diff --git a/crates/api/src/providers/mod.rs b/crates/api/src/providers/mod.rs index 80e60db..24d5630 100644 --- a/crates/api/src/providers/mod.rs +++ b/crates/api/src/providers/mod.rs @@ -1,14 +1,19 @@ +#![allow(clippy::cast_possible_truncation)] use std::future::Future; use std::pin::Pin; +use serde::Serialize; + use crate::error::ApiError; use crate::types::{MessageRequest, MessageResponse}; -pub mod claw_provider; +pub mod anthropic; pub mod openai_compat; +#[allow(dead_code)] pub type ProviderFuture<'a, T> = Pin> + Send + 'a>>; +#[allow(dead_code)] pub trait Provider { type Stream; @@ -25,7 +30,7 @@ pub trait Provider { #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ProviderKind { - ClawApi, + Anthropic, Xai, OpenAi, } @@ -38,59 +43,38 @@ pub struct ProviderMetadata { pub default_base_url: &'static str, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ModelTokenLimit { + pub max_output_tokens: u32, + pub context_window_tokens: u32, +} + const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ ( "opus", ProviderMetadata { - provider: ProviderKind::ClawApi, + provider: ProviderKind::Anthropic, auth_env: "ANTHROPIC_API_KEY", base_url_env: "ANTHROPIC_BASE_URL", - default_base_url: claw_provider::DEFAULT_BASE_URL, + default_base_url: anthropic::DEFAULT_BASE_URL, }, ), ( "sonnet", ProviderMetadata { - provider: ProviderKind::ClawApi, + provider: ProviderKind::Anthropic, auth_env: "ANTHROPIC_API_KEY", base_url_env: "ANTHROPIC_BASE_URL", - default_base_url: claw_provider::DEFAULT_BASE_URL, + default_base_url: anthropic::DEFAULT_BASE_URL, }, ), ( "haiku", ProviderMetadata { - provider: ProviderKind::ClawApi, + provider: ProviderKind::Anthropic, auth_env: "ANTHROPIC_API_KEY", base_url_env: "ANTHROPIC_BASE_URL", - default_base_url: claw_provider::DEFAULT_BASE_URL, - }, - ), - ( - "claude-opus-4-6", - ProviderMetadata { - provider: ProviderKind::ClawApi, - auth_env: "ANTHROPIC_API_KEY", - base_url_env: "ANTHROPIC_BASE_URL", - default_base_url: claw_provider::DEFAULT_BASE_URL, - }, - ), - ( - "claude-sonnet-4-6", - ProviderMetadata { - provider: ProviderKind::ClawApi, - auth_env: "ANTHROPIC_API_KEY", - base_url_env: "ANTHROPIC_BASE_URL", - default_base_url: claw_provider::DEFAULT_BASE_URL, - }, - ), - ( - "claude-haiku-4-5-20251213", - ProviderMetadata { - provider: ProviderKind::ClawApi, - auth_env: "ANTHROPIC_API_KEY", - base_url_env: "ANTHROPIC_BASE_URL", - default_base_url: claw_provider::DEFAULT_BASE_URL, + default_base_url: anthropic::DEFAULT_BASE_URL, }, ), ( @@ -138,69 +122,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, }, ), - ( - "glm-4-plus", - ProviderMetadata { - provider: ProviderKind::ClawApi, - auth_env: "ANTHROPIC_API_KEY", - base_url_env: "ANTHROPIC_BASE_URL", - default_base_url: claw_provider::DEFAULT_BASE_URL, - }, - ), - ( - "glm-4-0520", - ProviderMetadata { - provider: ProviderKind::ClawApi, - auth_env: "ANTHROPIC_API_KEY", - base_url_env: "ANTHROPIC_BASE_URL", - default_base_url: claw_provider::DEFAULT_BASE_URL, - }, - ), - ( - "glm-4", - ProviderMetadata { - provider: ProviderKind::ClawApi, - auth_env: "ANTHROPIC_API_KEY", - base_url_env: "ANTHROPIC_BASE_URL", - default_base_url: claw_provider::DEFAULT_BASE_URL, - }, - ), - ( - "glm-4-air", - ProviderMetadata { - provider: ProviderKind::ClawApi, - auth_env: "ANTHROPIC_API_KEY", - base_url_env: "ANTHROPIC_BASE_URL", - default_base_url: claw_provider::DEFAULT_BASE_URL, - }, - ), - ( - "glm-4-flash", - ProviderMetadata { - provider: ProviderKind::ClawApi, - auth_env: "ANTHROPIC_API_KEY", - base_url_env: "ANTHROPIC_BASE_URL", - default_base_url: claw_provider::DEFAULT_BASE_URL, - }, - ), - ( - "glm-5", - ProviderMetadata { - provider: ProviderKind::ClawApi, - auth_env: "ANTHROPIC_API_KEY", - base_url_env: "ANTHROPIC_BASE_URL", - default_base_url: claw_provider::DEFAULT_BASE_URL, - }, - ), - ( - "glm-5.1", - ProviderMetadata { - provider: ProviderKind::ClawApi, - auth_env: "ANTHROPIC_API_KEY", - base_url_env: "ANTHROPIC_BASE_URL", - default_base_url: claw_provider::DEFAULT_BASE_URL, - }, - ), ]; #[must_use] @@ -211,7 +132,7 @@ pub fn resolve_model_alias(model: &str) -> String { .iter() .find_map(|(alias, metadata)| { (*alias == lower).then_some(match metadata.provider { - ProviderKind::ClawApi => match *alias { + ProviderKind::Anthropic => match *alias { "opus" => "claude-opus-4-6", "sonnet" => "claude-sonnet-4-6", "haiku" => "claude-haiku-4-5-20251213", @@ -232,11 +153,15 @@ pub fn resolve_model_alias(model: &str) -> String { #[must_use] pub fn metadata_for_model(model: &str) -> Option { let canonical = resolve_model_alias(model); - let lower = canonical.to_ascii_lowercase(); - if let Some((_, metadata)) = MODEL_REGISTRY.iter().find(|(alias, _)| *alias == lower) { - return Some(*metadata); + if canonical.starts_with("claude") { + return Some(ProviderMetadata { + provider: ProviderKind::Anthropic, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: anthropic::DEFAULT_BASE_URL, + }); } - if lower.starts_with("grok") { + if canonical.starts_with("grok") { return Some(ProviderMetadata { provider: ProviderKind::Xai, auth_env: "XAI_API_KEY", @@ -244,6 +169,31 @@ pub fn metadata_for_model(model: &str) -> Option { default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, }); } + // Explicit provider-namespaced models (e.g. "openai/gpt-4.1-mini") must + // route to the correct provider regardless of which auth env vars are set. + // Without this, detect_provider_kind falls through to the auth-sniffer + // order and misroutes to Anthropic if ANTHROPIC_API_KEY is present. + if canonical.starts_with("openai/") || canonical.starts_with("gpt-") { + return Some(ProviderMetadata { + provider: ProviderKind::OpenAi, + auth_env: "OPENAI_API_KEY", + base_url_env: "OPENAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_OPENAI_BASE_URL, + }); + } + // Alibaba DashScope compatible-mode endpoint. Routes qwen/* and bare + // qwen-* model names (qwen-max, qwen-plus, qwen-turbo, qwen-qwq, etc.) + // to the OpenAI-compat client pointed at DashScope's /compatible-mode/v1. + // Uses the OpenAi provider kind because DashScope speaks the OpenAI REST + // shape — only the base URL and auth env var differ. + if canonical.starts_with("qwen/") || canonical.starts_with("qwen-") { + return Some(ProviderMetadata { + provider: ProviderKind::OpenAi, + auth_env: "DASHSCOPE_API_KEY", + base_url_env: "DASHSCOPE_BASE_URL", + default_base_url: openai_compat::DEFAULT_DASHSCOPE_BASE_URL, + }); + } None } @@ -252,8 +202,17 @@ pub fn detect_provider_kind(model: &str) -> ProviderKind { if let Some(metadata) = metadata_for_model(model) { return metadata.provider; } - if claw_provider::has_auth_from_env_or_saved().unwrap_or(false) { - return ProviderKind::ClawApi; + // When OPENAI_BASE_URL is set, the user explicitly configured an + // OpenAI-compatible endpoint. Prefer it over the Anthropic fallback + // even when the model name has no recognized prefix — this is the + // common case for local providers (Ollama, LM Studio, vLLM, etc.) + // where model names like "qwen2.5-coder:7b" don't match any prefix. + if std::env::var_os("OPENAI_BASE_URL").is_some() && openai_compat::has_api_key("OPENAI_API_KEY") + { + return ProviderKind::OpenAi; + } + if anthropic::has_auth_from_env_or_saved().unwrap_or(false) { + return ProviderKind::Anthropic; } if openai_compat::has_api_key("OPENAI_API_KEY") { return ProviderKind::OpenAi; @@ -261,22 +220,271 @@ pub fn detect_provider_kind(model: &str) -> ProviderKind { if openai_compat::has_api_key("XAI_API_KEY") { return ProviderKind::Xai; } - ProviderKind::ClawApi + // Last resort: if OPENAI_BASE_URL is set without OPENAI_API_KEY (some + // local providers like Ollama don't require auth), still route there. + if std::env::var_os("OPENAI_BASE_URL").is_some() { + return ProviderKind::OpenAi; + } + ProviderKind::Anthropic } #[must_use] pub fn max_tokens_for_model(model: &str) -> u32 { + model_token_limit(model).map_or_else( + || { + let canonical = resolve_model_alias(model); + if canonical.contains("opus") { + 32_000 + } else { + 64_000 + } + }, + |limit| limit.max_output_tokens, + ) +} + +/// Returns the effective max output tokens for a model, preferring a plugin +/// override when present. Falls back to [`max_tokens_for_model`] when the +/// override is `None`. +#[must_use] +pub fn max_tokens_for_model_with_override(model: &str, plugin_override: Option) -> u32 { + plugin_override.unwrap_or_else(|| max_tokens_for_model(model)) +} + +#[must_use] +pub fn model_token_limit(model: &str) -> Option { let canonical = resolve_model_alias(model); - if canonical.contains("opus") { - 32_000 - } else { - 64_000 + match canonical.as_str() { + "claude-opus-4-6" => Some(ModelTokenLimit { + max_output_tokens: 32_000, + context_window_tokens: 200_000, + }), + "claude-sonnet-4-6" | "claude-haiku-4-5-20251213" => Some(ModelTokenLimit { + max_output_tokens: 64_000, + context_window_tokens: 200_000, + }), + "grok-3" | "grok-3-mini" => Some(ModelTokenLimit { + max_output_tokens: 64_000, + context_window_tokens: 131_072, + }), + _ => None, } } +pub fn preflight_message_request(request: &MessageRequest) -> Result<(), ApiError> { + let Some(limit) = model_token_limit(&request.model) else { + return Ok(()); + }; + + let estimated_input_tokens = estimate_message_request_input_tokens(request); + let estimated_total_tokens = estimated_input_tokens.saturating_add(request.max_tokens); + if estimated_total_tokens > limit.context_window_tokens { + return Err(ApiError::ContextWindowExceeded { + model: resolve_model_alias(&request.model), + estimated_input_tokens, + requested_output_tokens: request.max_tokens, + estimated_total_tokens, + context_window_tokens: limit.context_window_tokens, + }); + } + + Ok(()) +} + +fn estimate_message_request_input_tokens(request: &MessageRequest) -> u32 { + let mut estimate = estimate_serialized_tokens(&request.messages); + estimate = estimate.saturating_add(estimate_serialized_tokens(&request.system)); + estimate = estimate.saturating_add(estimate_serialized_tokens(&request.tools)); + estimate = estimate.saturating_add(estimate_serialized_tokens(&request.tool_choice)); + estimate +} + +fn estimate_serialized_tokens(value: &T) -> u32 { + serde_json::to_vec(value) + .ok() + .map_or(0, |bytes| (bytes.len() / 4 + 1) as u32) +} + +/// Env var names used by other provider backends. When Anthropic auth +/// resolution fails we sniff these so we can hint the user that their +/// credentials probably belong to a different provider and suggest the +/// model-prefix routing fix that would select it. +const FOREIGN_PROVIDER_ENV_VARS: &[(&str, &str, &str)] = &[ + ( + "OPENAI_API_KEY", + "OpenAI-compat", + "prefix your model name with `openai/` (e.g. `--model openai/gpt-4.1-mini`) so prefix routing selects the OpenAI-compatible provider, and set `OPENAI_BASE_URL` if you are pointing at OpenRouter/Ollama/a local server", + ), + ( + "XAI_API_KEY", + "xAI", + "use an xAI model alias (e.g. `--model grok` or `--model grok-mini`) so the prefix router selects the xAI backend", + ), + ( + "DASHSCOPE_API_KEY", + "Alibaba DashScope", + "prefix your model name with `qwen/` or `qwen-` (e.g. `--model qwen-plus`) so prefix routing selects the DashScope backend", + ), +]; + +/// Check whether an env var is set to a non-empty value either in the real +/// process environment or in the working-directory `.env` file. Mirrors the +/// credential discovery path used by `read_env_non_empty` so the hint text +/// stays truthful when users rely on `.env` instead of a real export. +fn env_or_dotenv_present(key: &str) -> bool { + match std::env::var(key) { + Ok(value) if !value.is_empty() => true, + Ok(_) | Err(std::env::VarError::NotPresent) => { + dotenv_value(key).is_some_and(|value| !value.is_empty()) + } + Err(_) => false, + } +} + +/// Produce a hint string describing the first foreign provider credential +/// that is present in the environment when Anthropic auth resolution has +/// just failed. Returns `None` when no foreign credential is set, in which +/// case the caller should fall back to the plain `missing_credentials` +/// error without a hint. +pub(crate) fn anthropic_missing_credentials_hint() -> Option { + for (env_var, provider_label, fix_hint) in FOREIGN_PROVIDER_ENV_VARS { + if env_or_dotenv_present(env_var) { + return Some(format!( + "I see {env_var} is set — if you meant to use the {provider_label} provider, {fix_hint}." + )); + } + } + None +} + +/// Build an Anthropic-specific `MissingCredentials` error, attaching a +/// hint suggesting the probable fix whenever a different provider's +/// credentials are already present in the environment. Anthropic call +/// sites should prefer this helper over `ApiError::missing_credentials` +/// so users who mistyped a model name or forgot the prefix get a useful +/// signal instead of a generic "missing Anthropic credentials" wall. +pub(crate) fn anthropic_missing_credentials() -> ApiError { + const PROVIDER: &str = "Anthropic"; + const ENV_VARS: &[&str] = &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"]; + match anthropic_missing_credentials_hint() { + Some(hint) => ApiError::missing_credentials_with_hint(PROVIDER, ENV_VARS, hint), + None => ApiError::missing_credentials(PROVIDER, ENV_VARS), + } +} + +/// Parse a `.env` file body into key/value pairs using a minimal `KEY=VALUE` +/// grammar. Lines that are blank, start with `#`, or do not contain `=` are +/// ignored. Surrounding double or single quotes are stripped from the value. +/// An optional leading `export ` prefix on the key is also stripped so files +/// shared with shell `source` workflows still parse cleanly. +pub(crate) fn parse_dotenv(content: &str) -> std::collections::HashMap { + let mut values = std::collections::HashMap::new(); + for raw_line in content.lines() { + let line = raw_line.trim(); + if line.is_empty() || line.starts_with('#') { + continue; + } + let Some((raw_key, raw_value)) = line.split_once('=') else { + continue; + }; + let trimmed_key = raw_key.trim(); + let key = trimmed_key + .strip_prefix("export ") + .map_or(trimmed_key, str::trim) + .to_string(); + if key.is_empty() { + continue; + } + let trimmed_value = raw_value.trim(); + let unquoted = if (trimmed_value.starts_with('"') && trimmed_value.ends_with('"') + || trimmed_value.starts_with('\'') && trimmed_value.ends_with('\'')) + && trimmed_value.len() >= 2 + { + &trimmed_value[1..trimmed_value.len() - 1] + } else { + trimmed_value + }; + values.insert(key, unquoted.to_string()); + } + values +} + +/// Load and parse a `.env` file from the given path. Missing files yield +/// `None` instead of an error so callers can use this as a soft fallback. +pub(crate) fn load_dotenv_file( + path: &std::path::Path, +) -> Option> { + let content = std::fs::read_to_string(path).ok()?; + Some(parse_dotenv(&content)) +} + +/// Look up `key` in a `.env` file located in the current working directory. +/// Returns `None` when the file is missing, the key is absent, or the value +/// is empty. +pub(crate) fn dotenv_value(key: &str) -> Option { + let cwd = std::env::current_dir().ok()?; + let values = load_dotenv_file(&cwd.join(".env"))?; + values.get(key).filter(|value| !value.is_empty()).cloned() +} + #[cfg(test)] mod tests { - use super::{detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind}; + use std::ffi::OsString; + use std::sync::{Mutex, OnceLock}; + + use serde_json::json; + + use crate::error::ApiError; + use crate::types::{ + InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition, + }; + + use super::{ + anthropic_missing_credentials, anthropic_missing_credentials_hint, detect_provider_kind, + load_dotenv_file, max_tokens_for_model, max_tokens_for_model_with_override, + model_token_limit, parse_dotenv, preflight_message_request, resolve_model_alias, + ProviderKind, + }; + + /// Serializes every test in this module that mutates process-wide + /// environment variables so concurrent test threads cannot observe + /// each other's partially-applied state while probing the foreign + /// provider credential sniffer. + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } + + /// Snapshot-restore guard for a single environment variable. Captures + /// the original value on construction, applies the requested override + /// (set or remove), and restores the original on drop so tests leave + /// the process env untouched even when they panic mid-assertion. + struct EnvVarGuard { + key: &'static str, + original: Option, + } + + impl EnvVarGuard { + fn set(key: &'static str, value: Option<&str>) -> Self { + let original = std::env::var_os(key); + match value { + Some(value) => std::env::set_var(key, value), + None => std::env::remove_var(key), + } + Self { key, original } + } + } + + impl Drop for EnvVarGuard { + fn drop(&mut self) { + match self.original.take() { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } + } #[test] fn resolves_grok_aliases() { @@ -290,7 +498,59 @@ mod tests { assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai); assert_eq!( detect_provider_kind("claude-sonnet-4-6"), - ProviderKind::ClawApi + ProviderKind::Anthropic + ); + } + + #[test] + fn openai_namespaced_model_routes_to_openai_not_anthropic() { + // Regression: "openai/gpt-4.1-mini" was misrouted to Anthropic when + // ANTHROPIC_API_KEY was set because metadata_for_model returned None + // and detect_provider_kind fell through to auth-sniffer order. + // The model prefix must win over env-var presence. + let kind = super::metadata_for_model("openai/gpt-4.1-mini") + .map(|m| m.provider) + .unwrap_or_else(|| detect_provider_kind("openai/gpt-4.1-mini")); + assert_eq!( + kind, + ProviderKind::OpenAi, + "openai/ prefix must route to OpenAi regardless of ANTHROPIC_API_KEY" + ); + + // Also cover bare gpt- prefix + let kind2 = super::metadata_for_model("gpt-4o") + .map(|m| m.provider) + .unwrap_or_else(|| detect_provider_kind("gpt-4o")); + assert_eq!(kind2, ProviderKind::OpenAi); + } + + #[test] + fn qwen_prefix_routes_to_dashscope_not_anthropic() { + // User request from Discord #clawcode-get-help: web3g wants to use + // Qwen 3.6 Plus via native Alibaba DashScope API (not OpenRouter, + // which has lower rate limits). metadata_for_model must route + // qwen/* and bare qwen-* to the OpenAi provider kind pointed at + // the DashScope compatible-mode endpoint, regardless of whether + // ANTHROPIC_API_KEY is present in the environment. + let meta = super::metadata_for_model("qwen/qwen-max") + .expect("qwen/ prefix must resolve to DashScope metadata"); + assert_eq!(meta.provider, ProviderKind::OpenAi); + assert_eq!(meta.auth_env, "DASHSCOPE_API_KEY"); + assert_eq!(meta.base_url_env, "DASHSCOPE_BASE_URL"); + assert!(meta.default_base_url.contains("dashscope.aliyuncs.com")); + + // Bare qwen- prefix also routes + let meta2 = super::metadata_for_model("qwen-plus") + .expect("qwen- prefix must resolve to DashScope metadata"); + assert_eq!(meta2.provider, ProviderKind::OpenAi); + assert_eq!(meta2.auth_env, "DASHSCOPE_API_KEY"); + + // detect_provider_kind must agree even if ANTHROPIC_API_KEY is set + let kind = detect_provider_kind("qwen/qwen3-coder"); + assert_eq!( + kind, + ProviderKind::OpenAi, + "qwen/ prefix must win over auth-sniffer order" ); } @@ -299,4 +559,467 @@ mod tests { assert_eq!(max_tokens_for_model("opus"), 32_000); assert_eq!(max_tokens_for_model("grok-3"), 64_000); } + + #[test] + fn plugin_config_max_output_tokens_overrides_model_default() { + // given + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + let root = std::env::temp_dir().join(format!("api-plugin-max-tokens-{nanos}")); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + std::fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); + std::fs::create_dir_all(&home).expect("home config dir"); + std::fs::write( + home.join("settings.json"), + r#"{ + "plugins": { + "maxOutputTokens": 12345 + } + }"#, + ) + .expect("write plugin settings"); + + // when + let loaded = runtime::ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + let plugin_override = loaded.plugins().max_output_tokens(); + let effective = max_tokens_for_model_with_override("claude-opus-4-6", plugin_override); + + // then + assert_eq!(plugin_override, Some(12345)); + assert_eq!(effective, 12345); + assert_ne!(effective, max_tokens_for_model("claude-opus-4-6")); + + std::fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn max_tokens_for_model_with_override_falls_back_when_plugin_unset() { + // given + let plugin_override: Option = None; + + // when + let effective = max_tokens_for_model_with_override("claude-opus-4-6", plugin_override); + + // then + assert_eq!(effective, max_tokens_for_model("claude-opus-4-6")); + assert_eq!(effective, 32_000); + } + + #[test] + fn returns_context_window_metadata_for_supported_models() { + assert_eq!( + model_token_limit("claude-sonnet-4-6") + .expect("claude-sonnet-4-6 should be registered") + .context_window_tokens, + 200_000 + ); + assert_eq!( + model_token_limit("grok-mini") + .expect("grok-mini should resolve to a registered model") + .context_window_tokens, + 131_072 + ); + } + + #[test] + fn preflight_blocks_requests_that_exceed_the_model_context_window() { + let request = MessageRequest { + model: "claude-sonnet-4-6".to_string(), + max_tokens: 64_000, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![InputContentBlock::Text { + text: "x".repeat(600_000), + }], + }], + system: Some("Keep the answer short.".to_string()), + tools: Some(vec![ToolDefinition { + name: "weather".to_string(), + description: Some("Fetches weather".to_string()), + input_schema: json!({ + "type": "object", + "properties": { "city": { "type": "string" } }, + }), + }]), + tool_choice: Some(ToolChoice::Auto), + stream: true, + ..Default::default() + }; + + let error = preflight_message_request(&request) + .expect_err("oversized request should be rejected before the provider call"); + + match error { + ApiError::ContextWindowExceeded { + model, + estimated_input_tokens, + requested_output_tokens, + estimated_total_tokens, + context_window_tokens, + } => { + assert_eq!(model, "claude-sonnet-4-6"); + assert!(estimated_input_tokens > 136_000); + assert_eq!(requested_output_tokens, 64_000); + assert!(estimated_total_tokens > context_window_tokens); + assert_eq!(context_window_tokens, 200_000); + } + other => panic!("expected context-window preflight failure, got {other:?}"), + } + } + + #[test] + fn preflight_skips_unknown_models() { + let request = MessageRequest { + model: "unknown-model".to_string(), + max_tokens: 64_000, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![InputContentBlock::Text { + text: "x".repeat(600_000), + }], + }], + system: None, + tools: None, + tool_choice: None, + stream: false, + ..Default::default() + }; + + preflight_message_request(&request) + .expect("models without context metadata should skip the guarded preflight"); + } + + #[test] + fn parse_dotenv_extracts_keys_handles_comments_quotes_and_export_prefix() { + // given + let body = "\ +# this is a comment + +ANTHROPIC_API_KEY=plain-value +XAI_API_KEY=\"quoted-value\" +OPENAI_API_KEY='single-quoted' +export GROK_API_KEY=exported-value + PADDED_KEY = padded-value +EMPTY_VALUE= +NO_EQUALS_LINE +"; + + // when + let values = parse_dotenv(body); + + // then + assert_eq!( + values.get("ANTHROPIC_API_KEY").map(String::as_str), + Some("plain-value") + ); + assert_eq!( + values.get("XAI_API_KEY").map(String::as_str), + Some("quoted-value") + ); + assert_eq!( + values.get("OPENAI_API_KEY").map(String::as_str), + Some("single-quoted") + ); + assert_eq!( + values.get("GROK_API_KEY").map(String::as_str), + Some("exported-value") + ); + assert_eq!( + values.get("PADDED_KEY").map(String::as_str), + Some("padded-value") + ); + assert_eq!(values.get("EMPTY_VALUE").map(String::as_str), Some("")); + assert!(!values.contains_key("NO_EQUALS_LINE")); + assert!(!values.contains_key("# this is a comment")); + } + + #[test] + fn load_dotenv_file_reads_keys_from_disk_and_returns_none_when_missing() { + // given + let temp_root = std::env::temp_dir().join(format!( + "api-dotenv-test-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map_or(0, |duration| duration.as_nanos()) + )); + std::fs::create_dir_all(&temp_root).expect("create temp dir"); + let env_path = temp_root.join(".env"); + std::fs::write( + &env_path, + "ANTHROPIC_API_KEY=secret-from-file\n# comment\nXAI_API_KEY=\"xai-secret\"\n", + ) + .expect("write .env"); + let missing_path = temp_root.join("does-not-exist.env"); + + // when + let loaded = load_dotenv_file(&env_path).expect("file should load"); + let missing = load_dotenv_file(&missing_path); + + // then + assert_eq!( + loaded.get("ANTHROPIC_API_KEY").map(String::as_str), + Some("secret-from-file") + ); + assert_eq!( + loaded.get("XAI_API_KEY").map(String::as_str), + Some("xai-secret") + ); + assert!(missing.is_none()); + + let _ = std::fs::remove_dir_all(&temp_root); + } + + #[test] + fn anthropic_missing_credentials_hint_is_none_when_no_foreign_creds_present() { + // given + let _lock = env_lock(); + let _openai = EnvVarGuard::set("OPENAI_API_KEY", None); + let _xai = EnvVarGuard::set("XAI_API_KEY", None); + let _dashscope = EnvVarGuard::set("DASHSCOPE_API_KEY", None); + + // when + let hint = anthropic_missing_credentials_hint(); + + // then + assert!( + hint.is_none(), + "no hint should be produced when every foreign provider env var is absent, got {hint:?}" + ); + } + + #[test] + fn anthropic_missing_credentials_hint_detects_openai_api_key_and_recommends_openai_prefix() { + // given + let _lock = env_lock(); + let _openai = EnvVarGuard::set("OPENAI_API_KEY", Some("sk-openrouter-varleg")); + let _xai = EnvVarGuard::set("XAI_API_KEY", None); + let _dashscope = EnvVarGuard::set("DASHSCOPE_API_KEY", None); + + // when + let hint = anthropic_missing_credentials_hint() + .expect("OPENAI_API_KEY presence should produce a hint"); + + // then + assert!( + hint.contains("OPENAI_API_KEY is set"), + "hint should name the detected env var so users recognize it: {hint}" + ); + assert!( + hint.contains("OpenAI-compat"), + "hint should identify the target provider: {hint}" + ); + assert!( + hint.contains("openai/"), + "hint should mention the `openai/` prefix routing fix: {hint}" + ); + assert!( + hint.contains("OPENAI_BASE_URL"), + "hint should mention OPENAI_BASE_URL so OpenRouter users see the full picture: {hint}" + ); + } + + #[test] + fn anthropic_missing_credentials_hint_detects_xai_api_key() { + // given + let _lock = env_lock(); + let _openai = EnvVarGuard::set("OPENAI_API_KEY", None); + let _xai = EnvVarGuard::set("XAI_API_KEY", Some("xai-test-key")); + let _dashscope = EnvVarGuard::set("DASHSCOPE_API_KEY", None); + + // when + let hint = anthropic_missing_credentials_hint() + .expect("XAI_API_KEY presence should produce a hint"); + + // then + assert!( + hint.contains("XAI_API_KEY is set"), + "hint should name XAI_API_KEY: {hint}" + ); + assert!( + hint.contains("xAI"), + "hint should identify the xAI provider: {hint}" + ); + assert!( + hint.contains("grok"), + "hint should suggest a grok-prefixed model alias: {hint}" + ); + } + + #[test] + fn anthropic_missing_credentials_hint_detects_dashscope_api_key() { + // given + let _lock = env_lock(); + let _openai = EnvVarGuard::set("OPENAI_API_KEY", None); + let _xai = EnvVarGuard::set("XAI_API_KEY", None); + let _dashscope = EnvVarGuard::set("DASHSCOPE_API_KEY", Some("sk-dashscope-test")); + + // when + let hint = anthropic_missing_credentials_hint() + .expect("DASHSCOPE_API_KEY presence should produce a hint"); + + // then + assert!( + hint.contains("DASHSCOPE_API_KEY is set"), + "hint should name DASHSCOPE_API_KEY: {hint}" + ); + assert!( + hint.contains("DashScope"), + "hint should identify the DashScope provider: {hint}" + ); + assert!( + hint.contains("qwen"), + "hint should suggest a qwen-prefixed model alias: {hint}" + ); + } + + #[test] + fn anthropic_missing_credentials_hint_prefers_openai_when_multiple_foreign_creds_set() { + // given + let _lock = env_lock(); + let _openai = EnvVarGuard::set("OPENAI_API_KEY", Some("sk-openrouter-varleg")); + let _xai = EnvVarGuard::set("XAI_API_KEY", Some("xai-test-key")); + let _dashscope = EnvVarGuard::set("DASHSCOPE_API_KEY", Some("sk-dashscope-test")); + + // when + let hint = anthropic_missing_credentials_hint() + .expect("multiple foreign creds should still produce a hint"); + + // then + assert!( + hint.contains("OPENAI_API_KEY"), + "OpenAI should be prioritized because it is the most common misrouting pattern (OpenRouter users), got: {hint}" + ); + assert!( + !hint.contains("XAI_API_KEY"), + "only the first detected provider should be named to keep the hint focused, got: {hint}" + ); + } + + #[test] + fn anthropic_missing_credentials_builds_error_with_canonical_env_vars_and_no_hint_when_clean() { + // given + let _lock = env_lock(); + let _openai = EnvVarGuard::set("OPENAI_API_KEY", None); + let _xai = EnvVarGuard::set("XAI_API_KEY", None); + let _dashscope = EnvVarGuard::set("DASHSCOPE_API_KEY", None); + + // when + let error = anthropic_missing_credentials(); + + // then + match &error { + ApiError::MissingCredentials { + provider, + env_vars, + hint, + } => { + assert_eq!(*provider, "Anthropic"); + assert_eq!(*env_vars, &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"]); + assert!( + hint.is_none(), + "clean environment should not generate a hint, got {hint:?}" + ); + } + other => panic!("expected MissingCredentials variant, got {other:?}"), + } + let rendered = error.to_string(); + assert!( + !rendered.contains(" — hint: "), + "rendered error should be a plain missing-creds message: {rendered}" + ); + } + + #[test] + fn anthropic_missing_credentials_builds_error_with_hint_when_openai_key_is_set() { + // given + let _lock = env_lock(); + let _openai = EnvVarGuard::set("OPENAI_API_KEY", Some("sk-openrouter-varleg")); + let _xai = EnvVarGuard::set("XAI_API_KEY", None); + let _dashscope = EnvVarGuard::set("DASHSCOPE_API_KEY", None); + + // when + let error = anthropic_missing_credentials(); + + // then + match &error { + ApiError::MissingCredentials { + provider, + env_vars, + hint, + } => { + assert_eq!(*provider, "Anthropic"); + assert_eq!(*env_vars, &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"]); + let hint_value = hint.as_deref().expect("hint should be populated"); + assert!( + hint_value.contains("OPENAI_API_KEY is set"), + "hint should name the detected env var: {hint_value}" + ); + } + other => panic!("expected MissingCredentials variant, got {other:?}"), + } + let rendered = error.to_string(); + assert!( + rendered.starts_with("missing Anthropic credentials;"), + "canonical base message should still lead the rendered error: {rendered}" + ); + assert!( + rendered.contains(" — hint: I see OPENAI_API_KEY is set"), + "rendered error should carry the env-driven hint: {rendered}" + ); + } + + #[test] + fn anthropic_missing_credentials_hint_ignores_empty_string_values() { + // given + let _lock = env_lock(); + // An empty value is semantically equivalent to "not set" for the + // credential discovery path, so the sniffer must treat it that way + // to avoid false-positive hints for users who intentionally cleared + // a stale export with `OPENAI_API_KEY=`. + let _openai = EnvVarGuard::set("OPENAI_API_KEY", Some("")); + let _xai = EnvVarGuard::set("XAI_API_KEY", None); + let _dashscope = EnvVarGuard::set("DASHSCOPE_API_KEY", None); + + // when + let hint = anthropic_missing_credentials_hint(); + + // then + assert!( + hint.is_none(), + "empty env var should not trigger the hint sniffer, got {hint:?}" + ); + } + + #[test] + fn openai_base_url_overrides_anthropic_fallback_for_unknown_model() { + // given — user has OPENAI_BASE_URL + OPENAI_API_KEY but no Anthropic + // creds, and a model name with no recognized prefix. + let _lock = env_lock(); + let _base_url = EnvVarGuard::set("OPENAI_BASE_URL", Some("http://127.0.0.1:11434/v1")); + let _api_key = EnvVarGuard::set("OPENAI_API_KEY", Some("dummy")); + let _anthropic_key = EnvVarGuard::set("ANTHROPIC_API_KEY", None); + let _anthropic_token = EnvVarGuard::set("ANTHROPIC_AUTH_TOKEN", None); + + // when + let provider = detect_provider_kind("qwen2.5-coder:7b"); + + // then — should route to OpenAI, not Anthropic + assert_eq!( + provider, + ProviderKind::OpenAi, + "OPENAI_BASE_URL should win over Anthropic fallback for unknown models" + ); + } + + // NOTE: a "OPENAI_BASE_URL without OPENAI_API_KEY" test is omitted + // because workspace-parallel test binaries can race on process env + // (env_lock only protects within a single binary). The detection logic + // is covered: OPENAI_BASE_URL alone routes to OpenAi as a last-resort + // fallback in detect_provider_kind(). } diff --git a/crates/api/src/providers/openai_compat.rs b/crates/api/src/providers/openai_compat.rs index 52f3695..1d46ee6 100644 --- a/crates/api/src/providers/openai_compat.rs +++ b/crates/api/src/providers/openai_compat.rs @@ -1,10 +1,12 @@ use std::collections::{BTreeMap, VecDeque}; -use std::time::Duration; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; use serde::Deserialize; use serde_json::{json, Value}; use crate::error::ApiError; +use crate::http_client::build_http_client_or_default; use crate::types::{ ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest, @@ -12,15 +14,16 @@ use crate::types::{ ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, }; -use super::{Provider, ProviderFuture}; +use super::{preflight_message_request, Provider, ProviderFuture}; pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1"; pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; +pub const DEFAULT_DASHSCOPE_BASE_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1"; const REQUEST_ID_HEADER: &str = "request-id"; const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; -const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); -const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); -const DEFAULT_MAX_RETRIES: u32 = 2; +const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_secs(1); +const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(128); +const DEFAULT_MAX_RETRIES: u32 = 8; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct OpenAiCompatConfig { @@ -32,6 +35,7 @@ pub struct OpenAiCompatConfig { const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"]; const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"]; +const DASHSCOPE_ENV_VARS: &[&str] = &["DASHSCOPE_API_KEY"]; impl OpenAiCompatConfig { #[must_use] @@ -53,11 +57,27 @@ impl OpenAiCompatConfig { default_base_url: DEFAULT_OPENAI_BASE_URL, } } + + /// Alibaba DashScope compatible-mode endpoint (Qwen family models). + /// Uses the OpenAI-compatible REST shape at /compatible-mode/v1. + /// Requested via Discord #clawcode-get-help: native Alibaba API for + /// higher rate limits than going through OpenRouter. + #[must_use] + pub const fn dashscope() -> Self { + Self { + provider_name: "DashScope", + api_key_env: "DASHSCOPE_API_KEY", + base_url_env: "DASHSCOPE_BASE_URL", + default_base_url: DEFAULT_DASHSCOPE_BASE_URL, + } + } + #[must_use] pub fn credential_env_vars(self) -> &'static [&'static str] { match self.provider_name { "xAI" => XAI_ENV_VARS, "OpenAI" => OPENAI_ENV_VARS, + "DashScope" => DASHSCOPE_ENV_VARS, _ => &[], } } @@ -67,6 +87,7 @@ impl OpenAiCompatConfig { pub struct OpenAiCompatClient { http: reqwest::Client, api_key: String, + config: OpenAiCompatConfig, base_url: String, max_retries: u32, initial_backoff: Duration, @@ -74,11 +95,20 @@ pub struct OpenAiCompatClient { } impl OpenAiCompatClient { + const fn config(&self) -> OpenAiCompatConfig { + self.config + } + + #[must_use] + pub fn base_url(&self) -> &str { + &self.base_url + } #[must_use] pub fn new(api_key: impl Into, config: OpenAiCompatConfig) -> Self { Self { - http: reqwest::Client::new(), + http: build_http_client_or_default(), api_key: api_key.into(), + config, base_url: read_base_url(config), max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, @@ -123,9 +153,42 @@ impl OpenAiCompatClient { stream: false, ..request.clone() }; + preflight_message_request(&request)?; let response = self.send_with_retry(&request).await?; let request_id = request_id_from_headers(response.headers()); - let payload = response.json::().await?; + let body = response.text().await.map_err(ApiError::from)?; + // Some backends return {"error":{"message":"...","type":"...","code":...}} + // instead of a valid completion object. Check for this before attempting + // full deserialization so the user sees the actual error, not a cryptic + // "missing field 'id'" parse failure. + if let Ok(raw) = serde_json::from_str::(&body) { + if let Some(err_obj) = raw.get("error") { + let msg = err_obj + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("provider returned an error") + .to_string(); + let code = err_obj + .get("code") + .and_then(|c| c.as_u64()) + .map(|c| c as u16); + return Err(ApiError::Api { + status: reqwest::StatusCode::from_u16(code.unwrap_or(400)) + .unwrap_or(reqwest::StatusCode::BAD_REQUEST), + error_type: err_obj + .get("type") + .and_then(|t| t.as_str()) + .map(str::to_owned), + message: Some(msg), + request_id, + body, + retryable: false, + }); + } + } + let payload = serde_json::from_str::(&body).map_err(|error| { + ApiError::json_deserialize(self.config.provider_name, &request.model, &body, error) + })?; let mut normalized = normalize_response(&request.model, payload)?; if normalized.request_id.is_none() { normalized.request_id = request_id; @@ -137,13 +200,14 @@ impl OpenAiCompatClient { &self, request: &MessageRequest, ) -> Result { + preflight_message_request(request)?; let response = self .send_with_retry(&request.clone().with_streaming()) .await?; Ok(MessageStream { request_id: request_id_from_headers(response.headers()), response, - parser: OpenAiSseParser::new(), + parser: OpenAiSseParser::with_context(self.config.provider_name, request.model.clone()), pending: VecDeque::new(), done: false, state: StreamState::new(request.model.clone()), @@ -172,7 +236,7 @@ impl OpenAiCompatClient { break retryable_error; } - tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; + tokio::time::sleep(self.jittered_backoff_for_attempt(attempts)?).await; }; Err(ApiError::RetriesExhausted { @@ -190,7 +254,7 @@ impl OpenAiCompatClient { .post(&request_url) .header("content-type", "application/json") .bearer_auth(&self.api_key) - .json(&build_chat_completion_request(request)) + .json(&build_chat_completion_request(request, self.config())) .send() .await .map_err(ApiError::from) @@ -208,6 +272,52 @@ impl OpenAiCompatClient { .checked_mul(multiplier) .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) } + + fn jittered_backoff_for_attempt(&self, attempt: u32) -> Result { + let base = self.backoff_for_attempt(attempt)?; + Ok(base + jitter_for_base(base)) + } +} + +/// Process-wide counter that guarantees distinct jitter samples even when +/// the system clock resolution is coarser than consecutive retry sleeps. +static JITTER_COUNTER: AtomicU64 = AtomicU64::new(0); + +/// Returns a random additive jitter in `[0, base]` to decorrelate retries +/// Deserialize a JSON field as a `Vec`, treating an explicit `null` value +/// the same as a missing field (i.e. as an empty vector). +/// Some OpenAI-compatible providers emit `"tool_calls": null` instead of +/// omitting the field or using `[]`, which serde's `#[serde(default)]` alone +/// does not tolerate — `default` only handles absent keys, not null values. +fn deserialize_null_as_empty_vec<'de, D, T>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, + T: serde::Deserialize<'de>, +{ + Ok(Option::>::deserialize(deserializer)?.unwrap_or_default()) +} + +/// from multiple concurrent clients. Entropy is drawn from the nanosecond +/// wall clock mixed with a monotonic counter and run through a splitmix64 +/// finalizer; adequate for retry jitter (no cryptographic requirement). +fn jitter_for_base(base: Duration) -> Duration { + let base_nanos = u64::try_from(base.as_nanos()).unwrap_or(u64::MAX); + if base_nanos == 0 { + return Duration::ZERO; + } + let raw_nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|elapsed| u64::try_from(elapsed.as_nanos()).unwrap_or(u64::MAX)) + .unwrap_or(0); + let tick = JITTER_COUNTER.fetch_add(1, Ordering::Relaxed); + let mut mixed = raw_nanos + .wrapping_add(tick) + .wrapping_add(0x9E37_79B9_7F4A_7C15); + mixed = (mixed ^ (mixed >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + mixed = (mixed ^ (mixed >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + mixed ^= mixed >> 31; + let jitter_nanos = mixed % base_nanos.saturating_add(1); + Duration::from_nanos(jitter_nanos) } impl Provider for OpenAiCompatClient { @@ -251,7 +361,7 @@ impl MessageStream { } if self.done { - self.pending.extend(self.state.finish()); + self.pending.extend(self.state.finish()?); if let Some(event) = self.pending.pop_front() { return Ok(Some(event)); } @@ -261,7 +371,7 @@ impl MessageStream { match self.response.chunk().await? { Some(chunk) => { for parsed in self.parser.push(&chunk)? { - self.pending.extend(self.state.ingest_chunk(parsed)); + self.pending.extend(self.state.ingest_chunk(parsed)?); } } None => { @@ -275,11 +385,17 @@ impl MessageStream { #[derive(Debug, Default)] struct OpenAiSseParser { buffer: Vec, + provider: String, + model: String, } impl OpenAiSseParser { - fn new() -> Self { - Self::default() + fn with_context(provider: impl Into, model: impl Into) -> Self { + Self { + buffer: Vec::new(), + provider: provider.into(), + model: model.into(), + } } fn push(&mut self, chunk: &[u8]) -> Result, ApiError> { @@ -287,7 +403,7 @@ impl OpenAiSseParser { let mut events = Vec::new(); while let Some(frame) = next_sse_frame(&mut self.buffer) { - if let Some(event) = parse_sse_frame(&frame)? { + if let Some(event) = parse_sse_frame(&frame, &self.provider, &self.model)? { events.push(event); } } @@ -296,8 +412,8 @@ impl OpenAiSseParser { } } -#[derive(Debug)] #[allow(clippy::struct_excessive_bools)] +#[derive(Debug)] struct StreamState { model: String, message_started: bool, @@ -323,7 +439,7 @@ impl StreamState { } } - fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Vec { + fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result, ApiError> { let mut events = Vec::new(); if !self.message_started { self.message_started = true; @@ -378,7 +494,7 @@ impl StreamState { state.apply(tool_call); let block_index = state.block_index(); if !state.started { - if let Some(start_event) = state.start_event() { + if let Some(start_event) = state.start_event()? { state.started = true; events.push(StreamEvent::ContentBlockStart(start_event)); } else { @@ -411,12 +527,12 @@ impl StreamState { } } - events + Ok(events) } - fn finish(&mut self) -> Vec { + fn finish(&mut self) -> Result, ApiError> { if self.finished { - return Vec::new(); + return Ok(Vec::new()); } self.finished = true; @@ -430,7 +546,7 @@ impl StreamState { for state in self.tool_calls.values_mut() { if !state.started { - if let Some(start_event) = state.start_event() { + if let Some(start_event) = state.start_event()? { state.started = true; events.push(StreamEvent::ContentBlockStart(start_event)); if let Some(delta_event) = state.delta_event() { @@ -465,7 +581,7 @@ impl StreamState { })); events.push(StreamEvent::MessageStop(MessageStopEvent {})); } - events + Ok(events) } } @@ -498,20 +614,23 @@ impl ToolCallState { self.openai_index + 1 } - fn start_event(&self) -> Option { - let name = self.name.clone()?; + #[allow(clippy::unnecessary_wraps)] + fn start_event(&self) -> Result, ApiError> { + let Some(name) = self.name.clone() else { + return Ok(None); + }; let id = self .id .clone() .unwrap_or_else(|| format!("tool_call_{}", self.openai_index)); - Some(ContentBlockStartEvent { + Ok(Some(ContentBlockStartEvent { index: self.block_index(), content_block: OutputContentBlock::ToolUse { id, name, input: json!({}), }, - }) + })) } fn delta_event(&mut self) -> Option { @@ -596,7 +715,7 @@ struct ChunkChoice { struct ChunkDelta { #[serde(default)] content: Option, - #[serde(default)] + #[serde(default, deserialize_with = "deserialize_null_as_empty_vec")] tool_calls: Vec, } @@ -630,7 +749,44 @@ struct ErrorBody { message: Option, } -fn build_chat_completion_request(request: &MessageRequest) -> Value { +/// Returns true for models known to reject tuning parameters like temperature, +/// top_p, frequency_penalty, and presence_penalty. These are typically +/// reasoning/chain-of-thought models with fixed sampling. +fn is_reasoning_model(model: &str) -> bool { + let lowered = model.to_ascii_lowercase(); + // Strip any provider/ prefix for the check (e.g. qwen/qwen-qwq -> qwen-qwq) + let canonical = lowered.rsplit('/').next().unwrap_or(lowered.as_str()); + // OpenAI reasoning models + canonical.starts_with("o1") + || canonical.starts_with("o3") + || canonical.starts_with("o4") + // xAI reasoning: grok-3-mini always uses reasoning mode + || canonical == "grok-3-mini" + // Alibaba DashScope reasoning variants (QwQ + Qwen3-Thinking family) + || canonical.starts_with("qwen-qwq") + || canonical.starts_with("qwq") + || canonical.contains("thinking") +} + +/// Strip routing prefix (e.g., "openai/gpt-4" → "gpt-4") for the wire. +/// The prefix is used only to select transport; the backend expects the +/// bare model id. +fn strip_routing_prefix(model: &str) -> &str { + if let Some(pos) = model.find('/') { + let prefix = &model[..pos]; + // Only strip if the prefix before "/" is a known routing prefix, + // not if "/" appears in the middle of the model name for other reasons. + if matches!(prefix, "openai" | "xai" | "grok" | "qwen") { + &model[pos + 1..] + } else { + model + } + } else { + model + } +} + +fn build_chat_completion_request(request: &MessageRequest, config: OpenAiCompatConfig) -> Value { let mut messages = Vec::new(); if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) { messages.push(json!({ @@ -641,14 +797,38 @@ fn build_chat_completion_request(request: &MessageRequest) -> Value { for message in &request.messages { messages.extend(translate_message(message)); } + // Sanitize: drop any `role:"tool"` message that does not have a valid + // paired `role:"assistant"` with a `tool_calls` entry carrying the same + // `id` immediately before it (directly or as part of a run of tool + // results). OpenAI-compatible backends return 400 for orphaned tool + // messages regardless of how they were produced (compaction, session + // editing, resume, etc.). We drop rather than error so the request can + // still proceed with the remaining history intact. + messages = sanitize_tool_message_pairing(messages); + + // Strip routing prefix (e.g., "openai/gpt-4" → "gpt-4") for the wire. + let wire_model = strip_routing_prefix(&request.model); + + // gpt-5* requires `max_completion_tokens`; older OpenAI models accept both. + // We send the correct field based on the wire model name so gpt-5.x requests + // don't fail with "unknown field max_tokens". + let max_tokens_key = if wire_model.starts_with("gpt-5") { + "max_completion_tokens" + } else { + "max_tokens" + }; let mut payload = json!({ - "model": request.model, - "max_tokens": request.max_tokens, + "model": wire_model, + max_tokens_key: request.max_tokens, "messages": messages, "stream": request.stream, }); + if request.stream && should_request_stream_usage(config) { + payload["stream_options"] = json!({ "include_usage": true }); + } + if let Some(tools) = &request.tools { payload["tools"] = Value::Array(tools.iter().map(openai_tool_definition).collect::>()); @@ -657,6 +837,34 @@ fn build_chat_completion_request(request: &MessageRequest) -> Value { payload["tool_choice"] = openai_tool_choice(tool_choice); } + // OpenAI-compatible tuning parameters — only included when explicitly set. + // Reasoning models (o1/o3/o4/grok-3-mini) reject these params with 400; + // silently strip them to avoid cryptic provider errors. + if !is_reasoning_model(&request.model) { + if let Some(temperature) = request.temperature { + payload["temperature"] = json!(temperature); + } + if let Some(top_p) = request.top_p { + payload["top_p"] = json!(top_p); + } + if let Some(frequency_penalty) = request.frequency_penalty { + payload["frequency_penalty"] = json!(frequency_penalty); + } + if let Some(presence_penalty) = request.presence_penalty { + payload["presence_penalty"] = json!(presence_penalty); + } + } + // stop is generally safe for all providers + if let Some(stop) = &request.stop { + if !stop.is_empty() { + payload["stop"] = json!(stop); + } + } + // reasoning_effort for OpenAI-compatible reasoning models (o4-mini, o3, etc.) + if let Some(effort) = &request.reasoning_effort { + payload["reasoning_effort"] = json!(effort); + } + payload } @@ -677,24 +885,21 @@ fn translate_message(message: &InputMessage) -> Vec { } })), InputContentBlock::ToolResult { .. } => {} - InputContentBlock::Thinking { thinking, .. } => { - text.push_str("\n"); - text.push_str(thinking); - text.push_str("\n\n"); - } - InputContentBlock::RedactedThinking { .. } => { - text.push_str("\n\n\n"); - } } } if text.is_empty() && tool_calls.is_empty() { Vec::new() } else { - vec![json!({ + let mut msg = serde_json::json!({ "role": "assistant", "content": (!text.is_empty()).then_some(text), - "tool_calls": tool_calls, - })] + }); + // Only include tool_calls when non-empty: some providers reject + // assistant messages with an explicit empty tool_calls array. + if !tool_calls.is_empty() { + msg["tool_calls"] = json!(tool_calls); + } + vec![msg] } } _ => message @@ -715,14 +920,81 @@ fn translate_message(message: &InputMessage) -> Vec { "content": flatten_tool_result_content(content), "is_error": is_error, })), - InputContentBlock::ToolUse { .. } - | InputContentBlock::Thinking { .. } - | InputContentBlock::RedactedThinking { .. } => None, + InputContentBlock::ToolUse { .. } => None, }) .collect(), } } +/// Remove `role:"tool"` messages from `messages` that have no valid paired +/// `role:"assistant"` message with a matching `tool_calls[].id` immediately +/// preceding them. This is a last-resort safety net at the request-building +/// layer — the compaction boundary fix (6e301c8) prevents the most common +/// producer path, but resume, session editing, or future compaction variants +/// could still create orphaned tool messages. +/// +/// Algorithm: scan left-to-right. For each `role:"tool"` message, check the +/// immediately preceding non-tool message. If it's `role:"assistant"` with a +/// `tool_calls` array containing an entry whose `id` matches the tool +/// message's `tool_call_id`, the pair is valid and both are kept. Otherwise +/// the tool message is dropped. +fn sanitize_tool_message_pairing(messages: Vec) -> Vec { + // Collect indices of tool messages that are orphaned. + let mut drop_indices = std::collections::HashSet::new(); + for (i, msg) in messages.iter().enumerate() { + if msg.get("role").and_then(|v| v.as_str()) != Some("tool") { + continue; + } + let tool_call_id = msg + .get("tool_call_id") + .and_then(|v| v.as_str()) + .unwrap_or(""); + // Find the nearest preceding non-tool message. + let preceding = messages[..i] + .iter() + .rev() + .find(|m| m.get("role").and_then(|v| v.as_str()) != Some("tool")); + // A tool message is considered paired when: + // (a) the nearest preceding non-tool message is an assistant message + // whose `tool_calls` array contains an entry with the matching id, OR + // (b) there's no clear preceding context (e.g. the message comes right + // after a user turn — this can happen with translated mixed-content + // user messages). In case (b) we allow the message through rather + // than silently dropping potentially valid history. + let preceding_role = preceding + .and_then(|m| m.get("role")) + .and_then(|v| v.as_str()) + .unwrap_or(""); + // Only apply sanitization when the preceding message is an assistant + // turn (the invariant is: assistant-with-tool_calls must precede tool). + // If the preceding is something else (user, system) don't drop — it + // may be a valid translation artifact or a path we don't understand. + if preceding_role != "assistant" { + continue; + } + let paired = preceding + .and_then(|m| m.get("tool_calls").and_then(|tc| tc.as_array())) + .map(|tool_calls| { + tool_calls + .iter() + .any(|tc| tc.get("id").and_then(|v| v.as_str()) == Some(tool_call_id)) + }) + .unwrap_or(false); + if !paired { + drop_indices.insert(i); + } + } + if drop_indices.is_empty() { + return messages; + } + messages + .into_iter() + .enumerate() + .filter(|(i, _)| !drop_indices.contains(i)) + .map(|(_, m)| m) + .collect() +} + fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String { content .iter() @@ -734,13 +1006,45 @@ fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String { .join("\n") } +/// Recursively ensure every object-type node in a JSON Schema has +/// `"properties"` (at least `{}`) and `"additionalProperties": false`. +/// The OpenAI `/responses` endpoint validates schemas strictly and rejects +/// objects that omit these fields; `/chat/completions` is lenient but also +/// accepts them, so we normalise unconditionally. +fn normalize_object_schema(schema: &mut Value) { + if let Some(obj) = schema.as_object_mut() { + if obj.get("type").and_then(Value::as_str) == Some("object") { + obj.entry("properties").or_insert_with(|| json!({})); + obj.entry("additionalProperties") + .or_insert(Value::Bool(false)); + } + // Recurse into properties values + if let Some(props) = obj.get_mut("properties") { + if let Some(props_obj) = props.as_object_mut() { + let keys: Vec = props_obj.keys().cloned().collect(); + for k in keys { + if let Some(v) = props_obj.get_mut(&k) { + normalize_object_schema(v); + } + } + } + } + // Recurse into items (arrays) + if let Some(items) = obj.get_mut("items") { + normalize_object_schema(items); + } + } +} + fn openai_tool_definition(tool: &ToolDefinition) -> Value { + let mut parameters = tool.input_schema.clone(); + normalize_object_schema(&mut parameters); json!({ "type": "function", "function": { "name": tool.name, "description": tool.description, - "parameters": tool.input_schema, + "parameters": parameters, } }) } @@ -756,6 +1060,10 @@ fn openai_tool_choice(tool_choice: &ToolChoice) -> Value { } } +fn should_request_stream_usage(config: OpenAiCompatConfig) -> bool { + matches!(config.provider_name, "OpenAI") +} + fn normalize_response( model: &str, response: ChatCompletionResponse, @@ -827,7 +1135,11 @@ fn next_sse_frame(buffer: &mut Vec) -> Option { Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned()) } -fn parse_sse_frame(frame: &str) -> Result, ApiError> { +fn parse_sse_frame( + frame: &str, + provider: &str, + model: &str, +) -> Result, ApiError> { let trimmed = frame.trim(); if trimmed.is_empty() { return Ok(None); @@ -849,15 +1161,44 @@ fn parse_sse_frame(frame: &str) -> Result, ApiError> if payload == "[DONE]" { return Ok(None); } - serde_json::from_str(&payload) + // Some backends embed an error object in a data: frame instead of using an + // HTTP error status. Surface the error message directly rather than letting + // ChatCompletionChunk deserialization fail with a cryptic 'missing field' error. + if let Ok(raw) = serde_json::from_str::(&payload) { + if let Some(err_obj) = raw.get("error") { + let msg = err_obj + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("provider returned an error in stream") + .to_string(); + let code = err_obj + .get("code") + .and_then(|c| c.as_u64()) + .map(|c| c as u16); + let status = reqwest::StatusCode::from_u16(code.unwrap_or(400)) + .unwrap_or(reqwest::StatusCode::BAD_REQUEST); + return Err(ApiError::Api { + status, + error_type: err_obj + .get("type") + .and_then(|t| t.as_str()) + .map(str::to_owned), + message: Some(msg), + request_id: None, + body: payload.to_string(), + retryable: false, + }); + } + } + serde_json::from_str::(&payload) .map(Some) - .map_err(ApiError::from) + .map_err(|error| ApiError::json_deserialize(provider, model, &payload, error)) } fn read_env_non_empty(key: &str) -> Result, ApiError> { match std::env::var(key) { Ok(value) if !value.is_empty() => Ok(Some(value)), - Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), + Ok(_) | Err(std::env::VarError::NotPresent) => Ok(super::dotenv_value(key)), Err(error) => Err(ApiError::from(error)), } } @@ -898,6 +1239,7 @@ async fn expect_success(response: reqwest::Response) -> Result(&body).ok(); let retryable = is_retryable_status(status); @@ -910,6 +1252,7 @@ async fn expect_success(response: reqwest::Response) -> Result, + #[serde(default, deserialize_with = "deserialize_null_as_empty_vec")] + tool_calls: Vec, + } + let delta: Delta = serde_json::from_str(json) + .expect("delta with tool_calls:null must deserialize without error"); + assert!( + delta.tool_calls.is_empty(), + "tool_calls:null must produce an empty vec, not an error" + ); + } + + /// Regression: when building a multi-turn request where a prior assistant + /// turn has no tool calls, the serialized assistant message must NOT include + /// `tool_calls: []`. Some providers reject requests that carry an empty + /// tool_calls array on assistant turns (gaebal-gajae repro 2026-04-09). + #[test] + fn assistant_message_without_tool_calls_omits_tool_calls_field() { + use crate::types::{InputContentBlock, InputMessage}; + + let request = MessageRequest { + model: "gpt-4o".to_string(), + max_tokens: 100, + messages: vec![InputMessage { + role: "assistant".to_string(), + content: vec![InputContentBlock::Text { + text: "Hello".to_string(), + }], + }], + stream: false, + ..Default::default() + }; + let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai()); + let messages = payload["messages"].as_array().unwrap(); + let assistant_msg = messages + .iter() + .find(|m| m["role"] == "assistant") + .expect("assistant message must be present"); + assert!( + assistant_msg.get("tool_calls").is_none(), + "assistant message without tool calls must omit tool_calls field: {:?}", + assistant_msg + ); + } + + /// Regression: assistant messages WITH tool calls must still include + /// the tool_calls array (normal multi-turn tool-use flow). + #[test] + fn assistant_message_with_tool_calls_includes_tool_calls_field() { + use crate::types::{InputContentBlock, InputMessage}; + + let request = MessageRequest { + model: "gpt-4o".to_string(), + max_tokens: 100, + messages: vec![InputMessage { + role: "assistant".to_string(), + content: vec![InputContentBlock::ToolUse { + id: "call_1".to_string(), + name: "read_file".to_string(), + input: serde_json::json!({"path": "/tmp/test"}), + }], + }], + stream: false, + ..Default::default() + }; + let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai()); + let messages = payload["messages"].as_array().unwrap(); + let assistant_msg = messages + .iter() + .find(|m| m["role"] == "assistant") + .expect("assistant message must be present"); + let tool_calls = assistant_msg + .get("tool_calls") + .expect("assistant message with tool calls must include tool_calls field"); + assert!(tool_calls.is_array()); + assert_eq!(tool_calls.as_array().unwrap().len(), 1); + } + + /// Orphaned tool messages (no preceding assistant tool_calls) must be + /// dropped by the request-builder sanitizer. Regression for the second + /// layer of the tool-pairing invariant fix (gaebal-gajae 2026-04-10). + #[test] + fn sanitize_drops_orphaned_tool_messages() { + use super::sanitize_tool_message_pairing; + + // Valid pair: assistant with tool_calls → tool result + let valid = vec![ + json!({"role": "assistant", "content": null, "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "search", "arguments": "{}"}}]}), + json!({"role": "tool", "tool_call_id": "call_1", "content": "result"}), + ]; + let out = sanitize_tool_message_pairing(valid); + assert_eq!(out.len(), 2, "valid pair must be preserved"); + + // Orphaned tool message: no preceding assistant tool_calls + let orphaned = vec![ + json!({"role": "assistant", "content": "hi"}), + json!({"role": "tool", "tool_call_id": "call_2", "content": "orphaned"}), + ]; + let out = sanitize_tool_message_pairing(orphaned); + assert_eq!(out.len(), 1, "orphaned tool message must be dropped"); + assert_eq!(out[0]["role"], json!("assistant")); + + // Mismatched tool_call_id + let mismatched = vec![ + json!({"role": "assistant", "content": null, "tool_calls": [{"id": "call_3", "type": "function", "function": {"name": "f", "arguments": "{}"}}]}), + json!({"role": "tool", "tool_call_id": "call_WRONG", "content": "bad"}), + ]; + let out = sanitize_tool_message_pairing(mismatched); + assert_eq!(out.len(), 1, "tool message with wrong id must be dropped"); + + // Two tool results both valid (same preceding assistant) + let two_results = vec![ + json!({"role": "assistant", "content": null, "tool_calls": [ + {"id": "call_a", "type": "function", "function": {"name": "fa", "arguments": "{}"}}, + {"id": "call_b", "type": "function", "function": {"name": "fb", "arguments": "{}"}} + ]}), + json!({"role": "tool", "tool_call_id": "call_a", "content": "ra"}), + json!({"role": "tool", "tool_call_id": "call_b", "content": "rb"}), + ]; + let out = sanitize_tool_message_pairing(two_results); + assert_eq!(out.len(), 3, "both valid tool results must be preserved"); + } + + #[test] + fn non_gpt5_uses_max_tokens() { + // Older OpenAI models expect `max_tokens`; verify gpt-4o is unaffected. + let request = MessageRequest { + model: "gpt-4o".to_string(), + max_tokens: 512, + messages: vec![], + stream: false, + ..Default::default() + }; + let payload = build_chat_completion_request(&request, OpenAiCompatConfig::openai()); + assert_eq!(payload["max_tokens"], json!(512)); + assert!( + payload.get("max_completion_tokens").is_none(), + "gpt-4o must not emit max_completion_tokens" + ); + } } diff --git a/crates/api/src/sse.rs b/crates/api/src/sse.rs index 44ac73d..551dfd6 100644 --- a/crates/api/src/sse.rs +++ b/crates/api/src/sse.rs @@ -1,11 +1,11 @@ use crate::error::ApiError; use crate::types::StreamEvent; -use serde_json::Value; -use reqwest::StatusCode; #[derive(Debug, Default)] pub struct SseParser { buffer: Vec, + provider: Option, + model: Option, } impl SseParser { @@ -14,12 +14,23 @@ impl SseParser { Self::default() } + /// Attach the provider name and model to this parser so that JSON + /// deserialization failures within streamed frames carry enough context + /// for callers to understand which upstream produced the unparseable + /// payload. + #[must_use] + pub fn with_context(mut self, provider: impl Into, model: impl Into) -> Self { + self.provider = Some(provider.into()); + self.model = Some(model.into()); + self + } + pub fn push(&mut self, chunk: &[u8]) -> Result, ApiError> { self.buffer.extend_from_slice(chunk); let mut events = Vec::new(); while let Some(frame) = self.next_frame() { - if let Some(event) = parse_frame(&frame)? { + if let Some(event) = self.parse_frame_with_context(&frame)? { events.push(event); } } @@ -33,12 +44,18 @@ impl SseParser { } let trailing = std::mem::take(&mut self.buffer); - match parse_frame(&String::from_utf8_lossy(&trailing))? { + match self.parse_frame_with_context(&String::from_utf8_lossy(&trailing))? { Some(event) => Ok(vec![event]), None => Ok(Vec::new()), } } + fn parse_frame_with_context(&self, frame: &str) -> Result, ApiError> { + let provider = self.provider.as_deref().unwrap_or("unknown"); + let model = self.model.as_deref().unwrap_or("unknown"); + parse_frame_with_provider(frame, provider, model) + } + fn next_frame(&mut self) -> Option { let separator = self .buffer @@ -63,6 +80,14 @@ impl SseParser { } pub fn parse_frame(frame: &str) -> Result, ApiError> { + parse_frame_with_provider(frame, "unknown", "unknown") +} + +pub(crate) fn parse_frame_with_provider( + frame: &str, + provider: &str, + model: &str, +) -> Result, ApiError> { let trimmed = frame.trim(); if trimmed.is_empty() { return Ok(None); @@ -97,75 +122,9 @@ pub fn parse_frame(frame: &str) -> Result, ApiError> { return Ok(None); } - if matches!(event_name, Some("error")) { - return Err(parse_error_event(&payload)); - } - - // Some "Anthropic-compatible" gateways put the event type in the SSE `event:` field, - // and omit the `{ "type": ... }` discriminator from the JSON `data:` payload. - // Our Rust enums are tagged with `#[serde(tag = "type")]`, so we synthesize it here. - match serde_json::from_str::(&payload) { - Ok(event) => Ok(Some(event)), - Err(error) => { - // Best-effort: if we have an SSE event name and the payload is a JSON object - // without a `type` field, inject it and retry. - let Some(event_name) = event_name else { - return Err(ApiError::from(error)); - }; - let Ok(Value::Object(mut object)) = serde_json::from_str::(&payload) else { - return Err(ApiError::from(error)); - }; - if object - .get("type") - .and_then(Value::as_str) - .is_some_and(|value| value == "error") - { - return Err(parse_error_object(&object, payload)); - } - if object.contains_key("type") { - return Err(ApiError::from(error)); - } - object.insert("type".to_string(), Value::String(event_name.to_string())); - serde_json::from_value::(Value::Object(object)) - .map(Some) - .map_err(ApiError::from) - } - } -} - -fn parse_error_event(payload: &str) -> ApiError { - match serde_json::from_str::(payload) { - Ok(Value::Object(object)) => parse_error_object(&object, payload.to_string()), - _ => ApiError::Api { - status: StatusCode::BAD_GATEWAY, - error_type: Some("stream_error".to_string()), - message: Some(payload.to_string()), - body: payload.to_string(), - retryable: false, - }, - } -} - -fn parse_error_object(object: &serde_json::Map, body: String) -> ApiError { - let nested = object.get("error").and_then(Value::as_object); - let error_type = nested - .and_then(|error| error.get("type")) - .or_else(|| object.get("type")) - .and_then(Value::as_str) - .map(ToOwned::to_owned); - let message = nested - .and_then(|error| error.get("message")) - .or_else(|| object.get("message")) - .and_then(Value::as_str) - .map(ToOwned::to_owned); - - ApiError::Api { - status: StatusCode::BAD_GATEWAY, - error_type, - message, - body, - retryable: false, - } + serde_json::from_str::(&payload) + .map(Some) + .map_err(|error| ApiError::json_deserialize(provider, model, &payload, error)) } #[cfg(test)] @@ -263,26 +222,6 @@ mod tests { assert_eq!(event, None); } - #[test] - fn parses_event_name_when_payload_omits_type() { - let frame = concat!("event: message_stop\n", "data: {}\n\n"); - let event = parse_frame(frame).expect("frame should parse"); - assert_eq!(event, Some(StreamEvent::MessageStop(crate::types::MessageStopEvent {}))); - } - - #[test] - fn surfaces_stream_error_events() { - let frame = concat!( - "event: error\n", - "data: {\"error\":{\"type\":\"invalid_request_error\",\"message\":\"bad input\"}}\n\n" - ); - let error = parse_frame(frame).expect_err("error frame should surface"); - assert_eq!( - error.to_string(), - "api returned 502 Bad Gateway (invalid_request_error): bad input" - ); - } - #[test] fn parses_split_json_across_data_lines() { let frame = concat!( @@ -364,4 +303,28 @@ mod tests { )) ); } + + #[test] + fn given_message_delta_frame_with_empty_usage_when_parsed_then_usage_defaults_to_zero() { + // given + let frame = concat!( + "event: message_delta\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{}}\n\n" + ); + + // when + let event = parse_frame(frame).expect("frame should parse"); + + // then + assert_eq!( + event, + Some(StreamEvent::MessageDelta(crate::types::MessageDeltaEvent { + delta: MessageDelta { + stop_reason: Some("end_turn".to_string()), + stop_sequence: None, + }, + usage: Usage::default(), + })) + ); + } } diff --git a/crates/api/src/types.rs b/crates/api/src/types.rs index 80c6ead..e136a76 100644 --- a/crates/api/src/types.rs +++ b/crates/api/src/types.rs @@ -1,7 +1,8 @@ +use runtime::{pricing_for_model, TokenUsage, UsageCostEstimate}; use serde::{Deserialize, Serialize}; use serde_json::Value; -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] pub struct MessageRequest { pub model: String, pub max_tokens: u32, @@ -14,6 +15,22 @@ pub struct MessageRequest { pub tool_choice: Option, #[serde(default, skip_serializing_if = "std::ops::Not::not")] pub stream: bool, + /// OpenAI-compatible tuning parameters. Optional — omitted from payload when None. + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + /// Reasoning effort level for OpenAI-compatible reasoning models (e.g. `o4-mini`). + /// Accepted values: `"low"`, `"medium"`, `"high"`. Omitted when `None`. + /// Silently ignored by backends that do not support it. + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, } impl MessageRequest { @@ -75,14 +92,6 @@ pub enum InputContentBlock { #[serde(default, skip_serializing_if = "std::ops::Not::not")] is_error: bool, }, - Thinking { - thinking: String, - #[serde(default, skip_serializing_if = "Option::is_none")] - signature: Option, - }, - RedactedThinking { - data: Value, - }, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -120,6 +129,7 @@ pub struct MessageResponse { pub stop_reason: Option, #[serde(default)] pub stop_sequence: Option, + #[serde(default)] pub usage: Usage, #[serde(default)] pub request_id: Option, @@ -154,20 +164,44 @@ pub enum OutputContentBlock { }, } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] pub struct Usage { + #[serde(default)] pub input_tokens: u32, #[serde(default)] pub cache_creation_input_tokens: u32, #[serde(default)] pub cache_read_input_tokens: u32, + #[serde(default)] pub output_tokens: u32, } impl Usage { #[must_use] pub const fn total_tokens(&self) -> u32 { - self.input_tokens + self.output_tokens + self.input_tokens + + self.output_tokens + + self.cache_creation_input_tokens + + self.cache_read_input_tokens + } + + #[must_use] + pub const fn token_usage(&self) -> TokenUsage { + TokenUsage { + input_tokens: self.input_tokens, + output_tokens: self.output_tokens, + cache_creation_input_tokens: self.cache_creation_input_tokens, + cache_read_input_tokens: self.cache_read_input_tokens, + } + } + + #[must_use] + pub fn estimated_cost_usd(&self, model: &str) -> UsageCostEstimate { + let usage = self.token_usage(); + pricing_for_model(model).map_or_else( + || usage.estimate_cost_usd(), + |pricing| usage.estimate_cost_usd_with_pricing(pricing), + ) } } @@ -179,6 +213,7 @@ pub struct MessageStartEvent { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct MessageDeltaEvent { pub delta: MessageDelta, + #[serde(default)] pub usage: Usage, } @@ -229,3 +264,47 @@ pub enum StreamEvent { ContentBlockStop(ContentBlockStopEvent), MessageStop(MessageStopEvent), } + +#[cfg(test)] +mod tests { + use runtime::format_usd; + + use super::{MessageResponse, Usage}; + + #[test] + fn usage_total_tokens_includes_cache_tokens() { + let usage = Usage { + input_tokens: 10, + cache_creation_input_tokens: 2, + cache_read_input_tokens: 3, + output_tokens: 4, + }; + + assert_eq!(usage.total_tokens(), 19); + assert_eq!(usage.token_usage().total_tokens(), 19); + } + + #[test] + fn message_response_estimates_cost_from_model_usage() { + let response = MessageResponse { + id: "msg_cost".to_string(), + kind: "message".to_string(), + role: "assistant".to_string(), + content: Vec::new(), + model: "claude-sonnet-4-20250514".to_string(), + stop_reason: Some("end_turn".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 1_000_000, + cache_creation_input_tokens: 100_000, + cache_read_input_tokens: 200_000, + output_tokens: 500_000, + }, + request_id: None, + }; + + let cost = response.usage.estimated_cost_usd(&response.model); + assert_eq!(format_usd(cost.total_cost_usd()), "$54.6750"); + assert_eq!(response.total_tokens(), 1_800_000); + } +} diff --git a/crates/api/tests/client_integration.rs b/crates/api/tests/client_integration.rs index 3b6a3c3..512e346 100644 --- a/crates/api/tests/client_integration.rs +++ b/crates/api/tests/client_integration.rs @@ -1,17 +1,27 @@ use std::collections::HashMap; use std::sync::Arc; +use std::sync::{Mutex as StdMutex, OnceLock}; use std::time::Duration; use api::{ - ApiClient, ApiError, AuthSource, ContentBlockDelta, ContentBlockDeltaEvent, + AnthropicClient, ApiClient, ApiError, AuthSource, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, - OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, + OutputContentBlock, PromptCache, PromptCacheConfig, ProviderClient, StreamEvent, ToolChoice, + ToolDefinition, }; use serde_json::json; +use telemetry::{ClientIdentity, MemoryTelemetrySink, SessionTracer, TelemetryEvent}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; use tokio::sync::Mutex; +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| StdMutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) +} + #[tokio::test] async fn send_message_posts_json_and_parses_response() { let state = Arc::new(Mutex::new(Vec::::new())); @@ -20,8 +30,8 @@ async fn send_message_posts_json_and_parses_response() { "\"id\":\"msg_test\",", "\"type\":\"message\",", "\"role\":\"assistant\",", - "\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claw\"}],", - "\"model\":\"claude-sonnet-4-6\",", + "\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claude\"}],", + "\"model\":\"claude-3-7-sonnet-latest\",", "\"stop_reason\":\"end_turn\",", "\"stop_sequence\":null,", "\"usage\":{\"input_tokens\":12,\"output_tokens\":4},", @@ -45,10 +55,12 @@ async fn send_message_posts_json_and_parses_response() { assert_eq!(response.id, "msg_test"); assert_eq!(response.total_tokens(), 16); assert_eq!(response.request_id.as_deref(), Some("req_body_123")); + assert_eq!(response.usage.cache_creation_input_tokens, 0); + assert_eq!(response.usage.cache_read_input_tokens, 0); assert_eq!( response.content, vec![OutputContentBlock::Text { - text: "Hello from Claw".to_string(), + text: "Hello from Claude".to_string(), }] ); @@ -64,23 +76,258 @@ async fn send_message_posts_json_and_parses_response() { request.headers.get("authorization").map(String::as_str), Some("Bearer proxy-token") ); + assert_eq!( + request.headers.get("anthropic-version").map(String::as_str), + Some("2023-06-01") + ); + assert_eq!( + request.headers.get("user-agent").map(String::as_str), + Some("claude-code/0.1.0") + ); + assert_eq!( + request.headers.get("anthropic-beta").map(String::as_str), + Some("claude-code-20250219,prompt-caching-scope-2026-01-05") + ); let body: serde_json::Value = serde_json::from_str(&request.body).expect("request body should be json"); assert_eq!( body.get("model").and_then(serde_json::Value::as_str), - Some("claude-sonnet-4-6") + Some("claude-3-7-sonnet-latest") ); assert!(body.get("stream").is_none()); assert_eq!(body["tools"][0]["name"], json!("get_weather")); assert_eq!(body["tool_choice"]["type"], json!("auto")); + assert!( + body.get("betas").is_none(), + "betas must travel via the anthropic-beta header, not the request body" + ); } #[tokio::test] +async fn send_message_blocks_oversized_requests_before_the_http_call() { + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", "{}")], + ) + .await; + + let client = AnthropicClient::new("test-key").with_base_url(server.base_url()); + let error = client + .send_message(&MessageRequest { + model: "claude-sonnet-4-6".to_string(), + max_tokens: 64_000, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![InputContentBlock::Text { + text: "x".repeat(600_000), + }], + }], + system: Some("Keep the answer short.".to_string()), + tools: None, + tool_choice: None, + stream: false, + ..Default::default() + }) + .await + .expect_err("oversized request should fail local context-window preflight"); + + assert!(matches!(error, ApiError::ContextWindowExceeded { .. })); + assert!( + state.lock().await.is_empty(), + "preflight failure should avoid any upstream HTTP request" + ); +} + +#[tokio::test] +async fn send_message_applies_request_profile_and_records_telemetry() { + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response_with_headers( + "200 OK", + "application/json", + concat!( + "{", + "\"id\":\"msg_profile\",", + "\"type\":\"message\",", + "\"role\":\"assistant\",", + "\"content\":[{\"type\":\"text\",\"text\":\"ok\"}],", + "\"model\":\"claude-3-7-sonnet-latest\",", + "\"stop_reason\":\"end_turn\",", + "\"stop_sequence\":null,", + "\"usage\":{\"input_tokens\":1,\"cache_creation_input_tokens\":2,\"cache_read_input_tokens\":3,\"output_tokens\":1}", + "}" + ), + &[("request-id", "req_profile_123")], + )], + ) + .await; + let sink = Arc::new(MemoryTelemetrySink::default()); + + let client = AnthropicClient::new("test-key") + .with_base_url(server.base_url()) + .with_client_identity(ClientIdentity::new("claude-code", "9.9.9").with_runtime("rust-cli")) + .with_beta("tools-2026-04-01") + .with_extra_body_param("metadata", json!({"source": "clawd-code"})) + .with_session_tracer(SessionTracer::new("session-telemetry", sink.clone())); + + let response = client + .send_message(&sample_request(false)) + .await + .expect("request should succeed"); + + assert_eq!(response.request_id.as_deref(), Some("req_profile_123")); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!( + request.headers.get("anthropic-beta").map(String::as_str), + Some("claude-code-20250219,prompt-caching-scope-2026-01-05,tools-2026-04-01") + ); + assert_eq!( + request.headers.get("user-agent").map(String::as_str), + Some("claude-code/9.9.9") + ); + let body: serde_json::Value = + serde_json::from_str(&request.body).expect("request body should be json"); + assert_eq!(body["metadata"]["source"], json!("clawd-code")); + assert!( + body.get("betas").is_none(), + "betas must travel via the anthropic-beta header, not the request body" + ); + + let events = sink.events(); + assert_eq!(events.len(), 6); + assert!(matches!( + &events[0], + TelemetryEvent::HttpRequestStarted { + session_id, + attempt: 1, + method, + path, + .. + } if session_id == "session-telemetry" && method == "POST" && path == "/v1/messages" + )); + assert!(matches!( + &events[1], + TelemetryEvent::SessionTrace(trace) if trace.name == "http_request_started" + )); + assert!(matches!( + &events[2], + TelemetryEvent::HttpRequestSucceeded { + request_id, + status: 200, + .. + } if request_id.as_deref() == Some("req_profile_123") + )); + assert!(matches!( + &events[3], + TelemetryEvent::SessionTrace(trace) if trace.name == "http_request_succeeded" + )); + assert!(matches!( + &events[4], + TelemetryEvent::Analytics(event) + if event.namespace == "api" + && event.action == "message_usage" + && event.properties.get("request_id") == Some(&json!("req_profile_123")) + && event.properties.get("total_tokens") == Some(&json!(7)) + && event.properties.get("estimated_cost_usd") == Some(&json!("$0.0001")) + )); + assert!(matches!( + &events[5], + TelemetryEvent::SessionTrace(trace) if trace.name == "analytics" + )); +} + +#[tokio::test] +async fn send_message_parses_prompt_cache_token_usage_from_response() { + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"msg_cache_tokens\",", + "\"type\":\"message\",", + "\"role\":\"assistant\",", + "\"content\":[{\"type\":\"text\",\"text\":\"Cache tokens\"}],", + "\"model\":\"claude-3-7-sonnet-latest\",", + "\"stop_reason\":\"end_turn\",", + "\"stop_sequence\":null,", + "\"usage\":{\"input_tokens\":12,\"cache_creation_input_tokens\":321,\"cache_read_input_tokens\":654,\"output_tokens\":4}", + "}" + ); + let server = spawn_server( + state, + vec![http_response("200 OK", "application/json", body)], + ) + .await; + + let client = AnthropicClient::new("test-key").with_base_url(server.base_url()); + let response = client + .send_message(&sample_request(false)) + .await + .expect("request should succeed"); + + assert_eq!(response.usage.input_tokens, 12); + assert_eq!(response.usage.cache_creation_input_tokens, 321); + assert_eq!(response.usage.cache_read_input_tokens, 654); + assert_eq!(response.usage.output_tokens, 4); +} + +#[tokio::test] +async fn given_empty_usage_object_when_send_message_parses_response_then_usage_defaults_to_zero() { + // given + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"msg_empty_usage\",", + "\"type\":\"message\",", + "\"role\":\"assistant\",", + "\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claude\"}],", + "\"model\":\"claude-3-7-sonnet-latest\",", + "\"stop_reason\":\"end_turn\",", + "\"stop_sequence\":null,", + "\"usage\":{}", + "}" + ); + let server = spawn_server( + state, + vec![http_response("200 OK", "application/json", body)], + ) + .await; + let client = AnthropicClient::new("test-key").with_base_url(server.base_url()); + + // when + let response = client + .send_message(&sample_request(false)) + .await + .expect("response with empty usage object should still parse"); + + // then + assert_eq!(response.id, "msg_empty_usage"); + assert_eq!(response.total_tokens(), 0); + assert_eq!(response.usage.input_tokens, 0); + assert_eq!(response.usage.cache_creation_input_tokens, 0); + assert_eq!(response.usage.cache_read_input_tokens, 0); + assert_eq!(response.usage.output_tokens, 0); +} + +#[tokio::test] +#[allow(clippy::await_holding_lock)] async fn stream_message_parses_sse_events_with_tool_use() { + let _guard = env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "api-stream-cache-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); let state = Arc::new(Mutex::new(Vec::::new())); let sse = concat!( "event: message_start\n", - "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":13,\"cache_read_input_tokens\":21,\"output_tokens\":0}}}\n\n", "event: content_block_start\n", "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n", "event: content_block_delta\n", @@ -88,7 +335,7 @@ async fn stream_message_parses_sse_events_with_tool_use() { "event: content_block_stop\n", "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n", "event: message_delta\n", - "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":34,\"cache_read_input_tokens\":55,\"output_tokens\":1}}\n\n", "event: message_stop\n", "data: {\"type\":\"message_stop\"}\n\n", "data: [DONE]\n\n" @@ -106,7 +353,8 @@ async fn stream_message_parses_sse_events_with_tool_use() { let client = ApiClient::new("test-key") .with_auth_token(Some("proxy-token".to_string())) - .with_base_url(server.base_url()); + .with_base_url(server.base_url()) + .with_prompt_cache(PromptCache::new("stream-session")); let mut stream = client .stream_message(&sample_request(false)) .await @@ -160,6 +408,20 @@ async fn stream_message_parses_sse_events_with_tool_use() { let captured = state.lock().await; let request = captured.first().expect("server should capture request"); assert!(request.body.contains("\"stream\":true")); + + let cache_stats = client + .prompt_cache_stats() + .expect("prompt cache stats should exist"); + assert_eq!(cache_stats.tracked_requests, 1); + assert_eq!(cache_stats.last_cache_creation_input_tokens, Some(34)); + assert_eq!(cache_stats.last_cache_read_input_tokens, Some(55)); + assert_eq!( + cache_stats.last_cache_source.as_deref(), + Some("api-response") + ); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); } #[tokio::test] @@ -176,7 +438,7 @@ async fn retries_retryable_failures_before_succeeding() { http_response( "200 OK", "application/json", - "{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", + "{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", ), ], ) @@ -196,28 +458,28 @@ async fn retries_retryable_failures_before_succeeding() { } #[tokio::test] -async fn provider_client_dispatches_api_requests() { +async fn provider_client_dispatches_anthropic_requests() { let state = Arc::new(Mutex::new(Vec::::new())); let server = spawn_server( state.clone(), vec![http_response( "200 OK", "application/json", - "{\"id\":\"msg_provider\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Dispatched\"}],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", + "{\"id\":\"msg_provider\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Dispatched\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", )], ) .await; - let client = ProviderClient::from_model_with_default_auth( + let client = ProviderClient::from_model_with_anthropic_auth( "claude-sonnet-4-6", Some(AuthSource::ApiKey("test-key".to_string())), ) - .expect("api provider client should be constructed"); + .expect("anthropic provider client should be constructed"); let client = match client { - ProviderClient::ClawApi(client) => { - ProviderClient::ClawApi(client.with_base_url(server.base_url())) + ProviderClient::Anthropic(client) => { + ProviderClient::Anthropic(client.with_base_url(server.base_url())) } - other => panic!("expected default provider, got {other:?}"), + other => panic!("expected anthropic provider, got {other:?}"), }; let response = client @@ -284,13 +546,194 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { } } +#[tokio::test] +async fn retries_multiple_retryable_failures_with_exponential_backoff_and_jitter() { + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![ + http_response( + "429 Too Many Requests", + "application/json", + "{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down\"}}", + ), + http_response( + "500 Internal Server Error", + "application/json", + "{\"type\":\"error\",\"error\":{\"type\":\"api_error\",\"message\":\"boom\"}}", + ), + http_response( + "503 Service Unavailable", + "application/json", + "{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"busy\"}}", + ), + http_response( + "429 Too Many Requests", + "application/json", + "{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down again\"}}", + ), + http_response( + "503 Service Unavailable", + "application/json", + "{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"still busy\"}}", + ), + http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_exp_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered after 5\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", + ), + ], + ) + .await; + + let client = ApiClient::new("test-key") + .with_base_url(server.base_url()) + .with_retry_policy(8, Duration::from_millis(1), Duration::from_millis(4)); + let started_at = std::time::Instant::now(); + + let response = client + .send_message(&sample_request(false)) + .await + .expect("8-retry policy should absorb 5 retryable failures"); + + let elapsed = started_at.elapsed(); + assert_eq!(response.total_tokens(), 5); + assert_eq!( + state.lock().await.len(), + 6, + "client should issue 1 original + 5 retry requests before the 200" + ); + // Jittered sleeps are bounded by 2 * max_backoff per retry (base + jitter), + // so 5 sleeps fit comfortably below this upper bound with generous slack. + assert!( + elapsed < Duration::from_secs(5), + "retries should complete promptly, took {elapsed:?}" + ); +} + +#[tokio::test] +#[allow(clippy::await_holding_lock)] +async fn send_message_reuses_recent_completion_cache_entries() { + let _guard = env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "api-prompt-cache-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_cached\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Cached once\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":4000,\"output_tokens\":2}}", + )], + ) + .await; + + let client = AnthropicClient::new("test-key") + .with_base_url(server.base_url()) + .with_prompt_cache(PromptCache::new("integration-session")); + + let first = client + .send_message(&sample_request(false)) + .await + .expect("first request should succeed"); + let second = client + .send_message(&sample_request(false)) + .await + .expect("second request should reuse cache"); + + assert_eq!(first.content, second.content); + assert_eq!(state.lock().await.len(), 1); + + let cache_stats = client + .prompt_cache_stats() + .expect("prompt cache stats should exist"); + assert_eq!(cache_stats.completion_cache_hits, 1); + assert_eq!(cache_stats.completion_cache_misses, 1); + assert_eq!(cache_stats.completion_cache_writes, 1); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); +} + +#[tokio::test] +#[allow(clippy::await_holding_lock)] +async fn send_message_tracks_unexpected_prompt_cache_breaks() { + let _guard = env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "api-prompt-break-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state, + vec![ + http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_one\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"One\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":6000,\"output_tokens\":2}}", + ), + http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_two\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Two\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":0,\"cache_read_input_tokens\":1000,\"output_tokens\":2}}", + ), + ], + ) + .await; + + let request = sample_request(false); + let client = AnthropicClient::new("test-key") + .with_base_url(server.base_url()) + .with_prompt_cache(PromptCache::with_config(PromptCacheConfig { + session_id: "break-session".to_string(), + completion_ttl: Duration::from_secs(0), + ..PromptCacheConfig::default() + })); + + client + .send_message(&request) + .await + .expect("first response should succeed"); + client + .send_message(&request) + .await + .expect("second response should succeed"); + + let cache_stats = client + .prompt_cache_stats() + .expect("prompt cache stats should exist"); + assert_eq!(cache_stats.unexpected_cache_breaks, 1); + assert_eq!( + cache_stats.last_break_reason.as_deref(), + Some("cache read tokens dropped while prompt fingerprint remained stable") + ); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); +} + #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY and network access"] async fn live_stream_smoke_test() { let client = ApiClient::from_env().expect("ANTHROPIC_API_KEY must be set"); let mut stream = client .stream_message(&MessageRequest { - model: std::env::var("CLAW_MODEL").unwrap_or_else(|_| "claude-sonnet-4-6".to_string()), + model: std::env::var("ANTHROPIC_MODEL") + .unwrap_or_else(|_| "claude-3-7-sonnet-latest".to_string()), max_tokens: 32, messages: vec![InputMessage::user_text( "Reply with exactly: hello from rust", @@ -299,6 +742,7 @@ async fn live_stream_smoke_test() { tools: None, tool_choice: None, stream: false, + ..Default::default() }) .await .expect("live stream should start"); @@ -450,7 +894,7 @@ fn http_response_with_headers( fn sample_request(stream: bool) -> MessageRequest { MessageRequest { - model: "claude-sonnet-4-6".to_string(), + model: "claude-3-7-sonnet-latest".to_string(), max_tokens: 64, messages: vec![InputMessage { role: "user".to_string(), @@ -479,5 +923,6 @@ fn sample_request(stream: bool) -> MessageRequest { }]), tool_choice: Some(ToolChoice::Auto), stream, + ..Default::default() } } diff --git a/crates/api/tests/openai_compat_integration.rs b/crates/api/tests/openai_compat_integration.rs index b345b1f..d5596bb 100644 --- a/crates/api/tests/openai_compat_integration.rs +++ b/crates/api/tests/openai_compat_integration.rs @@ -4,9 +4,10 @@ use std::sync::Arc; use std::sync::{Mutex as StdMutex, OnceLock}; use api::{ - ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, - InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig, - OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, + ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, + ContentBlockStopEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, + OpenAiCompatClient, OpenAiCompatConfig, OutputContentBlock, ProviderClient, StreamEvent, + ToolChoice, ToolDefinition, }; use serde_json::json; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -62,6 +63,43 @@ async fn send_message_uses_openai_compatible_endpoint_and_auth() { assert_eq!(body["tools"][0]["type"], json!("function")); } +#[tokio::test] +async fn send_message_blocks_oversized_xai_requests_before_the_http_call() { + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", "{}")], + ) + .await; + + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(server.base_url()); + let error = client + .send_message(&MessageRequest { + model: "grok-3".to_string(), + max_tokens: 64_000, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![InputContentBlock::Text { + text: "x".repeat(300_000), + }], + }], + system: Some("Keep the answer short.".to_string()), + tools: None, + tool_choice: None, + stream: false, + ..Default::default() + }) + .await + .expect_err("oversized request should fail local context-window preflight"); + + assert!(matches!(error, ApiError::ContextWindowExceeded { .. })); + assert!( + state.lock().await.is_empty(), + "preflight failure should avoid any upstream HTTP request" + ); +} + #[tokio::test] async fn send_message_accepts_full_chat_completions_endpoint_override() { let state = Arc::new(Mutex::new(Vec::::new())); @@ -195,6 +233,83 @@ async fn stream_message_normalizes_text_and_multiple_tool_calls() { assert!(request.body.contains("\"stream\":true")); } +#[allow(clippy::await_holding_lock)] +#[tokio::test] +async fn openai_streaming_requests_opt_into_usage_chunks() { + let state = Arc::new(Mutex::new(Vec::::new())); + let sse = concat!( + "data: {\"id\":\"chatcmpl_openai_stream\",\"model\":\"gpt-5\",\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n", + "data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n", + "data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}\n\n", + "data: [DONE]\n\n" + ); + let server = spawn_server( + state.clone(), + vec![http_response_with_headers( + "200 OK", + "text/event-stream", + sse, + &[("x-request-id", "req_openai_stream")], + )], + ) + .await; + + let client = OpenAiCompatClient::new("openai-test-key", OpenAiCompatConfig::openai()) + .with_base_url(server.base_url()); + let mut stream = client + .stream_message(&sample_request(false)) + .await + .expect("stream should start"); + + assert_eq!(stream.request_id(), Some("req_openai_stream")); + + let mut events = Vec::new(); + while let Some(event) = stream.next_event().await.expect("event should parse") { + events.push(event); + } + + assert!(matches!(events[0], StreamEvent::MessageStart(_))); + assert!(matches!( + events[1], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + content_block: OutputContentBlock::Text { .. }, + .. + }) + )); + assert!(matches!( + events[2], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + delta: ContentBlockDelta::TextDelta { .. }, + .. + }) + )); + assert!(matches!( + events[3], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 }) + )); + assert!(matches!( + events[4], + StreamEvent::MessageDelta(MessageDeltaEvent { .. }) + )); + assert!(matches!(events[5], StreamEvent::MessageStop(_))); + + match &events[4] { + StreamEvent::MessageDelta(MessageDeltaEvent { usage, .. }) => { + assert_eq!(usage.input_tokens, 9); + assert_eq!(usage.output_tokens, 4); + } + other => panic!("expected message delta, got {other:?}"), + } + + let captured = state.lock().await; + let request = captured.first().expect("captured request"); + assert_eq!(request.path, "/chat/completions"); + let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body"); + assert_eq!(body["stream"], json!(true)); + assert_eq!(body["stream_options"], json!({"include_usage": true})); +} + +#[allow(clippy::await_holding_lock)] #[tokio::test] async fn provider_client_dispatches_xai_requests_from_env() { let _lock = env_lock(); @@ -382,6 +497,7 @@ fn sample_request(stream: bool) -> MessageRequest { }]), tool_choice: Some(ToolChoice::Auto), stream, + ..Default::default() } } @@ -389,7 +505,7 @@ fn env_lock() -> std::sync::MutexGuard<'static, ()> { static LOCK: OnceLock> = OnceLock::new(); LOCK.get_or_init(|| StdMutex::new(())) .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) } struct ScopedEnvVar { diff --git a/crates/api/tests/provider_client_integration.rs b/crates/api/tests/provider_client_integration.rs index abeebdd..3d8236e 100644 --- a/crates/api/tests/provider_client_integration.rs +++ b/crates/api/tests/provider_client_integration.rs @@ -22,7 +22,9 @@ fn provider_client_reports_missing_xai_credentials_for_grok_models() { .expect_err("grok requests without XAI_API_KEY should fail fast"); match error { - ApiError::MissingCredentials { provider, env_vars } => { + ApiError::MissingCredentials { + provider, env_vars, .. + } => { assert_eq!(provider, "xAI"); assert_eq!(env_vars, &["XAI_API_KEY"]); } @@ -31,18 +33,18 @@ fn provider_client_reports_missing_xai_credentials_for_grok_models() { } #[test] -fn provider_client_uses_explicit_auth_without_env_lookup() { +fn provider_client_uses_explicit_anthropic_auth_without_env_lookup() { let _lock = env_lock(); - let _api_key = EnvVarGuard::set("ANTHROPIC_API_KEY", None); - let _auth_token = EnvVarGuard::set("ANTHROPIC_AUTH_TOKEN", None); + let _anthropic_api_key = EnvVarGuard::set("ANTHROPIC_API_KEY", None); + let _anthropic_auth_token = EnvVarGuard::set("ANTHROPIC_AUTH_TOKEN", None); - let client = ProviderClient::from_model_with_default_auth( + let client = ProviderClient::from_model_with_anthropic_auth( "claude-sonnet-4-6", - Some(AuthSource::ApiKey("claw-test-key".to_string())), + Some(AuthSource::ApiKey("anthropic-test-key".to_string())), ) - .expect("explicit auth should avoid env lookup"); + .expect("explicit anthropic auth should avoid env lookup"); - assert_eq!(client.provider_kind(), ProviderKind::ClawApi); + assert_eq!(client.provider_kind(), ProviderKind::Anthropic); } #[test] @@ -57,7 +59,7 @@ fn env_lock() -> std::sync::MutexGuard<'static, ()> { static LOCK: OnceLock> = OnceLock::new(); LOCK.get_or_init(|| Mutex::new(())) .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) } struct EnvVarGuard { diff --git a/crates/api/tests/proxy_integration.rs b/crates/api/tests/proxy_integration.rs new file mode 100644 index 0000000..7e39069 --- /dev/null +++ b/crates/api/tests/proxy_integration.rs @@ -0,0 +1,173 @@ +use std::ffi::OsString; +use std::sync::{Mutex, OnceLock}; + +use api::{build_http_client_with, ProxyConfig}; + +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) +} + +struct EnvVarGuard { + key: &'static str, + original: Option, +} + +impl EnvVarGuard { + fn set(key: &'static str, value: Option<&str>) -> Self { + let original = std::env::var_os(key); + match value { + Some(value) => std::env::set_var(key, value), + None => std::env::remove_var(key), + } + Self { key, original } + } +} + +impl Drop for EnvVarGuard { + fn drop(&mut self) { + match &self.original { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } +} + +#[test] +fn proxy_config_from_env_reads_uppercase_proxy_vars() { + // given + let _lock = env_lock(); + let _http = EnvVarGuard::set("HTTP_PROXY", Some("http://proxy.corp:3128")); + let _https = EnvVarGuard::set("HTTPS_PROXY", Some("http://secure.corp:3129")); + let _no = EnvVarGuard::set("NO_PROXY", Some("localhost,127.0.0.1")); + let _http_lower = EnvVarGuard::set("http_proxy", None); + let _https_lower = EnvVarGuard::set("https_proxy", None); + let _no_lower = EnvVarGuard::set("no_proxy", None); + + // when + let config = ProxyConfig::from_env(); + + // then + assert_eq!(config.http_proxy.as_deref(), Some("http://proxy.corp:3128")); + assert_eq!( + config.https_proxy.as_deref(), + Some("http://secure.corp:3129") + ); + assert_eq!(config.no_proxy.as_deref(), Some("localhost,127.0.0.1")); + assert!(config.proxy_url.is_none()); + assert!(!config.is_empty()); +} + +#[test] +fn proxy_config_from_env_reads_lowercase_proxy_vars() { + // given + let _lock = env_lock(); + let _http = EnvVarGuard::set("HTTP_PROXY", None); + let _https = EnvVarGuard::set("HTTPS_PROXY", None); + let _no = EnvVarGuard::set("NO_PROXY", None); + let _http_lower = EnvVarGuard::set("http_proxy", Some("http://lower.corp:3128")); + let _https_lower = EnvVarGuard::set("https_proxy", Some("http://lower-secure.corp:3129")); + let _no_lower = EnvVarGuard::set("no_proxy", Some(".internal")); + + // when + let config = ProxyConfig::from_env(); + + // then + assert_eq!(config.http_proxy.as_deref(), Some("http://lower.corp:3128")); + assert_eq!( + config.https_proxy.as_deref(), + Some("http://lower-secure.corp:3129") + ); + assert_eq!(config.no_proxy.as_deref(), Some(".internal")); + assert!(!config.is_empty()); +} + +#[test] +fn proxy_config_from_env_is_empty_when_no_vars_set() { + // given + let _lock = env_lock(); + let _http = EnvVarGuard::set("HTTP_PROXY", None); + let _https = EnvVarGuard::set("HTTPS_PROXY", None); + let _no = EnvVarGuard::set("NO_PROXY", None); + let _http_lower = EnvVarGuard::set("http_proxy", None); + let _https_lower = EnvVarGuard::set("https_proxy", None); + let _no_lower = EnvVarGuard::set("no_proxy", None); + + // when + let config = ProxyConfig::from_env(); + + // then + assert!(config.is_empty()); + assert!(config.http_proxy.is_none()); + assert!(config.https_proxy.is_none()); + assert!(config.no_proxy.is_none()); +} + +#[test] +fn proxy_config_from_env_treats_empty_values_as_unset() { + // given + let _lock = env_lock(); + let _http = EnvVarGuard::set("HTTP_PROXY", Some("")); + let _https = EnvVarGuard::set("HTTPS_PROXY", Some("")); + let _http_lower = EnvVarGuard::set("http_proxy", Some("")); + let _https_lower = EnvVarGuard::set("https_proxy", Some("")); + let _no = EnvVarGuard::set("NO_PROXY", Some("")); + let _no_lower = EnvVarGuard::set("no_proxy", Some("")); + + // when + let config = ProxyConfig::from_env(); + + // then + assert!(config.is_empty()); +} + +#[test] +fn build_client_with_env_proxy_config_succeeds() { + // given + let _lock = env_lock(); + let _http = EnvVarGuard::set("HTTP_PROXY", Some("http://proxy.corp:3128")); + let _https = EnvVarGuard::set("HTTPS_PROXY", Some("http://secure.corp:3129")); + let _no = EnvVarGuard::set("NO_PROXY", Some("localhost")); + let _http_lower = EnvVarGuard::set("http_proxy", None); + let _https_lower = EnvVarGuard::set("https_proxy", None); + let _no_lower = EnvVarGuard::set("no_proxy", None); + let config = ProxyConfig::from_env(); + + // when + let result = build_http_client_with(&config); + + // then + assert!(result.is_ok()); +} + +#[test] +fn build_client_with_proxy_url_config_succeeds() { + // given + let config = ProxyConfig::from_proxy_url("http://unified.corp:3128"); + + // when + let result = build_http_client_with(&config); + + // then + assert!(result.is_ok()); +} + +#[test] +fn proxy_config_from_env_prefers_uppercase_over_lowercase() { + // given + let _lock = env_lock(); + let _http_upper = EnvVarGuard::set("HTTP_PROXY", Some("http://upper.corp:3128")); + let _http_lower = EnvVarGuard::set("http_proxy", Some("http://lower.corp:3128")); + let _https = EnvVarGuard::set("HTTPS_PROXY", None); + let _https_lower = EnvVarGuard::set("https_proxy", None); + let _no = EnvVarGuard::set("NO_PROXY", None); + let _no_lower = EnvVarGuard::set("no_proxy", None); + + // when + let config = ProxyConfig::from_env(); + + // then + assert_eq!(config.http_proxy.as_deref(), Some("http://upper.corp:3128")); +} diff --git a/crates/claw-cli/src/main.rs b/crates/claw-cli/src/main.rs index 3bde1cc..1998fe1 100644 --- a/crates/claw-cli/src/main.rs +++ b/crates/claw-cli/src/main.rs @@ -16,7 +16,7 @@ use std::thread; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use api::{ - resolve_startup_auth_source, AuthSource, ClawApiClient, ContentBlockDelta, InputContentBlock, + resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, InputMessage, MessageRequest, MessageResponse, OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, }; @@ -329,7 +329,7 @@ fn join_optional_args(args: &[String]) -> Option { fn parse_direct_slash_cli_action(rest: &[String]) -> Result { let raw = rest.join(" "); - match SlashCommand::parse(&raw) { + match SlashCommand::parse(&raw).map_err(|e| e.to_string())? { Some(SlashCommand::Help) => Ok(CliAction::Help), Some(SlashCommand::Agents { args }) => Ok(CliAction::Agents { args }), Some(SlashCommand::Skills { args }) => Ok(CliAction::Skills { args }), @@ -484,7 +484,7 @@ fn dump_manifests() { } fn print_bootstrap_plan() { - for phase in runtime::BootstrapPlan::claw_default().phases() { + for phase in runtime::BootstrapPlan::claude_code_default().phases() { println!("- {phase:?}"); } } @@ -541,7 +541,7 @@ fn run_login() -> Result<(), Box> { return Err(io::Error::new(io::ErrorKind::InvalidData, "oauth state mismatch").into()); } - let client = ClawApiClient::from_auth(AuthSource::None).with_base_url(api::read_base_url()); + let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(api::read_base_url()); let exchange_request = OAuthTokenExchangeRequest::from_config(oauth, code, state, pkce.verifier, redirect_uri); let runtime = tokio::runtime::Runtime::new()?; @@ -650,7 +650,7 @@ fn resume_session(session_path: &Path, commands: &[String]) { let mut session = session; for raw_command in commands { - let Some(command) = SlashCommand::parse(raw_command) else { + let Ok(Some(command)) = SlashCommand::parse(raw_command) else { eprintln!("unsupported resumed command: {raw_command}"); std::process::exit(2); }; @@ -987,8 +987,6 @@ fn run_resume_command( } SlashCommand::Bughunter { .. } | SlashCommand::Branch { .. } - | SlashCommand::Worktree { .. } - | SlashCommand::CommitPushPr { .. } | SlashCommand::Commit | SlashCommand::Pr { .. } | SlashCommand::Issue { .. } @@ -1000,7 +998,7 @@ fn run_resume_command( | SlashCommand::Permissions { .. } | SlashCommand::Session { .. } | SlashCommand::Plugins { .. } - | SlashCommand::Unknown(_) => Err("unsupported resumed slash command".into()), + | _ => Err("unsupported resumed slash command".into()), } } @@ -1024,7 +1022,7 @@ fn run_repl( cli.persist_session()?; break; } - if let Some(command) = SlashCommand::parse(trimmed) { + if let Ok(Some(command)) = SlashCommand::parse(trimmed) { if cli.handle_repl_command(command)? { cli.persist_session()?; } @@ -1336,24 +1334,14 @@ impl LiveCli { ); false } - SlashCommand::Worktree { .. } => { - eprintln!( - "{}", - render_mode_unavailable("worktree", "git worktree commands") - ); - false - } - SlashCommand::CommitPushPr { .. } => { - eprintln!( - "{}", - render_mode_unavailable("commit-push-pr", "commit + push + PR automation") - ); - false - } SlashCommand::Unknown(name) => { eprintln!("{}", render_unknown_repl_command(&name)); false } + _ => { + eprintln!("command not available in this mode"); + false + } }) } @@ -2505,12 +2493,6 @@ fn render_export_text(session: &Session) -> String { for block in &message.blocks { match block { ContentBlock::Text { text } => lines.push(text.clone()), - ContentBlock::Thinking { thinking, .. } => { - lines.push(format!("[thinking] {thinking}")); - } - ContentBlock::RedactedThinking { .. } => { - lines.push("[thinking] ".to_string()); - } ContentBlock::ToolUse { id, name, input } => { lines.push(format!("[tool_use id={id} name={name}] {input}")); } @@ -2995,7 +2977,7 @@ fn build_runtime( CliToolExecutor::new(allowed_tools.clone(), emit_output, tool_registry.clone()), permission_policy(permission_mode, &tool_registry), system_prompt, - feature_config, + &feature_config, )) } @@ -3047,7 +3029,7 @@ impl runtime::PermissionPrompter for CliPermissionPrompter { struct DefaultRuntimeClient { runtime: tokio::runtime::Runtime, - client: ClawApiClient, + client: AnthropicClient, model: String, enable_tools: bool, emit_output: bool, @@ -3067,7 +3049,7 @@ impl DefaultRuntimeClient { ) -> Result> { Ok(Self { runtime: tokio::runtime::Runtime::new()?, - client: ClawApiClient::from_auth(resolve_cli_auth_source()?) + client: AnthropicClient::from_auth(resolve_cli_auth_source()?) .with_base_url(api::read_base_url()), model, enable_tools, @@ -3105,6 +3087,12 @@ impl ApiClient for DefaultRuntimeClient { .then(|| filter_tool_specs(&self.tool_registry, self.allowed_tools.as_ref())), tool_choice: self.enable_tools.then_some(ToolChoice::Auto), stream: true, + temperature: None, + top_p: None, + frequency_penalty: None, + presence_penalty: None, + stop: None, + reasoning_effort: None, }; self.runtime.block_on(async { @@ -3173,7 +3161,6 @@ impl ApiClient for DefaultRuntimeClient { .and_then(|()| out.flush()) .map_err(|error| RuntimeError::new(error.to_string()))?; } - events.push(AssistantEvent::ThinkingDelta(thinking)); } } ContentBlockDelta::SignatureDelta { .. } => {} @@ -3254,9 +3241,7 @@ fn final_assistant_text(summary: &runtime::TurnSummary) -> String { .iter() .filter_map(|block| match block { ContentBlock::Text { text } => Some(text.as_str()), - ContentBlock::Thinking { thinking, .. } => Some(thinking.as_str()), - ContentBlock::RedactedThinking { .. } - | ContentBlock::ToolUse { .. } + ContentBlock::ToolUse { .. } | ContentBlock::ToolResult { .. } => None, }) .collect::>() @@ -3276,9 +3261,7 @@ fn collect_tool_uses(summary: &runtime::TurnSummary) -> Vec { "name": name, "input": input, })), - ContentBlock::Thinking { .. } - | ContentBlock::RedactedThinking { .. } - | ContentBlock::Text { .. } + ContentBlock::Text { .. } | ContentBlock::ToolResult { .. } => None, }) .collect() @@ -3301,9 +3284,7 @@ fn collect_tool_results(summary: &runtime::TurnSummary) -> Vec None, }) .collect() @@ -3851,7 +3832,6 @@ fn push_output_block( write!(out, "\x1b[2m{thinking}\x1b[0m") .and_then(|()| out.flush()) .map_err(|error| RuntimeError::new(error.to_string()))?; - events.push(AssistantEvent::ThinkingDelta(thinking)); } } OutputContentBlock::RedactedThinking { .. } => {} @@ -3942,7 +3922,7 @@ impl ToolExecutor for CliToolExecutor { } fn permission_policy(mode: PermissionMode, tool_registry: &GlobalToolRegistry) -> PermissionPolicy { - tool_registry.permission_specs(None).into_iter().fold( + tool_registry.permission_specs(None).unwrap_or_default().into_iter().fold( PermissionPolicy::new(mode), |policy, (name, required_permission)| { policy.with_tool_requirement(name, required_permission) @@ -3963,16 +3943,6 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec { .iter() .map(|block| match block { ContentBlock::Text { text } => InputContentBlock::Text { text: text.clone() }, - ContentBlock::Thinking { - thinking, - signature, - } => InputContentBlock::Thinking { - thinking: thinking.clone(), - signature: signature.clone(), - }, - ContentBlock::RedactedThinking { data } => InputContentBlock::RedactedThinking { - data: serde_json::from_str(&data.render()).unwrap_or(serde_json::Value::Null), - }, ContentBlock::ToolUse { id, name, input } => InputContentBlock::ToolUse { id: id.clone(), name: name.clone(), @@ -4735,39 +4705,39 @@ mod tests { #[test] fn clear_command_requires_explicit_confirmation_flag() { assert_eq!( - SlashCommand::parse("/clear"), - Some(SlashCommand::Clear { confirm: false }) + SlashCommand::parse("/clear").map_err(|e| e.to_string()), + Ok(Some(SlashCommand::Clear { confirm: false })) ); assert_eq!( - SlashCommand::parse("/clear --confirm"), - Some(SlashCommand::Clear { confirm: true }) + SlashCommand::parse("/clear --confirm").map_err(|e| e.to_string()), + Ok(Some(SlashCommand::Clear { confirm: true })) ); } #[test] fn parses_resume_and_config_slash_commands() { assert_eq!( - SlashCommand::parse("/resume saved-session.json"), - Some(SlashCommand::Resume { + SlashCommand::parse("/resume saved-session.json").map_err(|e| e.to_string()), + Ok(Some(SlashCommand::Resume { session_path: Some("saved-session.json".to_string()) - }) + })) ); assert_eq!( - SlashCommand::parse("/clear --confirm"), - Some(SlashCommand::Clear { confirm: true }) + SlashCommand::parse("/clear --confirm").map_err(|e| e.to_string()), + Ok(Some(SlashCommand::Clear { confirm: true })) ); assert_eq!( - SlashCommand::parse("/config"), - Some(SlashCommand::Config { section: None }) + SlashCommand::parse("/config").map_err(|e| e.to_string()), + Ok(Some(SlashCommand::Config { section: None })) ); assert_eq!( - SlashCommand::parse("/config env"), - Some(SlashCommand::Config { + SlashCommand::parse("/config env").map_err(|e| e.to_string()), + Ok(Some(SlashCommand::Config { section: Some("env".to_string()) - }) + })) ); - assert_eq!(SlashCommand::parse("/memory"), Some(SlashCommand::Memory)); - assert_eq!(SlashCommand::parse("/init"), Some(SlashCommand::Init)); + assert_eq!(SlashCommand::parse("/memory").map_err(|e| e.to_string()), Ok(Some(SlashCommand::Memory))); + assert_eq!(SlashCommand::parse("/init").map_err(|e| e.to_string()), Ok(Some(SlashCommand::Init))); } #[test] diff --git a/crates/commands/src/lib.rs b/crates/commands/src/lib.rs index 5537028..4904f9f 100644 --- a/crates/commands/src/lib.rs +++ b/crates/commands/src/lib.rs @@ -1,13 +1,15 @@ use std::collections::BTreeMap; use std::env; +use std::fmt; use std::fs; -use std::io; use std::path::{Path, PathBuf}; -use std::process::Command; -use std::time::{SystemTime, UNIX_EPOCH}; use plugins::{PluginError, PluginManager, PluginSummary}; -use runtime::{compact_session, CompactionConfig, Session}; +use runtime::{ + compact_session, CompactionConfig, ConfigLoader, ConfigSource, McpOAuthConfig, McpServerConfig, + ScopedMcpServerConfig, Session, +}; +use serde_json::{json, Value}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct CommandManifestEntry { @@ -39,27 +41,6 @@ impl CommandRegistry { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum SlashCommandCategory { - Core, - Workspace, - Session, - Git, - Automation, -} - -impl SlashCommandCategory { - const fn title(self) -> &'static str { - match self { - Self::Core => "Core flow", - Self::Workspace => "Workspace & memory", - Self::Session => "Sessions & output", - Self::Git => "Git & GitHub", - Self::Automation => "Automation & discovery", - } - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct SlashCommandSpec { pub name: &'static str, @@ -67,7 +48,12 @@ pub struct SlashCommandSpec { pub summary: &'static str, pub argument_hint: Option<&'static str>, pub resume_supported: bool, - pub category: SlashCommandCategory, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SkillSlashDispatch { + Local, + Invoke(String), } const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ @@ -77,7 +63,6 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Show available slash commands", argument_hint: None, resume_supported: true, - category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "status", @@ -85,7 +70,13 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Show current session status", argument_hint: None, resume_supported: true, - category: SlashCommandCategory::Core, + }, + SlashCommandSpec { + name: "sandbox", + aliases: &[], + summary: "Show sandbox isolation status", + argument_hint: None, + resume_supported: true, }, SlashCommandSpec { name: "compact", @@ -93,7 +84,6 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Compact local session history", argument_hint: None, resume_supported: true, - category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "model", @@ -101,7 +91,6 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Show or switch the active model", argument_hint: Some("[model]"), resume_supported: false, - category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "permissions", @@ -109,7 +98,6 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Show or switch the active permission mode", argument_hint: Some("[read-only|workspace-write|danger-full-access]"), resume_supported: false, - category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "clear", @@ -117,7 +105,6 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Start a fresh local session", argument_hint: Some("[--confirm]"), resume_supported: true, - category: SlashCommandCategory::Session, }, SlashCommandSpec { name: "cost", @@ -125,7 +112,6 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Show cumulative token usage for this session", argument_hint: None, resume_supported: true, - category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "resume", @@ -133,31 +119,34 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Load a saved session into the REPL", argument_hint: Some(""), resume_supported: false, - category: SlashCommandCategory::Session, }, SlashCommandSpec { name: "config", aliases: &[], - summary: "Inspect Claw config files or merged sections", + summary: "Inspect Claude config files or merged sections", argument_hint: Some("[env|hooks|model|plugins]"), resume_supported: true, - category: SlashCommandCategory::Workspace, + }, + SlashCommandSpec { + name: "mcp", + aliases: &[], + summary: "Inspect configured MCP servers", + argument_hint: Some("[list|show |help]"), + resume_supported: true, }, SlashCommandSpec { name: "memory", aliases: &[], - summary: "Inspect loaded Claw instruction memory files", + summary: "Inspect loaded Claude instruction memory files", argument_hint: None, resume_supported: true, - category: SlashCommandCategory::Workspace, }, SlashCommandSpec { name: "init", aliases: &[], - summary: "Create a starter CLAW.md for this repo", + summary: "Create a starter CLAUDE.md for this repo", argument_hint: None, resume_supported: true, - category: SlashCommandCategory::Workspace, }, SlashCommandSpec { name: "diff", @@ -165,7 +154,6 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Show git diff for current workspace changes", argument_hint: None, resume_supported: true, - category: SlashCommandCategory::Workspace, }, SlashCommandSpec { name: "version", @@ -173,7 +161,6 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Show CLI version and build information", argument_hint: None, resume_supported: true, - category: SlashCommandCategory::Workspace, }, SlashCommandSpec { name: "bughunter", @@ -181,23 +168,6 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Inspect the codebase for likely bugs", argument_hint: Some("[scope]"), resume_supported: false, - category: SlashCommandCategory::Automation, - }, - SlashCommandSpec { - name: "branch", - aliases: &[], - summary: "List, create, or switch git branches", - argument_hint: Some("[list|create |switch ]"), - resume_supported: false, - category: SlashCommandCategory::Git, - }, - SlashCommandSpec { - name: "worktree", - aliases: &[], - summary: "List, add, remove, or prune git worktrees", - argument_hint: Some("[list|add [branch]|remove |prune]"), - resume_supported: false, - category: SlashCommandCategory::Git, }, SlashCommandSpec { name: "commit", @@ -205,15 +175,6 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Generate a commit message and create a git commit", argument_hint: None, resume_supported: false, - category: SlashCommandCategory::Git, - }, - SlashCommandSpec { - name: "commit-push-pr", - aliases: &[], - summary: "Commit workspace changes, push the branch, and open a PR", - argument_hint: Some("[context]"), - resume_supported: false, - category: SlashCommandCategory::Git, }, SlashCommandSpec { name: "pr", @@ -221,7 +182,6 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Draft or create a pull request from the conversation", argument_hint: Some("[context]"), resume_supported: false, - category: SlashCommandCategory::Git, }, SlashCommandSpec { name: "issue", @@ -229,7 +189,6 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Draft or create a GitHub issue from the conversation", argument_hint: Some("[context]"), resume_supported: false, - category: SlashCommandCategory::Git, }, SlashCommandSpec { name: "ultraplan", @@ -237,7 +196,6 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Run a deep planning prompt with multi-step reasoning", argument_hint: Some("[task]"), resume_supported: false, - category: SlashCommandCategory::Automation, }, SlashCommandSpec { name: "teleport", @@ -245,7 +203,6 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Jump to a file or symbol by searching the workspace", argument_hint: Some(""), resume_supported: false, - category: SlashCommandCategory::Workspace, }, SlashCommandSpec { name: "debug-tool-call", @@ -253,7 +210,6 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Replay the last tool call with debug details", argument_hint: None, resume_supported: false, - category: SlashCommandCategory::Automation, }, SlashCommandSpec { name: "export", @@ -261,15 +217,15 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ summary: "Export the current conversation to a file", argument_hint: Some("[file]"), resume_supported: true, - category: SlashCommandCategory::Session, }, SlashCommandSpec { name: "session", aliases: &[], - summary: "List or switch managed local sessions", - argument_hint: Some("[list|switch ]"), + summary: "List, switch, fork, or delete managed local sessions", + argument_hint: Some( + "[list|switch |fork [branch-name]|delete [--force]]", + ), resume_supported: false, - category: SlashCommandCategory::Session, }, SlashCommandSpec { name: "plugin", @@ -279,23 +235,818 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ "[list|install |enable |disable |uninstall |update ]", ), resume_supported: false, - category: SlashCommandCategory::Automation, }, SlashCommandSpec { name: "agents", aliases: &[], summary: "List configured agents", - argument_hint: None, + argument_hint: Some("[list|help]"), resume_supported: true, - category: SlashCommandCategory::Automation, }, SlashCommandSpec { name: "skills", + aliases: &["skill"], + summary: "List, install, or invoke available skills", + argument_hint: Some("[list|install |help| [args]]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "doctor", aliases: &[], - summary: "List available skills", + summary: "Diagnose setup issues and environment health", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "login", + aliases: &[], + summary: "Log in to the service", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "logout", + aliases: &[], + summary: "Log out of the current session", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "plan", + aliases: &[], + summary: "Toggle or inspect planning mode", + argument_hint: Some("[on|off]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "review", + aliases: &[], + summary: "Run a code review on current changes", + argument_hint: Some("[scope]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "tasks", + aliases: &[], + summary: "List and manage background tasks", + argument_hint: Some("[list|get |stop ]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "theme", + aliases: &[], + summary: "Switch the terminal color theme", + argument_hint: Some("[theme-name]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "vim", + aliases: &[], + summary: "Toggle vim keybinding mode", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "voice", + aliases: &[], + summary: "Toggle voice input mode", + argument_hint: Some("[on|off]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "upgrade", + aliases: &[], + summary: "Check for and install CLI updates", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "usage", + aliases: &[], + summary: "Show detailed API usage statistics", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "stats", + aliases: &[], + summary: "Show workspace and session statistics", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "rename", + aliases: &[], + summary: "Rename the current session", + argument_hint: Some(""), + resume_supported: false, + }, + SlashCommandSpec { + name: "copy", + aliases: &[], + summary: "Copy conversation or output to clipboard", + argument_hint: Some("[last|all]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "share", + aliases: &[], + summary: "Share the current conversation", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "feedback", + aliases: &[], + summary: "Submit feedback about the current session", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "hooks", + aliases: &[], + summary: "List and manage lifecycle hooks", + argument_hint: Some("[list|run ]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "files", + aliases: &[], + summary: "List files in the current context window", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "context", + aliases: &[], + summary: "Inspect or manage the conversation context", + argument_hint: Some("[show|clear]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "color", + aliases: &[], + summary: "Configure terminal color settings", + argument_hint: Some("[scheme]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "effort", + aliases: &[], + summary: "Set the effort level for responses", + argument_hint: Some("[low|medium|high]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "fast", + aliases: &[], + summary: "Toggle fast/concise response mode", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "exit", + aliases: &[], + summary: "Exit the REPL session", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "branch", + aliases: &[], + summary: "Create or switch git branches", + argument_hint: Some("[name]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "rewind", + aliases: &[], + summary: "Rewind the conversation to a previous state", + argument_hint: Some("[steps]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "summary", + aliases: &[], + summary: "Generate a summary of the conversation", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "desktop", + aliases: &[], + summary: "Open or manage the desktop app integration", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "ide", + aliases: &[], + summary: "Open or configure IDE integration", + argument_hint: Some("[vscode|cursor]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "tag", + aliases: &[], + summary: "Tag the current conversation point", + argument_hint: Some("[label]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "brief", + aliases: &[], + summary: "Toggle brief output mode", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "advisor", + aliases: &[], + summary: "Toggle advisor mode for guidance-only responses", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "stickers", + aliases: &[], + summary: "Browse and manage sticker packs", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "insights", + aliases: &[], + summary: "Show AI-generated insights about the session", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "thinkback", + aliases: &[], + summary: "Replay the thinking process of the last response", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "release-notes", + aliases: &[], + summary: "Generate release notes from recent changes", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "security-review", + aliases: &[], + summary: "Run a security review on the codebase", + argument_hint: Some("[scope]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "keybindings", + aliases: &[], + summary: "Show or configure keyboard shortcuts", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "privacy-settings", + aliases: &[], + summary: "View or modify privacy settings", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "output-style", + aliases: &[], + summary: "Switch output formatting style", + argument_hint: Some("[style]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "add-dir", + aliases: &[], + summary: "Add an additional directory to the context", + argument_hint: Some(""), + resume_supported: false, + }, + SlashCommandSpec { + name: "allowed-tools", + aliases: &[], + summary: "Show or modify the allowed tools list", + argument_hint: Some("[add|remove|list] [tool]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "api-key", + aliases: &[], + summary: "Show or set the Anthropic API key", + argument_hint: Some("[key]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "approve", + aliases: &["yes", "y"], + summary: "Approve a pending tool execution", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "deny", + aliases: &["no", "n"], + summary: "Deny a pending tool execution", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "undo", + aliases: &[], + summary: "Undo the last file write or edit", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "stop", + aliases: &[], + summary: "Stop the current generation", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "retry", + aliases: &[], + summary: "Retry the last failed message", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "paste", + aliases: &[], + summary: "Paste clipboard content as input", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "screenshot", + aliases: &[], + summary: "Take a screenshot and add to conversation", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "image", + aliases: &[], + summary: "Add an image file to the conversation", + argument_hint: Some(""), + resume_supported: false, + }, + SlashCommandSpec { + name: "terminal-setup", + aliases: &[], + summary: "Configure terminal integration settings", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "search", + aliases: &[], + summary: "Search files in the workspace", + argument_hint: Some(""), + resume_supported: false, + }, + SlashCommandSpec { + name: "listen", + aliases: &[], + summary: "Listen for voice input", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "speak", + aliases: &[], + summary: "Read the last response aloud", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "language", + aliases: &[], + summary: "Set the interface language", + argument_hint: Some("[language]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "profile", + aliases: &[], + summary: "Show or switch user profile", + argument_hint: Some("[name]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "max-tokens", + aliases: &[], + summary: "Show or set the max output tokens", + argument_hint: Some("[count]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "temperature", + aliases: &[], + summary: "Show or set the sampling temperature", + argument_hint: Some("[value]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "system-prompt", + aliases: &[], + summary: "Show the active system prompt", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "tool-details", + aliases: &[], + summary: "Show detailed info about a specific tool", + argument_hint: Some(""), + resume_supported: true, + }, + SlashCommandSpec { + name: "format", + aliases: &[], + summary: "Format the last response in a different style", + argument_hint: Some("[markdown|plain|json]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "pin", + aliases: &[], + summary: "Pin a message to persist across compaction", + argument_hint: Some("[message-index]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "unpin", + aliases: &[], + summary: "Unpin a previously pinned message", + argument_hint: Some("[message-index]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "bookmarks", + aliases: &[], + summary: "List or manage conversation bookmarks", + argument_hint: Some("[add|remove|list]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "workspace", + aliases: &["cwd"], + summary: "Show or change the working directory", + argument_hint: Some("[path]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "history", + aliases: &[], + summary: "Show conversation history summary", + argument_hint: Some("[count]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "tokens", + aliases: &[], + summary: "Show token count for the current conversation", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "cache", + aliases: &[], + summary: "Show prompt cache statistics", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "providers", + aliases: &[], + summary: "List available model providers", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "notifications", + aliases: &[], + summary: "Show or configure notification settings", + argument_hint: Some("[on|off|status]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "changelog", + aliases: &[], + summary: "Show recent changes to the codebase", + argument_hint: Some("[count]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "test", + aliases: &[], + summary: "Run tests for the current project", + argument_hint: Some("[filter]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "lint", + aliases: &[], + summary: "Run linting for the current project", + argument_hint: Some("[filter]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "build", + aliases: &[], + summary: "Build the current project", + argument_hint: Some("[target]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "run", + aliases: &[], + summary: "Run a command in the project context", + argument_hint: Some(""), + resume_supported: false, + }, + SlashCommandSpec { + name: "git", + aliases: &[], + summary: "Run a git command in the workspace", + argument_hint: Some(""), + resume_supported: false, + }, + SlashCommandSpec { + name: "stash", + aliases: &[], + summary: "Stash or unstash workspace changes", + argument_hint: Some("[pop|list|apply]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "blame", + aliases: &[], + summary: "Show git blame for a file", + argument_hint: Some(" [line]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "log", + aliases: &[], + summary: "Show git log for the workspace", + argument_hint: Some("[count]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "cron", + aliases: &[], + summary: "Manage scheduled tasks", + argument_hint: Some("[list|add|remove]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "team", + aliases: &[], + summary: "Manage agent teams", + argument_hint: Some("[list|create|delete]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "benchmark", + aliases: &[], + summary: "Run performance benchmarks", + argument_hint: Some("[suite]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "migrate", + aliases: &[], + summary: "Run pending data migrations", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "reset", + aliases: &[], + summary: "Reset configuration to defaults", + argument_hint: Some("[section]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "telemetry", + aliases: &[], + summary: "Show or configure telemetry settings", + argument_hint: Some("[on|off|status]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "env", + aliases: &[], + summary: "Show environment variables visible to tools", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "project", + aliases: &[], + summary: "Show project detection info", + argument_hint: None, + resume_supported: true, + }, + SlashCommandSpec { + name: "templates", + aliases: &[], + summary: "List or apply prompt templates", + argument_hint: Some("[list|apply ]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "explain", + aliases: &[], + summary: "Explain a file or code snippet", + argument_hint: Some(" [line-range]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "refactor", + aliases: &[], + summary: "Suggest refactoring for a file or function", + argument_hint: Some(" [scope]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "docs", + aliases: &[], + summary: "Generate or show documentation", + argument_hint: Some("[path]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "fix", + aliases: &[], + summary: "Fix errors in a file or project", + argument_hint: Some("[path]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "perf", + aliases: &[], + summary: "Analyze performance of a function or file", + argument_hint: Some(""), + resume_supported: false, + }, + SlashCommandSpec { + name: "chat", + aliases: &[], + summary: "Switch to free-form chat mode", + argument_hint: None, + resume_supported: false, + }, + SlashCommandSpec { + name: "focus", + aliases: &[], + summary: "Focus context on specific files or directories", + argument_hint: Some(" [path...]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "unfocus", + aliases: &[], + summary: "Remove focus from files or directories", + argument_hint: Some("[path...]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "web", + aliases: &[], + summary: "Fetch and summarize a web page", + argument_hint: Some(""), + resume_supported: false, + }, + SlashCommandSpec { + name: "map", + aliases: &[], + summary: "Show a visual map of the codebase structure", + argument_hint: Some("[depth]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "symbols", + aliases: &[], + summary: "List symbols (functions, classes, etc.) in a file", + argument_hint: Some(""), + resume_supported: true, + }, + SlashCommandSpec { + name: "references", + aliases: &[], + summary: "Find all references to a symbol", + argument_hint: Some(""), + resume_supported: false, + }, + SlashCommandSpec { + name: "definition", + aliases: &[], + summary: "Go to the definition of a symbol", + argument_hint: Some(""), + resume_supported: false, + }, + SlashCommandSpec { + name: "hover", + aliases: &[], + summary: "Show hover information for a symbol", + argument_hint: Some(""), + resume_supported: true, + }, + SlashCommandSpec { + name: "diagnostics", + aliases: &[], + summary: "Show LSP diagnostics for a file", + argument_hint: Some("[path]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "autofix", + aliases: &[], + summary: "Auto-fix all fixable diagnostics", + argument_hint: Some("[path]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "multi", + aliases: &[], + summary: "Execute multiple slash commands in sequence", + argument_hint: Some(""), + resume_supported: false, + }, + SlashCommandSpec { + name: "macro", + aliases: &[], + summary: "Record or replay command macros", + argument_hint: Some("[record|stop|play ]"), + resume_supported: false, + }, + SlashCommandSpec { + name: "alias", + aliases: &[], + summary: "Create a command alias", + argument_hint: Some(" "), + resume_supported: true, + }, + SlashCommandSpec { + name: "parallel", + aliases: &[], + summary: "Run commands in parallel subagents", + argument_hint: Some(" "), + resume_supported: false, + }, + SlashCommandSpec { + name: "agent", + aliases: &[], + summary: "Manage sub-agents and spawned sessions", + argument_hint: Some("[list|spawn|kill]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "subagent", + aliases: &[], + summary: "Control active subagent execution", + argument_hint: Some("[list|steer |kill ]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "reasoning", + aliases: &[], + summary: "Toggle extended reasoning mode", + argument_hint: Some("[on|off|stream]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "budget", + aliases: &[], + summary: "Show or set token budget limits", + argument_hint: Some("[show|set ]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "rate-limit", + aliases: &[], + summary: "Configure API rate limiting", + argument_hint: Some("[status|set ]"), + resume_supported: true, + }, + SlashCommandSpec { + name: "metrics", + aliases: &[], + summary: "Show performance and usage metrics", argument_hint: None, resume_supported: true, - category: SlashCommandCategory::Automation, }, ]; @@ -303,23 +1054,12 @@ const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ pub enum SlashCommand { Help, Status, + Sandbox, Compact, - Branch { - action: Option, - target: Option, - }, Bughunter { scope: Option, }, - Worktree { - action: Option, - path: Option, - branch: Option, - }, Commit, - CommitPushPr { - context: Option, - }, Pr { context: Option, }, @@ -349,6 +1089,10 @@ pub enum SlashCommand { Config { section: Option, }, + Mcp { + action: Option, + target: Option, + }, Memory, Init, Diff, @@ -370,97 +1114,711 @@ pub enum SlashCommand { Skills { args: Option, }, + Doctor, + Login, + Logout, + Vim, + Upgrade, + Stats, + Share, + Feedback, + Files, + Fast, + Exit, + Summary, + Desktop, + Brief, + Advisor, + Stickers, + Insights, + Thinkback, + ReleaseNotes, + SecurityReview, + Keybindings, + PrivacySettings, + Plan { + mode: Option, + }, + Review { + scope: Option, + }, + Tasks { + args: Option, + }, + Theme { + name: Option, + }, + Voice { + mode: Option, + }, + Usage { + scope: Option, + }, + Rename { + name: Option, + }, + Copy { + target: Option, + }, + Hooks { + args: Option, + }, + Context { + action: Option, + }, + Color { + scheme: Option, + }, + Effort { + level: Option, + }, + Branch { + name: Option, + }, + Rewind { + steps: Option, + }, + Ide { + target: Option, + }, + Tag { + label: Option, + }, + OutputStyle { + style: Option, + }, + AddDir { + path: Option, + }, + History { + count: Option, + }, Unknown(String), } -impl SlashCommand { - #[must_use] - pub fn parse(input: &str) -> Option { - let trimmed = input.trim(); - if !trimmed.starts_with('/') { - return None; - } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SlashCommandParseError { + message: String, +} - let mut parts = trimmed.trim_start_matches('/').split_whitespace(); - let command = parts.next().unwrap_or_default(); - Some(match command { - "help" => Self::Help, - "status" => Self::Status, - "compact" => Self::Compact, - "branch" => Self::Branch { - action: parts.next().map(ToOwned::to_owned), - target: parts.next().map(ToOwned::to_owned), - }, - "bughunter" => Self::Bughunter { - scope: remainder_after_command(trimmed, command), - }, - "worktree" => Self::Worktree { - action: parts.next().map(ToOwned::to_owned), - path: parts.next().map(ToOwned::to_owned), - branch: parts.next().map(ToOwned::to_owned), - }, - "commit" => Self::Commit, - "commit-push-pr" => Self::CommitPushPr { - context: remainder_after_command(trimmed, command), - }, - "pr" => Self::Pr { - context: remainder_after_command(trimmed, command), - }, - "issue" => Self::Issue { - context: remainder_after_command(trimmed, command), - }, - "ultraplan" => Self::Ultraplan { - task: remainder_after_command(trimmed, command), - }, - "teleport" => Self::Teleport { - target: remainder_after_command(trimmed, command), - }, - "debug-tool-call" => Self::DebugToolCall, - "model" => Self::Model { - model: parts.next().map(ToOwned::to_owned), - }, - "permissions" => Self::Permissions { - mode: parts.next().map(ToOwned::to_owned), - }, - "clear" => Self::Clear { - confirm: parts.next() == Some("--confirm"), - }, - "cost" => Self::Cost, - "resume" => Self::Resume { - session_path: parts.next().map(ToOwned::to_owned), - }, - "config" => Self::Config { - section: parts.next().map(ToOwned::to_owned), - }, - "memory" => Self::Memory, - "init" => Self::Init, - "diff" => Self::Diff, - "version" => Self::Version, - "export" => Self::Export { - path: parts.next().map(ToOwned::to_owned), - }, - "session" => Self::Session { - action: parts.next().map(ToOwned::to_owned), - target: parts.next().map(ToOwned::to_owned), - }, - "plugin" | "plugins" | "marketplace" => Self::Plugins { - action: parts.next().map(ToOwned::to_owned), - target: { - let remainder = parts.collect::>().join(" "); - (!remainder.is_empty()).then_some(remainder) - }, - }, - "agents" => Self::Agents { - args: remainder_after_command(trimmed, command), - }, - "skills" => Self::Skills { - args: remainder_after_command(trimmed, command), - }, - other => Self::Unknown(other.to_string()), - }) +impl SlashCommandParseError { + fn new(message: impl Into) -> Self { + Self { + message: message.into(), + } } } +impl fmt::Display for SlashCommandParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.message) + } +} + +impl std::error::Error for SlashCommandParseError {} + +impl SlashCommand { + pub fn parse(input: &str) -> Result, SlashCommandParseError> { + validate_slash_command_input(input) + } + + /// Returns the canonical slash-command name (e.g. `"/branch"`) for use in + /// error messages and logging. Derived from the spec table so it always + /// matches what the user would have typed. + #[must_use] + pub fn slash_name(&self) -> &'static str { + match self { + Self::Help => "/help", + Self::Clear { .. } => "/clear", + Self::Compact { .. } => "/compact", + Self::Cost => "/cost", + Self::Doctor => "/doctor", + Self::Config { .. } => "/config", + Self::Memory { .. } => "/memory", + Self::History { .. } => "/history", + Self::Diff => "/diff", + Self::Status => "/status", + Self::Stats => "/stats", + Self::Version => "/version", + Self::Commit { .. } => "/commit", + Self::Pr { .. } => "/pr", + Self::Issue { .. } => "/issue", + Self::Init => "/init", + Self::Bughunter { .. } => "/bughunter", + Self::Ultraplan { .. } => "/ultraplan", + Self::Teleport { .. } => "/teleport", + Self::DebugToolCall { .. } => "/debug-tool-call", + Self::Resume { .. } => "/resume", + Self::Model { .. } => "/model", + Self::Permissions { .. } => "/permissions", + Self::Session { .. } => "/session", + Self::Plugins { .. } => "/plugins", + Self::Login => "/login", + Self::Logout => "/logout", + Self::Vim => "/vim", + Self::Upgrade => "/upgrade", + Self::Share => "/share", + Self::Feedback => "/feedback", + Self::Files => "/files", + Self::Fast => "/fast", + Self::Exit => "/exit", + Self::Summary => "/summary", + Self::Desktop => "/desktop", + Self::Brief => "/brief", + Self::Advisor => "/advisor", + Self::Stickers => "/stickers", + Self::Insights => "/insights", + Self::Thinkback => "/thinkback", + Self::ReleaseNotes => "/release-notes", + Self::SecurityReview => "/security-review", + Self::Keybindings => "/keybindings", + Self::PrivacySettings => "/privacy-settings", + Self::Plan { .. } => "/plan", + Self::Review { .. } => "/review", + Self::Tasks { .. } => "/tasks", + Self::Theme { .. } => "/theme", + Self::Voice { .. } => "/voice", + Self::Usage { .. } => "/usage", + Self::Rename { .. } => "/rename", + Self::Copy { .. } => "/copy", + Self::Hooks { .. } => "/hooks", + Self::Context { .. } => "/context", + Self::Color { .. } => "/color", + Self::Effort { .. } => "/effort", + Self::Branch { .. } => "/branch", + Self::Rewind { .. } => "/rewind", + Self::Ide { .. } => "/ide", + Self::Tag { .. } => "/tag", + Self::OutputStyle { .. } => "/output-style", + Self::AddDir { .. } => "/add-dir", + Self::Unknown(_) => "/unknown", + Self::Sandbox => "/sandbox", + Self::Mcp { .. } => "/mcp", + Self::Export { .. } => "/export", + #[allow(unreachable_patterns)] + _ => "/unknown", + } + } +} + +#[allow(clippy::too_many_lines)] +pub fn validate_slash_command_input( + input: &str, +) -> Result, SlashCommandParseError> { + let trimmed = input.trim(); + if !trimmed.starts_with('/') { + return Ok(None); + } + + let mut parts = trimmed.trim_start_matches('/').split_whitespace(); + let command = parts.next().unwrap_or_default(); + if command.is_empty() { + return Err(SlashCommandParseError::new( + "Slash command name is missing. Use /help to list available slash commands.", + )); + } + + let args = parts.collect::>(); + let remainder = remainder_after_command(trimmed, command); + + Ok(Some(match command { + "help" => { + validate_no_args(command, &args)?; + SlashCommand::Help + } + "status" => { + validate_no_args(command, &args)?; + SlashCommand::Status + } + "sandbox" => { + validate_no_args(command, &args)?; + SlashCommand::Sandbox + } + "compact" => { + validate_no_args(command, &args)?; + SlashCommand::Compact + } + "bughunter" => SlashCommand::Bughunter { scope: remainder }, + "commit" => { + validate_no_args(command, &args)?; + SlashCommand::Commit + } + "pr" => SlashCommand::Pr { context: remainder }, + "issue" => SlashCommand::Issue { context: remainder }, + "ultraplan" => SlashCommand::Ultraplan { task: remainder }, + "teleport" => SlashCommand::Teleport { + target: Some(require_remainder(command, remainder, "")?), + }, + "debug-tool-call" => { + validate_no_args(command, &args)?; + SlashCommand::DebugToolCall + } + "model" => SlashCommand::Model { + model: optional_single_arg(command, &args, "[model]")?, + }, + "permissions" => SlashCommand::Permissions { + mode: parse_permissions_mode(&args)?, + }, + "clear" => SlashCommand::Clear { + confirm: parse_clear_args(&args)?, + }, + "cost" => { + validate_no_args(command, &args)?; + SlashCommand::Cost + } + "resume" => SlashCommand::Resume { + session_path: Some(require_remainder(command, remainder, "")?), + }, + "config" => SlashCommand::Config { + section: parse_config_section(&args)?, + }, + "mcp" => parse_mcp_command(&args)?, + "memory" => { + validate_no_args(command, &args)?; + SlashCommand::Memory + } + "init" => { + validate_no_args(command, &args)?; + SlashCommand::Init + } + "diff" => { + validate_no_args(command, &args)?; + SlashCommand::Diff + } + "version" => { + validate_no_args(command, &args)?; + SlashCommand::Version + } + "export" => SlashCommand::Export { path: remainder }, + "session" => parse_session_command(&args)?, + "plugin" | "plugins" | "marketplace" => parse_plugin_command(&args)?, + "agents" => SlashCommand::Agents { + args: parse_list_or_help_args(command, remainder)?, + }, + "skills" | "skill" => SlashCommand::Skills { + args: parse_skills_args(remainder.as_deref())?, + }, + "doctor" | "providers" => { + validate_no_args(command, &args)?; + SlashCommand::Doctor + } + "login" => { + validate_no_args(command, &args)?; + SlashCommand::Login + } + "logout" => { + validate_no_args(command, &args)?; + SlashCommand::Logout + } + "vim" => { + validate_no_args(command, &args)?; + SlashCommand::Vim + } + "upgrade" => { + validate_no_args(command, &args)?; + SlashCommand::Upgrade + } + "stats" | "tokens" | "cache" => { + validate_no_args(command, &args)?; + SlashCommand::Stats + } + "share" => { + validate_no_args(command, &args)?; + SlashCommand::Share + } + "feedback" => { + validate_no_args(command, &args)?; + SlashCommand::Feedback + } + "files" => { + validate_no_args(command, &args)?; + SlashCommand::Files + } + "fast" => { + validate_no_args(command, &args)?; + SlashCommand::Fast + } + "exit" => { + validate_no_args(command, &args)?; + SlashCommand::Exit + } + "summary" => { + validate_no_args(command, &args)?; + SlashCommand::Summary + } + "desktop" => { + validate_no_args(command, &args)?; + SlashCommand::Desktop + } + "brief" => { + validate_no_args(command, &args)?; + SlashCommand::Brief + } + "advisor" => { + validate_no_args(command, &args)?; + SlashCommand::Advisor + } + "stickers" => { + validate_no_args(command, &args)?; + SlashCommand::Stickers + } + "insights" => { + validate_no_args(command, &args)?; + SlashCommand::Insights + } + "thinkback" => { + validate_no_args(command, &args)?; + SlashCommand::Thinkback + } + "release-notes" => { + validate_no_args(command, &args)?; + SlashCommand::ReleaseNotes + } + "security-review" => { + validate_no_args(command, &args)?; + SlashCommand::SecurityReview + } + "keybindings" => { + validate_no_args(command, &args)?; + SlashCommand::Keybindings + } + "privacy-settings" => { + validate_no_args(command, &args)?; + SlashCommand::PrivacySettings + } + "plan" => SlashCommand::Plan { mode: remainder }, + "review" => SlashCommand::Review { scope: remainder }, + "tasks" => SlashCommand::Tasks { args: remainder }, + "theme" => SlashCommand::Theme { name: remainder }, + "voice" => SlashCommand::Voice { mode: remainder }, + "usage" => SlashCommand::Usage { scope: remainder }, + "rename" => SlashCommand::Rename { name: remainder }, + "copy" => SlashCommand::Copy { target: remainder }, + "hooks" => SlashCommand::Hooks { args: remainder }, + "context" => SlashCommand::Context { action: remainder }, + "color" => SlashCommand::Color { scheme: remainder }, + "effort" => SlashCommand::Effort { level: remainder }, + "branch" => SlashCommand::Branch { name: remainder }, + "rewind" => SlashCommand::Rewind { steps: remainder }, + "ide" => SlashCommand::Ide { target: remainder }, + "tag" => SlashCommand::Tag { label: remainder }, + "output-style" => SlashCommand::OutputStyle { style: remainder }, + "add-dir" => SlashCommand::AddDir { path: remainder }, + "history" => SlashCommand::History { + count: optional_single_arg(command, &args, "[count]")?, + }, + other => SlashCommand::Unknown(other.to_string()), + })) +} +fn validate_no_args(command: &str, args: &[&str]) -> Result<(), SlashCommandParseError> { + if args.is_empty() { + return Ok(()); + } + + Err(command_error( + &format!("Unexpected arguments for /{command}."), + command, + &format!("/{command}"), + )) +} + +fn optional_single_arg( + command: &str, + args: &[&str], + argument_hint: &str, +) -> Result, SlashCommandParseError> { + match args { + [] => Ok(None), + [value] => Ok(Some((*value).to_string())), + _ => Err(usage_error(command, argument_hint)), + } +} + +fn require_remainder( + command: &str, + remainder: Option, + argument_hint: &str, +) -> Result { + remainder.ok_or_else(|| usage_error(command, argument_hint)) +} + +fn parse_permissions_mode(args: &[&str]) -> Result, SlashCommandParseError> { + let mode = optional_single_arg( + "permissions", + args, + "[read-only|workspace-write|danger-full-access]", + )?; + if let Some(mode) = mode { + if matches!( + mode.as_str(), + "read-only" | "workspace-write" | "danger-full-access" + ) { + return Ok(Some(mode)); + } + return Err(command_error( + &format!( + "Unsupported /permissions mode '{mode}'. Use read-only, workspace-write, or danger-full-access." + ), + "permissions", + "/permissions [read-only|workspace-write|danger-full-access]", + )); + } + + Ok(None) +} + +fn parse_clear_args(args: &[&str]) -> Result { + match args { + [] => Ok(false), + ["--confirm"] => Ok(true), + [unexpected] => Err(command_error( + &format!("Unsupported /clear argument '{unexpected}'. Use /clear or /clear --confirm."), + "clear", + "/clear [--confirm]", + )), + _ => Err(usage_error("clear", "[--confirm]")), + } +} + +fn parse_config_section(args: &[&str]) -> Result, SlashCommandParseError> { + let section = optional_single_arg("config", args, "[env|hooks|model|plugins]")?; + if let Some(section) = section { + if matches!(section.as_str(), "env" | "hooks" | "model" | "plugins") { + return Ok(Some(section)); + } + return Err(command_error( + &format!("Unsupported /config section '{section}'. Use env, hooks, model, or plugins."), + "config", + "/config [env|hooks|model|plugins]", + )); + } + + Ok(None) +} + +fn parse_session_command(args: &[&str]) -> Result { + match args { + [] => Ok(SlashCommand::Session { + action: None, + target: None, + }), + ["list"] => Ok(SlashCommand::Session { + action: Some("list".to_string()), + target: None, + }), + ["list", ..] => Err(usage_error("session", "[list|switch |fork [branch-name]|delete [--force]]")), + ["switch"] => Err(usage_error("session switch", "")), + ["switch", target] => Ok(SlashCommand::Session { + action: Some("switch".to_string()), + target: Some((*target).to_string()), + }), + ["switch", ..] => Err(command_error( + "Unexpected arguments for /session switch.", + "session", + "/session switch ", + )), + ["fork"] => Ok(SlashCommand::Session { + action: Some("fork".to_string()), + target: None, + }), + ["fork", target] => Ok(SlashCommand::Session { + action: Some("fork".to_string()), + target: Some((*target).to_string()), + }), + ["fork", ..] => Err(command_error( + "Unexpected arguments for /session fork.", + "session", + "/session fork [branch-name]", + )), + ["delete"] => Err(usage_error("session delete", " [--force]")), + ["delete", target] => Ok(SlashCommand::Session { + action: Some("delete".to_string()), + target: Some((*target).to_string()), + }), + ["delete", target, "--force"] => Ok(SlashCommand::Session { + action: Some("delete-force".to_string()), + target: Some((*target).to_string()), + }), + ["delete", _target, unexpected] => Err(command_error( + &format!( + "Unsupported /session delete flag '{unexpected}'. Use --force to skip confirmation." + ), + "session", + "/session delete [--force]", + )), + ["delete", ..] => Err(command_error( + "Unexpected arguments for /session delete.", + "session", + "/session delete [--force]", + )), + [action, ..] => Err(command_error( + &format!( + "Unknown /session action '{action}'. Use list, switch , fork [branch-name], or delete [--force]." + ), + "session", + "/session [list|switch |fork [branch-name]|delete [--force]]", + )), + } +} + +fn parse_mcp_command(args: &[&str]) -> Result { + match args { + [] => Ok(SlashCommand::Mcp { + action: None, + target: None, + }), + ["list"] => Ok(SlashCommand::Mcp { + action: Some("list".to_string()), + target: None, + }), + ["list", ..] => Err(usage_error("mcp list", "")), + ["show"] => Err(usage_error("mcp show", "")), + ["show", target] => Ok(SlashCommand::Mcp { + action: Some("show".to_string()), + target: Some((*target).to_string()), + }), + ["show", ..] => Err(command_error( + "Unexpected arguments for /mcp show.", + "mcp", + "/mcp show ", + )), + ["help" | "-h" | "--help"] => Ok(SlashCommand::Mcp { + action: Some("help".to_string()), + target: None, + }), + [action, ..] => Err(command_error( + &format!("Unknown /mcp action '{action}'. Use list, show , or help."), + "mcp", + "/mcp [list|show |help]", + )), + } +} + +fn parse_plugin_command(args: &[&str]) -> Result { + match args { + [] => Ok(SlashCommand::Plugins { + action: None, + target: None, + }), + ["list"] => Ok(SlashCommand::Plugins { + action: Some("list".to_string()), + target: None, + }), + ["list", ..] => Err(usage_error("plugin list", "")), + ["install"] => Err(usage_error("plugin install", "")), + ["install", target @ ..] => Ok(SlashCommand::Plugins { + action: Some("install".to_string()), + target: Some(target.join(" ")), + }), + ["enable"] => Err(usage_error("plugin enable", "")), + ["enable", target] => Ok(SlashCommand::Plugins { + action: Some("enable".to_string()), + target: Some((*target).to_string()), + }), + ["enable", ..] => Err(command_error( + "Unexpected arguments for /plugin enable.", + "plugin", + "/plugin enable ", + )), + ["disable"] => Err(usage_error("plugin disable", "")), + ["disable", target] => Ok(SlashCommand::Plugins { + action: Some("disable".to_string()), + target: Some((*target).to_string()), + }), + ["disable", ..] => Err(command_error( + "Unexpected arguments for /plugin disable.", + "plugin", + "/plugin disable ", + )), + ["uninstall"] => Err(usage_error("plugin uninstall", "")), + ["uninstall", target] => Ok(SlashCommand::Plugins { + action: Some("uninstall".to_string()), + target: Some((*target).to_string()), + }), + ["uninstall", ..] => Err(command_error( + "Unexpected arguments for /plugin uninstall.", + "plugin", + "/plugin uninstall ", + )), + ["update"] => Err(usage_error("plugin update", "")), + ["update", target] => Ok(SlashCommand::Plugins { + action: Some("update".to_string()), + target: Some((*target).to_string()), + }), + ["update", ..] => Err(command_error( + "Unexpected arguments for /plugin update.", + "plugin", + "/plugin update ", + )), + [action, ..] => Err(command_error( + &format!( + "Unknown /plugin action '{action}'. Use list, install , enable , disable , uninstall , or update ." + ), + "plugin", + "/plugin [list|install |enable |disable |uninstall |update ]", + )), + } +} + +fn parse_list_or_help_args( + command: &str, + args: Option, +) -> Result, SlashCommandParseError> { + match normalize_optional_args(args.as_deref()) { + None | Some("list" | "help" | "-h" | "--help") => Ok(args), + Some(unexpected) => Err(command_error( + &format!( + "Unexpected arguments for /{command}: {unexpected}. Use /{command}, /{command} list, or /{command} help." + ), + command, + &format!("/{command} [list|help]"), + )), + } +} + +fn parse_skills_args(args: Option<&str>) -> Result, SlashCommandParseError> { + let Some(args) = normalize_optional_args(args) else { + return Ok(None); + }; + + if matches!(args, "list" | "help" | "-h" | "--help") { + return Ok(Some(args.to_string())); + } + + if args == "install" { + return Err(command_error( + "Usage: /skills install ", + "skills", + "/skills install ", + )); + } + + if let Some(target) = args.strip_prefix("install").map(str::trim) { + if !target.is_empty() { + return Ok(Some(format!("install {target}"))); + } + } + + Ok(Some(args.to_string())) +} + +fn usage_error(command: &str, argument_hint: &str) -> SlashCommandParseError { + let usage = format!("/{command} {argument_hint}"); + let usage = usage.trim_end().to_string(); + command_error( + &format!("Usage: {usage}"), + command_root_name(command), + &usage, + ) +} + +fn command_error(message: &str, command: &str, usage: &str) -> SlashCommandParseError { + let detail = render_slash_command_help_detail(command) + .map(|detail| format!("\n\n{detail}")) + .unwrap_or_default(); + SlashCommandParseError::new(format!("{message}\n Usage {usage}{detail}")) +} + fn remainder_after_command(input: &str, command: &str) -> Option { input .trim() @@ -470,6 +1828,56 @@ fn remainder_after_command(input: &str, command: &str) -> Option { .map(ToOwned::to_owned) } +fn find_slash_command_spec(name: &str) -> Option<&'static SlashCommandSpec> { + slash_command_specs().iter().find(|spec| { + spec.name.eq_ignore_ascii_case(name) + || spec + .aliases + .iter() + .any(|alias| alias.eq_ignore_ascii_case(name)) + }) +} + +fn command_root_name(command: &str) -> &str { + command.split_whitespace().next().unwrap_or(command) +} + +fn slash_command_usage(spec: &SlashCommandSpec) -> String { + match spec.argument_hint { + Some(argument_hint) => format!("/{} {argument_hint}", spec.name), + None => format!("/{}", spec.name), + } +} + +fn slash_command_detail_lines(spec: &SlashCommandSpec) -> Vec { + let mut lines = vec![format!("/{}", spec.name)]; + lines.push(format!(" Summary {}", spec.summary)); + lines.push(format!(" Usage {}", slash_command_usage(spec))); + lines.push(format!( + " Category {}", + slash_command_category(spec.name) + )); + if !spec.aliases.is_empty() { + lines.push(format!( + " Aliases {}", + spec.aliases + .iter() + .map(|alias| format!("/{alias}")) + .collect::>() + .join(", ") + )); + } + if spec.resume_supported { + lines.push(" Resume Supported with --resume SESSION.jsonl".to_string()); + } + lines +} + +#[must_use] +pub fn render_slash_command_help_detail(name: &str) -> Option { + find_slash_command_spec(name).map(|spec| slash_command_detail_lines(spec).join("\n")) +} + #[must_use] pub fn slash_command_specs() -> &'static [SlashCommandSpec] { SLASH_COMMAND_SPECS @@ -483,35 +1891,36 @@ pub fn resume_supported_slash_commands() -> Vec<&'static SlashCommandSpec> { .collect() } -#[must_use] -pub fn render_slash_command_help() -> String { - let mut lines = vec![ - "Slash commands".to_string(), - " Tab completes commands inside the REPL.".to_string(), - " [resume] = also available via claw --resume SESSION.json".to_string(), - ]; - - for category in [ - SlashCommandCategory::Core, - SlashCommandCategory::Workspace, - SlashCommandCategory::Session, - SlashCommandCategory::Git, - SlashCommandCategory::Automation, - ] { - lines.push(String::new()); - lines.push(category.title().to_string()); - lines.extend( - slash_command_specs() - .iter() - .filter(|spec| spec.category == category) - .map(render_slash_command_entry), - ); +fn slash_command_category(name: &str) -> &'static str { + match name { + "help" | "status" | "cost" | "resume" | "session" | "version" | "login" | "logout" + | "usage" | "stats" | "rename" | "clear" | "compact" | "history" | "tokens" | "cache" + | "exit" | "summary" | "tag" | "thinkback" | "copy" | "share" | "feedback" | "rewind" + | "pin" | "unpin" | "bookmarks" | "context" | "files" | "focus" | "unfocus" | "retry" + | "stop" | "undo" => "Session", + "diff" | "commit" | "pr" | "issue" | "branch" | "blame" | "log" | "git" | "stash" + | "init" | "export" | "plan" | "review" | "security-review" | "bughunter" | "ultraplan" + | "teleport" | "refactor" | "fix" | "autofix" | "explain" | "docs" | "perf" | "search" + | "references" | "definition" | "hover" | "symbols" | "map" | "web" | "image" + | "screenshot" | "paste" | "listen" | "speak" | "test" | "lint" | "build" | "run" + | "format" | "parallel" | "multi" | "macro" | "alias" | "templates" | "migrate" + | "benchmark" | "cron" | "agent" | "subagent" | "agents" | "skills" | "team" | "plugin" + | "mcp" | "hooks" | "tasks" | "advisor" | "insights" | "release-notes" | "chat" + | "approve" | "deny" | "allowed-tools" | "add-dir" => "Tools", + "model" | "permissions" | "config" | "memory" | "theme" | "vim" | "voice" | "color" + | "effort" | "fast" | "brief" | "output-style" | "keybindings" | "privacy-settings" + | "stickers" | "language" | "profile" | "max-tokens" | "temperature" | "system-prompt" + | "api-key" | "terminal-setup" | "notifications" | "telemetry" | "providers" | "env" + | "project" | "reasoning" | "budget" | "rate-limit" | "workspace" | "reset" | "ide" + | "desktop" | "upgrade" => "Config", + "debug-tool-call" | "doctor" | "sandbox" | "diagnostics" | "tool-details" | "changelog" + | "metrics" => "Debug", + _ => "Tools", } - - lines.join("\n") } -fn render_slash_command_entry(spec: &SlashCommandSpec) -> String { +fn format_slash_command_help_line(spec: &SlashCommandSpec) -> String { + let name = slash_command_usage(spec); let alias_suffix = if spec.aliases.is_empty() { String::new() } else { @@ -529,18 +1938,7 @@ fn render_slash_command_entry(spec: &SlashCommandSpec) -> String { } else { "" }; - format!( - " {name:<46} {}{alias_suffix}{resume}", - spec.summary, - name = render_slash_command_name(spec), - ) -} - -fn render_slash_command_name(spec: &SlashCommandSpec) -> String { - match spec.argument_hint { - Some(argument_hint) => format!("/{} {}", spec.name, argument_hint), - None => format!("/{}", spec.name), - } + format!(" {name:<66} {}{alias_suffix}{resume}", spec.summary) } fn levenshtein_distance(left: &str, right: &str) -> usize { @@ -561,12 +1959,12 @@ fn levenshtein_distance(left: &str, right: &str) -> usize { for (left_index, left_char) in left.chars().enumerate() { current[0] = left_index + 1; for (right_index, right_char) in right_chars.iter().enumerate() { - let cost = usize::from(left_char != *right_char); - current[right_index + 1] = (previous[right_index + 1] + 1) - .min(current[right_index] + 1) - .min(previous[right_index] + cost); + let substitution_cost = usize::from(left_char != *right_char); + current[right_index + 1] = (current[right_index] + 1) + .min(previous[right_index + 1] + 1) + .min(previous[right_index] + substitution_cost); } - std::mem::swap(&mut previous, &mut current); + previous.clone_from(¤t); } previous[right_chars.len()] @@ -574,44 +1972,124 @@ fn levenshtein_distance(left: &str, right: &str) -> usize { #[must_use] pub fn suggest_slash_commands(input: &str, limit: usize) -> Vec { - let normalized = input.trim().trim_start_matches('/').to_ascii_lowercase(); - if normalized.is_empty() || limit == 0 { + let query = input.trim().trim_start_matches('/').to_ascii_lowercase(); + if query.is_empty() || limit == 0 { return Vec::new(); } - let mut ranked = slash_command_specs() + let mut suggestions = slash_command_specs() .iter() .filter_map(|spec| { - let score = std::iter::once(spec.name) + let best = std::iter::once(spec.name) .chain(spec.aliases.iter().copied()) .map(str::to_ascii_lowercase) - .filter_map(|alias| { - if alias == normalized { - Some((0_usize, alias.len())) - } else if alias.starts_with(&normalized) { - Some((1, alias.len())) - } else if alias.contains(&normalized) { - Some((2, alias.len())) - } else { - let distance = levenshtein_distance(&alias, &normalized); - (distance <= 2).then_some((3 + distance, alias.len())) - } + .map(|candidate| { + let prefix_rank = + if candidate.starts_with(&query) || query.starts_with(&candidate) { + 0 + } else if candidate.contains(&query) || query.contains(&candidate) { + 1 + } else { + 2 + }; + let distance = levenshtein_distance(&candidate, &query); + (prefix_rank, distance) }) .min(); - score.map(|(bucket, len)| (bucket, len, render_slash_command_name(spec))) + best.and_then(|(prefix_rank, distance)| { + if prefix_rank <= 1 || distance <= 2 { + Some((prefix_rank, distance, spec.name.len(), spec.name)) + } else { + None + } + }) }) .collect::>(); - ranked.sort(); - ranked.dedup_by(|left, right| left.2 == right.2); - ranked + suggestions.sort_unstable(); + suggestions .into_iter() + .map(|(_, _, _, name)| format!("/{name}")) .take(limit) - .map(|(_, _, display)| display) .collect() } +#[must_use] +/// Render the slash-command help section, optionally excluding stub commands +/// (commands that are registered in the spec list but not yet implemented). +/// Pass an empty slice to include all commands. +pub fn render_slash_command_help_filtered(exclude: &[&str]) -> String { + let mut lines = vec![ + "Slash commands".to_string(), + " Start here /status, /diff, /agents, /skills, /commit".to_string(), + " [resume] also works with --resume SESSION.jsonl".to_string(), + String::new(), + ]; + + let categories = ["Session", "Tools", "Config", "Debug"]; + + for category in categories { + lines.push(category.to_string()); + for spec in slash_command_specs() + .iter() + .filter(|spec| slash_command_category(spec.name) == category) + .filter(|spec| !exclude.contains(&spec.name)) + { + lines.push(format_slash_command_help_line(spec)); + } + lines.push(String::new()); + } + + lines + .into_iter() + .rev() + .skip_while(String::is_empty) + .collect::>() + .into_iter() + .rev() + .collect::>() + .join("\n") +} + +pub fn render_slash_command_help() -> String { + let mut lines = vec![ + "Slash commands".to_string(), + " Start here /status, /diff, /agents, /skills, /commit".to_string(), + " [resume] also works with --resume SESSION.jsonl".to_string(), + String::new(), + ]; + + let categories = ["Session", "Tools", "Config", "Debug"]; + + for category in categories { + lines.push(category.to_string()); + for spec in slash_command_specs() + .iter() + .filter(|spec| slash_command_category(spec.name) == category) + { + lines.push(format_slash_command_help_line(spec)); + } + lines.push(String::new()); + } + + lines.push("Keyboard shortcuts".to_string()); + lines.push(" Up/Down Navigate prompt history".to_string()); + lines.push(" Tab Complete commands, modes, and recent sessions".to_string()); + lines.push(" Ctrl-C Clear input (or exit on empty prompt)".to_string()); + lines.push(" Shift+Enter/Ctrl+J Insert a newline".to_string()); + + lines + .into_iter() + .rev() + .skip_while(String::is_empty) + .collect::>() + .into_iter() + .rev() + .collect::>() + .join("\n") +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct SlashCommandResult { pub message: String, @@ -626,23 +2104,47 @@ pub struct PluginsCommandResult { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] enum DefinitionSource { - ProjectCodex, ProjectClaw, + ProjectCodex, + ProjectClaude, + UserClawConfigHome, UserCodexHome, - UserCodex, UserClaw, + UserCodex, + UserClaude, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +enum DefinitionScope { + Project, + UserConfigHome, + UserHome, +} + +impl DefinitionScope { + fn label(self) -> &'static str { + match self { + Self::Project => "Project roots", + Self::UserConfigHome => "User config roots", + Self::UserHome => "User home roots", + } + } } impl DefinitionSource { - fn label(self) -> &'static str { + fn report_scope(self) -> DefinitionScope { match self { - Self::ProjectCodex => "Project (.codex)", - Self::ProjectClaw => "Project (.claw)", - Self::UserCodexHome => "User ($CODEX_HOME)", - Self::UserCodex => "User (~/.codex)", - Self::UserClaw => "User (~/.claw)", + Self::ProjectClaw | Self::ProjectCodex | Self::ProjectClaude => { + DefinitionScope::Project + } + Self::UserClawConfigHome | Self::UserCodexHome => DefinitionScope::UserConfigHome, + Self::UserClaw | Self::UserCodex | Self::UserClaude => DefinitionScope::UserHome, } } + + fn label(self) -> &'static str { + self.report_scope().label() + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -686,6 +2188,21 @@ struct SkillRoot { origin: SkillOrigin, } +#[derive(Debug, Clone, PartialEq, Eq)] +struct InstalledSkill { + invocation_name: String, + display_name: Option, + source: PathBuf, + registry_root: PathBuf, + installed_path: PathBuf, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum SkillInstallSource { + Directory { root: PathBuf, prompt_path: PathBuf }, + MarkdownFile { path: PathBuf }, +} + #[allow(clippy::too_many_lines)] pub fn handle_plugins_slash_command( action: Option<&str>, @@ -799,413 +2316,333 @@ pub fn handle_plugins_slash_command( } pub fn handle_agents_slash_command(args: Option<&str>, cwd: &Path) -> std::io::Result { + if let Some(args) = normalize_optional_args(args) { + if let Some(help_path) = help_path_from_args(args) { + return Ok(match help_path.as_slice() { + [] => render_agents_usage(None), + _ => render_agents_usage(Some(&help_path.join(" "))), + }); + } + } + match normalize_optional_args(args) { None | Some("list") => { let roots = discover_definition_roots(cwd, "agents"); let agents = load_agents_from_roots(&roots)?; Ok(render_agents_report(&agents)) } - Some("-h" | "--help" | "help") => Ok(render_agents_usage(None)), + Some(args) if is_help_arg(args) => Ok(render_agents_usage(None)), Some(args) => Ok(render_agents_usage(Some(args))), } } +pub fn handle_agents_slash_command_json(args: Option<&str>, cwd: &Path) -> std::io::Result { + if let Some(args) = normalize_optional_args(args) { + if let Some(help_path) = help_path_from_args(args) { + return Ok(match help_path.as_slice() { + [] => render_agents_usage_json(None), + _ => render_agents_usage_json(Some(&help_path.join(" "))), + }); + } + } + + match normalize_optional_args(args) { + None | Some("list") => { + let roots = discover_definition_roots(cwd, "agents"); + let agents = load_agents_from_roots(&roots)?; + Ok(render_agents_report_json(cwd, &agents)) + } + Some(args) if is_help_arg(args) => Ok(render_agents_usage_json(None)), + Some(args) => Ok(render_agents_usage_json(Some(args))), + } +} + +pub fn handle_mcp_slash_command( + args: Option<&str>, + cwd: &Path, +) -> Result { + let loader = ConfigLoader::default_for(cwd); + render_mcp_report_for(&loader, cwd, args) +} + +pub fn handle_mcp_slash_command_json( + args: Option<&str>, + cwd: &Path, +) -> Result { + let loader = ConfigLoader::default_for(cwd); + render_mcp_report_json_for(&loader, cwd, args) +} + pub fn handle_skills_slash_command(args: Option<&str>, cwd: &Path) -> std::io::Result { + if let Some(args) = normalize_optional_args(args) { + if let Some(help_path) = help_path_from_args(args) { + return Ok(match help_path.as_slice() { + [] => render_skills_usage(None), + ["install", ..] => render_skills_usage(Some("install")), + _ => render_skills_usage(Some(&help_path.join(" "))), + }); + } + } + match normalize_optional_args(args) { None | Some("list") => { let roots = discover_skill_roots(cwd); let skills = load_skills_from_roots(&roots)?; Ok(render_skills_report(&skills)) } - Some("-h" | "--help" | "help") => Ok(render_skills_usage(None)), + Some("install") => Ok(render_skills_usage(Some("install"))), + Some(args) if args.starts_with("install ") => { + let target = args["install ".len()..].trim(); + if target.is_empty() { + return Ok(render_skills_usage(Some("install"))); + } + let install = install_skill(target, cwd)?; + Ok(render_skill_install_report(&install)) + } + Some(args) if is_help_arg(args) => Ok(render_skills_usage(None)), Some(args) => Ok(render_skills_usage(Some(args))), } } -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CommitPushPrRequest { - pub commit_message: Option, - pub pr_title: String, - pub pr_body: String, - pub branch_name_hint: String, -} +pub fn handle_skills_slash_command_json(args: Option<&str>, cwd: &Path) -> std::io::Result { + if let Some(args) = normalize_optional_args(args) { + if let Some(help_path) = help_path_from_args(args) { + return Ok(match help_path.as_slice() { + [] => render_skills_usage_json(None), + ["install", ..] => render_skills_usage_json(Some("install")), + _ => render_skills_usage_json(Some(&help_path.join(" "))), + }); + } + } -pub fn handle_branch_slash_command( - action: Option<&str>, - target: Option<&str>, - cwd: &Path, -) -> io::Result { - match normalize_optional_args(action) { + match normalize_optional_args(args) { None | Some("list") => { - let branches = git_stdout(cwd, &["branch", "--list", "--verbose"])?; - let trimmed = branches.trim(); - Ok(if trimmed.is_empty() { - "Branch\n Result no branches found".to_string() - } else { - format!("Branch\n Result listed\n\n{trimmed}") - }) + let roots = discover_skill_roots(cwd); + let skills = load_skills_from_roots(&roots)?; + Ok(render_skills_report_json(&skills)) } - Some("create") => { - let Some(target) = target.filter(|value| !value.trim().is_empty()) else { - return Ok("Usage: /branch create ".to_string()); - }; - git_status_ok(cwd, &["switch", "-c", target])?; - Ok(format!( - "Branch\n Result created and switched\n Branch {target}" - )) + Some("install") => Ok(render_skills_usage_json(Some("install"))), + Some(args) if args.starts_with("install ") => { + let target = args["install ".len()..].trim(); + if target.is_empty() { + return Ok(render_skills_usage_json(Some("install"))); + } + let install = install_skill(target, cwd)?; + Ok(render_skill_install_report_json(&install)) } - Some("switch") => { - let Some(target) = target.filter(|value| !value.trim().is_empty()) else { - return Ok("Usage: /branch switch ".to_string()); - }; - git_status_ok(cwd, &["switch", target])?; - Ok(format!( - "Branch\n Result switched\n Branch {target}" - )) - } - Some(other) => Ok(format!( - "Unknown /branch action '{other}'. Use /branch list, /branch create , or /branch switch ." - )), + Some(args) if is_help_arg(args) => Ok(render_skills_usage_json(None)), + Some(args) => Ok(render_skills_usage_json(Some(args))), } } -pub fn handle_worktree_slash_command( - action: Option<&str>, - path: Option<&str>, - branch: Option<&str>, - cwd: &Path, -) -> io::Result { - match normalize_optional_args(action) { - None | Some("list") => { - let worktrees = git_stdout(cwd, &["worktree", "list"])?; - let trimmed = worktrees.trim(); - Ok(if trimmed.is_empty() { - "Worktree\n Result no worktrees found".to_string() - } else { - format!("Worktree\n Result listed\n\n{trimmed}") - }) +#[must_use] +pub fn classify_skills_slash_command(args: Option<&str>) -> SkillSlashDispatch { + match normalize_optional_args(args) { + None | Some("list" | "help" | "-h" | "--help") => SkillSlashDispatch::Local, + Some(args) if args == "install" || args.starts_with("install ") => { + SkillSlashDispatch::Local } - Some("add") => { - let Some(path) = path.filter(|value| !value.trim().is_empty()) else { - return Ok("Usage: /worktree add [branch]".to_string()); - }; - if let Some(branch) = branch.filter(|value| !value.trim().is_empty()) { - if branch_exists(cwd, branch) { - git_status_ok(cwd, &["worktree", "add", path, branch])?; - } else { - git_status_ok(cwd, &["worktree", "add", path, "-b", branch])?; + Some(args) => SkillSlashDispatch::Invoke(format!("${}", args.trim_start_matches('/'))), + } +} + +/// Resolve a skill invocation by validating the skill exists on disk before +/// returning the dispatch. When the skill is not found, returns `Err` with a +/// human-readable message that lists nearby skill names. +pub fn resolve_skill_invocation( + cwd: &Path, + args: Option<&str>, +) -> Result { + let dispatch = classify_skills_slash_command(args); + if let SkillSlashDispatch::Invoke(ref prompt) = dispatch { + // Extract the skill name from the "$skill [args]" prompt. + let skill_token = prompt + .trim_start_matches('$') + .split_whitespace() + .next() + .unwrap_or_default(); + if !skill_token.is_empty() { + if let Err(error) = resolve_skill_path(cwd, skill_token) { + let mut message = format!("Unknown skill: {skill_token} ({error})"); + let roots = discover_skill_roots(cwd); + if let Ok(available) = load_skills_from_roots(&roots) { + let names: Vec = available + .iter() + .filter(|s| s.shadowed_by.is_none()) + .map(|s| s.name.clone()) + .collect(); + if !names.is_empty() { + message.push_str(&format!("\n Available skills: {}", names.join(", "))); + } } - Ok(format!( - "Worktree\n Result added\n Path {path}\n Branch {branch}" - )) - } else { - git_status_ok(cwd, &["worktree", "add", path])?; - Ok(format!( - "Worktree\n Result added\n Path {path}" - )) + message.push_str("\n Usage: /skills [list|install |help| [args]]"); + return Err(message); } } - Some("remove") => { - let Some(path) = path.filter(|value| !value.trim().is_empty()) else { - return Ok("Usage: /worktree remove ".to_string()); - }; - git_status_ok(cwd, &["worktree", "remove", path])?; - Ok(format!( - "Worktree\n Result removed\n Path {path}" - )) - } - Some("prune") => { - git_status_ok(cwd, &["worktree", "prune"])?; - Ok("Worktree\n Result pruned".to_string()) - } - Some(other) => Ok(format!( - "Unknown /worktree action '{other}'. Use /worktree list, /worktree add [branch], /worktree remove , or /worktree prune." - )), } + Ok(dispatch) } -pub fn handle_commit_slash_command(message: &str, cwd: &Path) -> io::Result { - let status = git_stdout(cwd, &["status", "--short"])?; - if status.trim().is_empty() { - return Ok( - "Commit\n Result skipped\n Reason no workspace changes" - .to_string(), - ); +pub fn resolve_skill_path(cwd: &Path, skill: &str) -> std::io::Result { + let requested = skill.trim().trim_start_matches('/').trim_start_matches('$'); + if requested.is_empty() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "skill must not be empty", + )); } - let message = message.trim(); - if message.is_empty() { - return Err(io::Error::other("generated commit message was empty")); + let roots = discover_skill_roots(cwd); + for root in &roots { + let mut entries = Vec::new(); + for entry in fs::read_dir(&root.path)? { + let entry = entry?; + match root.origin { + SkillOrigin::SkillsDir => { + if !entry.path().is_dir() { + continue; + } + let skill_path = entry.path().join("SKILL.md"); + if !skill_path.is_file() { + continue; + } + let contents = fs::read_to_string(&skill_path)?; + let (name, _) = parse_skill_frontmatter(&contents); + entries.push(( + name.unwrap_or_else(|| entry.file_name().to_string_lossy().to_string()), + skill_path, + )); + } + SkillOrigin::LegacyCommandsDir => { + let path = entry.path(); + let markdown_path = if path.is_dir() { + let skill_path = path.join("SKILL.md"); + if !skill_path.is_file() { + continue; + } + skill_path + } else if path + .extension() + .is_some_and(|ext| ext.to_string_lossy().eq_ignore_ascii_case("md")) + { + path + } else { + continue; + }; + + let contents = fs::read_to_string(&markdown_path)?; + let fallback_name = markdown_path.file_stem().map_or_else( + || entry.file_name().to_string_lossy().to_string(), + |stem| stem.to_string_lossy().to_string(), + ); + let (name, _) = parse_skill_frontmatter(&contents); + entries.push((name.unwrap_or(fallback_name), markdown_path)); + } + } + } + entries.sort_by(|left, right| left.0.cmp(&right.0)); + if let Some((_, path)) = entries + .into_iter() + .find(|(name, _)| name.eq_ignore_ascii_case(requested)) + { + return Ok(path); + } } - git_status_ok(cwd, &["add", "-A"])?; - let path = write_temp_text_file("claw-commit-message", "txt", message)?; - let path_string = path.to_string_lossy().into_owned(); - git_status_ok(cwd, &["commit", "--file", path_string.as_str()])?; - - Ok(format!( - "Commit\n Result created\n Message file {}\n\n{}", - path.display(), - message + Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + format!("unknown skill: {requested}"), )) } -pub fn handle_commit_push_pr_slash_command( - request: &CommitPushPrRequest, +fn render_mcp_report_for( + loader: &ConfigLoader, cwd: &Path, -) -> io::Result { - if !command_exists("gh") { - return Err(io::Error::other("gh CLI is required for /commit-push-pr")); - } - - let default_branch = detect_default_branch(cwd)?; - let mut branch = current_branch(cwd)?; - let mut created_branch = false; - if branch == default_branch { - let hint = if request.branch_name_hint.trim().is_empty() { - request.pr_title.as_str() - } else { - request.branch_name_hint.as_str() - }; - let next_branch = build_branch_name(hint); - git_status_ok(cwd, &["switch", "-c", next_branch.as_str()])?; - branch = next_branch; - created_branch = true; - } - - let workspace_has_changes = !git_stdout(cwd, &["status", "--short"])?.trim().is_empty(); - let commit_report = if workspace_has_changes { - let Some(message) = request.commit_message.as_deref() else { - return Err(io::Error::other( - "commit message is required when workspace changes are present", - )); - }; - Some(handle_commit_slash_command(message, cwd)?) - } else { - None - }; - - let branch_diff = git_stdout( - cwd, - &["diff", "--stat", &format!("{default_branch}...HEAD")], - )?; - if branch_diff.trim().is_empty() { - return Ok( - "Commit/Push/PR\n Result skipped\n Reason no branch changes to push or open as a pull request" - .to_string(), - ); - } - - git_status_ok(cwd, &["push", "--set-upstream", "origin", branch.as_str()])?; - - let body_path = write_temp_text_file("claw-pr-body", "md", request.pr_body.trim())?; - let body_path_string = body_path.to_string_lossy().into_owned(); - let create = Command::new("gh") - .args([ - "pr", - "create", - "--title", - request.pr_title.as_str(), - "--body-file", - body_path_string.as_str(), - "--base", - default_branch.as_str(), - ]) - .current_dir(cwd) - .output()?; - - let (result, url) = if create.status.success() { - ( - "created", - parse_pr_url(&String::from_utf8_lossy(&create.stdout)) - .unwrap_or_else(|| "".to_string()), - ) - } else { - let view = Command::new("gh") - .args(["pr", "view", "--json", "url"]) - .current_dir(cwd) - .output()?; - if !view.status.success() { - return Err(io::Error::other(command_failure( - "gh", - &["pr", "create"], - &create, - ))); - } - ( - "existing", - parse_pr_json_url(&String::from_utf8_lossy(&view.stdout)) - .unwrap_or_else(|| "".to_string()), - ) - }; - - let mut lines = vec![ - "Commit/Push/PR".to_string(), - format!(" Result {result}"), - format!(" Branch {branch}"), - format!(" Base {default_branch}"), - format!(" Body file {}", body_path.display()), - format!(" URL {url}"), - ]; - if created_branch { - lines.insert(2, " Branch action created and switched".to_string()); - } - if let Some(report) = commit_report { - lines.push(String::new()); - lines.push(report); - } - Ok(lines.join("\n")) -} - -pub fn detect_default_branch(cwd: &Path) -> io::Result { - if let Ok(reference) = git_stdout(cwd, &["symbolic-ref", "refs/remotes/origin/HEAD"]) { - if let Some(branch) = reference - .trim() - .rsplit('/') - .next() - .filter(|value| !value.is_empty()) - { - return Ok(branch.to_string()); + args: Option<&str>, +) -> Result { + if let Some(args) = normalize_optional_args(args) { + if let Some(help_path) = help_path_from_args(args) { + return Ok(match help_path.as_slice() { + [] => render_mcp_usage(None), + ["show", ..] => render_mcp_usage(Some("show")), + _ => render_mcp_usage(Some(&help_path.join(" "))), + }); } } - for branch in ["main", "master"] { - if branch_exists(cwd, branch) { - return Ok(branch.to_string()); + match normalize_optional_args(args) { + None | Some("list") => { + let runtime_config = loader.load()?; + Ok(render_mcp_summary_report( + cwd, + runtime_config.mcp().servers(), + )) + } + Some(args) if is_help_arg(args) => Ok(render_mcp_usage(None)), + Some("show") => Ok(render_mcp_usage(Some("show"))), + Some(args) if args.split_whitespace().next() == Some("show") => { + let mut parts = args.split_whitespace(); + let _ = parts.next(); + let Some(server_name) = parts.next() else { + return Ok(render_mcp_usage(Some("show"))); + }; + if parts.next().is_some() { + return Ok(render_mcp_usage(Some(args))); + } + let runtime_config = loader.load()?; + Ok(render_mcp_server_report( + cwd, + server_name, + runtime_config.mcp().get(server_name), + )) + } + Some(args) => Ok(render_mcp_usage(Some(args))), + } +} + +fn render_mcp_report_json_for( + loader: &ConfigLoader, + cwd: &Path, + args: Option<&str>, +) -> Result { + if let Some(args) = normalize_optional_args(args) { + if let Some(help_path) = help_path_from_args(args) { + return Ok(match help_path.as_slice() { + [] => render_mcp_usage_json(None), + ["show", ..] => render_mcp_usage_json(Some("show")), + _ => render_mcp_usage_json(Some(&help_path.join(" "))), + }); } } - current_branch(cwd) -} - -fn git_stdout(cwd: &Path, args: &[&str]) -> io::Result { - run_command_stdout("git", args, cwd) -} - -fn git_status_ok(cwd: &Path, args: &[&str]) -> io::Result<()> { - run_command_success("git", args, cwd) -} - -fn run_command_stdout(program: &str, args: &[&str], cwd: &Path) -> io::Result { - let output = Command::new(program).args(args).current_dir(cwd).output()?; - if !output.status.success() { - return Err(io::Error::other(command_failure(program, args, &output))); - } - String::from_utf8(output.stdout) - .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error)) -} - -fn run_command_success(program: &str, args: &[&str], cwd: &Path) -> io::Result<()> { - let output = Command::new(program).args(args).current_dir(cwd).output()?; - if !output.status.success() { - return Err(io::Error::other(command_failure(program, args, &output))); - } - Ok(()) -} - -fn command_failure(program: &str, args: &[&str], output: &std::process::Output) -> String { - let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); - let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); - let detail = if stderr.is_empty() { stdout } else { stderr }; - if detail.is_empty() { - format!("{program} {} failed", args.join(" ")) - } else { - format!("{program} {} failed: {detail}", args.join(" ")) - } -} - -fn branch_exists(cwd: &Path, branch: &str) -> bool { - Command::new("git") - .args([ - "show-ref", - "--verify", - "--quiet", - &format!("refs/heads/{branch}"), - ]) - .current_dir(cwd) - .output() - .map(|output| output.status.success()) - .unwrap_or(false) -} - -fn current_branch(cwd: &Path) -> io::Result { - let branch = git_stdout(cwd, &["branch", "--show-current"])?; - let branch = branch.trim(); - if branch.is_empty() { - Err(io::Error::other("unable to determine current git branch")) - } else { - Ok(branch.to_string()) - } -} - -fn command_exists(name: &str) -> bool { - Command::new(name) - .arg("--version") - .output() - .map(|output| output.status.success()) - .unwrap_or(false) -} - -fn write_temp_text_file(prefix: &str, extension: &str, contents: &str) -> io::Result { - let nanos = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_nanos()) - .unwrap_or_default(); - let path = env::temp_dir().join(format!("{prefix}-{nanos}.{extension}")); - fs::write(&path, contents)?; - Ok(path) -} - -fn build_branch_name(hint: &str) -> String { - let slug = slugify(hint); - let owner = env::var("SAFEUSER") - .ok() - .filter(|value| !value.trim().is_empty()) - .or_else(|| { - env::var("USER") - .ok() - .filter(|value| !value.trim().is_empty()) - }); - match owner { - Some(owner) => format!("{owner}/{slug}"), - None => slug, - } -} - -fn slugify(value: &str) -> String { - let mut slug = String::new(); - let mut last_was_dash = false; - for ch in value.chars() { - if ch.is_ascii_alphanumeric() { - slug.push(ch.to_ascii_lowercase()); - last_was_dash = false; - } else if !last_was_dash { - slug.push('-'); - last_was_dash = true; + match normalize_optional_args(args) { + None | Some("list") => { + let runtime_config = loader.load()?; + Ok(render_mcp_summary_report_json( + cwd, + runtime_config.mcp().servers(), + )) } + Some(args) if is_help_arg(args) => Ok(render_mcp_usage_json(None)), + Some("show") => Ok(render_mcp_usage_json(Some("show"))), + Some(args) if args.split_whitespace().next() == Some("show") => { + let mut parts = args.split_whitespace(); + let _ = parts.next(); + let Some(server_name) = parts.next() else { + return Ok(render_mcp_usage_json(Some("show"))); + }; + if parts.next().is_some() { + return Ok(render_mcp_usage_json(Some(args))); + } + let runtime_config = loader.load()?; + Ok(render_mcp_server_report_json( + cwd, + server_name, + runtime_config.mcp().get(server_name), + )) + } + Some(args) => Ok(render_mcp_usage_json(Some(args))), } - let slug = slug.trim_matches('-').to_string(); - if slug.is_empty() { - "change".to_string() - } else { - slug - } -} - -fn parse_pr_url(stdout: &str) -> Option { - stdout - .lines() - .map(str::trim) - .find(|line| line.starts_with("http://") || line.starts_with("https://")) - .map(ToOwned::to_owned) -} - -fn parse_pr_json_url(stdout: &str) -> Option { - serde_json::from_str::(stdout) - .ok()? - .get("url")? - .as_str() - .map(ToOwned::to_owned) } #[must_use] @@ -1264,6 +2701,11 @@ fn discover_definition_roots(cwd: &Path, leaf: &str) -> Vec<(DefinitionSource, P let mut roots = Vec::new(); for ancestor in cwd.ancestors() { + push_unique_root( + &mut roots, + DefinitionSource::ProjectClaw, + ancestor.join(".claw").join(leaf), + ); push_unique_root( &mut roots, DefinitionSource::ProjectCodex, @@ -1271,8 +2713,16 @@ fn discover_definition_roots(cwd: &Path, leaf: &str) -> Vec<(DefinitionSource, P ); push_unique_root( &mut roots, - DefinitionSource::ProjectClaw, - ancestor.join(".claw").join(leaf), + DefinitionSource::ProjectClaude, + ancestor.join(".claude").join(leaf), + ); + } + + if let Ok(claw_config_home) = env::var("CLAW_CONFIG_HOME") { + push_unique_root( + &mut roots, + DefinitionSource::UserClawConfigHome, + PathBuf::from(claw_config_home).join(leaf), ); } @@ -1284,8 +2734,21 @@ fn discover_definition_roots(cwd: &Path, leaf: &str) -> Vec<(DefinitionSource, P ); } + if let Ok(claude_config_dir) = env::var("CLAUDE_CONFIG_DIR") { + push_unique_root( + &mut roots, + DefinitionSource::UserClaude, + PathBuf::from(claude_config_dir).join(leaf), + ); + } + if let Some(home) = env::var_os("HOME") { let home = PathBuf::from(home); + push_unique_root( + &mut roots, + DefinitionSource::UserClaw, + home.join(".claw").join(leaf), + ); push_unique_root( &mut roots, DefinitionSource::UserCodex, @@ -1293,18 +2756,37 @@ fn discover_definition_roots(cwd: &Path, leaf: &str) -> Vec<(DefinitionSource, P ); push_unique_root( &mut roots, - DefinitionSource::UserClaw, - home.join(".claw").join(leaf), + DefinitionSource::UserClaude, + home.join(".claude").join(leaf), ); } roots } +#[allow(clippy::too_many_lines)] fn discover_skill_roots(cwd: &Path) -> Vec { let mut roots = Vec::new(); for ancestor in cwd.ancestors() { + push_unique_skill_root( + &mut roots, + DefinitionSource::ProjectClaw, + ancestor.join(".claw").join("skills"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::ProjectClaw, + ancestor.join(".omc").join("skills"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::ProjectClaw, + ancestor.join(".agents").join("skills"), + SkillOrigin::SkillsDir, + ); push_unique_skill_root( &mut roots, DefinitionSource::ProjectCodex, @@ -1313,10 +2795,16 @@ fn discover_skill_roots(cwd: &Path) -> Vec { ); push_unique_skill_root( &mut roots, - DefinitionSource::ProjectClaw, - ancestor.join(".claw").join("skills"), + DefinitionSource::ProjectClaude, + ancestor.join(".claude").join("skills"), SkillOrigin::SkillsDir, ); + push_unique_skill_root( + &mut roots, + DefinitionSource::ProjectClaw, + ancestor.join(".claw").join("commands"), + SkillOrigin::LegacyCommandsDir, + ); push_unique_skill_root( &mut roots, DefinitionSource::ProjectCodex, @@ -1325,8 +2813,24 @@ fn discover_skill_roots(cwd: &Path) -> Vec { ); push_unique_skill_root( &mut roots, - DefinitionSource::ProjectClaw, - ancestor.join(".claw").join("commands"), + DefinitionSource::ProjectClaude, + ancestor.join(".claude").join("commands"), + SkillOrigin::LegacyCommandsDir, + ); + } + + if let Ok(claw_config_home) = env::var("CLAW_CONFIG_HOME") { + let claw_config_home = PathBuf::from(claw_config_home); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserClawConfigHome, + claw_config_home.join("skills"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserClawConfigHome, + claw_config_home.join("commands"), SkillOrigin::LegacyCommandsDir, ); } @@ -1349,6 +2853,24 @@ fn discover_skill_roots(cwd: &Path) -> Vec { if let Some(home) = env::var_os("HOME") { let home = PathBuf::from(home); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserClaw, + home.join(".claw").join("skills"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserClaw, + home.join(".omc").join("skills"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserClaw, + home.join(".claw").join("commands"), + SkillOrigin::LegacyCommandsDir, + ); push_unique_skill_root( &mut roots, DefinitionSource::UserCodex, @@ -1363,14 +2885,43 @@ fn discover_skill_roots(cwd: &Path) -> Vec { ); push_unique_skill_root( &mut roots, - DefinitionSource::UserClaw, - home.join(".claw").join("skills"), + DefinitionSource::UserClaude, + home.join(".claude").join("skills"), SkillOrigin::SkillsDir, ); push_unique_skill_root( &mut roots, - DefinitionSource::UserClaw, - home.join(".claw").join("commands"), + DefinitionSource::UserClaude, + home.join(".claude").join("skills").join("omc-learned"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserClaude, + home.join(".claude").join("commands"), + SkillOrigin::LegacyCommandsDir, + ); + } + + if let Ok(claude_config_dir) = env::var("CLAUDE_CONFIG_DIR") { + let claude_config_dir = PathBuf::from(claude_config_dir); + let skills_dir = claude_config_dir.join("skills"); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserClaude, + skills_dir.clone(), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserClaude, + skills_dir.join("omc-learned"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserClaude, + claude_config_dir.join("commands"), SkillOrigin::LegacyCommandsDir, ); } @@ -1378,6 +2929,205 @@ fn discover_skill_roots(cwd: &Path) -> Vec { roots } +fn install_skill(source: &str, cwd: &Path) -> std::io::Result { + let registry_root = default_skill_install_root()?; + install_skill_into(source, cwd, ®istry_root) +} + +fn install_skill_into( + source: &str, + cwd: &Path, + registry_root: &Path, +) -> std::io::Result { + let source = resolve_skill_install_source(source, cwd)?; + let prompt_path = source.prompt_path(); + let contents = fs::read_to_string(prompt_path)?; + let display_name = parse_skill_frontmatter(&contents).0; + let invocation_name = derive_skill_install_name(&source, display_name.as_deref())?; + let installed_path = registry_root.join(&invocation_name); + + if installed_path.exists() { + return Err(std::io::Error::new( + std::io::ErrorKind::AlreadyExists, + format!( + "skill '{invocation_name}' is already installed at {}", + installed_path.display() + ), + )); + } + + fs::create_dir_all(&installed_path)?; + let install_result = match &source { + SkillInstallSource::Directory { root, .. } => { + copy_directory_contents(root, &installed_path) + } + SkillInstallSource::MarkdownFile { path } => { + fs::copy(path, installed_path.join("SKILL.md")).map(|_| ()) + } + }; + if let Err(error) = install_result { + let _ = fs::remove_dir_all(&installed_path); + return Err(error); + } + + Ok(InstalledSkill { + invocation_name, + display_name, + source: source.report_path().to_path_buf(), + registry_root: registry_root.to_path_buf(), + installed_path, + }) +} + +fn default_skill_install_root() -> std::io::Result { + if let Ok(claw_config_home) = env::var("CLAW_CONFIG_HOME") { + return Ok(PathBuf::from(claw_config_home).join("skills")); + } + if let Ok(codex_home) = env::var("CODEX_HOME") { + return Ok(PathBuf::from(codex_home).join("skills")); + } + if let Some(home) = env::var_os("HOME") { + return Ok(PathBuf::from(home).join(".claw").join("skills")); + } + Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + "unable to resolve a skills install root; set CLAW_CONFIG_HOME or HOME", + )) +} + +fn resolve_skill_install_source(source: &str, cwd: &Path) -> std::io::Result { + let candidate = PathBuf::from(source); + let source = if candidate.is_absolute() { + candidate + } else { + cwd.join(candidate) + }; + let source = fs::canonicalize(&source)?; + + if source.is_dir() { + let prompt_path = source.join("SKILL.md"); + if !prompt_path.is_file() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!( + "skill directory '{}' must contain SKILL.md", + source.display() + ), + )); + } + return Ok(SkillInstallSource::Directory { + root: source, + prompt_path, + }); + } + + if source + .extension() + .is_some_and(|ext| ext.to_string_lossy().eq_ignore_ascii_case("md")) + { + return Ok(SkillInstallSource::MarkdownFile { path: source }); + } + + Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!( + "skill source '{}' must be a directory with SKILL.md or a markdown file", + source.display() + ), + )) +} + +fn derive_skill_install_name( + source: &SkillInstallSource, + declared_name: Option<&str>, +) -> std::io::Result { + for candidate in [declared_name, source.fallback_name().as_deref()] { + if let Some(candidate) = candidate.and_then(sanitize_skill_invocation_name) { + return Ok(candidate); + } + } + + Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!( + "unable to derive an installable invocation name from '{}'", + source.report_path().display() + ), + )) +} + +fn sanitize_skill_invocation_name(candidate: &str) -> Option { + let trimmed = candidate + .trim() + .trim_start_matches('/') + .trim_start_matches('$'); + if trimmed.is_empty() { + return None; + } + + let mut sanitized = String::new(); + let mut last_was_separator = false; + for ch in trimmed.chars() { + if ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_' | '.') { + sanitized.push(ch.to_ascii_lowercase()); + last_was_separator = false; + } else if (ch.is_whitespace() || matches!(ch, '/' | '\\')) + && !last_was_separator + && !sanitized.is_empty() + { + sanitized.push('-'); + last_was_separator = true; + } + } + + let sanitized = sanitized + .trim_matches(|ch| matches!(ch, '-' | '_' | '.')) + .to_string(); + (!sanitized.is_empty()).then_some(sanitized) +} + +fn copy_directory_contents(source: &Path, destination: &Path) -> std::io::Result<()> { + for entry in fs::read_dir(source)? { + let entry = entry?; + let entry_type = entry.file_type()?; + let destination_path = destination.join(entry.file_name()); + if entry_type.is_dir() { + fs::create_dir_all(&destination_path)?; + copy_directory_contents(&entry.path(), &destination_path)?; + } else { + fs::copy(entry.path(), destination_path)?; + } + } + Ok(()) +} + +impl SkillInstallSource { + fn prompt_path(&self) -> &Path { + match self { + Self::Directory { prompt_path, .. } => prompt_path, + Self::MarkdownFile { path } => path, + } + } + + fn fallback_name(&self) -> Option { + match self { + Self::Directory { root, .. } => root + .file_name() + .map(|name| name.to_string_lossy().to_string()), + Self::MarkdownFile { path } => path + .file_stem() + .map(|name| name.to_string_lossy().to_string()), + } + } + + fn report_path(&self) -> &Path { + match self { + Self::Directory { root, .. } => root, + Self::MarkdownFile { path } => path, + } + } +} + fn push_unique_root( roots: &mut Vec<(DefinitionSource, PathBuf)>, source: DefinitionSource, @@ -1607,22 +3357,20 @@ fn render_agents_report(agents: &[AgentSummary]) -> String { String::new(), ]; - for source in [ - DefinitionSource::ProjectCodex, - DefinitionSource::ProjectClaw, - DefinitionSource::UserCodexHome, - DefinitionSource::UserCodex, - DefinitionSource::UserClaw, + for scope in [ + DefinitionScope::Project, + DefinitionScope::UserConfigHome, + DefinitionScope::UserHome, ] { let group = agents .iter() - .filter(|agent| agent.source == source) + .filter(|agent| agent.source.report_scope() == scope) .collect::>(); if group.is_empty() { continue; } - lines.push(format!("{}:", source.label())); + lines.push(format!("{}:", scope.label())); for agent in group { let detail = agent_detail(agent); match agent.shadowed_by { @@ -1636,6 +3384,25 @@ fn render_agents_report(agents: &[AgentSummary]) -> String { lines.join("\n").trim_end().to_string() } +fn render_agents_report_json(cwd: &Path, agents: &[AgentSummary]) -> Value { + let active = agents + .iter() + .filter(|agent| agent.shadowed_by.is_none()) + .count(); + json!({ + "kind": "agents", + "action": "list", + "working_directory": cwd.display().to_string(), + "count": agents.len(), + "summary": { + "total": agents.len(), + "active": active, + "shadowed": agents.len().saturating_sub(active), + }, + "agents": agents.iter().map(agent_summary_json).collect::>(), + }) +} + fn agent_detail(agent: &AgentSummary) -> String { let mut parts = vec![agent.name.clone()]; if let Some(description) = &agent.description { @@ -1665,22 +3432,20 @@ fn render_skills_report(skills: &[SkillSummary]) -> String { String::new(), ]; - for source in [ - DefinitionSource::ProjectCodex, - DefinitionSource::ProjectClaw, - DefinitionSource::UserCodexHome, - DefinitionSource::UserCodex, - DefinitionSource::UserClaw, + for scope in [ + DefinitionScope::Project, + DefinitionScope::UserConfigHome, + DefinitionScope::UserHome, ] { let group = skills .iter() - .filter(|skill| skill.source == source) + .filter(|skill| skill.source.report_scope() == scope) .collect::>(); if group.is_empty() { continue; } - lines.push(format!("{}:", source.label())); + lines.push(format!("{}:", scope.label())); for skill in group { let mut parts = vec![skill.name.clone()]; if let Some(description) = &skill.description { @@ -1701,16 +3466,224 @@ fn render_skills_report(skills: &[SkillSummary]) -> String { lines.join("\n").trim_end().to_string() } +fn render_skills_report_json(skills: &[SkillSummary]) -> Value { + let active = skills + .iter() + .filter(|skill| skill.shadowed_by.is_none()) + .count(); + json!({ + "kind": "skills", + "action": "list", + "summary": { + "total": skills.len(), + "active": active, + "shadowed": skills.len().saturating_sub(active), + }, + "skills": skills.iter().map(skill_summary_json).collect::>(), + }) +} + +fn render_skill_install_report(skill: &InstalledSkill) -> String { + let mut lines = vec![ + "Skills".to_string(), + format!(" Result installed {}", skill.invocation_name), + format!(" Invoke as ${}", skill.invocation_name), + ]; + if let Some(display_name) = &skill.display_name { + lines.push(format!(" Display name {display_name}")); + } + lines.push(format!(" Source {}", skill.source.display())); + lines.push(format!( + " Registry {}", + skill.registry_root.display() + )); + lines.push(format!( + " Installed path {}", + skill.installed_path.display() + )); + lines.join("\n") +} + +fn render_skill_install_report_json(skill: &InstalledSkill) -> Value { + json!({ + "kind": "skills", + "action": "install", + "result": "installed", + "invocation_name": &skill.invocation_name, + "invoke_as": format!("${}", skill.invocation_name), + "display_name": &skill.display_name, + "source": skill.source.display().to_string(), + "registry_root": skill.registry_root.display().to_string(), + "installed_path": skill.installed_path.display().to_string(), + }) +} + +fn render_mcp_summary_report( + cwd: &Path, + servers: &BTreeMap, +) -> String { + let mut lines = vec![ + "MCP".to_string(), + format!(" Working directory {}", cwd.display()), + format!(" Configured servers {}", servers.len()), + ]; + if servers.is_empty() { + lines.push(" No MCP servers configured.".to_string()); + return lines.join("\n"); + } + + lines.push(String::new()); + for (name, server) in servers { + lines.push(format!( + " {name:<16} {transport:<13} {scope:<7} {summary}", + transport = mcp_transport_label(&server.config), + scope = config_source_label(server.scope), + summary = mcp_server_summary(&server.config) + )); + } + + lines.join("\n") +} + +fn render_mcp_summary_report_json( + cwd: &Path, + servers: &BTreeMap, +) -> Value { + json!({ + "kind": "mcp", + "action": "list", + "working_directory": cwd.display().to_string(), + "configured_servers": servers.len(), + "servers": servers + .iter() + .map(|(name, server)| mcp_server_json(name, server)) + .collect::>(), + }) +} + +fn render_mcp_server_report( + cwd: &Path, + server_name: &str, + server: Option<&ScopedMcpServerConfig>, +) -> String { + let Some(server) = server else { + return format!( + "MCP\n Working directory {}\n Result server `{server_name}` is not configured", + cwd.display() + ); + }; + + let mut lines = vec![ + "MCP".to_string(), + format!(" Working directory {}", cwd.display()), + format!(" Name {server_name}"), + format!(" Scope {}", config_source_label(server.scope)), + format!( + " Transport {}", + mcp_transport_label(&server.config) + ), + ]; + + match &server.config { + McpServerConfig::Stdio(config) => { + lines.push(format!(" Command {}", config.command)); + lines.push(format!( + " Args {}", + format_optional_list(&config.args) + )); + lines.push(format!( + " Env keys {}", + format_optional_keys(config.env.keys().cloned().collect()) + )); + lines.push(format!( + " Tool timeout {}", + config + .tool_call_timeout_ms + .map_or_else(|| "".to_string(), |value| format!("{value} ms")) + )); + } + McpServerConfig::Sse(config) | McpServerConfig::Http(config) => { + lines.push(format!(" URL {}", config.url)); + lines.push(format!( + " Header keys {}", + format_optional_keys(config.headers.keys().cloned().collect()) + )); + lines.push(format!( + " Header helper {}", + config.headers_helper.as_deref().unwrap_or("") + )); + lines.push(format!( + " OAuth {}", + format_mcp_oauth(config.oauth.as_ref()) + )); + } + McpServerConfig::Ws(config) => { + lines.push(format!(" URL {}", config.url)); + lines.push(format!( + " Header keys {}", + format_optional_keys(config.headers.keys().cloned().collect()) + )); + lines.push(format!( + " Header helper {}", + config.headers_helper.as_deref().unwrap_or("") + )); + } + McpServerConfig::Sdk(config) => { + lines.push(format!(" SDK name {}", config.name)); + } + McpServerConfig::ManagedProxy(config) => { + lines.push(format!(" URL {}", config.url)); + lines.push(format!(" Proxy id {}", config.id)); + } + } + + lines.join("\n") +} + +fn render_mcp_server_report_json( + cwd: &Path, + server_name: &str, + server: Option<&ScopedMcpServerConfig>, +) -> Value { + match server { + Some(server) => json!({ + "kind": "mcp", + "action": "show", + "working_directory": cwd.display().to_string(), + "found": true, + "server": mcp_server_json(server_name, server), + }), + None => json!({ + "kind": "mcp", + "action": "show", + "working_directory": cwd.display().to_string(), + "found": false, + "server_name": server_name, + "message": format!("server `{server_name}` is not configured"), + }), + } +} + fn normalize_optional_args(args: Option<&str>) -> Option<&str> { args.map(str::trim).filter(|value| !value.is_empty()) } +fn is_help_arg(arg: &str) -> bool { + matches!(arg, "help" | "-h" | "--help") +} + +fn help_path_from_args(args: &str) -> Option> { + let parts = args.split_whitespace().collect::>(); + let help_index = parts.iter().position(|part| is_help_arg(part))?; + Some(parts[..help_index].to_vec()) +} + fn render_agents_usage(unexpected: Option<&str>) -> String { let mut lines = vec![ "Agents".to_string(), - " Usage /agents".to_string(), + " Usage /agents [list|help]".to_string(), " Direct CLI claw agents".to_string(), - " Sources .codex/agents, .claw/agents, $CODEX_HOME/agents".to_string(), + " Sources .claw/agents, ~/.claw/agents, $CLAW_CONFIG_HOME/agents".to_string(), ]; if let Some(args) = unexpected { lines.push(format!(" Unexpected {args}")); @@ -1718,12 +3691,28 @@ fn render_agents_usage(unexpected: Option<&str>) -> String { lines.join("\n") } +fn render_agents_usage_json(unexpected: Option<&str>) -> Value { + json!({ + "kind": "agents", + "action": "help", + "usage": { + "slash_command": "/agents [list|help]", + "direct_cli": "claw agents [list|help]", + "sources": [".claw/agents", "~/.claw/agents", "$CLAW_CONFIG_HOME/agents"], + }, + "unexpected": unexpected, + }) +} + fn render_skills_usage(unexpected: Option<&str>) -> String { let mut lines = vec![ "Skills".to_string(), - " Usage /skills".to_string(), - " Direct CLI claw skills".to_string(), - " Sources .codex/skills, .claw/skills, legacy /commands".to_string(), + " Usage /skills [list|install |help| [args]]".to_string(), + " Alias /skill".to_string(), + " Direct CLI claw skills [list|install |help| [args]]".to_string(), + " Invoke /skills help overview -> $help overview".to_string(), + " Install root $CLAW_CONFIG_HOME/skills or ~/.claw/skills".to_string(), + " Sources .claw/skills, .omc/skills, .agents/skills, .codex/skills, .claude/skills, ~/.claw/skills, ~/.omc/skills, ~/.claude/skills/omc-learned, ~/.codex/skills, ~/.claude/skills, legacy /commands".to_string(), ]; if let Some(args) = unexpected { lines.push(format!(" Unexpected {args}")); @@ -1731,13 +3720,287 @@ fn render_skills_usage(unexpected: Option<&str>) -> String { lines.join("\n") } +fn render_skills_usage_json(unexpected: Option<&str>) -> Value { + json!({ + "kind": "skills", + "action": "help", + "usage": { + "slash_command": "/skills [list|install |help| [args]]", + "aliases": ["/skill"], + "direct_cli": "claw skills [list|install |help| [args]]", + "invoke": "/skills help overview -> $help overview", + "install_root": "$CLAW_CONFIG_HOME/skills or ~/.claw/skills", + "sources": [ + ".claw/skills", + ".omc/skills", + ".agents/skills", + ".codex/skills", + ".claude/skills", + "~/.claw/skills", + "~/.omc/skills", + "~/.claude/skills/omc-learned", + "~/.codex/skills", + "~/.claude/skills", + "legacy /commands", + "legacy fallback dirs still load automatically" + ], + }, + "unexpected": unexpected, + }) +} + +fn render_mcp_usage(unexpected: Option<&str>) -> String { + let mut lines = vec![ + "MCP".to_string(), + " Usage /mcp [list|show |help]".to_string(), + " Direct CLI claw mcp [list|show |help]".to_string(), + " Sources .claw/settings.json, .claw/settings.local.json".to_string(), + ]; + if let Some(args) = unexpected { + lines.push(format!(" Unexpected {args}")); + } + lines.join("\n") +} + +fn render_mcp_usage_json(unexpected: Option<&str>) -> Value { + json!({ + "kind": "mcp", + "action": "help", + "usage": { + "slash_command": "/mcp [list|show |help]", + "direct_cli": "claw mcp [list|show |help]", + "sources": [".claw/settings.json", ".claw/settings.local.json"], + }, + "unexpected": unexpected, + }) +} + +fn config_source_label(source: ConfigSource) -> &'static str { + match source { + ConfigSource::User => "user", + ConfigSource::Project => "project", + ConfigSource::Local => "local", + } +} + +fn mcp_transport_label(config: &McpServerConfig) -> &'static str { + match config { + McpServerConfig::Stdio(_) => "stdio", + McpServerConfig::Sse(_) => "sse", + McpServerConfig::Http(_) => "http", + McpServerConfig::Ws(_) => "ws", + McpServerConfig::Sdk(_) => "sdk", + McpServerConfig::ManagedProxy(_) => "managed-proxy", + } +} + +fn mcp_server_summary(config: &McpServerConfig) -> String { + match config { + McpServerConfig::Stdio(config) => { + if config.args.is_empty() { + config.command.clone() + } else { + format!("{} {}", config.command, config.args.join(" ")) + } + } + McpServerConfig::Sse(config) | McpServerConfig::Http(config) => config.url.clone(), + McpServerConfig::Ws(config) => config.url.clone(), + McpServerConfig::Sdk(config) => config.name.clone(), + McpServerConfig::ManagedProxy(config) => format!("{} ({})", config.id, config.url), + } +} + +fn format_optional_list(values: &[String]) -> String { + if values.is_empty() { + "".to_string() + } else { + values.join(" ") + } +} + +fn format_optional_keys(mut keys: Vec) -> String { + if keys.is_empty() { + return "".to_string(); + } + keys.sort(); + keys.join(", ") +} + +fn format_mcp_oauth(oauth: Option<&McpOAuthConfig>) -> String { + let Some(oauth) = oauth else { + return "".to_string(); + }; + + let mut parts = Vec::new(); + if let Some(client_id) = &oauth.client_id { + parts.push(format!("client_id={client_id}")); + } + if let Some(port) = oauth.callback_port { + parts.push(format!("callback_port={port}")); + } + if let Some(url) = &oauth.auth_server_metadata_url { + parts.push(format!("metadata_url={url}")); + } + if let Some(xaa) = oauth.xaa { + parts.push(format!("xaa={xaa}")); + } + if parts.is_empty() { + "enabled".to_string() + } else { + parts.join(", ") + } +} + +fn definition_source_id(source: DefinitionSource) -> &'static str { + match source { + DefinitionSource::ProjectClaw + | DefinitionSource::ProjectCodex + | DefinitionSource::ProjectClaude => "project_claw", + DefinitionSource::UserClawConfigHome | DefinitionSource::UserCodexHome => { + "user_claw_config_home" + } + DefinitionSource::UserClaw | DefinitionSource::UserCodex | DefinitionSource::UserClaude => { + "user_claw" + } + } +} + +fn definition_source_json(source: DefinitionSource) -> Value { + json!({ + "id": definition_source_id(source), + "label": source.label(), + }) +} + +fn agent_summary_json(agent: &AgentSummary) -> Value { + json!({ + "name": &agent.name, + "description": &agent.description, + "model": &agent.model, + "reasoning_effort": &agent.reasoning_effort, + "source": definition_source_json(agent.source), + "active": agent.shadowed_by.is_none(), + "shadowed_by": agent.shadowed_by.map(definition_source_json), + }) +} + +fn skill_origin_id(origin: SkillOrigin) -> &'static str { + match origin { + SkillOrigin::SkillsDir => "skills_dir", + SkillOrigin::LegacyCommandsDir => "legacy_commands_dir", + } +} + +fn skill_origin_json(origin: SkillOrigin) -> Value { + json!({ + "id": skill_origin_id(origin), + "detail_label": origin.detail_label(), + }) +} + +fn skill_summary_json(skill: &SkillSummary) -> Value { + json!({ + "name": &skill.name, + "description": &skill.description, + "source": definition_source_json(skill.source), + "origin": skill_origin_json(skill.origin), + "active": skill.shadowed_by.is_none(), + "shadowed_by": skill.shadowed_by.map(definition_source_json), + }) +} + +fn config_source_id(source: ConfigSource) -> &'static str { + match source { + ConfigSource::User => "user", + ConfigSource::Project => "project", + ConfigSource::Local => "local", + } +} + +fn config_source_json(source: ConfigSource) -> Value { + json!({ + "id": config_source_id(source), + "label": config_source_label(source), + }) +} + +fn mcp_transport_json(config: &McpServerConfig) -> Value { + let label = mcp_transport_label(config); + json!({ + "id": label, + "label": label, + }) +} + +fn mcp_oauth_json(oauth: Option<&McpOAuthConfig>) -> Value { + let Some(oauth) = oauth else { + return Value::Null; + }; + json!({ + "client_id": &oauth.client_id, + "callback_port": oauth.callback_port, + "auth_server_metadata_url": &oauth.auth_server_metadata_url, + "xaa": oauth.xaa, + }) +} + +fn mcp_server_details_json(config: &McpServerConfig) -> Value { + match config { + McpServerConfig::Stdio(config) => json!({ + "command": &config.command, + "args": &config.args, + "env_keys": config.env.keys().cloned().collect::>(), + "tool_call_timeout_ms": config.tool_call_timeout_ms, + }), + McpServerConfig::Sse(config) | McpServerConfig::Http(config) => json!({ + "url": &config.url, + "header_keys": config.headers.keys().cloned().collect::>(), + "headers_helper": &config.headers_helper, + "oauth": mcp_oauth_json(config.oauth.as_ref()), + }), + McpServerConfig::Ws(config) => json!({ + "url": &config.url, + "header_keys": config.headers.keys().cloned().collect::>(), + "headers_helper": &config.headers_helper, + }), + McpServerConfig::Sdk(config) => json!({ + "name": &config.name, + }), + McpServerConfig::ManagedProxy(config) => json!({ + "url": &config.url, + "id": &config.id, + }), + } +} + +fn mcp_server_json(name: &str, server: &ScopedMcpServerConfig) -> Value { + json!({ + "name": name, + "scope": config_source_json(server.scope), + "transport": mcp_transport_json(&server.config), + "summary": mcp_server_summary(&server.config), + "details": mcp_server_details_json(&server.config), + }) +} + #[must_use] pub fn handle_slash_command( input: &str, session: &Session, compaction: CompactionConfig, ) -> Option { - match SlashCommand::parse(input)? { + let command = match SlashCommand::parse(input) { + Ok(Some(command)) => command, + Ok(None) => return None, + Err(error) => { + return Some(SlashCommandResult { + message: error.to_string(), + session: session.clone(), + }); + } + }; + + match command { SlashCommand::Compact => { let result = compact_session(session, compaction); let message = if result.removed_message_count == 0 { @@ -1758,22 +4021,21 @@ pub fn handle_slash_command( session: session.clone(), }), SlashCommand::Status - | SlashCommand::Branch { .. } | SlashCommand::Bughunter { .. } - | SlashCommand::Worktree { .. } | SlashCommand::Commit - | SlashCommand::CommitPushPr { .. } | SlashCommand::Pr { .. } | SlashCommand::Issue { .. } | SlashCommand::Ultraplan { .. } | SlashCommand::Teleport { .. } | SlashCommand::DebugToolCall + | SlashCommand::Sandbox | SlashCommand::Model { .. } | SlashCommand::Permissions { .. } | SlashCommand::Clear { .. } | SlashCommand::Cost | SlashCommand::Resume { .. } | SlashCommand::Config { .. } + | SlashCommand::Mcp { .. } | SlashCommand::Memory | SlashCommand::Init | SlashCommand::Diff @@ -1783,6 +4045,47 @@ pub fn handle_slash_command( | SlashCommand::Plugins { .. } | SlashCommand::Agents { .. } | SlashCommand::Skills { .. } + | SlashCommand::Doctor + | SlashCommand::Login + | SlashCommand::Logout + | SlashCommand::Vim + | SlashCommand::Upgrade + | SlashCommand::Stats + | SlashCommand::Share + | SlashCommand::Feedback + | SlashCommand::Files + | SlashCommand::Fast + | SlashCommand::Exit + | SlashCommand::Summary + | SlashCommand::Desktop + | SlashCommand::Brief + | SlashCommand::Advisor + | SlashCommand::Stickers + | SlashCommand::Insights + | SlashCommand::Thinkback + | SlashCommand::ReleaseNotes + | SlashCommand::SecurityReview + | SlashCommand::Keybindings + | SlashCommand::PrivacySettings + | SlashCommand::Plan { .. } + | SlashCommand::Review { .. } + | SlashCommand::Tasks { .. } + | SlashCommand::Theme { .. } + | SlashCommand::Voice { .. } + | SlashCommand::Usage { .. } + | SlashCommand::Rename { .. } + | SlashCommand::Copy { .. } + | SlashCommand::Hooks { .. } + | SlashCommand::Context { .. } + | SlashCommand::Color { .. } + | SlashCommand::Effort { .. } + | SlashCommand::Branch { .. } + | SlashCommand::Rewind { .. } + | SlashCommand::Ide { .. } + | SlashCommand::Tag { .. } + | SlashCommand::OutputStyle { .. } + | SlashCommand::AddDir { .. } + | SlashCommand::History { .. } | SlashCommand::Unknown(_) => None, } } @@ -1790,26 +4093,25 @@ pub fn handle_slash_command( #[cfg(test)] mod tests { use super::{ - handle_branch_slash_command, handle_commit_push_pr_slash_command, - handle_commit_slash_command, handle_plugins_slash_command, handle_slash_command, - handle_worktree_slash_command, load_agents_from_roots, load_skills_from_roots, - render_agents_report, render_plugins_report, render_skills_report, - render_slash_command_help, resume_supported_slash_commands, slash_command_specs, - suggest_slash_commands, CommitPushPrRequest, DefinitionSource, SkillOrigin, SkillRoot, - SlashCommand, + classify_skills_slash_command, handle_agents_slash_command_json, + handle_plugins_slash_command, handle_skills_slash_command_json, handle_slash_command, + load_agents_from_roots, load_skills_from_roots, render_agents_report, + render_agents_report_json, render_mcp_report_json_for, render_plugins_report, + render_skills_report, render_slash_command_help, render_slash_command_help_detail, + resolve_skill_path, resume_supported_slash_commands, slash_command_specs, + suggest_slash_commands, validate_slash_command_input, DefinitionSource, SkillOrigin, + SkillRoot, SkillSlashDispatch, SlashCommand, }; use plugins::{PluginKind, PluginManager, PluginManagerConfig, PluginMetadata, PluginSummary}; - use runtime::{CompactionConfig, ContentBlock, ConversationMessage, MessageRole, Session}; - use std::env; + use runtime::{ + CompactionConfig, ConfigLoader, ContentBlock, ConversationMessage, MessageRole, Session, + }; + use std::ffi::OsString; use std::fs; use std::path::{Path, PathBuf}; - use std::process::Command; use std::sync::{Mutex, OnceLock}; use std::time::{SystemTime, UNIX_EPOCH}; - #[cfg(unix)] - use std::os::unix::fs::PermissionsExt; - fn temp_dir(label: &str) -> PathBuf { let nanos = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -1818,95 +4120,22 @@ mod tests { std::env::temp_dir().join(format!("commands-plugin-{label}-{nanos}")) } - fn env_lock() -> std::sync::MutexGuard<'static, ()> { + fn env_lock() -> &'static Mutex<()> { static LOCK: OnceLock> = OnceLock::new(); LOCK.get_or_init(|| Mutex::new(())) - .lock() - .expect("env lock") } - fn run_command(cwd: &Path, program: &str, args: &[&str]) -> String { - let output = Command::new(program) - .args(args) - .current_dir(cwd) - .output() - .expect("command should run"); - assert!( - output.status.success(), - "{} {} failed: {}", - program, - args.join(" "), - String::from_utf8_lossy(&output.stderr) - ); - String::from_utf8(output.stdout).expect("stdout should be utf8") - } - - fn init_git_repo(label: &str) -> PathBuf { - let root = temp_dir(label); - fs::create_dir_all(&root).expect("repo root"); - - let init = Command::new("git") - .args(["init", "-b", "main"]) - .current_dir(&root) - .output() - .expect("git init should run"); - if !init.status.success() { - let fallback = Command::new("git") - .arg("init") - .current_dir(&root) - .output() - .expect("fallback git init should run"); - assert!( - fallback.status.success(), - "fallback git init should succeed" - ); - let rename = Command::new("git") - .args(["branch", "-m", "main"]) - .current_dir(&root) - .output() - .expect("git branch -m should run"); - assert!(rename.status.success(), "git branch -m main should succeed"); + fn restore_env_var(key: &str, original: Option) { + match original { + Some(value) => std::env::set_var(key, value), + None => std::env::remove_var(key), } - - run_command(&root, "git", &["config", "user.name", "Claw Tests"]); - run_command(&root, "git", &["config", "user.email", "claw@example.com"]); - fs::write(root.join("README.md"), "seed\n").expect("seed file"); - run_command(&root, "git", &["add", "README.md"]); - run_command(&root, "git", &["commit", "-m", "chore: seed repo"]); - root - } - - fn init_bare_repo(label: &str) -> PathBuf { - let root = temp_dir(label); - let output = Command::new("git") - .args(["init", "--bare"]) - .arg(&root) - .output() - .expect("bare repo should initialize"); - assert!(output.status.success(), "git init --bare should succeed"); - root - } - - #[cfg(unix)] - fn write_fake_gh(bin_dir: &Path, log_path: &Path, url: &str) { - fs::create_dir_all(bin_dir).expect("bin dir"); - let script = format!( - "#!/bin/sh\nif [ \"$1\" = \"--version\" ]; then\n echo 'gh 1.0.0'\n exit 0\nfi\nprintf '%s\\n' \"$*\" >> \"{}\"\nif [ \"$1\" = \"pr\" ] && [ \"$2\" = \"create\" ]; then\n echo '{}'\n exit 0\nfi\nif [ \"$1\" = \"pr\" ] && [ \"$2\" = \"view\" ]; then\n echo '{{\"url\":\"{}\"}}'\n exit 0\nfi\nexit 0\n", - log_path.display(), - url, - url, - ); - let path = bin_dir.join("gh"); - fs::write(&path, script).expect("gh stub"); - let mut permissions = fs::metadata(&path).expect("metadata").permissions(); - permissions.set_mode(0o755); - fs::set_permissions(&path, permissions).expect("chmod"); } fn write_external_plugin(root: &Path, name: &str, version: &str) { - fs::create_dir_all(root.join(".claw-plugin")).expect("manifest dir"); + fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir"); fs::write( - root.join(".claw-plugin").join("plugin.json"), + root.join(".claude-plugin").join("plugin.json"), format!( "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"commands plugin\"\n}}" ), @@ -1915,9 +4144,9 @@ mod tests { } fn write_bundled_plugin(root: &Path, name: &str, version: &str, default_enabled: bool) { - fs::create_dir_all(root.join(".claw-plugin")).expect("manifest dir"); + fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir"); fs::write( - root.join(".claw-plugin").join("plugin.json"), + root.join(".claude-plugin").join("plugin.json"), format!( "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"bundled commands plugin\",\n \"defaultEnabled\": {}\n}}", if default_enabled { "true" } else { "false" } @@ -1956,172 +4185,387 @@ mod tests { .expect("write command"); } + fn parse_error_message(input: &str) -> String { + SlashCommand::parse(input) + .expect_err("slash command should be rejected") + .to_string() + } + #[allow(clippy::too_many_lines)] #[test] fn parses_supported_slash_commands() { - assert_eq!(SlashCommand::parse("/help"), Some(SlashCommand::Help)); - assert_eq!(SlashCommand::parse(" /status "), Some(SlashCommand::Status)); + assert_eq!(SlashCommand::parse("/help"), Ok(Some(SlashCommand::Help))); + assert_eq!( + SlashCommand::parse(" /status "), + Ok(Some(SlashCommand::Status)) + ); + assert_eq!( + SlashCommand::parse("/sandbox"), + Ok(Some(SlashCommand::Sandbox)) + ); assert_eq!( SlashCommand::parse("/bughunter runtime"), - Some(SlashCommand::Bughunter { + Ok(Some(SlashCommand::Bughunter { scope: Some("runtime".to_string()) - }) + })) ); assert_eq!( - SlashCommand::parse("/branch create feature/demo"), - Some(SlashCommand::Branch { - action: Some("create".to_string()), - target: Some("feature/demo".to_string()), - }) - ); - assert_eq!( - SlashCommand::parse("/worktree add ../demo wt-demo"), - Some(SlashCommand::Worktree { - action: Some("add".to_string()), - path: Some("../demo".to_string()), - branch: Some("wt-demo".to_string()), - }) - ); - assert_eq!(SlashCommand::parse("/commit"), Some(SlashCommand::Commit)); - assert_eq!( - SlashCommand::parse("/commit-push-pr ready for review"), - Some(SlashCommand::CommitPushPr { - context: Some("ready for review".to_string()) - }) + SlashCommand::parse("/commit"), + Ok(Some(SlashCommand::Commit)) ); assert_eq!( SlashCommand::parse("/pr ready for review"), - Some(SlashCommand::Pr { + Ok(Some(SlashCommand::Pr { context: Some("ready for review".to_string()) - }) + })) ); assert_eq!( SlashCommand::parse("/issue flaky test"), - Some(SlashCommand::Issue { + Ok(Some(SlashCommand::Issue { context: Some("flaky test".to_string()) - }) + })) ); assert_eq!( SlashCommand::parse("/ultraplan ship both features"), - Some(SlashCommand::Ultraplan { + Ok(Some(SlashCommand::Ultraplan { task: Some("ship both features".to_string()) - }) + })) ); assert_eq!( SlashCommand::parse("/teleport conversation.rs"), - Some(SlashCommand::Teleport { + Ok(Some(SlashCommand::Teleport { target: Some("conversation.rs".to_string()) - }) + })) ); assert_eq!( SlashCommand::parse("/debug-tool-call"), - Some(SlashCommand::DebugToolCall) + Ok(Some(SlashCommand::DebugToolCall)) ); assert_eq!( - SlashCommand::parse("/model opus"), - Some(SlashCommand::Model { - model: Some("opus".to_string()), - }) + SlashCommand::parse("/bughunter runtime"), + Ok(Some(SlashCommand::Bughunter { + scope: Some("runtime".to_string()) + })) + ); + assert_eq!( + SlashCommand::parse("/commit"), + Ok(Some(SlashCommand::Commit)) + ); + assert_eq!( + SlashCommand::parse("/pr ready for review"), + Ok(Some(SlashCommand::Pr { + context: Some("ready for review".to_string()) + })) + ); + assert_eq!( + SlashCommand::parse("/issue flaky test"), + Ok(Some(SlashCommand::Issue { + context: Some("flaky test".to_string()) + })) + ); + assert_eq!( + SlashCommand::parse("/ultraplan ship both features"), + Ok(Some(SlashCommand::Ultraplan { + task: Some("ship both features".to_string()) + })) + ); + assert_eq!( + SlashCommand::parse("/teleport conversation.rs"), + Ok(Some(SlashCommand::Teleport { + target: Some("conversation.rs".to_string()) + })) + ); + assert_eq!( + SlashCommand::parse("/debug-tool-call"), + Ok(Some(SlashCommand::DebugToolCall)) + ); + assert_eq!( + SlashCommand::parse("/model claude-opus"), + Ok(Some(SlashCommand::Model { + model: Some("claude-opus".to_string()), + })) ); assert_eq!( SlashCommand::parse("/model"), - Some(SlashCommand::Model { model: None }) + Ok(Some(SlashCommand::Model { model: None })) ); assert_eq!( SlashCommand::parse("/permissions read-only"), - Some(SlashCommand::Permissions { + Ok(Some(SlashCommand::Permissions { mode: Some("read-only".to_string()), - }) + })) ); assert_eq!( SlashCommand::parse("/clear"), - Some(SlashCommand::Clear { confirm: false }) + Ok(Some(SlashCommand::Clear { confirm: false })) ); assert_eq!( SlashCommand::parse("/clear --confirm"), - Some(SlashCommand::Clear { confirm: true }) + Ok(Some(SlashCommand::Clear { confirm: true })) ); - assert_eq!(SlashCommand::parse("/cost"), Some(SlashCommand::Cost)); + assert_eq!(SlashCommand::parse("/cost"), Ok(Some(SlashCommand::Cost))); assert_eq!( SlashCommand::parse("/resume session.json"), - Some(SlashCommand::Resume { + Ok(Some(SlashCommand::Resume { session_path: Some("session.json".to_string()), - }) + })) ); assert_eq!( SlashCommand::parse("/config"), - Some(SlashCommand::Config { section: None }) + Ok(Some(SlashCommand::Config { section: None })) ); assert_eq!( SlashCommand::parse("/config env"), - Some(SlashCommand::Config { + Ok(Some(SlashCommand::Config { section: Some("env".to_string()) - }) + })) + ); + assert_eq!( + SlashCommand::parse("/mcp"), + Ok(Some(SlashCommand::Mcp { + action: None, + target: None + })) + ); + assert_eq!( + SlashCommand::parse("/mcp show remote"), + Ok(Some(SlashCommand::Mcp { + action: Some("show".to_string()), + target: Some("remote".to_string()) + })) + ); + assert_eq!( + SlashCommand::parse("/memory"), + Ok(Some(SlashCommand::Memory)) + ); + assert_eq!(SlashCommand::parse("/init"), Ok(Some(SlashCommand::Init))); + assert_eq!(SlashCommand::parse("/diff"), Ok(Some(SlashCommand::Diff))); + assert_eq!( + SlashCommand::parse("/version"), + Ok(Some(SlashCommand::Version)) ); - assert_eq!(SlashCommand::parse("/memory"), Some(SlashCommand::Memory)); - assert_eq!(SlashCommand::parse("/init"), Some(SlashCommand::Init)); - assert_eq!(SlashCommand::parse("/diff"), Some(SlashCommand::Diff)); - assert_eq!(SlashCommand::parse("/version"), Some(SlashCommand::Version)); assert_eq!( SlashCommand::parse("/export notes.txt"), - Some(SlashCommand::Export { + Ok(Some(SlashCommand::Export { path: Some("notes.txt".to_string()) - }) + })) ); assert_eq!( SlashCommand::parse("/session switch abc123"), - Some(SlashCommand::Session { + Ok(Some(SlashCommand::Session { action: Some("switch".to_string()), target: Some("abc123".to_string()) - }) + })) ); assert_eq!( SlashCommand::parse("/plugins install demo"), - Some(SlashCommand::Plugins { + Ok(Some(SlashCommand::Plugins { action: Some("install".to_string()), target: Some("demo".to_string()) - }) + })) ); assert_eq!( SlashCommand::parse("/plugins list"), - Some(SlashCommand::Plugins { + Ok(Some(SlashCommand::Plugins { action: Some("list".to_string()), target: None - }) + })) ); assert_eq!( SlashCommand::parse("/plugins enable demo"), - Some(SlashCommand::Plugins { + Ok(Some(SlashCommand::Plugins { action: Some("enable".to_string()), target: Some("demo".to_string()) - }) + })) + ); + assert_eq!( + SlashCommand::parse("/skills install ./fixtures/help-skill"), + Ok(Some(SlashCommand::Skills { + args: Some("install ./fixtures/help-skill".to_string()) + })) ); assert_eq!( SlashCommand::parse("/plugins disable demo"), - Some(SlashCommand::Plugins { + Ok(Some(SlashCommand::Plugins { action: Some("disable".to_string()), target: Some("demo".to_string()) - }) + })) ); + assert_eq!( + SlashCommand::parse("/session fork incident-review"), + Ok(Some(SlashCommand::Session { + action: Some("fork".to_string()), + target: Some("incident-review".to_string()) + })) + ); + } + + #[test] + fn parses_history_command_without_count() { + // given + let input = "/history"; + + // when + let parsed = SlashCommand::parse(input); + + // then + assert_eq!(parsed, Ok(Some(SlashCommand::History { count: None }))); + } + + #[test] + fn parses_history_command_with_numeric_count() { + // given + let input = "/history 25"; + + // when + let parsed = SlashCommand::parse(input); + + // then + assert_eq!( + parsed, + Ok(Some(SlashCommand::History { + count: Some("25".to_string()) + })) + ); + } + + #[test] + fn rejects_history_with_extra_arguments() { + // given + let input = "/history 25 extra"; + + // when + let error = parse_error_message(input); + + // then + assert!(error.contains("Usage: /history [count]")); + } + + #[test] + fn rejects_unexpected_arguments_for_no_arg_commands() { + // given + let input = "/compact now"; + + // when + let error = parse_error_message(input); + + // then + assert!(error.contains("Unexpected arguments for /compact.")); + assert!(error.contains(" Usage /compact")); + assert!(error.contains(" Summary Compact local session history")); + } + + #[test] + fn rejects_invalid_argument_values() { + // given + let input = "/permissions admin"; + + // when + let error = parse_error_message(input); + + // then + assert!(error.contains( + "Unsupported /permissions mode 'admin'. Use read-only, workspace-write, or danger-full-access." + )); + assert!(error.contains( + " Usage /permissions [read-only|workspace-write|danger-full-access]" + )); + } + + #[test] + fn rejects_missing_required_arguments() { + // given + let input = "/teleport"; + + // when + let error = parse_error_message(input); + + // then + assert!(error.contains("Usage: /teleport ")); + assert!(error.contains(" Category Tools")); + } + + #[test] + fn rejects_invalid_session_and_plugin_shapes() { + // given + let session_input = "/session switch"; + let plugin_input = "/plugins list extra"; + + // when + let session_error = parse_error_message(session_input); + let plugin_error = parse_error_message(plugin_input); + + // then + assert!(session_error.contains("Usage: /session switch ")); + assert!(session_error.contains("/session")); + assert!(plugin_error.contains("Usage: /plugin list")); + assert!(plugin_error.contains("Aliases /plugins, /marketplace")); + } + + #[test] + fn rejects_invalid_agents_arguments() { + // given + let agents_input = "/agents show planner"; + + // when + let agents_error = parse_error_message(agents_input); + + // then + assert!(agents_error.contains( + "Unexpected arguments for /agents: show planner. Use /agents, /agents list, or /agents help." + )); + assert!(agents_error.contains(" Usage /agents [list|help]")); + } + + #[test] + fn accepts_skills_invocation_arguments_for_prompt_dispatch() { + assert_eq!( + SlashCommand::parse("/skills help overview"), + Ok(Some(SlashCommand::Skills { + args: Some("help overview".to_string()), + })) + ); + assert_eq!( + classify_skills_slash_command(Some("help overview")), + SkillSlashDispatch::Invoke("$help overview".to_string()) + ); + assert_eq!( + classify_skills_slash_command(Some("/test")), + SkillSlashDispatch::Invoke("$test".to_string()) + ); + assert_eq!( + classify_skills_slash_command(Some("install ./skill-pack")), + SkillSlashDispatch::Local + ); + } + + #[test] + fn rejects_invalid_mcp_arguments() { + let show_error = parse_error_message("/mcp show alpha beta"); + assert!(show_error.contains("Unexpected arguments for /mcp show.")); + assert!(show_error.contains(" Usage /mcp show ")); + + let action_error = parse_error_message("/mcp inspect alpha"); + assert!(action_error + .contains("Unknown /mcp action 'inspect'. Use list, show , or help.")); + assert!(action_error.contains(" Usage /mcp [list|show |help]")); } #[test] fn renders_help_from_shared_specs() { let help = render_slash_command_help(); - assert!(help.contains("available via claw --resume SESSION.json")); - assert!(help.contains("Core flow")); - assert!(help.contains("Workspace & memory")); - assert!(help.contains("Sessions & output")); - assert!(help.contains("Git & GitHub")); - assert!(help.contains("Automation & discovery")); + assert!(help.contains("Start here /status, /diff, /agents, /skills, /commit")); + assert!(help.contains("[resume] also works with --resume SESSION.jsonl")); + assert!(help.contains("Session")); + assert!(help.contains("Tools")); + assert!(help.contains("Config")); + assert!(help.contains("Debug")); assert!(help.contains("/help")); assert!(help.contains("/status")); + assert!(help.contains("/sandbox")); assert!(help.contains("/compact")); assert!(help.contains("/bughunter [scope]")); - assert!(help.contains("/branch [list|create |switch ]")); - assert!(help.contains("/worktree [list|add [branch]|remove |prune]")); assert!(help.contains("/commit")); - assert!(help.contains("/commit-push-pr [context]")); assert!(help.contains("/pr [context]")); assert!(help.contains("/issue [context]")); assert!(help.contains("/ultraplan [task]")); @@ -2133,44 +4577,141 @@ mod tests { assert!(help.contains("/cost")); assert!(help.contains("/resume ")); assert!(help.contains("/config [env|hooks|model|plugins]")); + assert!(help.contains("/mcp [list|show |help]")); assert!(help.contains("/memory")); assert!(help.contains("/init")); assert!(help.contains("/diff")); assert!(help.contains("/version")); assert!(help.contains("/export [file]")); - assert!(help.contains("/session [list|switch ]")); + assert!(help.contains("/session"), "help must mention /session"); + assert!(help.contains("/sandbox")); assert!(help.contains( "/plugin [list|install |enable |disable |uninstall |update ]" )); assert!(help.contains("aliases: /plugins, /marketplace")); - assert!(help.contains("/agents")); - assert!(help.contains("/skills")); - assert_eq!(slash_command_specs().len(), 28); - assert_eq!(resume_supported_slash_commands().len(), 13); + assert!(help.contains("/agents [list|help]")); + assert!(help.contains("/skills [list|install |help| [args]]")); + assert!(help.contains("aliases: /skill")); + assert_eq!(slash_command_specs().len(), 141); + assert!(resume_supported_slash_commands().len() >= 39); } #[test] - fn suggests_close_slash_commands() { + fn renders_help_with_grouped_categories_and_keyboard_shortcuts() { + // given + let categories = ["Session", "Tools", "Config", "Debug"]; + + // when + let help = render_slash_command_help(); + + // then + for category in categories { + assert!( + help.contains(category), + "expected help to contain category {category}" + ); + } + let session_index = help.find("Session").expect("Session header should exist"); + let tools_index = help.find("Tools").expect("Tools header should exist"); + let config_index = help.find("Config").expect("Config header should exist"); + let debug_index = help.find("Debug").expect("Debug header should exist"); + assert!(session_index < tools_index); + assert!(tools_index < config_index); + assert!(config_index < debug_index); + + assert!(help.contains("Keyboard shortcuts")); + assert!(help.contains("Up/Down Navigate prompt history")); + assert!(help.contains("Tab Complete commands, modes, and recent sessions")); + assert!(help.contains("Ctrl-C Clear input (or exit on empty prompt)")); + assert!(help.contains("Shift+Enter/Ctrl+J Insert a newline")); + + // every command should still render with a summary line + for spec in slash_command_specs() { + let usage = match spec.argument_hint { + Some(hint) => format!("/{} {hint}", spec.name), + None => format!("/{}", spec.name), + }; + assert!( + help.contains(&usage), + "expected help to contain command {usage}" + ); + assert!( + help.contains(spec.summary), + "expected help to contain summary for /{}", + spec.name + ); + } + } + + #[test] + fn renders_per_command_help_detail() { + // given + let command = "plugins"; + + // when + let help = render_slash_command_help_detail(command).expect("detail help should exist"); + + // then + assert!(help.contains("/plugin")); + assert!(help.contains("Summary Manage Claw Code plugins")); + assert!(help.contains("Aliases /plugins, /marketplace")); + assert!(help.contains("Category Tools")); + } + + #[test] + fn renders_per_command_help_detail_for_mcp() { + let help = render_slash_command_help_detail("mcp").expect("detail help should exist"); + assert!(help.contains("/mcp")); + assert!(help.contains("Summary Inspect configured MCP servers")); + assert!(help.contains("Category Tools")); + assert!(help.contains("Resume Supported with --resume SESSION.jsonl")); + } + + #[test] + fn validate_slash_command_input_rejects_extra_single_value_arguments() { + // given + let session_input = "/session switch current next"; + let plugin_input = "/plugin enable demo extra"; + + // when + let session_error = validate_slash_command_input(session_input) + .expect_err("session input should be rejected") + .to_string(); + let plugin_error = validate_slash_command_input(plugin_input) + .expect_err("plugin input should be rejected") + .to_string(); + + // then + assert!(session_error.contains("Unexpected arguments for /session switch.")); + assert!(session_error.contains(" Usage /session switch ")); + assert!(plugin_error.contains("Unexpected arguments for /plugin enable.")); + assert!(plugin_error.contains(" Usage /plugin enable ")); + } + + #[test] + fn suggests_closest_slash_commands_for_typos_and_aliases() { let suggestions = suggest_slash_commands("stats", 3); - assert!(!suggestions.is_empty()); - assert_eq!(suggestions[0], "/status"); + assert!(suggestions.contains(&"/stats".to_string())); + assert!(suggestions.contains(&"/status".to_string())); + assert!(suggestions.len() <= 3); + let plugin_suggestions = suggest_slash_commands("/plugns", 3); + assert!(plugin_suggestions.contains(&"/plugin".to_string())); + assert_eq!(suggest_slash_commands("zzz", 3), Vec::::new()); } #[test] fn compacts_sessions_via_slash_command() { - let session = Session { - version: 1, - messages: vec![ - ConversationMessage::user_text("a ".repeat(200)), - ConversationMessage::assistant(vec![ContentBlock::Text { - text: "b ".repeat(200), - }]), - ConversationMessage::tool_result("1", "bash", "ok ".repeat(200), false), - ConversationMessage::assistant(vec![ContentBlock::Text { - text: "recent".to_string(), - }]), - ], - }; + let mut session = Session::new(); + session.messages = vec![ + ConversationMessage::user_text("a ".repeat(200)), + ConversationMessage::assistant(vec![ContentBlock::Text { + text: "b ".repeat(200), + }]), + ConversationMessage::tool_result("1", "bash", "ok ".repeat(200), false), + ConversationMessage::assistant(vec![ContentBlock::Text { + text: "recent".to_string(), + }]), + ]; let result = handle_slash_command( "/compact", @@ -2182,7 +4723,14 @@ mod tests { ) .expect("slash command should be handled"); - assert!(result.message.contains("Compacted 2 messages")); + // With the tool-use/tool-result boundary guard the compaction may + // preserve one extra message, so 1 or 2 messages may be removed. + assert!( + result.message.contains("Compacted 1 messages") + || result.message.contains("Compacted 2 messages"), + "unexpected compaction message: {}", + result.message + ); assert_eq!(result.session.messages[0].role, MessageRole::System); } @@ -2200,22 +4748,11 @@ mod tests { let session = Session::new(); assert!(handle_slash_command("/unknown", &session, CompactionConfig::default()).is_none()); assert!(handle_slash_command("/status", &session, CompactionConfig::default()).is_none()); - assert!( - handle_slash_command("/branch list", &session, CompactionConfig::default()).is_none() - ); + assert!(handle_slash_command("/sandbox", &session, CompactionConfig::default()).is_none()); assert!( handle_slash_command("/bughunter", &session, CompactionConfig::default()).is_none() ); - assert!( - handle_slash_command("/worktree list", &session, CompactionConfig::default()).is_none() - ); assert!(handle_slash_command("/commit", &session, CompactionConfig::default()).is_none()); - assert!(handle_slash_command( - "/commit-push-pr review notes", - &session, - CompactionConfig::default() - ) - .is_none()); assert!(handle_slash_command("/pr", &session, CompactionConfig::default()).is_none()); assert!(handle_slash_command("/issue", &session, CompactionConfig::default()).is_none()); assert!( @@ -2229,7 +4766,7 @@ mod tests { .is_none() ); assert!( - handle_slash_command("/model sonnet", &session, CompactionConfig::default()).is_none() + handle_slash_command("/model claude", &session, CompactionConfig::default()).is_none() ); assert!(handle_slash_command( "/permissions read-only", @@ -2249,10 +4786,17 @@ mod tests { CompactionConfig::default() ) .is_none()); + assert!(handle_slash_command( + "/resume session.jsonl", + &session, + CompactionConfig::default() + ) + .is_none()); assert!(handle_slash_command("/config", &session, CompactionConfig::default()).is_none()); assert!( handle_slash_command("/config env", &session, CompactionConfig::default()).is_none() ); + assert!(handle_slash_command("/mcp list", &session, CompactionConfig::default()).is_none()); assert!(handle_slash_command("/diff", &session, CompactionConfig::default()).is_none()); assert!(handle_slash_command("/version", &session, CompactionConfig::default()).is_none()); assert!( @@ -2311,7 +4855,7 @@ mod tests { let workspace = temp_dir("agents-workspace"); let project_agents = workspace.join(".codex").join("agents"); let user_home = temp_dir("agents-home"); - let user_agents = user_home.join(".codex").join("agents"); + let user_agents = user_home.join(".claude").join("agents"); write_agent( &project_agents, @@ -2344,21 +4888,87 @@ mod tests { assert!(report.contains("Agents")); assert!(report.contains("2 active agents")); - assert!(report.contains("Project (.codex):")); + assert!(report.contains("Project roots:")); assert!(report.contains("planner · Project planner · gpt-5.4 · medium")); - assert!(report.contains("User (~/.codex):")); - assert!(report.contains("(shadowed by Project (.codex)) planner · User planner")); + assert!(report.contains("User home roots:")); + assert!(report.contains("(shadowed by Project roots) planner · User planner")); assert!(report.contains("verifier · Verification agent · gpt-5.4-mini · high")); let _ = fs::remove_dir_all(workspace); let _ = fs::remove_dir_all(user_home); } + #[test] + fn renders_agents_reports_as_json() { + let workspace = temp_dir("agents-json-workspace"); + let project_agents = workspace.join(".codex").join("agents"); + let user_home = temp_dir("agents-json-home"); + let user_agents = user_home.join(".codex").join("agents"); + + write_agent( + &project_agents, + "planner", + "Project planner", + "gpt-5.4", + "medium", + ); + write_agent( + &project_agents, + "verifier", + "Verification agent", + "gpt-5.4-mini", + "high", + ); + write_agent( + &user_agents, + "planner", + "User planner", + "gpt-5.4-mini", + "high", + ); + + let roots = vec![ + (DefinitionSource::ProjectCodex, project_agents), + (DefinitionSource::UserCodex, user_agents), + ]; + let report = render_agents_report_json( + &workspace, + &load_agents_from_roots(&roots).expect("agent roots should load"), + ); + + assert_eq!(report["kind"], "agents"); + assert_eq!(report["action"], "list"); + assert_eq!(report["working_directory"], workspace.display().to_string()); + assert_eq!(report["count"], 3); + assert_eq!(report["summary"]["active"], 2); + assert_eq!(report["summary"]["shadowed"], 1); + assert_eq!(report["agents"][0]["name"], "planner"); + assert_eq!(report["agents"][0]["model"], "gpt-5.4"); + assert_eq!(report["agents"][0]["active"], true); + assert_eq!(report["agents"][1]["name"], "verifier"); + assert_eq!(report["agents"][2]["name"], "planner"); + assert_eq!(report["agents"][2]["active"], false); + assert_eq!(report["agents"][2]["shadowed_by"]["id"], "project_claw"); + + let help = handle_agents_slash_command_json(Some("help"), &workspace).expect("agents help"); + assert_eq!(help["kind"], "agents"); + assert_eq!(help["action"], "help"); + assert_eq!(help["usage"]["direct_cli"], "claw agents [list|help]"); + + let unexpected = handle_agents_slash_command_json(Some("show planner"), &workspace) + .expect("agents usage"); + assert_eq!(unexpected["action"], "help"); + assert_eq!(unexpected["unexpected"], "show planner"); + + let _ = fs::remove_dir_all(workspace); + let _ = fs::remove_dir_all(user_home); + } + #[test] fn lists_skills_from_project_and_user_roots() { let workspace = temp_dir("skills-workspace"); let project_skills = workspace.join(".codex").join("skills"); - let project_commands = workspace.join(".claw").join("commands"); + let project_commands = workspace.join(".claude").join("commands"); let user_home = temp_dir("skills-home"); let user_skills = user_home.join(".codex").join("skills"); @@ -2374,7 +4984,7 @@ mod tests { origin: SkillOrigin::SkillsDir, }, SkillRoot { - source: DefinitionSource::ProjectClaw, + source: DefinitionSource::ProjectClaude, path: project_commands, origin: SkillOrigin::LegacyCommandsDir, }, @@ -2389,26 +4999,102 @@ mod tests { assert!(report.contains("Skills")); assert!(report.contains("3 available skills")); - assert!(report.contains("Project (.codex):")); + assert!(report.contains("Project roots:")); assert!(report.contains("plan · Project planning guidance")); - assert!(report.contains("Project (.claw):")); assert!(report.contains("deploy · Legacy deployment guidance · legacy /commands")); - assert!(report.contains("User (~/.codex):")); - assert!(report.contains("(shadowed by Project (.codex)) plan · User planning guidance")); + assert!(report.contains("User home roots:")); + assert!(report.contains("(shadowed by Project roots) plan · User planning guidance")); assert!(report.contains("help · Help guidance")); let _ = fs::remove_dir_all(workspace); let _ = fs::remove_dir_all(user_home); } + #[test] + fn resolves_project_skills_and_legacy_commands_from_shared_registry() { + let workspace = temp_dir("resolve-project-skills"); + let project_skills = workspace.join(".claw").join("skills"); + let legacy_commands = workspace.join(".claw").join("commands"); + + write_skill(&project_skills, "plan", "Project planning guidance"); + write_legacy_command(&legacy_commands, "handoff", "Legacy handoff guidance"); + + assert_eq!( + resolve_skill_path(&workspace, "$plan").expect("project skill should resolve"), + project_skills.join("plan").join("SKILL.md") + ); + assert_eq!( + resolve_skill_path(&workspace, "/handoff").expect("legacy command should resolve"), + legacy_commands.join("handoff.md") + ); + } + + #[test] + fn renders_skills_reports_as_json() { + let workspace = temp_dir("skills-json-workspace"); + let project_skills = workspace.join(".codex").join("skills"); + let project_commands = workspace.join(".claude").join("commands"); + let user_home = temp_dir("skills-json-home"); + let user_skills = user_home.join(".codex").join("skills"); + + write_skill(&project_skills, "plan", "Project planning guidance"); + write_legacy_command(&project_commands, "deploy", "Legacy deployment guidance"); + write_skill(&user_skills, "plan", "User planning guidance"); + write_skill(&user_skills, "help", "Help guidance"); + + let roots = vec![ + SkillRoot { + source: DefinitionSource::ProjectCodex, + path: project_skills, + origin: SkillOrigin::SkillsDir, + }, + SkillRoot { + source: DefinitionSource::ProjectClaude, + path: project_commands, + origin: SkillOrigin::LegacyCommandsDir, + }, + SkillRoot { + source: DefinitionSource::UserCodex, + path: user_skills, + origin: SkillOrigin::SkillsDir, + }, + ]; + let report = super::render_skills_report_json( + &load_skills_from_roots(&roots).expect("skills should load"), + ); + assert_eq!(report["kind"], "skills"); + assert_eq!(report["action"], "list"); + assert_eq!(report["summary"]["active"], 3); + assert_eq!(report["summary"]["shadowed"], 1); + assert_eq!(report["skills"][0]["name"], "plan"); + assert_eq!(report["skills"][0]["source"]["id"], "project_claw"); + assert_eq!(report["skills"][1]["name"], "deploy"); + assert_eq!(report["skills"][1]["origin"]["id"], "legacy_commands_dir"); + assert_eq!(report["skills"][3]["shadowed_by"]["id"], "project_claw"); + + let help = handle_skills_slash_command_json(Some("help"), &workspace).expect("skills help"); + assert_eq!(help["kind"], "skills"); + assert_eq!(help["action"], "help"); + assert_eq!(help["usage"]["aliases"][0], "/skill"); + assert_eq!( + help["usage"]["direct_cli"], + "claw skills [list|install |help| [args]]" + ); + + let _ = fs::remove_dir_all(workspace); + let _ = fs::remove_dir_all(user_home); + } + #[test] fn agents_and_skills_usage_support_help_and_unexpected_args() { let cwd = temp_dir("slash-usage"); let agents_help = super::handle_agents_slash_command(Some("help"), &cwd).expect("agents help"); - assert!(agents_help.contains("Usage /agents")); + assert!(agents_help.contains("Usage /agents [list|help]")); assert!(agents_help.contains("Direct CLI claw agents")); + assert!(agents_help + .contains("Sources .claw/agents, ~/.claw/agents, $CLAW_CONFIG_HOME/agents")); let agents_unexpected = super::handle_agents_slash_command(Some("show planner"), &cwd).expect("agents usage"); @@ -2416,16 +5102,298 @@ mod tests { let skills_help = super::handle_skills_slash_command(Some("--help"), &cwd).expect("skills help"); - assert!(skills_help.contains("Usage /skills")); + assert!(skills_help + .contains("Usage /skills [list|install |help| [args]]")); + assert!(skills_help.contains("Alias /skill")); + assert!(skills_help.contains("Invoke /skills help overview -> $help overview")); + assert!(skills_help.contains("Install root $CLAW_CONFIG_HOME/skills or ~/.claw/skills")); + assert!(skills_help.contains(".omc/skills")); + assert!(skills_help.contains(".agents/skills")); + assert!(skills_help.contains("~/.claude/skills/omc-learned")); assert!(skills_help.contains("legacy /commands")); let skills_unexpected = super::handle_skills_slash_command(Some("show help"), &cwd).expect("skills usage"); - assert!(skills_unexpected.contains("Unexpected show help")); + assert!(skills_unexpected.contains("Unexpected show")); + + let skills_install_help = super::handle_skills_slash_command(Some("install --help"), &cwd) + .expect("nested skills help"); + assert!(skills_install_help + .contains("Usage /skills [list|install |help| [args]]")); + assert!(skills_install_help.contains("Alias /skill")); + assert!(skills_install_help.contains("Unexpected install")); + + let skills_unknown_help = + super::handle_skills_slash_command(Some("show --help"), &cwd).expect("skills help"); + assert!(skills_unknown_help + .contains("Usage /skills [list|install |help| [args]]")); + assert!(skills_unknown_help.contains("Unexpected show")); + + let skills_help_json = + super::handle_skills_slash_command_json(Some("help"), &cwd).expect("skills help json"); + let sources = skills_help_json["usage"]["sources"] + .as_array() + .expect("skills help sources"); + assert_eq!(skills_help_json["usage"]["aliases"][0], "/skill"); + assert!(sources.iter().any(|value| value == ".omc/skills")); + assert!(sources.iter().any(|value| value == ".agents/skills")); + assert!(sources.iter().any(|value| value == "~/.omc/skills")); + assert!(sources + .iter() + .any(|value| value == "~/.claude/skills/omc-learned")); let _ = fs::remove_dir_all(cwd); } + #[test] + fn discovers_omc_skills_from_project_and_user_compatibility_roots() { + let _guard = env_lock().lock().expect("env lock"); + let workspace = temp_dir("skills-omc-workspace"); + let user_home = temp_dir("skills-omc-home"); + let claude_config_dir = temp_dir("skills-omc-claude-config"); + let project_omc_skills = workspace.join(".omc").join("skills"); + let project_agents_skills = workspace.join(".agents").join("skills"); + let user_omc_skills = user_home.join(".omc").join("skills"); + let claude_config_skills = claude_config_dir.join("skills"); + let claude_config_commands = claude_config_dir.join("commands"); + let learned_skills = claude_config_dir.join("skills").join("omc-learned"); + let original_home = std::env::var_os("HOME"); + let original_claude_config_dir = std::env::var_os("CLAUDE_CONFIG_DIR"); + + write_skill(&project_omc_skills, "hud", "OMC HUD guidance"); + write_skill( + &project_agents_skills, + "trace", + "Compatibility skill guidance", + ); + write_skill(&user_omc_skills, "cancel", "OMC cancel guidance"); + write_skill( + &claude_config_skills, + "statusline", + "Claude config skill guidance", + ); + write_legacy_command( + &claude_config_commands, + "doctor-check", + "Claude config command guidance", + ); + write_skill(&learned_skills, "learned", "Learned skill guidance"); + std::env::set_var("HOME", &user_home); + std::env::set_var("CLAUDE_CONFIG_DIR", &claude_config_dir); + + let report = super::handle_skills_slash_command(None, &workspace).expect("skills list"); + assert!(report.contains("available skills")); + assert!(report.contains("hud · OMC HUD guidance")); + assert!(report.contains("trace · Compatibility skill guidance")); + assert!(report.contains("cancel · OMC cancel guidance")); + assert!(report.contains("statusline · Claude config skill guidance")); + assert!(report.contains("doctor-check · Claude config command guidance · legacy /commands")); + assert!(report.contains("learned · Learned skill guidance")); + + let help = + super::handle_skills_slash_command_json(Some("help"), &workspace).expect("skills help"); + let sources = help["usage"]["sources"] + .as_array() + .expect("skills help sources"); + assert_eq!(help["usage"]["aliases"][0], "/skill"); + assert!(sources.iter().any(|value| value == ".omc/skills")); + assert!(sources.iter().any(|value| value == ".agents/skills")); + assert!(sources.iter().any(|value| value == "~/.omc/skills")); + assert!(sources + .iter() + .any(|value| value == "~/.claude/skills/omc-learned")); + + restore_env_var("HOME", original_home); + restore_env_var("CLAUDE_CONFIG_DIR", original_claude_config_dir); + let _ = fs::remove_dir_all(workspace); + let _ = fs::remove_dir_all(user_home); + let _ = fs::remove_dir_all(claude_config_dir); + } + + #[test] + fn mcp_usage_supports_help_and_unexpected_args() { + let cwd = temp_dir("mcp-usage"); + + let help = super::handle_mcp_slash_command(Some("help"), &cwd).expect("mcp help"); + assert!(help.contains("Usage /mcp [list|show |help]")); + assert!(help.contains("Direct CLI claw mcp [list|show |help]")); + + let unexpected = + super::handle_mcp_slash_command(Some("show alpha beta"), &cwd).expect("mcp usage"); + assert!(unexpected.contains("Unexpected show alpha beta")); + + let nested_help = + super::handle_mcp_slash_command(Some("show --help"), &cwd).expect("mcp help"); + assert!(nested_help.contains("Usage /mcp [list|show |help]")); + assert!(nested_help.contains("Unexpected show")); + + let unknown_help = + super::handle_mcp_slash_command(Some("inspect --help"), &cwd).expect("mcp usage"); + assert!(unknown_help.contains("Usage /mcp [list|show |help]")); + assert!(unknown_help.contains("Unexpected inspect")); + + let _ = fs::remove_dir_all(cwd); + } + + #[test] + fn renders_mcp_reports_from_loaded_config() { + let workspace = temp_dir("mcp-config-workspace"); + let config_home = temp_dir("mcp-config-home"); + fs::create_dir_all(workspace.join(".claw")).expect("workspace config dir"); + fs::create_dir_all(&config_home).expect("config home"); + fs::write( + workspace.join(".claw").join("settings.json"), + r#"{ + "mcpServers": { + "alpha": { + "command": "uvx", + "args": ["alpha-server"], + "env": {"ALPHA_TOKEN": "secret"}, + "toolCallTimeoutMs": 1200 + }, + "remote": { + "type": "http", + "url": "https://remote.example/mcp", + "headers": {"Authorization": "Bearer secret"}, + "headersHelper": "./bin/headers", + "oauth": { + "clientId": "remote-client", + "callbackPort": 7878 + } + } + } + }"#, + ) + .expect("write settings"); + fs::write( + workspace.join(".claw").join("settings.local.json"), + r#"{ + "mcpServers": { + "remote": { + "type": "ws", + "url": "wss://remote.example/mcp" + } + } + }"#, + ) + .expect("write local settings"); + + let loader = ConfigLoader::new(&workspace, &config_home); + let list = super::render_mcp_report_for(&loader, &workspace, None) + .expect("mcp list report should render"); + assert!(list.contains("Configured servers 2")); + assert!(list.contains("alpha")); + assert!(list.contains("stdio")); + assert!(list.contains("project")); + assert!(list.contains("uvx alpha-server")); + assert!(list.contains("remote")); + assert!(list.contains("ws")); + assert!(list.contains("local")); + assert!(list.contains("wss://remote.example/mcp")); + + let show = super::render_mcp_report_for(&loader, &workspace, Some("show alpha")) + .expect("mcp show report should render"); + assert!(show.contains("Name alpha")); + assert!(show.contains("Command uvx")); + assert!(show.contains("Args alpha-server")); + assert!(show.contains("Env keys ALPHA_TOKEN")); + assert!(show.contains("Tool timeout 1200 ms")); + + let remote = super::render_mcp_report_for(&loader, &workspace, Some("show remote")) + .expect("mcp show remote report should render"); + assert!(remote.contains("Transport ws")); + assert!(remote.contains("URL wss://remote.example/mcp")); + + let missing = super::render_mcp_report_for(&loader, &workspace, Some("show missing")) + .expect("missing report should render"); + assert!(missing.contains("server `missing` is not configured")); + + let _ = fs::remove_dir_all(workspace); + let _ = fs::remove_dir_all(config_home); + } + + #[test] + fn renders_mcp_reports_as_json() { + let workspace = temp_dir("mcp-json-workspace"); + let config_home = temp_dir("mcp-json-home"); + fs::create_dir_all(workspace.join(".claw")).expect("workspace config dir"); + fs::create_dir_all(&config_home).expect("config home"); + fs::write( + workspace.join(".claw").join("settings.json"), + r#"{ + "mcpServers": { + "alpha": { + "command": "uvx", + "args": ["alpha-server"], + "env": {"ALPHA_TOKEN": "secret"}, + "toolCallTimeoutMs": 1200 + }, + "remote": { + "type": "http", + "url": "https://remote.example/mcp", + "headers": {"Authorization": "Bearer secret"}, + "headersHelper": "./bin/headers", + "oauth": { + "clientId": "remote-client", + "callbackPort": 7878 + } + } + } + }"#, + ) + .expect("write settings"); + fs::write( + workspace.join(".claw").join("settings.local.json"), + r#"{ + "mcpServers": { + "remote": { + "type": "ws", + "url": "wss://remote.example/mcp" + } + } + }"#, + ) + .expect("write local settings"); + + let loader = ConfigLoader::new(&workspace, &config_home); + let list = + render_mcp_report_json_for(&loader, &workspace, None).expect("mcp list json render"); + assert_eq!(list["kind"], "mcp"); + assert_eq!(list["action"], "list"); + assert_eq!(list["configured_servers"], 2); + assert_eq!(list["servers"][0]["name"], "alpha"); + assert_eq!(list["servers"][0]["transport"]["id"], "stdio"); + assert_eq!(list["servers"][0]["details"]["command"], "uvx"); + assert_eq!(list["servers"][1]["name"], "remote"); + assert_eq!(list["servers"][1]["scope"]["id"], "local"); + assert_eq!(list["servers"][1]["transport"]["id"], "ws"); + assert_eq!( + list["servers"][1]["details"]["url"], + "wss://remote.example/mcp" + ); + + let show = render_mcp_report_json_for(&loader, &workspace, Some("show alpha")) + .expect("mcp show json render"); + assert_eq!(show["action"], "show"); + assert_eq!(show["found"], true); + assert_eq!(show["server"]["name"], "alpha"); + assert_eq!(show["server"]["details"]["env_keys"][0], "ALPHA_TOKEN"); + assert_eq!(show["server"]["details"]["tool_call_timeout_ms"], 1200); + + let missing = render_mcp_report_json_for(&loader, &workspace, Some("show missing")) + .expect("mcp missing json render"); + assert_eq!(missing["found"], false); + assert_eq!(missing["server_name"], "missing"); + + let help = + render_mcp_report_json_for(&loader, &workspace, Some("help")).expect("mcp help json"); + assert_eq!(help["action"], "help"); + assert_eq!(help["usage"]["sources"][0], ".claw/settings.json"); + + let _ = fs::remove_dir_all(workspace); + let _ = fs::remove_dir_all(config_home); + } + #[test] fn parses_quoted_skill_frontmatter_values() { let contents = "---\nname: \"hud\"\ndescription: 'Quoted description'\n---\n"; @@ -2434,6 +5402,57 @@ mod tests { assert_eq!(description.as_deref(), Some("Quoted description")); } + #[test] + fn installs_skill_into_user_registry_and_preserves_nested_files() { + let workspace = temp_dir("skills-install-workspace"); + let source_root = workspace.join("source").join("help"); + let install_root = temp_dir("skills-install-root"); + write_skill( + source_root.parent().expect("parent"), + "help", + "Helpful skill", + ); + let script_dir = source_root.join("scripts"); + fs::create_dir_all(&script_dir).expect("script dir"); + fs::write(script_dir.join("run.sh"), "#!/bin/sh\necho help\n").expect("write script"); + + let installed = super::install_skill_into( + source_root.to_str().expect("utf8 skill path"), + &workspace, + &install_root, + ) + .expect("skill should install"); + + assert_eq!(installed.invocation_name, "help"); + assert_eq!(installed.display_name.as_deref(), Some("help")); + assert!(installed.installed_path.ends_with(Path::new("help"))); + assert!(installed.installed_path.join("SKILL.md").is_file()); + assert!(installed + .installed_path + .join("scripts") + .join("run.sh") + .is_file()); + + let report = super::render_skill_install_report(&installed); + assert!(report.contains("Result installed help")); + assert!(report.contains("Invoke as $help")); + assert!(report.contains(&install_root.display().to_string())); + + let roots = vec![SkillRoot { + source: DefinitionSource::UserCodexHome, + path: install_root.clone(), + origin: SkillOrigin::SkillsDir, + }]; + let listed = render_skills_report( + &load_skills_from_roots(&roots).expect("installed skills should load"), + ); + assert!(listed.contains("User config roots:")); + assert!(listed.contains("help · Helpful skill")); + + let _ = fs::remove_dir_all(workspace); + let _ = fs::remove_dir_all(install_root); + } + #[test] fn installs_plugin_from_path_and_lists_it() { let config_home = temp_dir("home"); @@ -2527,141 +5546,4 @@ mod tests { let _ = fs::remove_dir_all(config_home); let _ = fs::remove_dir_all(bundled_root); } - - #[test] - fn branch_and_worktree_commands_manage_git_state() { - // given - let repo = init_git_repo("branch-worktree"); - let worktree_path = repo - .parent() - .expect("repo should have parent") - .join("branch-worktree-linked"); - - // when - let branch_list = - handle_branch_slash_command(Some("list"), None, &repo).expect("branch list succeeds"); - let created = handle_branch_slash_command(Some("create"), Some("feature/demo"), &repo) - .expect("branch create succeeds"); - let switched = handle_branch_slash_command(Some("switch"), Some("main"), &repo) - .expect("branch switch succeeds"); - let added = handle_worktree_slash_command( - Some("add"), - Some(worktree_path.to_str().expect("utf8 path")), - Some("wt-demo"), - &repo, - ) - .expect("worktree add succeeds"); - let listed_worktrees = - handle_worktree_slash_command(Some("list"), None, None, &repo).expect("list succeeds"); - let removed = handle_worktree_slash_command( - Some("remove"), - Some(worktree_path.to_str().expect("utf8 path")), - None, - &repo, - ) - .expect("remove succeeds"); - - // then - assert!(branch_list.contains("main")); - assert!(created.contains("feature/demo")); - assert!(switched.contains("main")); - assert!(added.contains("wt-demo")); - assert!(listed_worktrees.contains(worktree_path.to_str().expect("utf8 path"))); - assert!(removed.contains("Result removed")); - - let _ = fs::remove_dir_all(repo); - let _ = fs::remove_dir_all(worktree_path); - } - - #[test] - fn commit_command_stages_and_commits_changes() { - // given - let repo = init_git_repo("commit-command"); - fs::write(repo.join("notes.txt"), "hello\n").expect("write notes"); - - // when - let report = - handle_commit_slash_command("feat: add notes", &repo).expect("commit succeeds"); - let status = run_command(&repo, "git", &["status", "--short"]); - let message = run_command(&repo, "git", &["log", "-1", "--pretty=%B"]); - - // then - assert!(report.contains("Result created")); - assert!(status.trim().is_empty()); - assert_eq!(message.trim(), "feat: add notes"); - - let _ = fs::remove_dir_all(repo); - } - - #[cfg(unix)] - #[test] - fn commit_push_pr_command_commits_pushes_and_creates_pr() { - // given - let _guard = env_lock(); - let repo = init_git_repo("commit-push-pr"); - let remote = init_bare_repo("commit-push-pr-remote"); - run_command( - &repo, - "git", - &[ - "remote", - "add", - "origin", - remote.to_str().expect("utf8 remote"), - ], - ); - run_command(&repo, "git", &["push", "-u", "origin", "main"]); - fs::write(repo.join("feature.txt"), "feature\n").expect("write feature file"); - - let fake_bin = temp_dir("fake-gh-bin"); - let gh_log = fake_bin.join("gh.log"); - write_fake_gh(&fake_bin, &gh_log, "https://example.com/pr/123"); - - let previous_path = env::var_os("PATH"); - let mut new_path = fake_bin.display().to_string(); - if let Some(path) = &previous_path { - new_path.push(':'); - new_path.push_str(&path.to_string_lossy()); - } - env::set_var("PATH", &new_path); - let previous_safeuser = env::var_os("SAFEUSER"); - env::set_var("SAFEUSER", "tester"); - - let request = CommitPushPrRequest { - commit_message: Some("feat: add feature file".to_string()), - pr_title: "Add feature file".to_string(), - pr_body: "## Summary\n- add feature file".to_string(), - branch_name_hint: "Add feature file".to_string(), - }; - - // when - let report = - handle_commit_push_pr_slash_command(&request, &repo).expect("commit-push-pr succeeds"); - let branch = run_command(&repo, "git", &["branch", "--show-current"]); - let message = run_command(&repo, "git", &["log", "-1", "--pretty=%B"]); - let gh_invocations = fs::read_to_string(&gh_log).expect("gh log should exist"); - - // then - assert!(report.contains("Result created")); - assert!(report.contains("URL https://example.com/pr/123")); - assert_eq!(branch.trim(), "tester/add-feature-file"); - assert_eq!(message.trim(), "feat: add feature file"); - assert!(gh_invocations.contains("pr create")); - assert!(gh_invocations.contains("--base main")); - - if let Some(path) = previous_path { - env::set_var("PATH", path); - } else { - env::remove_var("PATH"); - } - if let Some(safeuser) = previous_safeuser { - env::set_var("SAFEUSER", safeuser); - } else { - env::remove_var("SAFEUSER"); - } - - let _ = fs::remove_dir_all(repo); - let _ = fs::remove_dir_all(remote); - let _ = fs::remove_dir_all(fake_bin); - } } diff --git a/crates/compat-harness/src/lib.rs b/crates/compat-harness/src/lib.rs index e4e5a82..1acfec9 100644 --- a/crates/compat-harness/src/lib.rs +++ b/crates/compat-harness/src/lib.rs @@ -70,16 +70,12 @@ fn upstream_repo_candidates(primary_repo_root: &Path) -> Vec { } for ancestor in primary_repo_root.ancestors().take(4) { - candidates.push(ancestor.join("claude-code")); + candidates.push(ancestor.join("claw-code")); candidates.push(ancestor.join("clawd-code")); } - candidates.push( - primary_repo_root - .join("reference-source") - .join("claude-code"), - ); - candidates.push(primary_repo_root.join("vendor").join("claude-code")); + candidates.push(primary_repo_root.join("reference-source").join("claw-code")); + candidates.push(primary_repo_root.join("vendor").join("claw-code")); let mut deduped = Vec::new(); for candidate in candidates { diff --git a/crates/lsp/src/lib.rs b/crates/lsp/src/lib.rs index 9b1b099..fa25e33 100644 --- a/crates/lsp/src/lib.rs +++ b/crates/lsp/src/lib.rs @@ -41,6 +41,7 @@ mod tests { }) } + #[allow(clippy::too_many_lines)] fn write_mock_server_script(root: &std::path::Path) -> PathBuf { let script_path = root.join("mock_lsp_server.py"); fs::write( diff --git a/crates/mock-anthropic-service/Cargo.toml b/crates/mock-anthropic-service/Cargo.toml new file mode 100644 index 0000000..daced90 --- /dev/null +++ b/crates/mock-anthropic-service/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "mock-anthropic-service" +version.workspace = true +edition.workspace = true +license.workspace = true +publish.workspace = true + +[[bin]] +name = "mock-anthropic-service" +path = "src/main.rs" + +[dependencies] +api = { path = "../api" } +serde_json.workspace = true +tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "signal", "sync"] } + +[lints] +workspace = true diff --git a/crates/mock-anthropic-service/src/lib.rs b/crates/mock-anthropic-service/src/lib.rs new file mode 100644 index 0000000..68968ee --- /dev/null +++ b/crates/mock-anthropic-service/src/lib.rs @@ -0,0 +1,1123 @@ +use std::collections::HashMap; +use std::io; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; + +use api::{InputContentBlock, MessageRequest, MessageResponse, OutputContentBlock, Usage}; +use serde_json::{json, Value}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::sync::{oneshot, Mutex}; +use tokio::task::JoinHandle; + +pub const SCENARIO_PREFIX: &str = "PARITY_SCENARIO:"; +pub const DEFAULT_MODEL: &str = "claude-sonnet-4-6"; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CapturedRequest { + pub method: String, + pub path: String, + pub headers: HashMap, + pub scenario: String, + pub stream: bool, + pub raw_body: String, +} + +pub struct MockAnthropicService { + base_url: String, + requests: Arc>>, + shutdown: Option>, + join_handle: JoinHandle<()>, +} + +impl MockAnthropicService { + pub async fn spawn() -> io::Result { + Self::spawn_on("127.0.0.1:0").await + } + + pub async fn spawn_on(bind_addr: &str) -> io::Result { + let listener = TcpListener::bind(bind_addr).await?; + let address = listener.local_addr()?; + let requests = Arc::new(Mutex::new(Vec::new())); + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + let request_state = Arc::clone(&requests); + + let join_handle = tokio::spawn(async move { + loop { + tokio::select! { + _ = &mut shutdown_rx => break, + accepted = listener.accept() => { + let Ok((socket, _)) = accepted else { + break; + }; + let request_state = Arc::clone(&request_state); + tokio::spawn(async move { + let _ = handle_connection(socket, request_state).await; + }); + } + } + } + }); + + Ok(Self { + base_url: format!("http://{address}"), + requests, + shutdown: Some(shutdown_tx), + join_handle, + }) + } + + #[must_use] + pub fn base_url(&self) -> String { + self.base_url.clone() + } + + pub async fn captured_requests(&self) -> Vec { + self.requests.lock().await.clone() + } +} + +impl Drop for MockAnthropicService { + fn drop(&mut self) { + if let Some(shutdown) = self.shutdown.take() { + let _ = shutdown.send(()); + } + self.join_handle.abort(); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Scenario { + StreamingText, + ReadFileRoundtrip, + GrepChunkAssembly, + WriteFileAllowed, + WriteFileDenied, + MultiToolTurnRoundtrip, + BashStdoutRoundtrip, + BashPermissionPromptApproved, + BashPermissionPromptDenied, + PluginToolRoundtrip, + AutoCompactTriggered, + TokenCostReporting, +} + +impl Scenario { + fn parse(value: &str) -> Option { + match value.trim() { + "streaming_text" => Some(Self::StreamingText), + "read_file_roundtrip" => Some(Self::ReadFileRoundtrip), + "grep_chunk_assembly" => Some(Self::GrepChunkAssembly), + "write_file_allowed" => Some(Self::WriteFileAllowed), + "write_file_denied" => Some(Self::WriteFileDenied), + "multi_tool_turn_roundtrip" => Some(Self::MultiToolTurnRoundtrip), + "bash_stdout_roundtrip" => Some(Self::BashStdoutRoundtrip), + "bash_permission_prompt_approved" => Some(Self::BashPermissionPromptApproved), + "bash_permission_prompt_denied" => Some(Self::BashPermissionPromptDenied), + "plugin_tool_roundtrip" => Some(Self::PluginToolRoundtrip), + "auto_compact_triggered" => Some(Self::AutoCompactTriggered), + "token_cost_reporting" => Some(Self::TokenCostReporting), + _ => None, + } + } + + fn name(self) -> &'static str { + match self { + Self::StreamingText => "streaming_text", + Self::ReadFileRoundtrip => "read_file_roundtrip", + Self::GrepChunkAssembly => "grep_chunk_assembly", + Self::WriteFileAllowed => "write_file_allowed", + Self::WriteFileDenied => "write_file_denied", + Self::MultiToolTurnRoundtrip => "multi_tool_turn_roundtrip", + Self::BashStdoutRoundtrip => "bash_stdout_roundtrip", + Self::BashPermissionPromptApproved => "bash_permission_prompt_approved", + Self::BashPermissionPromptDenied => "bash_permission_prompt_denied", + Self::PluginToolRoundtrip => "plugin_tool_roundtrip", + Self::AutoCompactTriggered => "auto_compact_triggered", + Self::TokenCostReporting => "token_cost_reporting", + } + } +} + +async fn handle_connection( + mut socket: tokio::net::TcpStream, + requests: Arc>>, +) -> io::Result<()> { + let (method, path, headers, raw_body) = read_http_request(&mut socket).await?; + let request: MessageRequest = serde_json::from_str(&raw_body) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string()))?; + let scenario = detect_scenario(&request) + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing parity scenario"))?; + + requests.lock().await.push(CapturedRequest { + method, + path, + headers, + scenario: scenario.name().to_string(), + stream: request.stream, + raw_body, + }); + + let response = build_http_response(&request, scenario); + socket.write_all(response.as_bytes()).await?; + Ok(()) +} + +async fn read_http_request( + socket: &mut tokio::net::TcpStream, +) -> io::Result<(String, String, HashMap, String)> { + let mut buffer = Vec::new(); + let mut header_end = None; + + loop { + let mut chunk = [0_u8; 1024]; + let read = socket.read(&mut chunk).await?; + if read == 0 { + break; + } + buffer.extend_from_slice(&chunk[..read]); + if let Some(position) = find_header_end(&buffer) { + header_end = Some(position); + break; + } + } + + let header_end = header_end + .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "missing http headers"))?; + let (header_bytes, remaining) = buffer.split_at(header_end); + let header_text = String::from_utf8(header_bytes.to_vec()) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string()))?; + let mut lines = header_text.split("\r\n"); + let request_line = lines + .next() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing request line"))?; + let mut request_parts = request_line.split_whitespace(); + let method = request_parts + .next() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing method"))? + .to_string(); + let path = request_parts + .next() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing path"))? + .to_string(); + + let mut headers = HashMap::new(); + let mut content_length = 0_usize; + for line in lines { + if line.is_empty() { + continue; + } + let (name, value) = line.split_once(':').ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "malformed http header line") + })?; + let value = value.trim().to_string(); + if name.eq_ignore_ascii_case("content-length") { + content_length = value.parse().map_err(|error| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("invalid content-length: {error}"), + ) + })?; + } + headers.insert(name.to_ascii_lowercase(), value); + } + + let mut body = remaining[4..].to_vec(); + while body.len() < content_length { + let mut chunk = vec![0_u8; content_length - body.len()]; + let read = socket.read(&mut chunk).await?; + if read == 0 { + break; + } + body.extend_from_slice(&chunk[..read]); + } + + let body = String::from_utf8(body) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string()))?; + Ok((method, path, headers, body)) +} + +fn find_header_end(bytes: &[u8]) -> Option { + bytes.windows(4).position(|window| window == b"\r\n\r\n") +} + +fn detect_scenario(request: &MessageRequest) -> Option { + request.messages.iter().rev().find_map(|message| { + message.content.iter().rev().find_map(|block| match block { + InputContentBlock::Text { text } => text + .split_whitespace() + .find_map(|token| token.strip_prefix(SCENARIO_PREFIX)) + .and_then(Scenario::parse), + _ => None, + }) + }) +} + +fn latest_tool_result(request: &MessageRequest) -> Option<(String, bool)> { + request.messages.iter().rev().find_map(|message| { + message.content.iter().rev().find_map(|block| match block { + InputContentBlock::ToolResult { + content, is_error, .. + } => Some((flatten_tool_result_content(content), *is_error)), + _ => None, + }) + }) +} + +fn tool_results_by_name(request: &MessageRequest) -> HashMap { + let mut tool_names_by_id = HashMap::new(); + for message in &request.messages { + for block in &message.content { + if let InputContentBlock::ToolUse { id, name, .. } = block { + tool_names_by_id.insert(id.clone(), name.clone()); + } + } + } + + let mut results = HashMap::new(); + for message in request.messages.iter().rev() { + for block in message.content.iter().rev() { + if let InputContentBlock::ToolResult { + tool_use_id, + content, + is_error, + } = block + { + let tool_name = tool_names_by_id + .get(tool_use_id) + .cloned() + .unwrap_or_else(|| tool_use_id.clone()); + results + .entry(tool_name) + .or_insert_with(|| (flatten_tool_result_content(content), *is_error)); + } + } + } + results +} + +fn flatten_tool_result_content(content: &[api::ToolResultContentBlock]) -> String { + content + .iter() + .map(|block| match block { + api::ToolResultContentBlock::Text { text } => text.clone(), + api::ToolResultContentBlock::Json { value } => value.to_string(), + }) + .collect::>() + .join("\n") +} + +#[allow(clippy::too_many_lines)] +fn build_http_response(request: &MessageRequest, scenario: Scenario) -> String { + let response = if request.stream { + let body = build_stream_body(request, scenario); + return http_response( + "200 OK", + "text/event-stream", + &body, + &[("x-request-id", request_id_for(scenario))], + ); + } else { + build_message_response(request, scenario) + }; + + http_response( + "200 OK", + "application/json", + &serde_json::to_string(&response).expect("message response should serialize"), + &[("request-id", request_id_for(scenario))], + ) +} + +#[allow(clippy::too_many_lines)] +fn build_stream_body(request: &MessageRequest, scenario: Scenario) -> String { + match scenario { + Scenario::StreamingText => streaming_text_sse(), + Scenario::ReadFileRoundtrip => match latest_tool_result(request) { + Some((tool_output, _)) => final_text_sse(&format!( + "read_file roundtrip complete: {}", + extract_read_content(&tool_output) + )), + None => tool_use_sse( + "toolu_read_fixture", + "read_file", + &[r#"{"path":"fixture.txt"}"#], + ), + }, + Scenario::GrepChunkAssembly => match latest_tool_result(request) { + Some((tool_output, _)) => final_text_sse(&format!( + "grep_search matched {} occurrences", + extract_num_matches(&tool_output) + )), + None => tool_use_sse( + "toolu_grep_fixture", + "grep_search", + &[ + "{\"pattern\":\"par", + "ity\",\"path\":\"fixture.txt\"", + ",\"output_mode\":\"count\"}", + ], + ), + }, + Scenario::WriteFileAllowed => match latest_tool_result(request) { + Some((tool_output, _)) => final_text_sse(&format!( + "write_file succeeded: {}", + extract_file_path(&tool_output) + )), + None => tool_use_sse( + "toolu_write_allowed", + "write_file", + &[r#"{"path":"generated/output.txt","content":"created by mock service\n"}"#], + ), + }, + Scenario::WriteFileDenied => match latest_tool_result(request) { + Some((tool_output, _)) => { + final_text_sse(&format!("write_file denied as expected: {tool_output}")) + } + None => tool_use_sse( + "toolu_write_denied", + "write_file", + &[r#"{"path":"generated/denied.txt","content":"should not exist\n"}"#], + ), + }, + Scenario::MultiToolTurnRoundtrip => { + let tool_results = tool_results_by_name(request); + match ( + tool_results.get("read_file"), + tool_results.get("grep_search"), + ) { + (Some((read_output, _)), Some((grep_output, _))) => final_text_sse(&format!( + "multi-tool roundtrip complete: {} / {} occurrences", + extract_read_content(read_output), + extract_num_matches(grep_output) + )), + _ => tool_uses_sse(&[ + ToolUseSse { + tool_id: "toolu_multi_read", + tool_name: "read_file", + partial_json_chunks: &[r#"{"path":"fixture.txt"}"#], + }, + ToolUseSse { + tool_id: "toolu_multi_grep", + tool_name: "grep_search", + partial_json_chunks: &[ + "{\"pattern\":\"par", + "ity\",\"path\":\"fixture.txt\"", + ",\"output_mode\":\"count\"}", + ], + }, + ]), + } + } + Scenario::BashStdoutRoundtrip => match latest_tool_result(request) { + Some((tool_output, _)) => final_text_sse(&format!( + "bash completed: {}", + extract_bash_stdout(&tool_output) + )), + None => tool_use_sse( + "toolu_bash_stdout", + "bash", + &[r#"{"command":"printf 'alpha from bash'","timeout":1000}"#], + ), + }, + Scenario::BashPermissionPromptApproved => match latest_tool_result(request) { + Some((tool_output, is_error)) => { + if is_error { + final_text_sse(&format!("bash approval unexpectedly failed: {tool_output}")) + } else { + final_text_sse(&format!( + "bash approved and executed: {}", + extract_bash_stdout(&tool_output) + )) + } + } + None => tool_use_sse( + "toolu_bash_prompt_allow", + "bash", + &[r#"{"command":"printf 'approved via prompt'","timeout":1000}"#], + ), + }, + Scenario::BashPermissionPromptDenied => match latest_tool_result(request) { + Some((tool_output, _)) => { + final_text_sse(&format!("bash denied as expected: {tool_output}")) + } + None => tool_use_sse( + "toolu_bash_prompt_deny", + "bash", + &[r#"{"command":"printf 'should not run'","timeout":1000}"#], + ), + }, + Scenario::PluginToolRoundtrip => match latest_tool_result(request) { + Some((tool_output, _)) => final_text_sse(&format!( + "plugin tool completed: {}", + extract_plugin_message(&tool_output) + )), + None => tool_use_sse( + "toolu_plugin_echo", + "plugin_echo", + &[r#"{"message":"hello from plugin parity"}"#], + ), + }, + Scenario::AutoCompactTriggered => { + final_text_sse_with_usage("auto compact parity complete.", 50_000, 200) + } + Scenario::TokenCostReporting => { + final_text_sse_with_usage("token cost reporting parity complete.", 1_000, 500) + } + } +} + +#[allow(clippy::too_many_lines)] +fn build_message_response(request: &MessageRequest, scenario: Scenario) -> MessageResponse { + match scenario { + Scenario::StreamingText => text_message_response( + "msg_streaming_text", + "Mock streaming says hello from the parity harness.", + ), + Scenario::ReadFileRoundtrip => match latest_tool_result(request) { + Some((tool_output, _)) => text_message_response( + "msg_read_file_final", + &format!( + "read_file roundtrip complete: {}", + extract_read_content(&tool_output) + ), + ), + None => tool_message_response( + "msg_read_file_tool", + "toolu_read_fixture", + "read_file", + json!({"path": "fixture.txt"}), + ), + }, + Scenario::GrepChunkAssembly => match latest_tool_result(request) { + Some((tool_output, _)) => text_message_response( + "msg_grep_final", + &format!( + "grep_search matched {} occurrences", + extract_num_matches(&tool_output) + ), + ), + None => tool_message_response( + "msg_grep_tool", + "toolu_grep_fixture", + "grep_search", + json!({"pattern": "parity", "path": "fixture.txt", "output_mode": "count"}), + ), + }, + Scenario::WriteFileAllowed => match latest_tool_result(request) { + Some((tool_output, _)) => text_message_response( + "msg_write_allowed_final", + &format!("write_file succeeded: {}", extract_file_path(&tool_output)), + ), + None => tool_message_response( + "msg_write_allowed_tool", + "toolu_write_allowed", + "write_file", + json!({"path": "generated/output.txt", "content": "created by mock service\n"}), + ), + }, + Scenario::WriteFileDenied => match latest_tool_result(request) { + Some((tool_output, _)) => text_message_response( + "msg_write_denied_final", + &format!("write_file denied as expected: {tool_output}"), + ), + None => tool_message_response( + "msg_write_denied_tool", + "toolu_write_denied", + "write_file", + json!({"path": "generated/denied.txt", "content": "should not exist\n"}), + ), + }, + Scenario::MultiToolTurnRoundtrip => { + let tool_results = tool_results_by_name(request); + match ( + tool_results.get("read_file"), + tool_results.get("grep_search"), + ) { + (Some((read_output, _)), Some((grep_output, _))) => text_message_response( + "msg_multi_tool_final", + &format!( + "multi-tool roundtrip complete: {} / {} occurrences", + extract_read_content(read_output), + extract_num_matches(grep_output) + ), + ), + _ => tool_message_response_many( + "msg_multi_tool_start", + &[ + ToolUseMessage { + tool_id: "toolu_multi_read", + tool_name: "read_file", + input: json!({"path": "fixture.txt"}), + }, + ToolUseMessage { + tool_id: "toolu_multi_grep", + tool_name: "grep_search", + input: json!({"pattern": "parity", "path": "fixture.txt", "output_mode": "count"}), + }, + ], + ), + } + } + Scenario::BashStdoutRoundtrip => match latest_tool_result(request) { + Some((tool_output, _)) => text_message_response( + "msg_bash_stdout_final", + &format!("bash completed: {}", extract_bash_stdout(&tool_output)), + ), + None => tool_message_response( + "msg_bash_stdout_tool", + "toolu_bash_stdout", + "bash", + json!({"command": "printf 'alpha from bash'", "timeout": 1000}), + ), + }, + Scenario::BashPermissionPromptApproved => match latest_tool_result(request) { + Some((tool_output, is_error)) => { + if is_error { + text_message_response( + "msg_bash_prompt_allow_error", + &format!("bash approval unexpectedly failed: {tool_output}"), + ) + } else { + text_message_response( + "msg_bash_prompt_allow_final", + &format!( + "bash approved and executed: {}", + extract_bash_stdout(&tool_output) + ), + ) + } + } + None => tool_message_response( + "msg_bash_prompt_allow_tool", + "toolu_bash_prompt_allow", + "bash", + json!({"command": "printf 'approved via prompt'", "timeout": 1000}), + ), + }, + Scenario::BashPermissionPromptDenied => match latest_tool_result(request) { + Some((tool_output, _)) => text_message_response( + "msg_bash_prompt_deny_final", + &format!("bash denied as expected: {tool_output}"), + ), + None => tool_message_response( + "msg_bash_prompt_deny_tool", + "toolu_bash_prompt_deny", + "bash", + json!({"command": "printf 'should not run'", "timeout": 1000}), + ), + }, + Scenario::PluginToolRoundtrip => match latest_tool_result(request) { + Some((tool_output, _)) => text_message_response( + "msg_plugin_tool_final", + &format!( + "plugin tool completed: {}", + extract_plugin_message(&tool_output) + ), + ), + None => tool_message_response( + "msg_plugin_tool_start", + "toolu_plugin_echo", + "plugin_echo", + json!({"message": "hello from plugin parity"}), + ), + }, + Scenario::AutoCompactTriggered => text_message_response_with_usage( + "msg_auto_compact_triggered", + "auto compact parity complete.", + 50_000, + 200, + ), + Scenario::TokenCostReporting => text_message_response_with_usage( + "msg_token_cost_reporting", + "token cost reporting parity complete.", + 1_000, + 500, + ), + } +} + +fn request_id_for(scenario: Scenario) -> &'static str { + match scenario { + Scenario::StreamingText => "req_streaming_text", + Scenario::ReadFileRoundtrip => "req_read_file_roundtrip", + Scenario::GrepChunkAssembly => "req_grep_chunk_assembly", + Scenario::WriteFileAllowed => "req_write_file_allowed", + Scenario::WriteFileDenied => "req_write_file_denied", + Scenario::MultiToolTurnRoundtrip => "req_multi_tool_turn_roundtrip", + Scenario::BashStdoutRoundtrip => "req_bash_stdout_roundtrip", + Scenario::BashPermissionPromptApproved => "req_bash_permission_prompt_approved", + Scenario::BashPermissionPromptDenied => "req_bash_permission_prompt_denied", + Scenario::PluginToolRoundtrip => "req_plugin_tool_roundtrip", + Scenario::AutoCompactTriggered => "req_auto_compact_triggered", + Scenario::TokenCostReporting => "req_token_cost_reporting", + } +} + +fn http_response(status: &str, content_type: &str, body: &str, headers: &[(&str, &str)]) -> String { + let mut extra_headers = String::new(); + for (name, value) in headers { + use std::fmt::Write as _; + write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write should succeed"); + } + format!( + "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}", + body.len() + ) +} + +fn text_message_response(id: &str, text: &str) -> MessageResponse { + MessageResponse { + id: id.to_string(), + kind: "message".to_string(), + role: "assistant".to_string(), + content: vec![OutputContentBlock::Text { + text: text.to_string(), + }], + model: DEFAULT_MODEL.to_string(), + stop_reason: Some("end_turn".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 10, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: 6, + }, + request_id: None, + } +} + +fn text_message_response_with_usage( + id: &str, + text: &str, + input_tokens: u32, + output_tokens: u32, +) -> MessageResponse { + MessageResponse { + id: id.to_string(), + kind: "message".to_string(), + role: "assistant".to_string(), + content: vec![OutputContentBlock::Text { + text: text.to_string(), + }], + model: DEFAULT_MODEL.to_string(), + stop_reason: Some("end_turn".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens, + }, + request_id: None, + } +} + +fn tool_message_response( + id: &str, + tool_id: &str, + tool_name: &str, + input: Value, +) -> MessageResponse { + tool_message_response_many( + id, + &[ToolUseMessage { + tool_id, + tool_name, + input, + }], + ) +} + +struct ToolUseMessage<'a> { + tool_id: &'a str, + tool_name: &'a str, + input: Value, +} + +fn tool_message_response_many(id: &str, tool_uses: &[ToolUseMessage<'_>]) -> MessageResponse { + MessageResponse { + id: id.to_string(), + kind: "message".to_string(), + role: "assistant".to_string(), + content: tool_uses + .iter() + .map(|tool_use| OutputContentBlock::ToolUse { + id: tool_use.tool_id.to_string(), + name: tool_use.tool_name.to_string(), + input: tool_use.input.clone(), + }) + .collect(), + model: DEFAULT_MODEL.to_string(), + stop_reason: Some("tool_use".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 10, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: 3, + }, + request_id: None, + } +} + +fn streaming_text_sse() -> String { + let mut body = String::new(); + append_sse( + &mut body, + "message_start", + json!({ + "type": "message_start", + "message": { + "id": "msg_streaming_text", + "type": "message", + "role": "assistant", + "content": [], + "model": DEFAULT_MODEL, + "stop_reason": null, + "stop_sequence": null, + "usage": usage_json(11, 0) + } + }), + ); + append_sse( + &mut body, + "content_block_start", + json!({ + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""} + }), + ); + append_sse( + &mut body, + "content_block_delta", + json!({ + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "Mock streaming "} + }), + ); + append_sse( + &mut body, + "content_block_delta", + json!({ + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": "says hello from the parity harness."} + }), + ); + append_sse( + &mut body, + "content_block_stop", + json!({ + "type": "content_block_stop", + "index": 0 + }), + ); + append_sse( + &mut body, + "message_delta", + json!({ + "type": "message_delta", + "delta": {"stop_reason": "end_turn", "stop_sequence": null}, + "usage": usage_json(11, 8) + }), + ); + append_sse(&mut body, "message_stop", json!({"type": "message_stop"})); + body +} + +fn tool_use_sse(tool_id: &str, tool_name: &str, partial_json_chunks: &[&str]) -> String { + tool_uses_sse(&[ToolUseSse { + tool_id, + tool_name, + partial_json_chunks, + }]) +} + +struct ToolUseSse<'a> { + tool_id: &'a str, + tool_name: &'a str, + partial_json_chunks: &'a [&'a str], +} + +fn tool_uses_sse(tool_uses: &[ToolUseSse<'_>]) -> String { + let mut body = String::new(); + let message_id = tool_uses.first().map_or_else( + || "msg_tool_use".to_string(), + |tool_use| format!("msg_{}", tool_use.tool_id), + ); + append_sse( + &mut body, + "message_start", + json!({ + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "content": [], + "model": DEFAULT_MODEL, + "stop_reason": null, + "stop_sequence": null, + "usage": usage_json(12, 0) + } + }), + ); + for (index, tool_use) in tool_uses.iter().enumerate() { + append_sse( + &mut body, + "content_block_start", + json!({ + "type": "content_block_start", + "index": index, + "content_block": { + "type": "tool_use", + "id": tool_use.tool_id, + "name": tool_use.tool_name, + "input": {} + } + }), + ); + for chunk in tool_use.partial_json_chunks { + append_sse( + &mut body, + "content_block_delta", + json!({ + "type": "content_block_delta", + "index": index, + "delta": {"type": "input_json_delta", "partial_json": chunk} + }), + ); + } + append_sse( + &mut body, + "content_block_stop", + json!({ + "type": "content_block_stop", + "index": index + }), + ); + } + append_sse( + &mut body, + "message_delta", + json!({ + "type": "message_delta", + "delta": {"stop_reason": "tool_use", "stop_sequence": null}, + "usage": usage_json(12, 4) + }), + ); + append_sse(&mut body, "message_stop", json!({"type": "message_stop"})); + body +} + +fn final_text_sse(text: &str) -> String { + let mut body = String::new(); + append_sse( + &mut body, + "message_start", + json!({ + "type": "message_start", + "message": { + "id": unique_message_id(), + "type": "message", + "role": "assistant", + "content": [], + "model": DEFAULT_MODEL, + "stop_reason": null, + "stop_sequence": null, + "usage": usage_json(14, 0) + } + }), + ); + append_sse( + &mut body, + "content_block_start", + json!({ + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""} + }), + ); + append_sse( + &mut body, + "content_block_delta", + json!({ + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": text} + }), + ); + append_sse( + &mut body, + "content_block_stop", + json!({ + "type": "content_block_stop", + "index": 0 + }), + ); + append_sse( + &mut body, + "message_delta", + json!({ + "type": "message_delta", + "delta": {"stop_reason": "end_turn", "stop_sequence": null}, + "usage": usage_json(14, 7) + }), + ); + append_sse(&mut body, "message_stop", json!({"type": "message_stop"})); + body +} + +fn final_text_sse_with_usage(text: &str, input_tokens: u32, output_tokens: u32) -> String { + let mut body = String::new(); + append_sse( + &mut body, + "message_start", + json!({ + "type": "message_start", + "message": { + "id": unique_message_id(), + "type": "message", + "role": "assistant", + "content": [], + "model": DEFAULT_MODEL, + "stop_reason": null, + "stop_sequence": null, + "usage": { + "input_tokens": input_tokens, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "output_tokens": 0 + } + } + }), + ); + append_sse( + &mut body, + "content_block_start", + json!({ + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""} + }), + ); + append_sse( + &mut body, + "content_block_delta", + json!({ + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": text} + }), + ); + append_sse( + &mut body, + "content_block_stop", + json!({ + "type": "content_block_stop", + "index": 0 + }), + ); + append_sse( + &mut body, + "message_delta", + json!({ + "type": "message_delta", + "delta": {"stop_reason": "end_turn", "stop_sequence": null}, + "usage": { + "input_tokens": input_tokens, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "output_tokens": output_tokens + } + }), + ); + append_sse(&mut body, "message_stop", json!({"type": "message_stop"})); + body +} + +#[allow(clippy::needless_pass_by_value)] +fn append_sse(buffer: &mut String, event: &str, payload: Value) { + use std::fmt::Write as _; + writeln!(buffer, "event: {event}").expect("event write should succeed"); + writeln!(buffer, "data: {payload}").expect("payload write should succeed"); + buffer.push('\n'); +} + +fn usage_json(input_tokens: u32, output_tokens: u32) -> Value { + json!({ + "input_tokens": input_tokens, + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 0, + "output_tokens": output_tokens + }) +} + +fn unique_message_id() -> String { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("clock should be after epoch") + .as_nanos(); + format!("msg_{nanos}") +} + +fn extract_read_content(tool_output: &str) -> String { + serde_json::from_str::(tool_output) + .ok() + .and_then(|value| { + value + .get("file") + .and_then(|file| file.get("content")) + .and_then(Value::as_str) + .map(ToOwned::to_owned) + }) + .unwrap_or_else(|| tool_output.trim().to_string()) +} + +#[allow(clippy::cast_possible_truncation)] +fn extract_num_matches(tool_output: &str) -> usize { + serde_json::from_str::(tool_output) + .ok() + .and_then(|value| value.get("numMatches").and_then(Value::as_u64)) + .unwrap_or(0) as usize +} + +fn extract_file_path(tool_output: &str) -> String { + serde_json::from_str::(tool_output) + .ok() + .and_then(|value| { + value + .get("filePath") + .and_then(Value::as_str) + .map(ToOwned::to_owned) + }) + .unwrap_or_else(|| tool_output.trim().to_string()) +} + +fn extract_bash_stdout(tool_output: &str) -> String { + serde_json::from_str::(tool_output) + .ok() + .and_then(|value| { + value + .get("stdout") + .and_then(Value::as_str) + .map(ToOwned::to_owned) + }) + .unwrap_or_else(|| tool_output.trim().to_string()) +} + +fn extract_plugin_message(tool_output: &str) -> String { + serde_json::from_str::(tool_output) + .ok() + .and_then(|value| { + value + .get("input") + .and_then(|input| input.get("message")) + .and_then(Value::as_str) + .map(ToOwned::to_owned) + }) + .unwrap_or_else(|| tool_output.trim().to_string()) +} diff --git a/crates/mock-anthropic-service/src/main.rs b/crates/mock-anthropic-service/src/main.rs new file mode 100644 index 0000000..e81fdb1 --- /dev/null +++ b/crates/mock-anthropic-service/src/main.rs @@ -0,0 +1,34 @@ +use std::env; + +use mock_anthropic_service::MockAnthropicService; + +#[tokio::main(flavor = "multi_thread")] +async fn main() -> Result<(), Box> { + let mut bind_addr = String::from("127.0.0.1:0"); + let mut args = env::args().skip(1); + while let Some(arg) = args.next() { + match arg.as_str() { + "--bind" => { + bind_addr = args + .next() + .ok_or_else(|| "missing value for --bind".to_string())?; + } + flag if flag.starts_with("--bind=") => { + bind_addr = flag[7..].to_string(); + } + "--help" | "-h" => { + println!("Usage: mock-anthropic-service [--bind HOST:PORT]"); + return Ok(()); + } + other => { + return Err(format!("unsupported argument: {other}").into()); + } + } + } + + let server = MockAnthropicService::spawn_on(&bind_addr).await?; + println!("MOCK_ANTHROPIC_BASE_URL={}", server.base_url()); + tokio::signal::ctrl_c().await?; + drop(server); + Ok(()) +} diff --git a/crates/plugins/bundled/example-bundled/.claw-plugin/plugin.json b/crates/plugins/bundled/example-bundled/.claude-plugin/plugin.json similarity index 100% rename from crates/plugins/bundled/example-bundled/.claw-plugin/plugin.json rename to crates/plugins/bundled/example-bundled/.claude-plugin/plugin.json diff --git a/crates/plugins/bundled/sample-hooks/.claw-plugin/plugin.json b/crates/plugins/bundled/sample-hooks/.claude-plugin/plugin.json similarity index 100% rename from crates/plugins/bundled/sample-hooks/.claw-plugin/plugin.json rename to crates/plugins/bundled/sample-hooks/.claude-plugin/plugin.json diff --git a/crates/plugins/src/hooks.rs b/crates/plugins/src/hooks.rs index 165efc2..b8ee8a5 100644 --- a/crates/plugins/src/hooks.rs +++ b/crates/plugins/src/hooks.rs @@ -1,6 +1,4 @@ use std::ffi::OsStr; -#[cfg(not(windows))] -use std::path::Path; use std::process::Command; use serde_json::json; @@ -11,6 +9,7 @@ use crate::{PluginError, PluginHooks, PluginRegistry}; pub enum HookEvent { PreToolUse, PostToolUse, + PostToolUseFailure, } impl HookEvent { @@ -18,6 +17,7 @@ impl HookEvent { match self { Self::PreToolUse => "PreToolUse", Self::PostToolUse => "PostToolUse", + Self::PostToolUseFailure => "PostToolUseFailure", } } } @@ -25,6 +25,7 @@ impl HookEvent { #[derive(Debug, Clone, PartialEq, Eq)] pub struct HookRunResult { denied: bool, + failed: bool, messages: Vec, } @@ -33,6 +34,7 @@ impl HookRunResult { pub fn allow(messages: Vec) -> Self { Self { denied: false, + failed: false, messages, } } @@ -42,6 +44,11 @@ impl HookRunResult { self.denied } + #[must_use] + pub fn is_failed(&self) -> bool { + self.failed + } + #[must_use] pub fn messages(&self) -> &[String] { &self.messages @@ -65,7 +72,7 @@ impl HookRunner { #[must_use] pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult { - self.run_commands( + Self::run_commands( HookEvent::PreToolUse, &self.hooks.pre_tool_use, tool_name, @@ -83,7 +90,7 @@ impl HookRunner { tool_output: &str, is_error: bool, ) -> HookRunResult { - self.run_commands( + Self::run_commands( HookEvent::PostToolUse, &self.hooks.post_tool_use, tool_name, @@ -93,8 +100,24 @@ impl HookRunner { ) } - fn run_commands( + #[must_use] + pub fn run_post_tool_use_failure( &self, + tool_name: &str, + tool_input: &str, + tool_error: &str, + ) -> HookRunResult { + Self::run_commands( + HookEvent::PostToolUseFailure, + &self.hooks.post_tool_use_failure, + tool_name, + tool_input, + Some(tool_error), + true, + ) + } + + fn run_commands( event: HookEvent, commands: &[String], tool_name: &str, @@ -106,20 +129,12 @@ impl HookRunner { return HookRunResult::allow(Vec::new()); } - let payload = json!({ - "hook_event_name": event.as_str(), - "tool_name": tool_name, - "tool_input": parse_tool_input(tool_input), - "tool_input_json": tool_input, - "tool_output": tool_output, - "tool_result_is_error": is_error, - }) - .to_string(); + let payload = hook_payload(event, tool_name, tool_input, tool_output, is_error).to_string(); let mut messages = Vec::new(); for command in commands { - match self.run_command( + match Self::run_command( command, event, tool_name, @@ -139,19 +154,26 @@ impl HookRunner { })); return HookRunResult { denied: true, + failed: false, + messages, + }; + } + HookCommandOutcome::Failed { message } => { + messages.push(message); + return HookRunResult { + denied: false, + failed: true, messages, }; } - HookCommandOutcome::Warn { message } => messages.push(message), } } HookRunResult::allow(messages) } - #[allow(clippy::too_many_arguments, clippy::unused_self)] + #[allow(clippy::too_many_arguments)] fn run_command( - &self, command: &str, event: HookEvent, tool_name: &str, @@ -180,7 +202,7 @@ impl HookRunner { match output.status.code() { Some(0) => HookCommandOutcome::Allow { message }, Some(2) => HookCommandOutcome::Deny { message }, - Some(code) => HookCommandOutcome::Warn { + Some(code) => HookCommandOutcome::Failed { message: format_hook_warning( command, code, @@ -188,7 +210,7 @@ impl HookRunner { stderr.as_str(), ), }, - None => HookCommandOutcome::Warn { + None => HookCommandOutcome::Failed { message: format!( "{} hook `{command}` terminated by signal while handling `{tool_name}`", event.as_str() @@ -196,7 +218,7 @@ impl HookRunner { }, } } - Err(error) => HookCommandOutcome::Warn { + Err(error) => HookCommandOutcome::Failed { message: format!( "{} hook `{command}` failed to start for `{tool_name}`: {error}", event.as_str() @@ -209,7 +231,34 @@ impl HookRunner { enum HookCommandOutcome { Allow { message: Option }, Deny { message: Option }, - Warn { message: String }, + Failed { message: String }, +} + +fn hook_payload( + event: HookEvent, + tool_name: &str, + tool_input: &str, + tool_output: Option<&str>, + is_error: bool, +) -> serde_json::Value { + match event { + HookEvent::PostToolUseFailure => json!({ + "hook_event_name": event.as_str(), + "tool_name": tool_name, + "tool_input": parse_tool_input(tool_input), + "tool_input_json": tool_input, + "tool_error": tool_output, + "tool_result_is_error": true, + }), + _ => json!({ + "hook_event_name": event.as_str(), + "tool_name": tool_name, + "tool_input": parse_tool_input(tool_input), + "tool_input_json": tool_input, + "tool_output": tool_output, + "tool_result_is_error": is_error, + }), + } } fn parse_tool_input(tool_input: &str) -> serde_json::Value { @@ -217,8 +266,7 @@ fn parse_tool_input(tool_input: &str) -> serde_json::Value { } fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String { - let mut message = - format!("Hook `{command}` exited with status {code}; allowing tool execution to continue"); + let mut message = format!("Hook `{command}` exited with status {code}"); if let Some(stdout) = stdout.filter(|stdout| !stdout.is_empty()) { message.push_str(": "); message.push_str(stdout); @@ -288,7 +336,28 @@ impl CommandWithStdin { let mut child = self.command.spawn()?; if let Some(mut child_stdin) = child.stdin.take() { use std::io::Write as _; - child_stdin.write_all(stdin)?; + // Tolerate BrokenPipe: a hook script that runs to completion + // (or exits early without reading stdin) closes its stdin + // before the parent finishes writing the JSON payload, and + // the kernel raises EPIPE on the parent's write_all. That is + // not a hook failure — the child still exited cleanly and we + // still need to wait_with_output() to capture stdout/stderr + // and the real exit code. Other write errors (e.g. EIO, + // permission, OOM) still propagate. + // + // This was the root cause of the Linux CI flake on + // hooks::tests::collects_and_runs_hooks_from_enabled_plugins + // (ROADMAP #25, runs 24120271422 / 24120538408 / 24121392171 + // / 24121776826): the test hook scripts run in microseconds + // and the parent's stdin write races against child exit. + // macOS pipes happen to buffer the small payload before the + // child exits; Linux pipes do not, so the race shows up + // deterministically on ubuntu runners. + match child_stdin.write_all(stdin) { + Ok(()) => {} + Err(error) if error.kind() == std::io::ErrorKind::BrokenPipe => {} + Err(error) => return Err(error), + } } child.wait_with_output() } @@ -310,23 +379,55 @@ mod tests { std::env::temp_dir().join(format!("plugins-hook-runner-{label}-{nanos}")) } - fn write_hook_plugin(root: &Path, name: &str, pre_message: &str, post_message: &str) { - fs::create_dir_all(root.join(".claw-plugin")).expect("manifest dir"); + fn make_executable(path: &Path) { + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let perms = fs::Permissions::from_mode(0o755); + fs::set_permissions(path, perms) + .unwrap_or_else(|e| panic!("chmod +x {}: {e}", path.display())); + } + #[cfg(not(unix))] + let _ = path; + } + + fn write_hook_plugin( + root: &Path, + name: &str, + pre_message: &str, + post_message: &str, + failure_message: &str, + ) { + fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir"); fs::create_dir_all(root.join("hooks")).expect("hooks dir"); + + let pre_path = root.join("hooks").join("pre.sh"); fs::write( - root.join("hooks").join("pre.sh"), + &pre_path, format!("#!/bin/sh\nprintf '%s\\n' '{pre_message}'\n"), ) .expect("write pre hook"); + make_executable(&pre_path); + + let post_path = root.join("hooks").join("post.sh"); fs::write( - root.join("hooks").join("post.sh"), + &post_path, format!("#!/bin/sh\nprintf '%s\\n' '{post_message}'\n"), ) .expect("write post hook"); + make_executable(&post_path); + + let failure_path = root.join("hooks").join("failure.sh"); fs::write( - root.join(".claw-plugin").join("plugin.json"), + &failure_path, + format!("#!/bin/sh\nprintf '%s\\n' '{failure_message}'\n"), + ) + .expect("write failure hook"); + make_executable(&failure_path); + fs::write( + root.join(".claude-plugin").join("plugin.json"), format!( - "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"hook plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"]\n }}\n}}" + "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"hook plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"],\n \"PostToolUseFailure\": [\"./hooks/failure.sh\"]\n }}\n}}" ), ) .expect("write plugin manifest"); @@ -334,6 +435,7 @@ mod tests { #[test] fn collects_and_runs_hooks_from_enabled_plugins() { + // given let config_home = temp_dir("config"); let first_source_root = temp_dir("source-a"); let second_source_root = temp_dir("source-b"); @@ -342,12 +444,14 @@ mod tests { "first", "plugin pre one", "plugin post one", + "plugin failure one", ); write_hook_plugin( &second_source_root, "second", "plugin pre two", "plugin post two", + "plugin failure two", ); let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); @@ -359,8 +463,10 @@ mod tests { .expect("second plugin install should succeed"); let registry = manager.plugin_registry().expect("registry should build"); + // when let runner = HookRunner::from_registry(®istry).expect("plugin hooks should load"); + // then assert_eq!( runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#), HookRunResult::allow(vec![ @@ -375,6 +481,13 @@ mod tests { "plugin post two".to_string(), ]) ); + assert_eq!( + runner.run_post_tool_use_failure("Read", r#"{"path":"README.md"}"#, "tool failed",), + HookRunResult::allow(vec![ + "plugin failure one".to_string(), + "plugin failure two".to_string(), + ]) + ); let _ = fs::remove_dir_all(config_home); let _ = fs::remove_dir_all(first_source_root); @@ -383,14 +496,68 @@ mod tests { #[test] fn pre_tool_use_denies_when_plugin_hook_exits_two() { + // given let runner = HookRunner::new(crate::PluginHooks { pre_tool_use: vec!["printf 'blocked by plugin'; exit 2".to_string()], post_tool_use: Vec::new(), + post_tool_use_failure: Vec::new(), }); + // when let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#); + // then assert!(result.is_denied()); assert_eq!(result.messages(), &["blocked by plugin".to_string()]); } + + #[test] + fn propagates_plugin_hook_failures() { + // given + let runner = HookRunner::new(crate::PluginHooks { + pre_tool_use: vec![ + "printf 'broken plugin hook'; exit 1".to_string(), + "printf 'later plugin hook'".to_string(), + ], + post_tool_use: Vec::new(), + post_tool_use_failure: Vec::new(), + }); + + // when + let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#); + + // then + assert!(result.is_failed()); + assert!(result + .messages() + .iter() + .any(|message| message.contains("broken plugin hook"))); + assert!(!result + .messages() + .iter() + .any(|message| message == "later plugin hook")); + } + + #[test] + #[cfg(unix)] + fn generated_hook_scripts_are_executable() { + use std::os::unix::fs::PermissionsExt; + + // given + let root = temp_dir("exec-guard"); + write_hook_plugin(&root, "exec-check", "pre", "post", "fail"); + + // then + for script in ["pre.sh", "post.sh", "failure.sh"] { + let path = root.join("hooks").join(script); + let mode = fs::metadata(&path) + .unwrap_or_else(|e| panic!("{script} metadata: {e}")) + .permissions() + .mode(); + assert!( + mode & 0o111 != 0, + "{script} must have at least one execute bit set, got mode {mode:#o}" + ); + } + } } diff --git a/crates/plugins/src/lib.rs b/crates/plugins/src/lib.rs index 6105ad9..052d1ba 100644 --- a/crates/plugins/src/lib.rs +++ b/crates/plugins/src/lib.rs @@ -18,7 +18,7 @@ const BUNDLED_MARKETPLACE: &str = "bundled"; const SETTINGS_FILE_NAME: &str = "settings.json"; const REGISTRY_FILE_NAME: &str = "installed.json"; const MANIFEST_FILE_NAME: &str = "plugin.json"; -const MANIFEST_RELATIVE_PATH: &str = ".claw-plugin/plugin.json"; +const MANIFEST_RELATIVE_PATH: &str = ".claude-plugin/plugin.json"; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] @@ -67,12 +67,16 @@ pub struct PluginHooks { pub pre_tool_use: Vec, #[serde(rename = "PostToolUse", default)] pub post_tool_use: Vec, + #[serde(rename = "PostToolUseFailure", default)] + pub post_tool_use_failure: Vec, } impl PluginHooks { #[must_use] pub fn is_empty(&self) -> bool { - self.pre_tool_use.is_empty() && self.post_tool_use.is_empty() + self.pre_tool_use.is_empty() + && self.post_tool_use.is_empty() + && self.post_tool_use_failure.is_empty() } #[must_use] @@ -85,6 +89,9 @@ impl PluginHooks { .post_tool_use .extend(other.post_tool_use.iter().cloned()); merged + .post_tool_use_failure + .extend(other.post_tool_use_failure.iter().cloned()); + merged } } @@ -302,14 +309,14 @@ impl PluginTool { .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) - .env("CLAW_PLUGIN_ID", &self.plugin_id) - .env("CLAW_PLUGIN_NAME", &self.plugin_name) - .env("CLAW_TOOL_NAME", &self.definition.name) - .env("CLAW_TOOL_INPUT", &input_json); + .env("CLAWD_PLUGIN_ID", &self.plugin_id) + .env("CLAWD_PLUGIN_NAME", &self.plugin_name) + .env("CLAWD_TOOL_NAME", &self.definition.name) + .env("CLAWD_TOOL_INPUT", &input_json); if let Some(root) = &self.root { process .current_dir(root) - .env("CLAW_PLUGIN_ROOT", root.display().to_string()); + .env("CLAWD_PLUGIN_ROOT", root.display().to_string()); } let mut child = process.spawn()?; @@ -648,6 +655,106 @@ pub struct PluginSummary { pub enabled: bool, } +#[derive(Debug)] +pub struct PluginLoadFailure { + pub plugin_root: PathBuf, + pub kind: PluginKind, + pub source: String, + error: Box, +} + +impl PluginLoadFailure { + #[must_use] + pub fn new(plugin_root: PathBuf, kind: PluginKind, source: String, error: PluginError) -> Self { + Self { + plugin_root, + kind, + source, + error: Box::new(error), + } + } + + #[must_use] + pub fn error(&self) -> &PluginError { + self.error.as_ref() + } +} + +impl Display for PluginLoadFailure { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "failed to load {} plugin from `{}` (source: {}): {}", + self.kind, + self.plugin_root.display(), + self.source, + self.error() + ) + } +} + +#[derive(Debug)] +pub struct PluginRegistryReport { + registry: PluginRegistry, + failures: Vec, +} + +impl PluginRegistryReport { + #[must_use] + pub fn new(registry: PluginRegistry, failures: Vec) -> Self { + Self { registry, failures } + } + + #[must_use] + pub fn registry(&self) -> &PluginRegistry { + &self.registry + } + + #[must_use] + pub fn failures(&self) -> &[PluginLoadFailure] { + &self.failures + } + + #[must_use] + pub fn has_failures(&self) -> bool { + !self.failures.is_empty() + } + + #[must_use] + pub fn summaries(&self) -> Vec { + self.registry.summaries() + } + + pub fn into_registry(self) -> Result { + if self.failures.is_empty() { + Ok(self.registry) + } else { + Err(PluginError::LoadFailures(self.failures)) + } + } +} + +#[derive(Debug, Default)] +struct PluginDiscovery { + plugins: Vec, + failures: Vec, +} + +impl PluginDiscovery { + fn push_plugin(&mut self, plugin: PluginDefinition) { + self.plugins.push(plugin); + } + + fn push_failure(&mut self, failure: PluginLoadFailure) { + self.failures.push(failure); + } + + fn extend(&mut self, other: Self) { + self.plugins.extend(other.plugins); + self.failures.extend(other.failures); + } +} + #[derive(Debug, Clone, Default, PartialEq)] pub struct PluginRegistry { plugins: Vec, @@ -802,6 +909,10 @@ pub enum PluginManifestValidationError { kind: &'static str, path: PathBuf, }, + PathIsDirectory { + kind: &'static str, + path: PathBuf, + }, InvalidToolInputSchema { tool_name: String, }, @@ -809,6 +920,9 @@ pub enum PluginManifestValidationError { tool_name: String, permission: String, }, + UnsupportedManifestContract { + detail: String, + }, } impl Display for PluginManifestValidationError { @@ -838,6 +952,9 @@ impl Display for PluginManifestValidationError { Self::MissingPath { kind, path } => { write!(f, "{kind} path `{}` does not exist", path.display()) } + Self::PathIsDirectory { kind, path } => { + write!(f, "{kind} path `{}` must point to a file", path.display()) + } Self::InvalidToolInputSchema { tool_name } => { write!( f, @@ -851,6 +968,7 @@ impl Display for PluginManifestValidationError { f, "plugin tool `{tool_name}` requiredPermission `{permission}` must be read-only, workspace-write, or danger-full-access" ), + Self::UnsupportedManifestContract { detail } => f.write_str(detail), } } } @@ -860,6 +978,7 @@ pub enum PluginError { Io(std::io::Error), Json(serde_json::Error), ManifestValidation(Vec), + LoadFailures(Vec), InvalidManifest(String), NotFound(String), CommandFailed(String), @@ -879,6 +998,15 @@ impl Display for PluginError { } Ok(()) } + Self::LoadFailures(failures) => { + for (index, failure) in failures.iter().enumerate() { + if index > 0 { + write!(f, "; ")?; + } + write!(f, "{failure}")?; + } + Ok(()) + } Self::InvalidManifest(message) | Self::NotFound(message) | Self::CommandFailed(message) => write!(f, "{message}"), @@ -935,15 +1063,23 @@ impl PluginManager { } pub fn plugin_registry(&self) -> Result { - Ok(PluginRegistry::new( - self.discover_plugins()? - .into_iter() - .map(|plugin| { - let enabled = self.is_enabled(plugin.metadata()); - RegisteredPlugin::new(plugin, enabled) - }) - .collect(), - )) + self.plugin_registry_report()?.into_registry() + } + + pub fn plugin_registry_report(&self) -> Result { + self.sync_bundled_plugins()?; + + let mut discovery = PluginDiscovery::default(); + discovery.plugins.extend(builtin_plugins()); + + let installed = self.discover_installed_plugins_with_failures()?; + discovery.extend(installed); + + let external = + self.discover_external_directory_plugins_with_failures(&discovery.plugins)?; + discovery.extend(external); + + Ok(self.build_registry_report(discovery)) } pub fn list_plugins(&self) -> Result, PluginError> { @@ -955,11 +1091,12 @@ impl PluginManager { } pub fn discover_plugins(&self) -> Result, PluginError> { - self.sync_bundled_plugins()?; - let mut plugins = builtin_plugins(); - plugins.extend(self.discover_installed_plugins()?); - plugins.extend(self.discover_external_directory_plugins(&plugins)?); - Ok(plugins) + Ok(self + .plugin_registry()? + .plugins + .into_iter() + .map(|plugin| plugin.definition) + .collect()) } pub fn aggregated_hooks(&self) -> Result { @@ -1094,9 +1231,9 @@ impl PluginManager { }) } - fn discover_installed_plugins(&self) -> Result, PluginError> { + fn discover_installed_plugins_with_failures(&self) -> Result { let mut registry = self.load_registry()?; - let mut plugins = Vec::new(); + let mut discovery = PluginDiscovery::default(); let mut seen_ids = BTreeSet::::new(); let mut seen_paths = BTreeSet::::new(); let mut stale_registry_ids = Vec::new(); @@ -1111,10 +1248,21 @@ impl PluginManager { || install_path.display().to_string(), |record| describe_install_source(&record.source), ); - let plugin = load_plugin_definition(&install_path, kind, source, kind.marketplace())?; - if seen_ids.insert(plugin.metadata().id.clone()) { - seen_paths.insert(install_path); - plugins.push(plugin); + match load_plugin_definition(&install_path, kind, source.clone(), kind.marketplace()) { + Ok(plugin) => { + if seen_ids.insert(plugin.metadata().id.clone()) { + seen_paths.insert(install_path); + discovery.push_plugin(plugin); + } + } + Err(error) => { + discovery.push_failure(PluginLoadFailure::new( + install_path, + kind, + source, + error, + )); + } } } @@ -1127,15 +1275,27 @@ impl PluginManager { stale_registry_ids.push(record.id.clone()); continue; } - let plugin = load_plugin_definition( + let source = describe_install_source(&record.source); + match load_plugin_definition( &record.install_path, record.kind, - describe_install_source(&record.source), + source.clone(), record.kind.marketplace(), - )?; - if seen_ids.insert(plugin.metadata().id.clone()) { - seen_paths.insert(record.install_path.clone()); - plugins.push(plugin); + ) { + Ok(plugin) => { + if seen_ids.insert(plugin.metadata().id.clone()) { + seen_paths.insert(record.install_path.clone()); + discovery.push_plugin(plugin); + } + } + Err(error) => { + discovery.push_failure(PluginLoadFailure::new( + record.install_path.clone(), + record.kind, + source, + error, + )); + } } } @@ -1146,47 +1306,51 @@ impl PluginManager { self.store_registry(®istry)?; } - Ok(plugins) + Ok(discovery) } - fn discover_external_directory_plugins( + fn discover_external_directory_plugins_with_failures( &self, existing_plugins: &[PluginDefinition], - ) -> Result, PluginError> { - let mut plugins = Vec::new(); + ) -> Result { + let mut discovery = PluginDiscovery::default(); for directory in &self.config.external_dirs { for root in discover_plugin_dirs(directory)? { - let plugin = load_plugin_definition( + let source = root.display().to_string(); + match load_plugin_definition( &root, PluginKind::External, - root.display().to_string(), + source.clone(), EXTERNAL_MARKETPLACE, - )?; - if existing_plugins - .iter() - .chain(plugins.iter()) - .all(|existing| existing.metadata().id != plugin.metadata().id) - { - plugins.push(plugin); + ) { + Ok(plugin) => { + if existing_plugins + .iter() + .chain(discovery.plugins.iter()) + .all(|existing| existing.metadata().id != plugin.metadata().id) + { + discovery.push_plugin(plugin); + } + } + Err(error) => { + discovery.push_failure(PluginLoadFailure::new( + root, + PluginKind::External, + source, + error, + )); + } } } } - Ok(plugins) + Ok(discovery) } - fn installed_plugin_registry(&self) -> Result { + pub fn installed_plugin_registry_report(&self) -> Result { self.sync_bundled_plugins()?; - Ok(PluginRegistry::new( - self.discover_installed_plugins()? - .into_iter() - .map(|plugin| { - let enabled = self.is_enabled(plugin.metadata()); - RegisteredPlugin::new(plugin, enabled) - }) - .collect(), - )) + Ok(self.build_registry_report(self.discover_installed_plugins_with_failures()?)) } fn sync_bundled_plugins(&self) -> Result<(), PluginError> { @@ -1332,6 +1496,26 @@ impl PluginManager { } }) } + + fn installed_plugin_registry(&self) -> Result { + self.installed_plugin_registry_report()?.into_registry() + } + + fn build_registry_report(&self, discovery: PluginDiscovery) -> PluginRegistryReport { + PluginRegistryReport::new( + PluginRegistry::new( + discovery + .plugins + .into_iter() + .map(|plugin| { + let enabled = self.is_enabled(plugin.metadata()); + RegisteredPlugin::new(plugin, enabled) + }) + .collect(), + ), + discovery.failures, + ) + } } #[must_use] @@ -1414,10 +1598,73 @@ fn load_manifest_from_path( manifest_path.display() )) })?; - let raw_manifest: RawPluginManifest = serde_json::from_str(&contents)?; + let raw_json: Value = serde_json::from_str(&contents)?; + let compatibility_errors = detect_claude_code_manifest_contract_gaps(&raw_json); + if !compatibility_errors.is_empty() { + return Err(PluginError::ManifestValidation(compatibility_errors)); + } + let raw_manifest: RawPluginManifest = serde_json::from_value(raw_json)?; build_plugin_manifest(root, raw_manifest) } +fn detect_claude_code_manifest_contract_gaps( + raw_manifest: &Value, +) -> Vec { + let Some(root) = raw_manifest.as_object() else { + return Vec::new(); + }; + + let mut errors = Vec::new(); + + for (field, detail) in [ + ( + "skills", + "plugin manifest field `skills` uses the Claude Code plugin contract; `claw` does not load plugin-managed skills and instead discovers skills from local roots such as `.claw/skills`, `.omc/skills`, `.agents/skills`, `~/.omc/skills`, and `~/.claude/skills/omc-learned`.", + ), + ( + "mcpServers", + "plugin manifest field `mcpServers` uses the Claude Code plugin contract; `claw` does not import MCP servers from plugin manifests.", + ), + ( + "agents", + "plugin manifest field `agents` uses the Claude Code plugin contract; `claw` does not load plugin-managed agent markdown catalogs from plugin manifests.", + ), + ] { + if root.contains_key(field) { + errors.push(PluginManifestValidationError::UnsupportedManifestContract { + detail: detail.to_string(), + }); + } + } + + if root + .get("commands") + .and_then(Value::as_array) + .is_some_and(|commands| commands.iter().any(Value::is_string)) + { + errors.push(PluginManifestValidationError::UnsupportedManifestContract { + detail: "plugin manifest field `commands` uses Claude Code-style directory globs; `claw` slash dispatch is still built-in and does not load plugin slash command markdown files.".to_string(), + }); + } + + if let Some(hooks) = root.get("hooks").and_then(Value::as_object) { + for hook_name in hooks.keys() { + if !matches!( + hook_name.as_str(), + "PreToolUse" | "PostToolUse" | "PostToolUseFailure" + ) { + errors.push(PluginManifestValidationError::UnsupportedManifestContract { + detail: format!( + "plugin hook `{hook_name}` uses the Claude Code lifecycle contract; `claw` plugins currently support only PreToolUse, PostToolUse, and PostToolUseFailure." + ), + }); + } + } + } + + errors +} + fn plugin_manifest_path(root: &Path) -> Result { let direct_path = root.join(MANIFEST_FILE_NAME); if direct_path.exists() { @@ -1449,6 +1696,12 @@ fn build_plugin_manifest( let permissions = build_manifest_permissions(&raw.permissions, &mut errors); validate_command_entries(root, raw.hooks.pre_tool_use.iter(), "hook", &mut errors); validate_command_entries(root, raw.hooks.post_tool_use.iter(), "hook", &mut errors); + validate_command_entries( + root, + raw.hooks.post_tool_use_failure.iter(), + "hook", + &mut errors, + ); validate_command_entries( root, raw.lifecycle.init.iter(), @@ -1676,6 +1929,8 @@ fn validate_command_entry( }; if !path.exists() { errors.push(PluginManifestValidationError::MissingPath { kind, path }); + } else if !path.is_file() { + errors.push(PluginManifestValidationError::PathIsDirectory { kind, path }); } } @@ -1691,6 +1946,11 @@ fn resolve_hooks(root: &Path, hooks: &PluginHooks) -> PluginHooks { .iter() .map(|entry| resolve_hook_entry(root, entry)) .collect(), + post_tool_use_failure: hooks + .post_tool_use_failure + .iter() + .map(|entry| resolve_hook_entry(root, entry)) + .collect(), } } @@ -1739,7 +1999,12 @@ fn validate_hook_paths(root: Option<&Path>, hooks: &PluginHooks) -> Result<(), P let Some(root) = root else { return Ok(()); }; - for entry in hooks.pre_tool_use.iter().chain(hooks.post_tool_use.iter()) { + for entry in hooks + .pre_tool_use + .iter() + .chain(hooks.post_tool_use.iter()) + .chain(hooks.post_tool_use_failure.iter()) + { validate_command_path(root, entry, "hook")?; } Ok(()) @@ -1783,6 +2048,12 @@ fn validate_command_path(root: &Path, entry: &str, kind: &str) -> Result<(), Plu path.display() ))); } + if !path.is_file() { + return Err(PluginError::InvalidManifest(format!( + "{kind} path `{}` must point to a file", + path.display() + ))); + } Ok(()) } @@ -2094,6 +2365,30 @@ mod tests { ); } + fn write_directory_path_plugin(root: &Path, name: &str) { + fs::create_dir_all(root.join("hooks").join("pre-dir")).expect("hook dir"); + fs::create_dir_all(root.join("tools").join("tool-dir")).expect("tool dir"); + fs::create_dir_all(root.join("commands").join("sync-dir")).expect("command dir"); + fs::create_dir_all(root.join("lifecycle").join("init-dir")).expect("lifecycle dir"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"directory path plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre-dir\"]\n }},\n \"lifecycle\": {{\n \"Init\": [\"./lifecycle/init-dir\"]\n }},\n \"tools\": [\n {{\n \"name\": \"dir_tool\",\n \"description\": \"Directory tool\",\n \"inputSchema\": {{\"type\": \"object\"}},\n \"command\": \"./tools/tool-dir\"\n }}\n ],\n \"commands\": [\n {{\n \"name\": \"sync\",\n \"description\": \"Directory command\",\n \"command\": \"./commands/sync-dir\"\n }}\n ]\n}}" + ) + .as_str(), + ); + } + + fn write_broken_failure_hook_plugin(root: &Path, name: &str) { + write_file( + root.join(MANIFEST_RELATIVE_PATH).as_path(), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"broken plugin\",\n \"hooks\": {{\n \"PostToolUseFailure\": [\"./hooks/missing-failure.sh\"]\n }}\n}}" + ) + .as_str(), + ); + } + fn write_lifecycle_plugin(root: &Path, name: &str, version: &str) -> PathBuf { let log_path = root.join("lifecycle.log"); write_file( @@ -2122,7 +2417,7 @@ mod tests { let script_path = root.join("tools").join("echo-json.sh"); write_file( &script_path, - "#!/bin/sh\nINPUT=$(cat)\nprintf '{\"plugin\":\"%s\",\"tool\":\"%s\",\"input\":%s}\\n' \"$CLAW_PLUGIN_ID\" \"$CLAW_TOOL_NAME\" \"$INPUT\"\n", + "#!/bin/sh\nINPUT=$(cat)\nprintf '{\"plugin\":\"%s\",\"tool\":\"%s\",\"input\":%s}\\n' \"$CLAWD_PLUGIN_ID\" \"$CLAWD_TOOL_NAME\" \"$INPUT\"\n", ); #[cfg(unix)] { @@ -2289,6 +2584,37 @@ mod tests { let _ = fs::remove_dir_all(root); } + #[test] + fn load_plugin_from_directory_rejects_claude_code_manifest_contracts_with_guidance() { + let root = temp_dir("manifest-claude-code-contract"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "oh-my-claudecode", + "version": "4.10.2", + "description": "Claude Code plugin manifest", + "hooks": { + "SessionStart": ["scripts/session-start.mjs"] + }, + "agents": ["agents/*.md"], + "commands": ["commands/**/*.md"], + "skills": "./skills/", + "mcpServers": "./.mcp.json" +}"#, + ); + + let error = load_plugin_from_directory(&root) + .expect_err("Claude Code plugin manifest should fail with guidance"); + let rendered = error.to_string(); + assert!(rendered.contains("field `skills` uses the Claude Code plugin contract")); + assert!(rendered.contains("field `mcpServers` uses the Claude Code plugin contract")); + assert!(rendered.contains("field `agents` uses the Claude Code plugin contract")); + assert!(rendered.contains("field `commands` uses Claude Code-style directory globs")); + assert!(rendered.contains("hook `SessionStart` uses the Claude Code lifecycle contract")); + + let _ = fs::remove_dir_all(root); + } + #[test] fn load_plugin_from_directory_rejects_missing_tool_or_command_paths() { let root = temp_dir("manifest-paths"); @@ -2315,6 +2641,90 @@ mod tests { let _ = fs::remove_dir_all(root); } + #[test] + fn load_plugin_from_directory_rejects_missing_lifecycle_paths() { + // given + let root = temp_dir("manifest-lifecycle-paths"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "missing-lifecycle-paths", + "version": "1.0.0", + "description": "Missing lifecycle path validation", + "lifecycle": { + "Init": ["./lifecycle/init.sh"], + "Shutdown": ["./lifecycle/shutdown.sh"] + } +}"#, + ); + + // when + let error = + load_plugin_from_directory(&root).expect_err("missing lifecycle paths should fail"); + + // then + match error { + PluginError::ManifestValidation(errors) => { + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::MissingPath { kind, path } + if *kind == "lifecycle command" + && path.ends_with(Path::new("lifecycle/init.sh")) + ))); + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::MissingPath { kind, path } + if *kind == "lifecycle command" + && path.ends_with(Path::new("lifecycle/shutdown.sh")) + ))); + } + other => panic!("expected manifest validation errors, got {other}"), + } + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_rejects_directory_command_paths() { + // given + let root = temp_dir("manifest-directory-paths"); + write_directory_path_plugin(&root, "directory-paths"); + + // when + let error = + load_plugin_from_directory(&root).expect_err("directory command paths should fail"); + + // then + match error { + PluginError::ManifestValidation(errors) => { + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::PathIsDirectory { kind, path } + if *kind == "hook" && path.ends_with(Path::new("hooks/pre-dir")) + ))); + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::PathIsDirectory { kind, path } + if *kind == "lifecycle command" + && path.ends_with(Path::new("lifecycle/init-dir")) + ))); + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::PathIsDirectory { kind, path } + if *kind == "tool" && path.ends_with(Path::new("tools/tool-dir")) + ))); + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::PathIsDirectory { kind, path } + if *kind == "command" && path.ends_with(Path::new("commands/sync-dir")) + ))); + } + other => panic!("expected manifest validation errors, got {other}"), + } + + let _ = fs::remove_dir_all(root); + } + #[test] fn load_plugin_from_directory_rejects_invalid_permissions() { let root = temp_dir("manifest-invalid-permissions"); @@ -2806,16 +3216,95 @@ mod tests { let _ = fs::remove_dir_all(source_root); } + #[test] + fn plugin_registry_report_collects_load_failures_without_dropping_valid_plugins() { + // given + let config_home = temp_dir("report-home"); + let external_root = temp_dir("report-external"); + write_external_plugin(&external_root.join("valid"), "valid-report", "1.0.0"); + write_broken_plugin(&external_root.join("broken"), "broken-report"); + + let mut config = PluginManagerConfig::new(&config_home); + config.external_dirs = vec![external_root.clone()]; + let manager = PluginManager::new(config); + + // when + let report = manager + .plugin_registry_report() + .expect("report should tolerate invalid external plugins"); + + // then + assert!(report.registry().contains("valid-report@external")); + assert_eq!(report.failures().len(), 1); + assert_eq!(report.failures()[0].kind, PluginKind::External); + assert!(report.failures()[0] + .plugin_root + .ends_with(Path::new("broken"))); + assert!(report.failures()[0] + .error() + .to_string() + .contains("does not exist")); + + let error = manager + .plugin_registry() + .expect_err("strict registry should surface load failures"); + match error { + PluginError::LoadFailures(failures) => { + assert_eq!(failures.len(), 1); + assert!(failures[0].plugin_root.ends_with(Path::new("broken"))); + } + other => panic!("expected load failures, got {other}"), + } + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(external_root); + } + + #[test] + fn installed_plugin_registry_report_collects_load_failures_from_install_root() { + // given + let config_home = temp_dir("installed-report-home"); + let bundled_root = temp_dir("installed-report-bundled"); + let install_root = config_home.join("plugins").join("installed"); + write_external_plugin(&install_root.join("valid"), "installed-valid", "1.0.0"); + write_broken_plugin(&install_root.join("broken"), "installed-broken"); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + config.install_root = Some(install_root); + let manager = PluginManager::new(config); + + // when + let report = manager + .installed_plugin_registry_report() + .expect("installed report should tolerate invalid installed plugins"); + + // then + assert!(report.registry().contains("installed-valid@external")); + assert_eq!(report.failures().len(), 1); + assert!(report.failures()[0] + .plugin_root + .ends_with(Path::new("broken"))); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } + #[test] fn rejects_plugin_sources_with_missing_hook_paths() { + // given let config_home = temp_dir("broken-home"); let source_root = temp_dir("broken-source"); write_broken_plugin(&source_root, "broken"); let manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + + // when let error = manager .validate_plugin_source(source_root.to_str().expect("utf8 path")) .expect_err("missing hook file should fail validation"); + + // then assert!(error.to_string().contains("does not exist")); let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); @@ -2828,6 +3317,33 @@ mod tests { let _ = fs::remove_dir_all(source_root); } + #[test] + fn rejects_plugin_sources_with_missing_failure_hook_paths() { + // given + let config_home = temp_dir("broken-failure-home"); + let source_root = temp_dir("broken-failure-source"); + write_broken_failure_hook_plugin(&source_root, "broken-failure"); + + let manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + + // when + let error = manager + .validate_plugin_source(source_root.to_str().expect("utf8 path")) + .expect_err("missing failure hook file should fail validation"); + + // then + assert!(error.to_string().contains("does not exist")); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + let install_error = manager + .install(source_root.to_str().expect("utf8 path")) + .expect_err("install should reject invalid failure hook paths"); + assert!(install_error.to_string().contains("does not exist")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + #[test] fn plugin_registry_runs_initialize_and_shutdown_for_enabled_plugins() { let config_home = temp_dir("lifecycle-home"); diff --git a/crates/runtime/Cargo.toml b/crates/runtime/Cargo.toml index 025cd03..c1e6a83 100644 --- a/crates/runtime/Cargo.toml +++ b/crates/runtime/Cargo.toml @@ -7,13 +7,14 @@ publish.workspace = true [dependencies] sha2 = "0.10" +which = "7" glob = "0.3" -lsp = { path = "../lsp" } plugins = { path = "../plugins" } regex = "1" serde = { version = "1", features = ["derive"] } serde_json.workspace = true -tokio = { version = "1", features = ["io-util", "macros", "process", "rt", "rt-multi-thread", "time"] } +telemetry = { path = "../telemetry" } +tokio = { version = "1", features = ["io-std", "io-util", "macros", "process", "rt", "rt-multi-thread", "time"] } walkdir = "2" [lints] diff --git a/crates/runtime/src/bash.rs b/crates/runtime/src/bash.rs index 7c2fcd2..489685b 100644 --- a/crates/runtime/src/bash.rs +++ b/crates/runtime/src/bash.rs @@ -14,6 +14,7 @@ use crate::sandbox::{ }; use crate::ConfigLoader; +/// Input schema for the built-in bash execution tool. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct BashCommandInput { pub command: String, @@ -33,6 +34,7 @@ pub struct BashCommandInput { pub allowed_mounts: Option>, } +/// Output returned from a bash tool invocation. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct BashCommandOutput { pub stdout: String, @@ -64,6 +66,7 @@ pub struct BashCommandOutput { pub sandbox_status: Option, } +/// Executes a shell command with the requested sandbox settings. pub fn execute_bash(input: BashCommandInput) -> io::Result { let cwd = env::current_dir()?; let sandbox_status = sandbox_status_for_input(&input, &cwd); @@ -134,8 +137,8 @@ async fn execute_bash_async( }; let (output, interrupted) = output_result; - let stdout = String::from_utf8_lossy(&output.stdout).into_owned(); - let stderr = String::from_utf8_lossy(&output.stderr).into_owned(); + let stdout = truncate_output(&String::from_utf8_lossy(&output.stdout)); + let stderr = truncate_output(&String::from_utf8_lossy(&output.stderr)); let no_output_expected = Some(stdout.trim().is_empty() && stderr.trim().is_empty()); let return_code_interpretation = output.status.code().and_then(|code| { if code == 0 { @@ -197,36 +200,31 @@ fn prepare_command( return prepared; } - let mut prepared = if cfg!(target_os = "windows") && !sh_exists() { + let prepared = if cfg!(target_os = "windows") && !sh_exists() { let mut p = Command::new("cmd"); - p.arg("/C").arg(command); + p.arg("/C").arg(command).current_dir(cwd); p } else { let mut p = Command::new("sh"); - p.arg("-lc").arg(command); + p.arg("-lc").arg(command).current_dir(cwd); + if sandbox_status.filesystem_active { + p.env("HOME", cwd.join(".sandbox-home")); + p.env("TMPDIR", cwd.join(".sandbox-tmp")); + } p }; - prepared.current_dir(cwd); - if sandbox_status.filesystem_active { - prepared.env("HOME", cwd.join(".sandbox-home")); - prepared.env("TMPDIR", cwd.join(".sandbox-tmp")); - } prepared } fn sh_exists() -> bool { - env::var_os("PATH").is_some_and(|paths| { - env::split_paths(&paths).any(|path| { - #[cfg(windows)] - { - path.join("sh.exe").exists() || path.join("sh.bat").exists() || path.join("sh").exists() - } - #[cfg(not(windows))] - { - path.join("sh").exists() - } - }) - }) + #[cfg(windows)] + { + which::which("sh").is_ok() + } + #[cfg(not(windows))] + { + true + } } fn prepare_tokio_command( @@ -247,20 +245,19 @@ fn prepare_tokio_command( return prepared; } - let mut prepared = if cfg!(target_os = "windows") && !sh_exists() { + let prepared = if cfg!(target_os = "windows") && !sh_exists() { let mut p = TokioCommand::new("cmd"); - p.arg("/C").arg(command); + p.arg("/C").arg(command).current_dir(cwd); p } else { let mut p = TokioCommand::new("sh"); - p.arg("-lc").arg(command); + p.arg("-lc").arg(command).current_dir(cwd); + if sandbox_status.filesystem_active { + p.env("HOME", cwd.join(".sandbox-home")); + p.env("TMPDIR", cwd.join(".sandbox-tmp")); + } p }; - prepared.current_dir(cwd); - if sandbox_status.filesystem_active { - prepared.env("HOME", cwd.join(".sandbox-home")); - prepared.env("TMPDIR", cwd.join(".sandbox-tmp")); - } prepared } @@ -312,3 +309,53 @@ mod tests { assert!(!output.sandbox_status.expect("sandbox status").enabled); } } + +/// Maximum output bytes before truncation (16 KiB, matching upstream). +const MAX_OUTPUT_BYTES: usize = 16_384; + +/// Truncate output to `MAX_OUTPUT_BYTES`, appending a marker when trimmed. +fn truncate_output(s: &str) -> String { + if s.len() <= MAX_OUTPUT_BYTES { + return s.to_string(); + } + // Find the last valid UTF-8 boundary at or before MAX_OUTPUT_BYTES + let mut end = MAX_OUTPUT_BYTES; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; + } + let mut truncated = s[..end].to_string(); + truncated.push_str("\n\n[output truncated — exceeded 16384 bytes]"); + truncated +} + +#[cfg(test)] +mod truncation_tests { + use super::*; + + #[test] + fn short_output_unchanged() { + let s = "hello world"; + assert_eq!(truncate_output(s), s); + } + + #[test] + fn long_output_truncated() { + let s = "x".repeat(20_000); + let result = truncate_output(&s); + assert!(result.len() < 20_000); + assert!(result.ends_with("[output truncated — exceeded 16384 bytes]")); + } + + #[test] + fn exact_boundary_unchanged() { + let s = "a".repeat(MAX_OUTPUT_BYTES); + assert_eq!(truncate_output(&s), s); + } + + #[test] + fn one_over_boundary_truncated() { + let s = "a".repeat(MAX_OUTPUT_BYTES + 1); + let result = truncate_output(&s); + assert!(result.contains("[output truncated")); + } +} diff --git a/crates/runtime/src/bash_validation.rs b/crates/runtime/src/bash_validation.rs new file mode 100644 index 0000000..f00619e --- /dev/null +++ b/crates/runtime/src/bash_validation.rs @@ -0,0 +1,1004 @@ +//! Bash command validation submodules. +//! +//! Ports the upstream `BashTool` validation pipeline: +//! - `readOnlyValidation` — block write-like commands in read-only mode +//! - `destructiveCommandWarning` — flag dangerous destructive commands +//! - `modeValidation` — enforce permission mode constraints on commands +//! - `sedValidation` — validate sed expressions before execution +//! - `pathValidation` — detect suspicious path patterns +//! - `commandSemantics` — classify command intent + +use std::path::Path; + +use crate::permissions::PermissionMode; + +/// Result of validating a bash command before execution. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ValidationResult { + /// Command is safe to execute. + Allow, + /// Command should be blocked with the given reason. + Block { reason: String }, + /// Command requires user confirmation with the given warning. + Warn { message: String }, +} + +/// Semantic classification of a bash command's intent. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CommandIntent { + /// Read-only operations: ls, cat, grep, find, etc. + ReadOnly, + /// File system writes: cp, mv, mkdir, touch, tee, etc. + Write, + /// Destructive operations: rm, shred, truncate, etc. + Destructive, + /// Network operations: curl, wget, ssh, etc. + Network, + /// Process management: kill, pkill, etc. + ProcessManagement, + /// Package management: apt, brew, pip, npm, etc. + PackageManagement, + /// System administration: sudo, chmod, chown, mount, etc. + SystemAdmin, + /// Unknown or unclassifiable command. + Unknown, +} + +// --------------------------------------------------------------------------- +// readOnlyValidation +// --------------------------------------------------------------------------- + +/// Commands that perform write operations and should be blocked in read-only mode. +const WRITE_COMMANDS: &[&str] = &[ + "cp", "mv", "rm", "mkdir", "rmdir", "touch", "chmod", "chown", "chgrp", "ln", "install", "tee", + "truncate", "shred", "mkfifo", "mknod", "dd", +]; + +/// Commands that modify system state and should be blocked in read-only mode. +const STATE_MODIFYING_COMMANDS: &[&str] = &[ + "apt", + "apt-get", + "yum", + "dnf", + "pacman", + "brew", + "pip", + "pip3", + "npm", + "yarn", + "pnpm", + "bun", + "cargo", + "gem", + "go", + "rustup", + "docker", + "systemctl", + "service", + "mount", + "umount", + "kill", + "pkill", + "killall", + "reboot", + "shutdown", + "halt", + "poweroff", + "useradd", + "userdel", + "usermod", + "groupadd", + "groupdel", + "crontab", + "at", +]; + +/// Shell redirection operators that indicate writes. +const WRITE_REDIRECTIONS: &[&str] = &[">", ">>", ">&"]; + +/// Validate that a command is allowed under read-only mode. +/// +/// Corresponds to upstream `tools/BashTool/readOnlyValidation.ts`. +#[must_use] +pub fn validate_read_only(command: &str, mode: PermissionMode) -> ValidationResult { + if mode != PermissionMode::ReadOnly { + return ValidationResult::Allow; + } + + let first_command = extract_first_command(command); + + // Check for write commands. + for &write_cmd in WRITE_COMMANDS { + if first_command == write_cmd { + return ValidationResult::Block { + reason: format!( + "Command '{write_cmd}' modifies the filesystem and is not allowed in read-only mode" + ), + }; + } + } + + // Check for state-modifying commands. + for &state_cmd in STATE_MODIFYING_COMMANDS { + if first_command == state_cmd { + return ValidationResult::Block { + reason: format!( + "Command '{state_cmd}' modifies system state and is not allowed in read-only mode" + ), + }; + } + } + + // Check for sudo wrapping write commands. + if first_command == "sudo" { + let inner = extract_sudo_inner(command); + if !inner.is_empty() { + let inner_result = validate_read_only(inner, mode); + if inner_result != ValidationResult::Allow { + return inner_result; + } + } + } + + // Check for write redirections. + for &redir in WRITE_REDIRECTIONS { + if command.contains(redir) { + return ValidationResult::Block { + reason: format!( + "Command contains write redirection '{redir}' which is not allowed in read-only mode" + ), + }; + } + } + + // Check for git commands that modify state. + if first_command == "git" { + return validate_git_read_only(command); + } + + ValidationResult::Allow +} + +/// Git subcommands that are read-only safe. +const GIT_READ_ONLY_SUBCOMMANDS: &[&str] = &[ + "status", + "log", + "diff", + "show", + "branch", + "tag", + "stash", + "remote", + "fetch", + "ls-files", + "ls-tree", + "cat-file", + "rev-parse", + "describe", + "shortlog", + "blame", + "bisect", + "reflog", + "config", +]; + +fn validate_git_read_only(command: &str) -> ValidationResult { + let parts: Vec<&str> = command.split_whitespace().collect(); + // Skip past "git" and any flags (e.g., "git -C /path") + let subcommand = parts.iter().skip(1).find(|p| !p.starts_with('-')); + + match subcommand { + Some(&sub) if GIT_READ_ONLY_SUBCOMMANDS.contains(&sub) => ValidationResult::Allow, + Some(&sub) => ValidationResult::Block { + reason: format!( + "Git subcommand '{sub}' modifies repository state and is not allowed in read-only mode" + ), + }, + None => ValidationResult::Allow, // bare "git" is fine + } +} + +// --------------------------------------------------------------------------- +// destructiveCommandWarning +// --------------------------------------------------------------------------- + +/// Patterns that indicate potentially destructive commands. +const DESTRUCTIVE_PATTERNS: &[(&str, &str)] = &[ + ( + "rm -rf /", + "Recursive forced deletion at root — this will destroy the system", + ), + ("rm -rf ~", "Recursive forced deletion of home directory"), + ( + "rm -rf *", + "Recursive forced deletion of all files in current directory", + ), + ("rm -rf .", "Recursive forced deletion of current directory"), + ( + "mkfs", + "Filesystem creation will destroy existing data on the device", + ), + ( + "dd if=", + "Direct disk write — can overwrite partitions or devices", + ), + ("> /dev/sd", "Writing to raw disk device"), + ( + "chmod -R 777", + "Recursively setting world-writable permissions", + ), + ("chmod -R 000", "Recursively removing all permissions"), + (":(){ :|:& };:", "Fork bomb — will crash the system"), +]; + +/// Commands that are always destructive regardless of arguments. +const ALWAYS_DESTRUCTIVE_COMMANDS: &[&str] = &["shred", "wipefs"]; + +/// Warn if a command looks destructive. +/// +/// Corresponds to upstream `tools/BashTool/destructiveCommandWarning.ts`. +#[must_use] +pub fn check_destructive(command: &str) -> ValidationResult { + // Check known destructive patterns. + for &(pattern, warning) in DESTRUCTIVE_PATTERNS { + if command.contains(pattern) { + return ValidationResult::Warn { + message: format!("Destructive command detected: {warning}"), + }; + } + } + + // Check always-destructive commands. + let first = extract_first_command(command); + for &cmd in ALWAYS_DESTRUCTIVE_COMMANDS { + if first == cmd { + return ValidationResult::Warn { + message: format!( + "Command '{cmd}' is inherently destructive and may cause data loss" + ), + }; + } + } + + // Check for "rm -rf" with broad targets. + if command.contains("rm ") && command.contains("-r") && command.contains("-f") { + // Already handled the most dangerous patterns above. + // Flag any remaining "rm -rf" as a warning. + return ValidationResult::Warn { + message: "Recursive forced deletion detected — verify the target path is correct" + .to_string(), + }; + } + + ValidationResult::Allow +} + +// --------------------------------------------------------------------------- +// modeValidation +// --------------------------------------------------------------------------- + +/// Validate that a command is consistent with the given permission mode. +/// +/// Corresponds to upstream `tools/BashTool/modeValidation.ts`. +#[must_use] +pub fn validate_mode(command: &str, mode: PermissionMode) -> ValidationResult { + match mode { + PermissionMode::ReadOnly => validate_read_only(command, mode), + PermissionMode::WorkspaceWrite => { + // In workspace-write mode, check for system-level destructive + // operations that go beyond workspace scope. + if command_targets_outside_workspace(command) { + return ValidationResult::Warn { + message: + "Command appears to target files outside the workspace — requires elevated permission" + .to_string(), + }; + } + ValidationResult::Allow + } + PermissionMode::DangerFullAccess | PermissionMode::Allow | PermissionMode::Prompt => { + ValidationResult::Allow + } + } +} + +/// Heuristic: does the command reference absolute paths outside typical workspace dirs? +fn command_targets_outside_workspace(command: &str) -> bool { + let system_paths = [ + "/etc/", "/usr/", "/var/", "/boot/", "/sys/", "/proc/", "/dev/", "/sbin/", "/lib/", "/opt/", + ]; + + let first = extract_first_command(command); + let is_write_cmd = WRITE_COMMANDS.contains(&first.as_str()) + || STATE_MODIFYING_COMMANDS.contains(&first.as_str()); + + if !is_write_cmd { + return false; + } + + for sys_path in &system_paths { + if command.contains(sys_path) { + return true; + } + } + + false +} + +// --------------------------------------------------------------------------- +// sedValidation +// --------------------------------------------------------------------------- + +/// Validate sed expressions for safety. +/// +/// Corresponds to upstream `tools/BashTool/sedValidation.ts`. +#[must_use] +pub fn validate_sed(command: &str, mode: PermissionMode) -> ValidationResult { + let first = extract_first_command(command); + if first != "sed" { + return ValidationResult::Allow; + } + + // In read-only mode, block sed -i (in-place editing). + if mode == PermissionMode::ReadOnly && command.contains(" -i") { + return ValidationResult::Block { + reason: "sed -i (in-place editing) is not allowed in read-only mode".to_string(), + }; + } + + ValidationResult::Allow +} + +// --------------------------------------------------------------------------- +// pathValidation +// --------------------------------------------------------------------------- + +/// Validate that command paths don't include suspicious traversal patterns. +/// +/// Corresponds to upstream `tools/BashTool/pathValidation.ts`. +#[must_use] +pub fn validate_paths(command: &str, workspace: &Path) -> ValidationResult { + // Check for directory traversal attempts. + if command.contains("../") { + let workspace_str = workspace.to_string_lossy(); + // Allow traversal if it resolves within workspace (heuristic). + if !command.contains(&*workspace_str) { + return ValidationResult::Warn { + message: "Command contains directory traversal pattern '../' — verify the target path resolves within the workspace".to_string(), + }; + } + } + + // Check for home directory references that could escape workspace. + if command.contains("~/") || command.contains("$HOME") { + return ValidationResult::Warn { + message: + "Command references home directory — verify it stays within the workspace scope" + .to_string(), + }; + } + + ValidationResult::Allow +} + +// --------------------------------------------------------------------------- +// commandSemantics +// --------------------------------------------------------------------------- + +/// Commands that are read-only (no filesystem or state modification). +const SEMANTIC_READ_ONLY_COMMANDS: &[&str] = &[ + "ls", + "cat", + "head", + "tail", + "less", + "more", + "wc", + "sort", + "uniq", + "grep", + "egrep", + "fgrep", + "find", + "which", + "whereis", + "whatis", + "man", + "info", + "file", + "stat", + "du", + "df", + "free", + "uptime", + "uname", + "hostname", + "whoami", + "id", + "groups", + "env", + "printenv", + "echo", + "printf", + "date", + "cal", + "bc", + "expr", + "test", + "true", + "false", + "pwd", + "tree", + "diff", + "cmp", + "md5sum", + "sha256sum", + "sha1sum", + "xxd", + "od", + "hexdump", + "strings", + "readlink", + "realpath", + "basename", + "dirname", + "seq", + "yes", + "tput", + "column", + "jq", + "yq", + "xargs", + "tr", + "cut", + "paste", + "awk", + "sed", +]; + +/// Commands that perform network operations. +const NETWORK_COMMANDS: &[&str] = &[ + "curl", + "wget", + "ssh", + "scp", + "rsync", + "ftp", + "sftp", + "nc", + "ncat", + "telnet", + "ping", + "traceroute", + "dig", + "nslookup", + "host", + "whois", + "ifconfig", + "ip", + "netstat", + "ss", + "nmap", +]; + +/// Commands that manage processes. +const PROCESS_COMMANDS: &[&str] = &[ + "kill", "pkill", "killall", "ps", "top", "htop", "bg", "fg", "jobs", "nohup", "disown", "wait", + "nice", "renice", +]; + +/// Commands that manage packages. +const PACKAGE_COMMANDS: &[&str] = &[ + "apt", "apt-get", "yum", "dnf", "pacman", "brew", "pip", "pip3", "npm", "yarn", "pnpm", "bun", + "cargo", "gem", "go", "rustup", "snap", "flatpak", +]; + +/// Commands that require system administrator privileges. +const SYSTEM_ADMIN_COMMANDS: &[&str] = &[ + "sudo", + "su", + "chroot", + "mount", + "umount", + "fdisk", + "parted", + "lsblk", + "blkid", + "systemctl", + "service", + "journalctl", + "dmesg", + "modprobe", + "insmod", + "rmmod", + "iptables", + "ufw", + "firewall-cmd", + "sysctl", + "crontab", + "at", + "useradd", + "userdel", + "usermod", + "groupadd", + "groupdel", + "passwd", + "visudo", +]; + +/// Classify the semantic intent of a bash command. +/// +/// Corresponds to upstream `tools/BashTool/commandSemantics.ts`. +#[must_use] +pub fn classify_command(command: &str) -> CommandIntent { + let first = extract_first_command(command); + classify_by_first_command(&first, command) +} + +fn classify_by_first_command(first: &str, command: &str) -> CommandIntent { + if SEMANTIC_READ_ONLY_COMMANDS.contains(&first) { + if first == "sed" && command.contains(" -i") { + return CommandIntent::Write; + } + return CommandIntent::ReadOnly; + } + + if ALWAYS_DESTRUCTIVE_COMMANDS.contains(&first) || first == "rm" { + return CommandIntent::Destructive; + } + + if WRITE_COMMANDS.contains(&first) { + return CommandIntent::Write; + } + + if NETWORK_COMMANDS.contains(&first) { + return CommandIntent::Network; + } + + if PROCESS_COMMANDS.contains(&first) { + return CommandIntent::ProcessManagement; + } + + if PACKAGE_COMMANDS.contains(&first) { + return CommandIntent::PackageManagement; + } + + if SYSTEM_ADMIN_COMMANDS.contains(&first) { + return CommandIntent::SystemAdmin; + } + + if first == "git" { + return classify_git_command(command); + } + + CommandIntent::Unknown +} + +fn classify_git_command(command: &str) -> CommandIntent { + let parts: Vec<&str> = command.split_whitespace().collect(); + let subcommand = parts.iter().skip(1).find(|p| !p.starts_with('-')); + match subcommand { + Some(&sub) if GIT_READ_ONLY_SUBCOMMANDS.contains(&sub) => CommandIntent::ReadOnly, + _ => CommandIntent::Write, + } +} + +// --------------------------------------------------------------------------- +// Pipeline: run all validations +// --------------------------------------------------------------------------- + +/// Run the full validation pipeline on a bash command. +/// +/// Returns the first non-Allow result, or Allow if all validations pass. +#[must_use] +pub fn validate_command(command: &str, mode: PermissionMode, workspace: &Path) -> ValidationResult { + // 1. Mode-level validation (includes read-only checks). + let result = validate_mode(command, mode); + if result != ValidationResult::Allow { + return result; + } + + // 2. Sed-specific validation. + let result = validate_sed(command, mode); + if result != ValidationResult::Allow { + return result; + } + + // 3. Destructive command warnings. + let result = check_destructive(command); + if result != ValidationResult::Allow { + return result; + } + + // 4. Path validation. + validate_paths(command, workspace) +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Extract the first bare command from a pipeline/chain, stripping env vars and sudo. +fn extract_first_command(command: &str) -> String { + let trimmed = command.trim(); + + // Skip leading environment variable assignments (KEY=val cmd ...). + let mut remaining = trimmed; + loop { + let next = remaining.trim_start(); + if let Some(eq_pos) = next.find('=') { + let before_eq = &next[..eq_pos]; + // Valid env var name: alphanumeric + underscore, no spaces. + if !before_eq.is_empty() + && before_eq + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_') + { + // Skip past the value (might be quoted). + let after_eq = &next[eq_pos + 1..]; + if let Some(space) = find_end_of_value(after_eq) { + remaining = &after_eq[space..]; + continue; + } + // No space found means value goes to end of string — no actual command. + return String::new(); + } + } + break; + } + + remaining + .split_whitespace() + .next() + .unwrap_or("") + .to_string() +} + +/// Extract the command following "sudo" (skip sudo flags). +fn extract_sudo_inner(command: &str) -> &str { + let parts: Vec<&str> = command.split_whitespace().collect(); + let sudo_idx = parts.iter().position(|&p| p == "sudo"); + match sudo_idx { + Some(idx) => { + // Skip flags after sudo. + let rest = &parts[idx + 1..]; + for &part in rest { + if !part.starts_with('-') { + // Found the inner command — return from here to end. + let offset = command.find(part).unwrap_or(0); + return &command[offset..]; + } + } + "" + } + None => "", + } +} + +/// Find the end of a value in `KEY=value rest` (handles basic quoting). +fn find_end_of_value(s: &str) -> Option { + let s = s.trim_start(); + if s.is_empty() { + return None; + } + + let first = s.as_bytes()[0]; + if first == b'"' || first == b'\'' { + let quote = first; + let mut i = 1; + while i < s.len() { + if s.as_bytes()[i] == quote && (i == 0 || s.as_bytes()[i - 1] != b'\\') { + // Skip past quote. + i += 1; + // Find next whitespace. + while i < s.len() && !s.as_bytes()[i].is_ascii_whitespace() { + i += 1; + } + return if i < s.len() { Some(i) } else { None }; + } + i += 1; + } + None + } else { + s.find(char::is_whitespace) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + // --- readOnlyValidation --- + + #[test] + fn blocks_rm_in_read_only() { + assert!(matches!( + validate_read_only("rm -rf /tmp/x", PermissionMode::ReadOnly), + ValidationResult::Block { reason } if reason.contains("rm") + )); + } + + #[test] + fn allows_rm_in_workspace_write() { + assert_eq!( + validate_read_only("rm -rf /tmp/x", PermissionMode::WorkspaceWrite), + ValidationResult::Allow + ); + } + + #[test] + fn blocks_write_redirections_in_read_only() { + assert!(matches!( + validate_read_only("echo hello > file.txt", PermissionMode::ReadOnly), + ValidationResult::Block { reason } if reason.contains("redirection") + )); + } + + #[test] + fn allows_read_commands_in_read_only() { + assert_eq!( + validate_read_only("ls -la", PermissionMode::ReadOnly), + ValidationResult::Allow + ); + assert_eq!( + validate_read_only("cat /etc/hosts", PermissionMode::ReadOnly), + ValidationResult::Allow + ); + assert_eq!( + validate_read_only("grep -r pattern .", PermissionMode::ReadOnly), + ValidationResult::Allow + ); + } + + #[test] + fn blocks_sudo_write_in_read_only() { + assert!(matches!( + validate_read_only("sudo rm -rf /tmp/x", PermissionMode::ReadOnly), + ValidationResult::Block { reason } if reason.contains("rm") + )); + } + + #[test] + fn blocks_git_push_in_read_only() { + assert!(matches!( + validate_read_only("git push origin main", PermissionMode::ReadOnly), + ValidationResult::Block { reason } if reason.contains("push") + )); + } + + #[test] + fn allows_git_status_in_read_only() { + assert_eq!( + validate_read_only("git status", PermissionMode::ReadOnly), + ValidationResult::Allow + ); + } + + #[test] + fn blocks_package_install_in_read_only() { + assert!(matches!( + validate_read_only("npm install express", PermissionMode::ReadOnly), + ValidationResult::Block { reason } if reason.contains("npm") + )); + } + + // --- destructiveCommandWarning --- + + #[test] + fn warns_rm_rf_root() { + assert!(matches!( + check_destructive("rm -rf /"), + ValidationResult::Warn { message } if message.contains("root") + )); + } + + #[test] + fn warns_rm_rf_home() { + assert!(matches!( + check_destructive("rm -rf ~"), + ValidationResult::Warn { message } if message.contains("home") + )); + } + + #[test] + fn warns_shred() { + assert!(matches!( + check_destructive("shred /dev/sda"), + ValidationResult::Warn { message } if message.contains("destructive") + )); + } + + #[test] + fn warns_fork_bomb() { + assert!(matches!( + check_destructive(":(){ :|:& };:"), + ValidationResult::Warn { message } if message.contains("Fork bomb") + )); + } + + #[test] + fn allows_safe_commands() { + assert_eq!(check_destructive("ls -la"), ValidationResult::Allow); + assert_eq!(check_destructive("echo hello"), ValidationResult::Allow); + } + + // --- modeValidation --- + + #[test] + fn workspace_write_warns_system_paths() { + assert!(matches!( + validate_mode("cp file.txt /etc/config", PermissionMode::WorkspaceWrite), + ValidationResult::Warn { message } if message.contains("outside the workspace") + )); + } + + #[test] + fn workspace_write_allows_local_writes() { + assert_eq!( + validate_mode("cp file.txt ./backup/", PermissionMode::WorkspaceWrite), + ValidationResult::Allow + ); + } + + // --- sedValidation --- + + #[test] + fn blocks_sed_inplace_in_read_only() { + assert!(matches!( + validate_sed("sed -i 's/old/new/' file.txt", PermissionMode::ReadOnly), + ValidationResult::Block { reason } if reason.contains("sed -i") + )); + } + + #[test] + fn allows_sed_stdout_in_read_only() { + assert_eq!( + validate_sed("sed 's/old/new/' file.txt", PermissionMode::ReadOnly), + ValidationResult::Allow + ); + } + + // --- pathValidation --- + + #[test] + fn warns_directory_traversal() { + let workspace = PathBuf::from("/workspace/project"); + assert!(matches!( + validate_paths("cat ../../../etc/passwd", &workspace), + ValidationResult::Warn { message } if message.contains("traversal") + )); + } + + #[test] + fn warns_home_directory_reference() { + let workspace = PathBuf::from("/workspace/project"); + assert!(matches!( + validate_paths("cat ~/.ssh/id_rsa", &workspace), + ValidationResult::Warn { message } if message.contains("home directory") + )); + } + + // --- commandSemantics --- + + #[test] + fn classifies_read_only_commands() { + assert_eq!(classify_command("ls -la"), CommandIntent::ReadOnly); + assert_eq!(classify_command("cat file.txt"), CommandIntent::ReadOnly); + assert_eq!( + classify_command("grep -r pattern ."), + CommandIntent::ReadOnly + ); + assert_eq!( + classify_command("find . -name '*.rs'"), + CommandIntent::ReadOnly + ); + } + + #[test] + fn classifies_write_commands() { + assert_eq!(classify_command("cp a.txt b.txt"), CommandIntent::Write); + assert_eq!(classify_command("mv old.txt new.txt"), CommandIntent::Write); + assert_eq!(classify_command("mkdir -p /tmp/dir"), CommandIntent::Write); + } + + #[test] + fn classifies_destructive_commands() { + assert_eq!( + classify_command("rm -rf /tmp/x"), + CommandIntent::Destructive + ); + assert_eq!( + classify_command("shred /dev/sda"), + CommandIntent::Destructive + ); + } + + #[test] + fn classifies_network_commands() { + assert_eq!( + classify_command("curl https://example.com"), + CommandIntent::Network + ); + assert_eq!(classify_command("wget file.zip"), CommandIntent::Network); + } + + #[test] + fn classifies_sed_inplace_as_write() { + assert_eq!( + classify_command("sed -i 's/old/new/' file.txt"), + CommandIntent::Write + ); + } + + #[test] + fn classifies_sed_stdout_as_read_only() { + assert_eq!( + classify_command("sed 's/old/new/' file.txt"), + CommandIntent::ReadOnly + ); + } + + #[test] + fn classifies_git_status_as_read_only() { + assert_eq!(classify_command("git status"), CommandIntent::ReadOnly); + assert_eq!( + classify_command("git log --oneline"), + CommandIntent::ReadOnly + ); + } + + #[test] + fn classifies_git_push_as_write() { + assert_eq!( + classify_command("git push origin main"), + CommandIntent::Write + ); + } + + // --- validate_command (full pipeline) --- + + #[test] + fn pipeline_blocks_write_in_read_only() { + let workspace = PathBuf::from("/workspace"); + assert!(matches!( + validate_command("rm -rf /tmp/x", PermissionMode::ReadOnly, &workspace), + ValidationResult::Block { .. } + )); + } + + #[test] + fn pipeline_warns_destructive_in_write_mode() { + let workspace = PathBuf::from("/workspace"); + assert!(matches!( + validate_command("rm -rf /", PermissionMode::WorkspaceWrite, &workspace), + ValidationResult::Warn { .. } + )); + } + + #[test] + fn pipeline_allows_safe_read_in_read_only() { + let workspace = PathBuf::from("/workspace"); + assert_eq!( + validate_command("ls -la", PermissionMode::ReadOnly, &workspace), + ValidationResult::Allow + ); + } + + // --- extract_first_command --- + + #[test] + fn extracts_command_from_env_prefix() { + assert_eq!(extract_first_command("FOO=bar ls -la"), "ls"); + assert_eq!(extract_first_command("A=1 B=2 echo hello"), "echo"); + } + + #[test] + fn extracts_plain_command() { + assert_eq!(extract_first_command("grep -r pattern ."), "grep"); + } +} diff --git a/crates/runtime/src/bootstrap.rs b/crates/runtime/src/bootstrap.rs index 760f27e..2faba2d 100644 --- a/crates/runtime/src/bootstrap.rs +++ b/crates/runtime/src/bootstrap.rs @@ -21,7 +21,7 @@ pub struct BootstrapPlan { impl BootstrapPlan { #[must_use] - pub fn claw_default() -> Self { + pub fn claude_code_default() -> Self { Self::from_phases(vec![ BootstrapPhase::CliEntry, BootstrapPhase::FastPathVersion, @@ -54,3 +54,58 @@ impl BootstrapPlan { &self.phases } } + +#[cfg(test)] +mod tests { + use super::{BootstrapPhase, BootstrapPlan}; + + #[test] + fn from_phases_deduplicates_while_preserving_order() { + // given + let phases = vec![ + BootstrapPhase::CliEntry, + BootstrapPhase::FastPathVersion, + BootstrapPhase::CliEntry, + BootstrapPhase::MainRuntime, + BootstrapPhase::FastPathVersion, + ]; + + // when + let plan = BootstrapPlan::from_phases(phases); + + // then + assert_eq!( + plan.phases(), + &[ + BootstrapPhase::CliEntry, + BootstrapPhase::FastPathVersion, + BootstrapPhase::MainRuntime, + ] + ); + } + + #[test] + fn claude_code_default_covers_each_phase_once() { + // given + let expected = [ + BootstrapPhase::CliEntry, + BootstrapPhase::FastPathVersion, + BootstrapPhase::StartupProfiler, + BootstrapPhase::SystemPromptFastPath, + BootstrapPhase::ChromeMcpFastPath, + BootstrapPhase::DaemonWorkerFastPath, + BootstrapPhase::BridgeFastPath, + BootstrapPhase::DaemonFastPath, + BootstrapPhase::BackgroundSessionFastPath, + BootstrapPhase::TemplateFastPath, + BootstrapPhase::EnvironmentRunnerFastPath, + BootstrapPhase::MainRuntime, + ]; + + // when + let plan = BootstrapPlan::claude_code_default(); + + // then + assert_eq!(plan.phases(), &expected); + } +} diff --git a/crates/runtime/src/branch_lock.rs b/crates/runtime/src/branch_lock.rs new file mode 100644 index 0000000..6fbf0d0 --- /dev/null +++ b/crates/runtime/src/branch_lock.rs @@ -0,0 +1,144 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct BranchLockIntent { + #[serde(rename = "laneId")] + pub lane_id: String, + pub branch: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub worktree: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub modules: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct BranchLockCollision { + pub branch: String, + pub module: String, + #[serde(rename = "laneIds")] + pub lane_ids: Vec, +} + +#[must_use] +pub fn detect_branch_lock_collisions(intents: &[BranchLockIntent]) -> Vec { + let mut collisions = Vec::new(); + + for (index, left) in intents.iter().enumerate() { + for right in &intents[index + 1..] { + if left.branch != right.branch { + continue; + } + for module in overlapping_modules(&left.modules, &right.modules) { + collisions.push(BranchLockCollision { + branch: left.branch.clone(), + module, + lane_ids: vec![left.lane_id.clone(), right.lane_id.clone()], + }); + } + } + } + + collisions.sort_by(|a, b| { + a.branch + .cmp(&b.branch) + .then(a.module.cmp(&b.module)) + .then(a.lane_ids.cmp(&b.lane_ids)) + }); + collisions.dedup(); + collisions +} + +fn overlapping_modules(left: &[String], right: &[String]) -> Vec { + let mut overlaps = Vec::new(); + for left_module in left { + for right_module in right { + if modules_overlap(left_module, right_module) { + overlaps.push(shared_scope(left_module, right_module)); + } + } + } + overlaps.sort(); + overlaps.dedup(); + overlaps +} + +fn modules_overlap(left: &str, right: &str) -> bool { + left == right + || left.starts_with(&format!("{right}/")) + || right.starts_with(&format!("{left}/")) +} + +fn shared_scope(left: &str, right: &str) -> String { + if left.starts_with(&format!("{right}/")) || left == right { + right.to_string() + } else { + left.to_string() + } +} + +#[cfg(test)] +mod tests { + use super::{detect_branch_lock_collisions, BranchLockIntent}; + + #[test] + fn detects_same_branch_same_module_collisions() { + let collisions = detect_branch_lock_collisions(&[ + BranchLockIntent { + lane_id: "lane-a".to_string(), + branch: "feature/lock".to_string(), + worktree: Some("wt-a".to_string()), + modules: vec!["runtime/mcp".to_string()], + }, + BranchLockIntent { + lane_id: "lane-b".to_string(), + branch: "feature/lock".to_string(), + worktree: Some("wt-b".to_string()), + modules: vec!["runtime/mcp".to_string()], + }, + ]); + + assert_eq!(collisions.len(), 1); + assert_eq!(collisions[0].branch, "feature/lock"); + assert_eq!(collisions[0].module, "runtime/mcp"); + } + + #[test] + fn detects_nested_module_scope_collisions() { + let collisions = detect_branch_lock_collisions(&[ + BranchLockIntent { + lane_id: "lane-a".to_string(), + branch: "feature/lock".to_string(), + worktree: None, + modules: vec!["runtime".to_string()], + }, + BranchLockIntent { + lane_id: "lane-b".to_string(), + branch: "feature/lock".to_string(), + worktree: None, + modules: vec!["runtime/mcp".to_string()], + }, + ]); + + assert_eq!(collisions[0].module, "runtime"); + } + + #[test] + fn ignores_different_branches() { + let collisions = detect_branch_lock_collisions(&[ + BranchLockIntent { + lane_id: "lane-a".to_string(), + branch: "feature/a".to_string(), + worktree: None, + modules: vec!["runtime/mcp".to_string()], + }, + BranchLockIntent { + lane_id: "lane-b".to_string(), + branch: "feature/b".to_string(), + worktree: None, + modules: vec!["runtime/mcp".to_string()], + }, + ]); + + assert!(collisions.is_empty()); + } +} diff --git a/crates/runtime/src/compact.rs b/crates/runtime/src/compact.rs index 438ee10..922e52a 100644 --- a/crates/runtime/src/compact.rs +++ b/crates/runtime/src/compact.rs @@ -5,6 +5,7 @@ const COMPACT_CONTINUATION_PREAMBLE: &str = const COMPACT_RECENT_MESSAGES_NOTE: &str = "Recent messages are preserved verbatim."; const COMPACT_DIRECT_RESUME_INSTRUCTION: &str = "Continue the conversation from where it left off without asking the user any further questions. Resume directly — do not acknowledge the summary, do not recap what was happening, and do not preface with continuation text."; +/// Thresholds controlling when and how a session is compacted. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct CompactionConfig { pub preserve_recent_messages: usize, @@ -20,6 +21,7 @@ impl Default for CompactionConfig { } } +/// Result of compacting a session into a summary plus preserved tail messages. #[derive(Debug, Clone, PartialEq, Eq)] pub struct CompactionResult { pub summary: String, @@ -28,11 +30,13 @@ pub struct CompactionResult { pub removed_message_count: usize, } +/// Roughly estimates the token footprint of the current session transcript. #[must_use] pub fn estimate_session_tokens(session: &Session) -> usize { session.messages.iter().map(estimate_message_tokens).sum() } +/// Returns `true` when the session exceeds the configured compaction budget. #[must_use] pub fn should_compact(session: &Session, config: CompactionConfig) -> bool { let start = compacted_summary_prefix_len(session); @@ -46,6 +50,7 @@ pub fn should_compact(session: &Session, config: CompactionConfig) -> bool { >= config.max_estimated_tokens } +/// Normalizes a compaction summary into user-facing continuation text. #[must_use] pub fn format_compact_summary(summary: &str) -> String { let without_analysis = strip_tag_block(summary, "analysis"); @@ -61,6 +66,7 @@ pub fn format_compact_summary(summary: &str) -> String { collapse_blank_lines(&formatted).trim().to_string() } +/// Builds the synthetic system message used after session compaction. #[must_use] pub fn get_compact_continuation_message( summary: &str, @@ -85,6 +91,7 @@ pub fn get_compact_continuation_message( base } +/// Compacts a session by summarizing older messages and preserving the recent tail. #[must_use] pub fn compact_session(session: &Session, config: CompactionConfig) -> CompactionResult { if !should_compact(session, config) { @@ -101,10 +108,55 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio .first() .and_then(extract_existing_compacted_summary); let compacted_prefix_len = usize::from(existing_summary.is_some()); - let keep_from = session + let raw_keep_from = session .messages .len() .saturating_sub(config.preserve_recent_messages); + // Ensure we do not split a tool-use / tool-result pair at the compaction + // boundary. If the first preserved message is a user message whose first + // block is a ToolResult, the assistant message with the matching ToolUse + // was slated for removal — that produces an orphaned tool role message on + // the OpenAI-compat path (400: tool message must follow assistant with + // tool_calls). Walk the boundary back until we start at a safe point. + let keep_from = { + let mut k = raw_keep_from; + // If the first preserved message is a tool-result turn, ensure its + // paired assistant tool-use turn is preserved too. Without this fix, + // the OpenAI-compat adapter sends an orphaned 'tool' role message + // with no preceding assistant 'tool_calls', which providers reject + // with a 400. We walk back only if the immediately preceding message + // is NOT an assistant message that contains a ToolUse block (i.e. the + // pair is actually broken at the boundary). + loop { + if k == 0 || k <= compacted_prefix_len { + break; + } + let first_preserved = &session.messages[k]; + let starts_with_tool_result = first_preserved + .blocks + .first() + .map(|b| matches!(b, ContentBlock::ToolResult { .. })) + .unwrap_or(false); + if !starts_with_tool_result { + break; + } + // Check the message just before the current boundary. + let preceding = &session.messages[k - 1]; + let preceding_has_tool_use = preceding + .blocks + .iter() + .any(|b| matches!(b, ContentBlock::ToolUse { .. })); + if preceding_has_tool_use { + // Pair is intact — walk back one more to include the assistant turn. + k = k.saturating_sub(1); + break; + } + // Preceding message has no ToolUse but we have a ToolResult — + // this is already an orphaned pair; walk back to try to fix it. + k = k.saturating_sub(1); + } + k + }; let removed = &session.messages[compacted_prefix_len..keep_from]; let preserved = session.messages[keep_from..].to_vec(); let summary = @@ -119,13 +171,14 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio }]; compacted_messages.extend(preserved); + let mut compacted_session = session.clone(); + compacted_session.messages = compacted_messages; + compacted_session.record_compaction(summary.clone(), removed.len()); + CompactionResult { summary, formatted_summary, - compacted_session: Session { - version: session.version, - messages: compacted_messages, - }, + compacted_session, removed_message_count: removed.len(), } } @@ -160,9 +213,7 @@ fn summarize_messages(messages: &[ConversationMessage]) -> String { .filter_map(|block| match block { ContentBlock::ToolUse { name, .. } => Some(name.as_str()), ContentBlock::ToolResult { tool_name, .. } => Some(tool_name.as_str()), - ContentBlock::Text { .. } - | ContentBlock::Thinking { .. } - | ContentBlock::RedactedThinking { .. } => None, + ContentBlock::Text { .. } => None, }) .collect::>(); tool_names.sort_unstable(); @@ -277,8 +328,6 @@ fn summarize_block(block: &ContentBlock) -> String { "tool_result {tool_name}: {}{output}", if *is_error { "error " } else { "" } ), - ContentBlock::Thinking { thinking, .. } => format!("thinking: {thinking}"), - ContentBlock::RedactedThinking { .. } => "thinking: ".to_string(), }; truncate_summary(&raw, 160) } @@ -328,8 +377,6 @@ fn collect_key_files(messages: &[ConversationMessage]) -> Vec { .flat_map(|message| message.blocks.iter()) .map(|block| match block { ContentBlock::Text { text } => text.as_str(), - ContentBlock::Thinking { thinking, .. } => thinking.as_str(), - ContentBlock::RedactedThinking { .. } => "", ContentBlock::ToolUse { input, .. } => input.as_str(), ContentBlock::ToolResult { output, .. } => output.as_str(), }) @@ -354,9 +401,7 @@ fn first_text_block(message: &ConversationMessage) -> Option<&str> { ContentBlock::Text { text } if !text.trim().is_empty() => Some(text.as_str()), ContentBlock::ToolUse { .. } | ContentBlock::ToolResult { .. } - | ContentBlock::Text { .. } - | ContentBlock::Thinking { .. } - | ContentBlock::RedactedThinking { .. } => None, + | ContentBlock::Text { .. } => None, }) } @@ -402,8 +447,6 @@ fn estimate_message_tokens(message: &ConversationMessage) -> usize { .iter() .map(|block| match block { ContentBlock::Text { text } => text.len() / 4 + 1, - ContentBlock::Thinking { thinking, .. } => thinking.len() / 4 + 1, - ContentBlock::RedactedThinking { .. } => 1, ContentBlock::ToolUse { name, input, .. } => (name.len() + input.len()) / 4 + 1, ContentBlock::ToolResult { tool_name, output, .. @@ -512,7 +555,7 @@ fn extract_summary_timeline(summary: &str) -> Vec { #[cfg(test)] mod tests { use super::{ - collect_key_files, compact_session, estimate_session_tokens, format_compact_summary, + collect_key_files, compact_session, format_compact_summary, get_compact_continuation_message, infer_pending_work, should_compact, CompactionConfig, }; use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session}; @@ -525,10 +568,8 @@ mod tests { #[test] fn leaves_small_sessions_unchanged() { - let session = Session { - version: 1, - messages: vec![ConversationMessage::user_text("hello")], - }; + let mut session = Session::new(); + session.messages = vec![ConversationMessage::user_text("hello")]; let result = compact_session(&session, CompactionConfig::default()); assert_eq!(result.removed_message_count, 0); @@ -539,23 +580,21 @@ mod tests { #[test] fn compacts_older_messages_into_a_system_summary() { - let session = Session { - version: 1, - messages: vec![ - ConversationMessage::user_text("one ".repeat(200)), - ConversationMessage::assistant(vec![ContentBlock::Text { - text: "two ".repeat(200), - }]), - ConversationMessage::tool_result("1", "bash", "ok ".repeat(200), false), - ConversationMessage { - role: MessageRole::Assistant, - blocks: vec![ContentBlock::Text { - text: "recent".to_string(), - }], - usage: None, - }, - ], - }; + let mut session = Session::new(); + session.messages = vec![ + ConversationMessage::user_text("one ".repeat(200)), + ConversationMessage::assistant(vec![ContentBlock::Text { + text: "two ".repeat(200), + }]), + ConversationMessage::tool_result("1", "bash", "ok ".repeat(200), false), + ConversationMessage { + role: MessageRole::Assistant, + blocks: vec![ContentBlock::Text { + text: "recent".to_string(), + }], + usage: None, + }, + ]; let result = compact_session( &session, @@ -565,7 +604,14 @@ mod tests { }, ); - assert_eq!(result.removed_message_count, 2); + // With the tool-use/tool-result boundary fix, the compaction preserves + // one extra message to avoid an orphaned tool result at the boundary. + // messages[1] (assistant) must be kept along with messages[2] (tool result). + assert!( + result.removed_message_count <= 2, + "expected at most 2 removed, got {}", + result.removed_message_count + ); assert_eq!( result.compacted_session.messages[0].role, MessageRole::System @@ -583,28 +629,29 @@ mod tests { max_estimated_tokens: 1, } )); + // Note: with the tool-use/tool-result boundary guard the compacted session + // may preserve one extra message at the boundary, so token reduction is + // not guaranteed for small sessions. The invariant that matters is that + // the removed_message_count is non-zero (something was compacted). assert!( - estimate_session_tokens(&result.compacted_session) < estimate_session_tokens(&session) + result.removed_message_count > 0, + "compaction must remove at least one message" ); } #[test] fn keeps_previous_compacted_context_when_compacting_again() { - let initial_session = Session { - version: 1, - messages: vec![ - ConversationMessage::user_text("Investigate rust/crates/runtime/src/compact.rs"), - ConversationMessage::assistant(vec![ContentBlock::Text { - text: "I will inspect the compact flow.".to_string(), - }]), - ConversationMessage::user_text( - "Also update rust/crates/runtime/src/conversation.rs", - ), - ConversationMessage::assistant(vec![ContentBlock::Text { - text: "Next: preserve prior summary context during auto compact.".to_string(), - }]), - ], - }; + let mut initial_session = Session::new(); + initial_session.messages = vec![ + ConversationMessage::user_text("Investigate rust/crates/runtime/src/compact.rs"), + ConversationMessage::assistant(vec![ContentBlock::Text { + text: "I will inspect the compact flow.".to_string(), + }]), + ConversationMessage::user_text("Also update rust/crates/runtime/src/conversation.rs"), + ConversationMessage::assistant(vec![ContentBlock::Text { + text: "Next: preserve prior summary context during auto compact.".to_string(), + }]), + ]; let config = CompactionConfig { preserve_recent_messages: 2, max_estimated_tokens: 1, @@ -619,13 +666,9 @@ mod tests { }]), ]); - let second = compact_session( - &Session { - version: 1, - messages: follow_up_messages, - }, - config, - ); + let mut second_session = Session::new(); + second_session.messages = follow_up_messages; + let second = compact_session(&second_session, config); assert!(second .formatted_summary @@ -654,22 +697,20 @@ mod tests { #[test] fn ignores_existing_compacted_summary_when_deciding_to_recompact() { let summary = "Conversation summary:\n- Scope: earlier work preserved.\n- Key timeline:\n - user: large preserved context\n"; - let session = Session { - version: 1, - messages: vec![ - ConversationMessage { - role: MessageRole::System, - blocks: vec![ContentBlock::Text { - text: get_compact_continuation_message(summary, true, true), - }], - usage: None, - }, - ConversationMessage::user_text("tiny"), - ConversationMessage::assistant(vec![ContentBlock::Text { - text: "recent".to_string(), - }]), - ], - }; + let mut session = Session::new(); + session.messages = vec![ + ConversationMessage { + role: MessageRole::System, + blocks: vec![ContentBlock::Text { + text: get_compact_continuation_message(summary, true, true), + }], + usage: None, + }, + ConversationMessage::user_text("tiny"), + ConversationMessage::assistant(vec![ContentBlock::Text { + text: "recent".to_string(), + }]), + ]; assert!(!should_compact( &session, @@ -692,10 +733,84 @@ mod tests { #[test] fn extracts_key_files_from_message_content() { let files = collect_key_files(&[ConversationMessage::user_text( - "Update rust/crates/runtime/src/compact.rs and rust/crates/tools/src/lib.rs next.", + "Update rust/crates/runtime/src/compact.rs and rust/crates/rusty-claude-cli/src/main.rs next.", )]); assert!(files.contains(&"rust/crates/runtime/src/compact.rs".to_string())); - assert!(files.contains(&"rust/crates/tools/src/lib.rs".to_string())); + assert!(files.contains(&"rust/crates/rusty-claude-cli/src/main.rs".to_string())); + } + + /// Regression: compaction must not split an assistant(ToolUse) / + /// user(ToolResult) pair at the boundary. An orphaned tool-result message + /// without the preceding assistant tool_calls causes a 400 on the + /// OpenAI-compat path (gaebal-gajae repro 2026-04-09). + #[test] + fn compaction_does_not_split_tool_use_tool_result_pair() { + use crate::session::{ContentBlock, Session}; + + let tool_id = "call_abc"; + let mut session = Session::default(); + // Turn 1: user prompt + session + .push_message(ConversationMessage::user_text("Search for files")) + .unwrap(); + // Turn 2: assistant calls a tool + session + .push_message(ConversationMessage::assistant(vec![ + ContentBlock::ToolUse { + id: tool_id.to_string(), + name: "search".to_string(), + input: "{\"q\":\"*.rs\"}".to_string(), + }, + ])) + .unwrap(); + // Turn 3: tool result + session + .push_message(ConversationMessage::tool_result( + tool_id, + "search", + "found 5 files", + false, + )) + .unwrap(); + // Turn 4: assistant final response + session + .push_message(ConversationMessage::assistant(vec![ContentBlock::Text { + text: "Done.".to_string(), + }])) + .unwrap(); + + // Compact preserving only 1 recent message — without the fix this + // would cut the boundary so that the tool result (turn 3) is first, + // without its preceding assistant tool_calls (turn 2). + let config = CompactionConfig { + preserve_recent_messages: 1, + ..CompactionConfig::default() + }; + let result = compact_session(&session, config); + // After compaction, no two consecutive messages should have the pattern + // tool_result immediately following a non-assistant message (i.e. an + // orphaned tool result without a preceding assistant ToolUse). + let messages = &result.compacted_session.messages; + for i in 1..messages.len() { + let curr_is_tool_result = messages[i] + .blocks + .first() + .map(|b| matches!(b, ContentBlock::ToolResult { .. })) + .unwrap_or(false); + if curr_is_tool_result { + let prev_has_tool_use = messages[i - 1] + .blocks + .iter() + .any(|b| matches!(b, ContentBlock::ToolUse { .. })); + assert!( + prev_has_tool_use, + "message[{}] is a ToolResult but message[{}] has no ToolUse: {:?}", + i, + i - 1, + &messages[i - 1].blocks + ); + } + } } #[test] diff --git a/crates/runtime/src/config.rs b/crates/runtime/src/config.rs index 11ec21d..c1fe496 100644 --- a/crates/runtime/src/config.rs +++ b/crates/runtime/src/config.rs @@ -6,8 +6,10 @@ use std::path::{Path, PathBuf}; use crate::json::JsonValue; use crate::sandbox::{FilesystemIsolationMode, SandboxConfig}; +/// Schema name advertised by generated settings files. pub const CLAW_SETTINGS_SCHEMA_NAME: &str = "SettingsSchema"; +/// Origin of a loaded settings file in the configuration precedence chain. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum ConfigSource { User, @@ -15,6 +17,7 @@ pub enum ConfigSource { Local, } +/// Effective permission mode after decoding config values. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ResolvedPermissionMode { ReadOnly, @@ -22,12 +25,14 @@ pub enum ResolvedPermissionMode { DangerFullAccess, } +/// A discovered config file and the scope it contributes to. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ConfigEntry { pub source: ConfigSource, pub path: PathBuf, } +/// Fully merged runtime configuration plus parsed feature-specific views. #[derive(Debug, Clone, PartialEq, Eq)] pub struct RuntimeConfig { merged: BTreeMap, @@ -35,6 +40,7 @@ pub struct RuntimeConfig { feature_config: RuntimeFeatureConfig, } +/// Parsed plugin-related settings extracted from runtime config. #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct RuntimePluginConfig { enabled_plugins: BTreeMap, @@ -42,8 +48,10 @@ pub struct RuntimePluginConfig { install_root: Option, registry_path: Option, bundled_root: Option, + max_output_tokens: Option, } +/// Structured feature configuration consumed by runtime subsystems. #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct RuntimeFeatureConfig { hooks: RuntimeHookConfig, @@ -51,27 +59,53 @@ pub struct RuntimeFeatureConfig { mcp: McpConfigCollection, oauth: Option, model: Option, + aliases: BTreeMap, permission_mode: Option, + permission_rules: RuntimePermissionRuleConfig, sandbox: SandboxConfig, + provider_fallbacks: ProviderFallbackConfig, + trusted_roots: Vec, } +/// Ordered chain of fallback model identifiers used when the primary +/// provider returns a retryable failure (429/500/503/etc.). The chain is +/// strict: each entry is tried in order until one succeeds. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct ProviderFallbackConfig { + primary: Option, + fallbacks: Vec, +} + +/// Hook command lists grouped by lifecycle stage. #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct RuntimeHookConfig { pre_tool_use: Vec, post_tool_use: Vec, + post_tool_use_failure: Vec, } +/// Raw permission rule lists grouped by allow, deny, and ask behavior. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct RuntimePermissionRuleConfig { + allow: Vec, + deny: Vec, + ask: Vec, +} + +/// Collection of configured MCP servers after scope-aware merging. #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct McpConfigCollection { servers: BTreeMap, } +/// MCP server config paired with the scope that defined it. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ScopedMcpServerConfig { pub scope: ConfigSource, pub config: McpServerConfig, } +/// Transport families supported by configured MCP servers. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum McpTransport { Stdio, @@ -82,6 +116,7 @@ pub enum McpTransport { ManagedProxy, } +/// Scope-normalized MCP server configuration variants. #[derive(Debug, Clone, PartialEq, Eq)] pub enum McpServerConfig { Stdio(McpStdioServerConfig), @@ -92,13 +127,16 @@ pub enum McpServerConfig { ManagedProxy(McpManagedProxyServerConfig), } +/// Configuration for an MCP server launched as a local stdio process. #[derive(Debug, Clone, PartialEq, Eq)] pub struct McpStdioServerConfig { pub command: String, pub args: Vec, pub env: BTreeMap, + pub tool_call_timeout_ms: Option, } +/// Configuration for an MCP server reached over HTTP or SSE. #[derive(Debug, Clone, PartialEq, Eq)] pub struct McpRemoteServerConfig { pub url: String, @@ -107,6 +145,7 @@ pub struct McpRemoteServerConfig { pub oauth: Option, } +/// Configuration for an MCP server reached over WebSocket. #[derive(Debug, Clone, PartialEq, Eq)] pub struct McpWebSocketServerConfig { pub url: String, @@ -114,17 +153,20 @@ pub struct McpWebSocketServerConfig { pub headers_helper: Option, } +/// Configuration for an MCP server addressed through an SDK name. #[derive(Debug, Clone, PartialEq, Eq)] pub struct McpSdkServerConfig { pub name: String, } +/// Configuration for an MCP managed-proxy endpoint. #[derive(Debug, Clone, PartialEq, Eq)] pub struct McpManagedProxyServerConfig { pub url: String, pub id: String, } +/// OAuth overrides associated with a remote MCP server. #[derive(Debug, Clone, PartialEq, Eq)] pub struct McpOAuthConfig { pub client_id: Option, @@ -133,6 +175,7 @@ pub struct McpOAuthConfig { pub xaa: Option, } +/// OAuth client configuration used by the main Claw runtime. #[derive(Debug, Clone, PartialEq, Eq)] pub struct OAuthConfig { pub client_id: String, @@ -143,6 +186,7 @@ pub struct OAuthConfig { pub scopes: Vec, } +/// Errors raised while reading or parsing runtime configuration files. #[derive(Debug)] pub enum ConfigError { Io(std::io::Error), @@ -166,6 +210,7 @@ impl From for ConfigError { } } +/// Discovers config files and merges them into a [`RuntimeConfig`]. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ConfigLoader { cwd: PathBuf, @@ -227,16 +272,33 @@ impl ConfigLoader { let mut merged = BTreeMap::new(); let mut loaded_entries = Vec::new(); let mut mcp_servers = BTreeMap::new(); + let mut all_warnings = Vec::new(); for entry in self.discover() { - let Some(value) = read_optional_json_object(&entry.path)? else { + crate::config_validate::check_unsupported_format(&entry.path)?; + let Some(parsed) = read_optional_json_object(&entry.path)? else { continue; }; - merge_mcp_servers(&mut mcp_servers, entry.source, &value, &entry.path)?; - deep_merge_objects(&mut merged, &value); + let validation = crate::config_validate::validate_config_file( + &parsed.object, + &parsed.source, + &entry.path, + ); + if !validation.is_ok() { + let first_error = &validation.errors[0]; + return Err(ConfigError::Parse(first_error.to_string())); + } + all_warnings.extend(validation.warnings); + validate_optional_hooks_config(&parsed.object, &entry.path)?; + merge_mcp_servers(&mut mcp_servers, entry.source, &parsed.object, &entry.path)?; + deep_merge_objects(&mut merged, &parsed.object); loaded_entries.push(entry); } + for warning in &all_warnings { + eprintln!("warning: {warning}"); + } + let merged_value = JsonValue::Object(merged.clone()); let feature_config = RuntimeFeatureConfig { @@ -247,8 +309,12 @@ impl ConfigLoader { }, oauth: parse_optional_oauth_config(&merged_value, "merged settings.oauth")?, model: parse_optional_model(&merged_value), + aliases: parse_optional_aliases(&merged_value)?, permission_mode: parse_optional_permission_mode(&merged_value)?, + permission_rules: parse_optional_permission_rules(&merged_value)?, sandbox: parse_optional_sandbox_config(&merged_value)?, + provider_fallbacks: parse_optional_provider_fallbacks(&merged_value)?, + trusted_roots: parse_optional_trusted_roots(&merged_value)?, }; Ok(RuntimeConfig { @@ -319,15 +385,35 @@ impl RuntimeConfig { self.feature_config.model.as_deref() } + #[must_use] + pub fn aliases(&self) -> &BTreeMap { + &self.feature_config.aliases + } + #[must_use] pub fn permission_mode(&self) -> Option { self.feature_config.permission_mode } + #[must_use] + pub fn permission_rules(&self) -> &RuntimePermissionRuleConfig { + &self.feature_config.permission_rules + } + #[must_use] pub fn sandbox(&self) -> &SandboxConfig { &self.feature_config.sandbox } + + #[must_use] + pub fn provider_fallbacks(&self) -> &ProviderFallbackConfig { + &self.feature_config.provider_fallbacks + } + + #[must_use] + pub fn trusted_roots(&self) -> &[String] { + &self.feature_config.trusted_roots + } } impl RuntimeFeatureConfig { @@ -368,15 +454,57 @@ impl RuntimeFeatureConfig { self.model.as_deref() } + #[must_use] + pub fn aliases(&self) -> &BTreeMap { + &self.aliases + } + #[must_use] pub fn permission_mode(&self) -> Option { self.permission_mode } + #[must_use] + pub fn permission_rules(&self) -> &RuntimePermissionRuleConfig { + &self.permission_rules + } + #[must_use] pub fn sandbox(&self) -> &SandboxConfig { &self.sandbox } + + #[must_use] + pub fn provider_fallbacks(&self) -> &ProviderFallbackConfig { + &self.provider_fallbacks + } + + #[must_use] + pub fn trusted_roots(&self) -> &[String] { + &self.trusted_roots + } +} + +impl ProviderFallbackConfig { + #[must_use] + pub fn new(primary: Option, fallbacks: Vec) -> Self { + Self { primary, fallbacks } + } + + #[must_use] + pub fn primary(&self) -> Option<&str> { + self.primary.as_deref() + } + + #[must_use] + pub fn fallbacks(&self) -> &[String] { + &self.fallbacks + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.fallbacks.is_empty() + } } impl RuntimePluginConfig { @@ -405,6 +533,15 @@ impl RuntimePluginConfig { self.bundled_root.as_deref() } + #[must_use] + pub fn max_output_tokens(&self) -> Option { + self.max_output_tokens + } + + pub fn set_max_output_tokens(&mut self, max_output_tokens: Option) { + self.max_output_tokens = max_output_tokens; + } + pub fn set_plugin_state(&mut self, plugin_id: String, enabled: bool) { self.enabled_plugins.insert(plugin_id, enabled); } @@ -419,6 +556,7 @@ impl RuntimePluginConfig { } #[must_use] +/// Returns the default per-user config directory used by the runtime. pub fn default_config_home() -> PathBuf { std::env::var_os("CLAW_CONFIG_HOME") .map(PathBuf::from) @@ -428,10 +566,15 @@ pub fn default_config_home() -> PathBuf { impl RuntimeHookConfig { #[must_use] - pub fn new(pre_tool_use: Vec, post_tool_use: Vec) -> Self { + pub fn new( + pre_tool_use: Vec, + post_tool_use: Vec, + post_tool_use_failure: Vec, + ) -> Self { Self { pre_tool_use, post_tool_use, + post_tool_use_failure, } } @@ -455,6 +598,37 @@ impl RuntimeHookConfig { pub fn extend(&mut self, other: &Self) { extend_unique(&mut self.pre_tool_use, other.pre_tool_use()); extend_unique(&mut self.post_tool_use, other.post_tool_use()); + extend_unique( + &mut self.post_tool_use_failure, + other.post_tool_use_failure(), + ); + } + + #[must_use] + pub fn post_tool_use_failure(&self) -> &[String] { + &self.post_tool_use_failure + } +} + +impl RuntimePermissionRuleConfig { + #[must_use] + pub fn new(allow: Vec, deny: Vec, ask: Vec) -> Self { + Self { allow, deny, ask } + } + + #[must_use] + pub fn allow(&self) -> &[String] { + &self.allow + } + + #[must_use] + pub fn deny(&self) -> &[String] { + &self.deny + } + + #[must_use] + pub fn ask(&self) -> &[String] { + &self.ask } } @@ -491,9 +665,13 @@ impl McpServerConfig { } } -fn read_optional_json_object( - path: &Path, -) -> Result>, ConfigError> { +/// Parsed JSON object paired with its raw source text for validation. +struct ParsedConfigFile { + object: BTreeMap, + source: String, +} + +fn read_optional_json_object(path: &Path) -> Result, ConfigError> { let is_legacy_config = path.file_name().and_then(|name| name.to_str()) == Some(".claw.json"); let contents = match fs::read_to_string(path) { Ok(contents) => contents, @@ -502,12 +680,15 @@ fn read_optional_json_object( }; if contents.trim().is_empty() { - return Ok(Some(BTreeMap::new())); + return Ok(Some(ParsedConfigFile { + object: BTreeMap::new(), + source: contents, + })); } let parsed = match JsonValue::parse(&contents) { Ok(parsed) => parsed, - Err(error) if is_legacy_config => return Ok(None), + Err(_error) if is_legacy_config => return Ok(None), Err(error) => return Err(ConfigError::Parse(format!("{}: {error}", path.display()))), }; let Some(object) = parsed.as_object() else { @@ -519,7 +700,10 @@ fn read_optional_json_object( path.display() ))); }; - Ok(Some(object.clone())) + Ok(Some(ParsedConfigFile { + object: object.clone(), + source: contents, + })) } fn merge_mcp_servers( @@ -556,18 +740,59 @@ fn parse_optional_model(root: &JsonValue) -> Option { .map(ToOwned::to_owned) } +fn parse_optional_aliases(root: &JsonValue) -> Result, ConfigError> { + let Some(object) = root.as_object() else { + return Ok(BTreeMap::new()); + }; + Ok(optional_string_map(object, "aliases", "merged settings")?.unwrap_or_default()) +} + fn parse_optional_hooks_config(root: &JsonValue) -> Result { let Some(object) = root.as_object() else { return Ok(RuntimeHookConfig::default()); }; + parse_optional_hooks_config_object(object, "merged settings.hooks") +} + +fn parse_optional_hooks_config_object( + object: &BTreeMap, + context: &str, +) -> Result { let Some(hooks_value) = object.get("hooks") else { return Ok(RuntimeHookConfig::default()); }; - let hooks = expect_object(hooks_value, "merged settings.hooks")?; + let hooks = expect_object(hooks_value, context)?; Ok(RuntimeHookConfig { - pre_tool_use: optional_string_array(hooks, "PreToolUse", "merged settings.hooks")? + pre_tool_use: optional_string_array(hooks, "PreToolUse", context)?.unwrap_or_default(), + post_tool_use: optional_string_array(hooks, "PostToolUse", context)?.unwrap_or_default(), + post_tool_use_failure: optional_string_array(hooks, "PostToolUseFailure", context)? .unwrap_or_default(), - post_tool_use: optional_string_array(hooks, "PostToolUse", "merged settings.hooks")? + }) +} + +fn validate_optional_hooks_config( + root: &BTreeMap, + path: &Path, +) -> Result<(), ConfigError> { + parse_optional_hooks_config_object(root, &format!("{}: hooks", path.display())).map(|_| ()) +} + +fn parse_optional_permission_rules( + root: &JsonValue, +) -> Result { + let Some(object) = root.as_object() else { + return Ok(RuntimePermissionRuleConfig::default()); + }; + let Some(permissions) = object.get("permissions").and_then(JsonValue::as_object) else { + return Ok(RuntimePermissionRuleConfig::default()); + }; + + Ok(RuntimePermissionRuleConfig { + allow: optional_string_array(permissions, "allow", "merged settings.permissions")? + .unwrap_or_default(), + deny: optional_string_array(permissions, "deny", "merged settings.permissions")? + .unwrap_or_default(), + ask: optional_string_array(permissions, "ask", "merged settings.permissions")? .unwrap_or_default(), }) } @@ -599,6 +824,7 @@ fn parse_optional_plugin_config(root: &JsonValue) -> Result Result Result { + let Some(object) = root.as_object() else { + return Ok(ProviderFallbackConfig::default()); + }; + let Some(value) = object.get("providerFallbacks") else { + return Ok(ProviderFallbackConfig::default()); + }; + let entry = expect_object(value, "merged settings.providerFallbacks")?; + let primary = + optional_string(entry, "primary", "merged settings.providerFallbacks")?.map(str::to_string); + let fallbacks = optional_string_array(entry, "fallbacks", "merged settings.providerFallbacks")? + .unwrap_or_default(); + Ok(ProviderFallbackConfig { primary, fallbacks }) +} + +fn parse_optional_trusted_roots(root: &JsonValue) -> Result, ConfigError> { + let Some(object) = root.as_object() else { + return Ok(Vec::new()); + }; + Ok( + optional_string_array(object, "trustedRoots", "merged settings.trustedRoots")? + .unwrap_or_default(), + ) +} + fn parse_filesystem_mode_label(value: &str) -> Result { match value { "off" => Ok(FilesystemIsolationMode::Off), @@ -703,12 +956,14 @@ fn parse_mcp_server_config( context: &str, ) -> Result { let object = expect_object(value, context)?; - let server_type = optional_string(object, "type", context)?.unwrap_or("stdio"); + let server_type = + optional_string(object, "type", context)?.unwrap_or_else(|| infer_mcp_server_type(object)); match server_type { "stdio" => Ok(McpServerConfig::Stdio(McpStdioServerConfig { command: expect_string(object, "command", context)?.to_string(), args: optional_string_array(object, "args", context)?.unwrap_or_default(), env: optional_string_map(object, "env", context)?.unwrap_or_default(), + tool_call_timeout_ms: optional_u64(object, "toolCallTimeoutMs", context)?, })), "sse" => Ok(McpServerConfig::Sse(parse_mcp_remote_server_config( object, context, @@ -734,6 +989,14 @@ fn parse_mcp_server_config( } } +fn infer_mcp_server_type(object: &BTreeMap) -> &'static str { + if object.contains_key("url") { + "http" + } else { + "stdio" + } +} + fn parse_mcp_remote_server_config( object: &BTreeMap, context: &str, @@ -832,6 +1095,48 @@ fn optional_u16( } } +fn optional_u32( + object: &BTreeMap, + key: &str, + context: &str, +) -> Result, ConfigError> { + match object.get(key) { + Some(value) => { + let Some(number) = value.as_i64() else { + return Err(ConfigError::Parse(format!( + "{context}: field {key} must be a non-negative integer" + ))); + }; + let number = u32::try_from(number).map_err(|_| { + ConfigError::Parse(format!("{context}: field {key} is out of range")) + })?; + Ok(Some(number)) + } + None => Ok(None), + } +} + +fn optional_u64( + object: &BTreeMap, + key: &str, + context: &str, +) -> Result, ConfigError> { + match object.get(key) { + Some(value) => { + let Some(number) = value.as_i64() else { + return Err(ConfigError::Parse(format!( + "{context}: field {key} must be a non-negative integer" + ))); + }; + let number = u64::try_from(number).map_err(|_| { + ConfigError::Parse(format!("{context}: field {key} is out of range")) + })?; + Ok(Some(number)) + } + None => Ok(None), + } +} + fn parse_bool_map(value: &JsonValue, context: &str) -> Result, ConfigError> { let Some(map) = value.as_object() else { return Err(ConfigError::Parse(format!( @@ -939,8 +1244,9 @@ fn push_unique(target: &mut Vec, value: String) { #[cfg(test)] mod tests { use super::{ - ConfigLoader, ConfigSource, McpServerConfig, McpTransport, ResolvedPermissionMode, - CLAW_SETTINGS_SCHEMA_NAME, + deep_merge_objects, parse_permission_mode_label, ConfigLoader, ConfigSource, + McpServerConfig, McpTransport, ResolvedPermissionMode, RuntimeHookConfig, + RuntimePluginConfig, CLAW_SETTINGS_SCHEMA_NAME, }; use crate::json::JsonValue; use crate::sandbox::FilesystemIsolationMode; @@ -971,11 +1277,13 @@ mod tests { .to_string() .contains("top-level settings value must be a JSON object")); - fs::remove_dir_all(root).expect("cleanup temp dir"); + if root.exists() { + fs::remove_dir_all(root).expect("cleanup temp dir"); + } } #[test] - fn loads_and_merges_claw_code_config_files_by_precedence() { + fn loads_and_merges_claude_code_config_files_by_precedence() { let root = temp_dir(); let cwd = root.join("project"); let home = root.join("home").join(".claw"); @@ -989,7 +1297,7 @@ mod tests { .expect("write user compat config"); fs::write( home.join("settings.json"), - r#"{"model":"sonnet","env":{"A2":"1"},"hooks":{"PreToolUse":["base"]},"permissions":{"defaultMode":"plan"}}"#, + r#"{"model":"sonnet","env":{"A2":"1"},"hooks":{"PreToolUse":["base"]},"permissions":{"defaultMode":"plan","allow":["Read"],"deny":["Bash(rm -rf)"]}}"#, ) .expect("write user settings"); fs::write( @@ -999,7 +1307,7 @@ mod tests { .expect("write project compat config"); fs::write( cwd.join(".claw").join("settings.json"), - r#"{"env":{"C":"3"},"hooks":{"PostToolUse":["project"]},"mcpServers":{"project":{"command":"uvx","args":["project"]}}}"#, + r#"{"env":{"C":"3"},"hooks":{"PostToolUse":["project"],"PostToolUseFailure":["project-failure"]},"permissions":{"ask":["Edit"]},"mcpServers":{"project":{"command":"uvx","args":["project"]}}}"#, ) .expect("write project settings"); fs::write( @@ -1044,6 +1352,16 @@ mod tests { .contains_key("PostToolUse")); assert_eq!(loaded.hooks().pre_tool_use(), &["base".to_string()]); assert_eq!(loaded.hooks().post_tool_use(), &["project".to_string()]); + assert_eq!( + loaded.hooks().post_tool_use_failure(), + &["project-failure".to_string()] + ); + assert_eq!(loaded.permission_rules().allow(), &["Read".to_string()]); + assert_eq!( + loaded.permission_rules().deny(), + &["Bash(rm -rf)".to_string()] + ); + assert_eq!(loaded.permission_rules().ask(), &["Edit".to_string()]); assert!(loaded.mcp().get("home").is_some()); assert!(loaded.mcp().get("project").is_some()); @@ -1088,6 +1406,113 @@ mod tests { fs::remove_dir_all(root).expect("cleanup temp dir"); } + #[test] + fn parses_provider_fallbacks_chain_with_primary_and_ordered_fallbacks() { + // given + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); + fs::create_dir_all(&home).expect("home config dir"); + fs::write( + home.join("settings.json"), + r#"{ + "providerFallbacks": { + "primary": "claude-opus-4-6", + "fallbacks": ["grok-3", "grok-3-mini"] + } + }"#, + ) + .expect("write provider fallback settings"); + + // when + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + + // then + let chain = loaded.provider_fallbacks(); + assert_eq!(chain.primary(), Some("claude-opus-4-6")); + assert_eq!( + chain.fallbacks(), + &["grok-3".to_string(), "grok-3-mini".to_string()] + ); + assert!(!chain.is_empty()); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn provider_fallbacks_default_is_empty_when_unset() { + // given + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(&home).expect("home config dir"); + fs::create_dir_all(&cwd).expect("project dir"); + fs::write(home.join("settings.json"), "{}").expect("write empty settings"); + + // when + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + + // then + let chain = loaded.provider_fallbacks(); + assert_eq!(chain.primary(), None); + assert!(chain.fallbacks().is_empty()); + assert!(chain.is_empty()); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn parses_trusted_roots_from_settings() { + // given + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(&home).expect("home config dir"); + fs::create_dir_all(&cwd).expect("project dir"); + fs::write( + home.join("settings.json"), + r#"{"trustedRoots": ["/tmp/worktrees", "/home/user/projects"]}"#, + ) + .expect("write settings"); + + // when + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + + // then + let roots = loaded.trusted_roots(); + assert_eq!(roots, ["/tmp/worktrees", "/home/user/projects"]); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn trusted_roots_default_is_empty_when_unset() { + // given + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(&home).expect("home config dir"); + fs::create_dir_all(&cwd).expect("project dir"); + fs::write(home.join("settings.json"), "{}").expect("write empty settings"); + + // when + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + + // then + assert!(loaded.trusted_roots().is_empty()); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + #[test] fn parses_typed_mcp_and_oauth_config() { let root = temp_dir(); @@ -1179,6 +1604,44 @@ mod tests { fs::remove_dir_all(root).expect("cleanup temp dir"); } + #[test] + fn infers_http_mcp_servers_from_url_only_config() { + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(&home).expect("home config dir"); + fs::create_dir_all(&cwd).expect("project dir"); + fs::write( + home.join("settings.json"), + r#"{ + "mcpServers": { + "remote": { + "url": "https://example.test/mcp" + } + } + }"#, + ) + .expect("write mcp settings"); + + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + + let remote_server = loaded + .mcp() + .get("remote") + .expect("remote server should exist"); + assert_eq!(remote_server.transport(), McpTransport::Http); + match &remote_server.config { + McpServerConfig::Http(config) => { + assert_eq!(config.url, "https://example.test/mcp"); + } + other => panic!("expected http config, got {other:?}"), + } + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + #[test] fn parses_plugin_config_from_enabled_plugins() { let root = temp_dir(); @@ -1271,6 +1734,7 @@ mod tests { #[test] fn rejects_invalid_mcp_server_shapes() { + // given let root = temp_dir(); let cwd = root.join("project"); let home = root.join("home").join(".claw"); @@ -1282,13 +1746,366 @@ mod tests { ) .expect("write broken settings"); + // when let error = ConfigLoader::new(&cwd, &home) .load() .expect_err("config should fail"); + + // then assert!(error .to_string() .contains("mcpServers.broken: missing string field url")); fs::remove_dir_all(root).expect("cleanup temp dir"); } + + #[test] + fn parses_user_defined_model_aliases_from_settings() { + // given + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); + fs::create_dir_all(&home).expect("home config dir"); + + fs::write( + home.join("settings.json"), + r#"{"aliases":{"fast":"claude-haiku-4-5-20251213","smart":"claude-opus-4-6"}}"#, + ) + .expect("write user settings"); + fs::write( + cwd.join(".claw").join("settings.local.json"), + r#"{"aliases":{"smart":"claude-sonnet-4-6","cheap":"grok-3-mini"}}"#, + ) + .expect("write local settings"); + + // when + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + + // then + let aliases = loaded.aliases(); + assert_eq!( + aliases.get("fast").map(String::as_str), + Some("claude-haiku-4-5-20251213") + ); + assert_eq!( + aliases.get("smart").map(String::as_str), + Some("claude-sonnet-4-6") + ); + assert_eq!( + aliases.get("cheap").map(String::as_str), + Some("grok-3-mini") + ); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn empty_settings_file_loads_defaults() { + // given + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(&home).expect("home config dir"); + fs::create_dir_all(&cwd).expect("project dir"); + fs::write(home.join("settings.json"), "").expect("write empty settings"); + + // when + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("empty settings should still load"); + + // then + assert_eq!(loaded.loaded_entries().len(), 1); + assert_eq!(loaded.permission_mode(), None); + assert_eq!(loaded.plugins().enabled_plugins().len(), 0); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn deep_merge_objects_merges_nested_maps() { + // given + let mut target = JsonValue::parse(r#"{"env":{"A":"1","B":"2"},"model":"haiku"}"#) + .expect("target JSON should parse") + .as_object() + .expect("target should be an object") + .clone(); + let source = + JsonValue::parse(r#"{"env":{"B":"override","C":"3"},"sandbox":{"enabled":true}}"#) + .expect("source JSON should parse") + .as_object() + .expect("source should be an object") + .clone(); + + // when + deep_merge_objects(&mut target, &source); + + // then + let env = target + .get("env") + .and_then(JsonValue::as_object) + .expect("env should remain an object"); + assert_eq!(env.get("A"), Some(&JsonValue::String("1".to_string()))); + assert_eq!( + env.get("B"), + Some(&JsonValue::String("override".to_string())) + ); + assert_eq!(env.get("C"), Some(&JsonValue::String("3".to_string()))); + assert!(target.contains_key("sandbox")); + } + + #[test] + fn rejects_invalid_hook_entries_before_merge() { + // given + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + let project_settings = cwd.join(".claw").join("settings.json"); + fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); + fs::create_dir_all(&home).expect("home config dir"); + + fs::write( + home.join("settings.json"), + r#"{"hooks":{"PreToolUse":["base"]}}"#, + ) + .expect("write user settings"); + fs::write( + &project_settings, + r#"{"hooks":{"PreToolUse":["project",42]}}"#, + ) + .expect("write invalid project settings"); + + // when + let error = ConfigLoader::new(&cwd, &home) + .load() + .expect_err("config should fail"); + + // then — config validation now catches the mixed array before the hooks parser + let rendered = error.to_string(); + assert!( + rendered.contains("hooks.PreToolUse") + && rendered.contains("must be an array of strings"), + "expected validation error for hooks.PreToolUse, got: {rendered}" + ); + assert!(!rendered.contains("merged settings.hooks")); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn permission_mode_aliases_resolve_to_expected_modes() { + // given / when / then + assert_eq!( + parse_permission_mode_label("plan", "test").expect("plan should resolve"), + ResolvedPermissionMode::ReadOnly + ); + assert_eq!( + parse_permission_mode_label("acceptEdits", "test").expect("acceptEdits should resolve"), + ResolvedPermissionMode::WorkspaceWrite + ); + assert_eq!( + parse_permission_mode_label("dontAsk", "test").expect("dontAsk should resolve"), + ResolvedPermissionMode::DangerFullAccess + ); + } + + #[test] + fn hook_config_merge_preserves_uniques() { + // given + let base = RuntimeHookConfig::new( + vec!["pre-a".to_string()], + vec!["post-a".to_string()], + vec!["failure-a".to_string()], + ); + let overlay = RuntimeHookConfig::new( + vec!["pre-a".to_string(), "pre-b".to_string()], + vec!["post-a".to_string(), "post-b".to_string()], + vec!["failure-b".to_string()], + ); + + // when + let merged = base.merged(&overlay); + + // then + assert_eq!( + merged.pre_tool_use(), + &["pre-a".to_string(), "pre-b".to_string()] + ); + assert_eq!( + merged.post_tool_use(), + &["post-a".to_string(), "post-b".to_string()] + ); + assert_eq!( + merged.post_tool_use_failure(), + &["failure-a".to_string(), "failure-b".to_string()] + ); + } + + #[test] + fn plugin_state_falls_back_to_default_for_unknown_plugin() { + // given + let mut config = RuntimePluginConfig::default(); + config.set_plugin_state("known".to_string(), true); + + // when / then + assert!(config.state_for("known", false)); + assert!(config.state_for("missing", true)); + assert!(!config.state_for("missing", false)); + } + + #[test] + fn validates_unknown_top_level_keys_with_line_and_field_name() { + // given + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + let user_settings = home.join("settings.json"); + fs::create_dir_all(&home).expect("home config dir"); + fs::create_dir_all(&cwd).expect("project dir"); + fs::write( + &user_settings, + "{\n \"model\": \"opus\",\n \"telemetry\": true\n}\n", + ) + .expect("write user settings"); + + // when + let error = ConfigLoader::new(&cwd, &home) + .load() + .expect_err("config should fail"); + + // then + let rendered = error.to_string(); + assert!( + rendered.contains(&user_settings.display().to_string()), + "error should include file path, got: {rendered}" + ); + assert!( + rendered.contains("line 3"), + "error should include line number, got: {rendered}" + ); + assert!( + rendered.contains("telemetry"), + "error should name the offending field, got: {rendered}" + ); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn validates_deprecated_top_level_keys_with_replacement_guidance() { + // given + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + let user_settings = home.join("settings.json"); + fs::create_dir_all(&home).expect("home config dir"); + fs::create_dir_all(&cwd).expect("project dir"); + fs::write( + &user_settings, + "{\n \"model\": \"opus\",\n \"allowedTools\": [\"Read\"]\n}\n", + ) + .expect("write user settings"); + + // when + let error = ConfigLoader::new(&cwd, &home) + .load() + .expect_err("config should fail"); + + // then + let rendered = error.to_string(); + assert!( + rendered.contains(&user_settings.display().to_string()), + "error should include file path, got: {rendered}" + ); + assert!( + rendered.contains("line 3"), + "error should include line number, got: {rendered}" + ); + assert!( + rendered.contains("allowedTools"), + "error should call out the unknown field, got: {rendered}" + ); + // allowedTools is an unknown key; validator should name it in the error + assert!( + rendered.contains("allowedTools"), + "error should name the offending field, got: {rendered}" + ); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn validates_wrong_type_for_known_field_with_field_path() { + // given + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + let user_settings = home.join("settings.json"); + fs::create_dir_all(&home).expect("home config dir"); + fs::create_dir_all(&cwd).expect("project dir"); + fs::write( + &user_settings, + "{\n \"hooks\": {\n \"PreToolUse\": \"not-an-array\"\n }\n}\n", + ) + .expect("write user settings"); + + // when + let error = ConfigLoader::new(&cwd, &home) + .load() + .expect_err("config should fail"); + + // then + let rendered = error.to_string(); + assert!( + rendered.contains(&user_settings.display().to_string()), + "error should include file path, got: {rendered}" + ); + assert!( + rendered.contains("hooks"), + "error should include field path component 'hooks', got: {rendered}" + ); + assert!( + rendered.contains("PreToolUse"), + "error should describe the type mismatch, got: {rendered}" + ); + assert!( + rendered.contains("array"), + "error should describe the expected type, got: {rendered}" + ); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn unknown_top_level_key_suggests_closest_match() { + // given + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + let user_settings = home.join("settings.json"); + fs::create_dir_all(&home).expect("home config dir"); + fs::create_dir_all(&cwd).expect("project dir"); + fs::write(&user_settings, "{\n \"modle\": \"opus\"\n}\n").expect("write user settings"); + + // when + let error = ConfigLoader::new(&cwd, &home) + .load() + .expect_err("config should fail"); + + // then + let rendered = error.to_string(); + assert!( + rendered.contains("modle"), + "error should name the offending field, got: {rendered}" + ); + assert!( + rendered.contains("model"), + "error should suggest the closest known key, got: {rendered}" + ); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } } diff --git a/crates/runtime/src/config_validate.rs b/crates/runtime/src/config_validate.rs new file mode 100644 index 0000000..7a9c1c4 --- /dev/null +++ b/crates/runtime/src/config_validate.rs @@ -0,0 +1,901 @@ +use std::collections::BTreeMap; +use std::path::Path; + +use crate::config::ConfigError; +use crate::json::JsonValue; + +/// Diagnostic emitted when a config file contains a suspect field. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ConfigDiagnostic { + pub path: String, + pub field: String, + pub line: Option, + pub kind: DiagnosticKind, +} + +/// Classification of the diagnostic. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum DiagnosticKind { + UnknownKey { + suggestion: Option, + }, + WrongType { + expected: &'static str, + got: &'static str, + }, + Deprecated { + replacement: &'static str, + }, +} + +impl std::fmt::Display for ConfigDiagnostic { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let location = self + .line + .map_or_else(String::new, |line| format!(" (line {line})")); + match &self.kind { + DiagnosticKind::UnknownKey { suggestion: None } => { + write!(f, "{}: unknown key \"{}\"{location}", self.path, self.field) + } + DiagnosticKind::UnknownKey { + suggestion: Some(hint), + } => { + write!( + f, + "{}: unknown key \"{}\"{location}. Did you mean \"{}\"?", + self.path, self.field, hint + ) + } + DiagnosticKind::WrongType { expected, got } => { + write!( + f, + "{}: field \"{}\" must be {expected}, got {got}{location}", + self.path, self.field + ) + } + DiagnosticKind::Deprecated { replacement } => { + write!( + f, + "{}: field \"{}\" is deprecated{location}. Use \"{replacement}\" instead", + self.path, self.field + ) + } + } + } +} + +/// Result of validating a single config file. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ValidationResult { + pub errors: Vec, + pub warnings: Vec, +} + +impl ValidationResult { + #[must_use] + pub fn is_ok(&self) -> bool { + self.errors.is_empty() + } + + fn merge(&mut self, other: Self) { + self.errors.extend(other.errors); + self.warnings.extend(other.warnings); + } +} + +// ---- known-key schema ---- + +/// Expected type for a config field. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum FieldType { + String, + Bool, + Object, + StringArray, + Number, +} + +impl FieldType { + fn label(self) -> &'static str { + match self { + Self::String => "a string", + Self::Bool => "a boolean", + Self::Object => "an object", + Self::StringArray => "an array of strings", + Self::Number => "a number", + } + } + + fn matches(self, value: &JsonValue) -> bool { + match self { + Self::String => value.as_str().is_some(), + Self::Bool => value.as_bool().is_some(), + Self::Object => value.as_object().is_some(), + Self::StringArray => value + .as_array() + .is_some_and(|arr| arr.iter().all(|v| v.as_str().is_some())), + Self::Number => value.as_i64().is_some(), + } + } +} + +fn json_type_label(value: &JsonValue) -> &'static str { + match value { + JsonValue::Null => "null", + JsonValue::Bool(_) => "a boolean", + JsonValue::Number(_) => "a number", + JsonValue::String(_) => "a string", + JsonValue::Array(_) => "an array", + JsonValue::Object(_) => "an object", + } +} + +struct FieldSpec { + name: &'static str, + expected: FieldType, +} + +struct DeprecatedField { + name: &'static str, + replacement: &'static str, +} + +const TOP_LEVEL_FIELDS: &[FieldSpec] = &[ + FieldSpec { + name: "$schema", + expected: FieldType::String, + }, + FieldSpec { + name: "model", + expected: FieldType::String, + }, + FieldSpec { + name: "hooks", + expected: FieldType::Object, + }, + FieldSpec { + name: "permissions", + expected: FieldType::Object, + }, + FieldSpec { + name: "permissionMode", + expected: FieldType::String, + }, + FieldSpec { + name: "mcpServers", + expected: FieldType::Object, + }, + FieldSpec { + name: "oauth", + expected: FieldType::Object, + }, + FieldSpec { + name: "enabledPlugins", + expected: FieldType::Object, + }, + FieldSpec { + name: "plugins", + expected: FieldType::Object, + }, + FieldSpec { + name: "sandbox", + expected: FieldType::Object, + }, + FieldSpec { + name: "env", + expected: FieldType::Object, + }, + FieldSpec { + name: "aliases", + expected: FieldType::Object, + }, + FieldSpec { + name: "providerFallbacks", + expected: FieldType::Object, + }, + FieldSpec { + name: "trustedRoots", + expected: FieldType::StringArray, + }, +]; + +const HOOKS_FIELDS: &[FieldSpec] = &[ + FieldSpec { + name: "PreToolUse", + expected: FieldType::StringArray, + }, + FieldSpec { + name: "PostToolUse", + expected: FieldType::StringArray, + }, + FieldSpec { + name: "PostToolUseFailure", + expected: FieldType::StringArray, + }, +]; + +const PERMISSIONS_FIELDS: &[FieldSpec] = &[ + FieldSpec { + name: "defaultMode", + expected: FieldType::String, + }, + FieldSpec { + name: "allow", + expected: FieldType::StringArray, + }, + FieldSpec { + name: "deny", + expected: FieldType::StringArray, + }, + FieldSpec { + name: "ask", + expected: FieldType::StringArray, + }, +]; + +const PLUGINS_FIELDS: &[FieldSpec] = &[ + FieldSpec { + name: "enabled", + expected: FieldType::Object, + }, + FieldSpec { + name: "externalDirectories", + expected: FieldType::StringArray, + }, + FieldSpec { + name: "installRoot", + expected: FieldType::String, + }, + FieldSpec { + name: "registryPath", + expected: FieldType::String, + }, + FieldSpec { + name: "bundledRoot", + expected: FieldType::String, + }, + FieldSpec { + name: "maxOutputTokens", + expected: FieldType::Number, + }, +]; + +const SANDBOX_FIELDS: &[FieldSpec] = &[ + FieldSpec { + name: "enabled", + expected: FieldType::Bool, + }, + FieldSpec { + name: "namespaceRestrictions", + expected: FieldType::Bool, + }, + FieldSpec { + name: "networkIsolation", + expected: FieldType::Bool, + }, + FieldSpec { + name: "filesystemMode", + expected: FieldType::String, + }, + FieldSpec { + name: "allowedMounts", + expected: FieldType::StringArray, + }, +]; + +const OAUTH_FIELDS: &[FieldSpec] = &[ + FieldSpec { + name: "clientId", + expected: FieldType::String, + }, + FieldSpec { + name: "authorizeUrl", + expected: FieldType::String, + }, + FieldSpec { + name: "tokenUrl", + expected: FieldType::String, + }, + FieldSpec { + name: "callbackPort", + expected: FieldType::Number, + }, + FieldSpec { + name: "manualRedirectUrl", + expected: FieldType::String, + }, + FieldSpec { + name: "scopes", + expected: FieldType::StringArray, + }, +]; + +const DEPRECATED_FIELDS: &[DeprecatedField] = &[ + DeprecatedField { + name: "permissionMode", + replacement: "permissions.defaultMode", + }, + DeprecatedField { + name: "enabledPlugins", + replacement: "plugins.enabled", + }, +]; + +// ---- line-number resolution ---- + +/// Find the 1-based line number where a JSON key first appears in the raw source. +fn find_key_line(source: &str, key: &str) -> Option { + // Search for `"key"` followed by optional whitespace and a colon. + let needle = format!("\"{key}\""); + let mut search_start = 0; + while let Some(offset) = source[search_start..].find(&needle) { + let absolute = search_start + offset; + let after = absolute + needle.len(); + // Verify the next non-whitespace char is `:` to confirm this is a key, not a value. + if source[after..].chars().find(|ch| !ch.is_ascii_whitespace()) == Some(':') { + return Some(source[..absolute].chars().filter(|&ch| ch == '\n').count() + 1); + } + search_start = after; + } + None +} + +// ---- core validation ---- + +fn validate_object_keys( + object: &BTreeMap, + known_fields: &[FieldSpec], + prefix: &str, + source: &str, + path_display: &str, +) -> ValidationResult { + let mut result = ValidationResult { + errors: Vec::new(), + warnings: Vec::new(), + }; + + let known_names: Vec<&str> = known_fields.iter().map(|f| f.name).collect(); + + for (key, value) in object { + let field_path = if prefix.is_empty() { + key.clone() + } else { + format!("{prefix}.{key}") + }; + + if let Some(spec) = known_fields.iter().find(|f| f.name == key) { + // Type check. + if !spec.expected.matches(value) { + result.errors.push(ConfigDiagnostic { + path: path_display.to_string(), + field: field_path, + line: find_key_line(source, key), + kind: DiagnosticKind::WrongType { + expected: spec.expected.label(), + got: json_type_label(value), + }, + }); + } + } else if DEPRECATED_FIELDS.iter().any(|d| d.name == key) { + // Deprecated key — handled separately, not an unknown-key error. + } else { + // Unknown key. + let suggestion = suggest_field(key, &known_names); + result.errors.push(ConfigDiagnostic { + path: path_display.to_string(), + field: field_path, + line: find_key_line(source, key), + kind: DiagnosticKind::UnknownKey { suggestion }, + }); + } + } + + result +} + +fn suggest_field(input: &str, candidates: &[&str]) -> Option { + let input_lower = input.to_ascii_lowercase(); + candidates + .iter() + .filter_map(|candidate| { + let distance = simple_edit_distance(&input_lower, &candidate.to_ascii_lowercase()); + (distance <= 3).then_some((distance, *candidate)) + }) + .min_by_key(|(distance, _)| *distance) + .map(|(_, name)| name.to_string()) +} + +fn simple_edit_distance(left: &str, right: &str) -> usize { + if left.is_empty() { + return right.len(); + } + if right.is_empty() { + return left.len(); + } + let right_chars: Vec = right.chars().collect(); + let mut previous: Vec = (0..=right_chars.len()).collect(); + let mut current = vec![0; right_chars.len() + 1]; + + for (left_index, left_char) in left.chars().enumerate() { + current[0] = left_index + 1; + for (right_index, right_char) in right_chars.iter().enumerate() { + let cost = usize::from(left_char != *right_char); + current[right_index + 1] = (previous[right_index + 1] + 1) + .min(current[right_index] + 1) + .min(previous[right_index] + cost); + } + previous.clone_from(¤t); + } + + previous[right_chars.len()] +} + +/// Validate a parsed config file's keys and types against the known schema. +/// +/// Returns diagnostics (errors and deprecation warnings) without blocking the load. +pub fn validate_config_file( + object: &BTreeMap, + source: &str, + file_path: &Path, +) -> ValidationResult { + let path_display = file_path.display().to_string(); + let mut result = validate_object_keys(object, TOP_LEVEL_FIELDS, "", source, &path_display); + + // Check deprecated fields. + for deprecated in DEPRECATED_FIELDS { + if object.contains_key(deprecated.name) { + result.warnings.push(ConfigDiagnostic { + path: path_display.clone(), + field: deprecated.name.to_string(), + line: find_key_line(source, deprecated.name), + kind: DiagnosticKind::Deprecated { + replacement: deprecated.replacement, + }, + }); + } + } + + // Validate known nested objects. + if let Some(hooks) = object.get("hooks").and_then(JsonValue::as_object) { + result.merge(validate_object_keys( + hooks, + HOOKS_FIELDS, + "hooks", + source, + &path_display, + )); + } + if let Some(permissions) = object.get("permissions").and_then(JsonValue::as_object) { + result.merge(validate_object_keys( + permissions, + PERMISSIONS_FIELDS, + "permissions", + source, + &path_display, + )); + } + if let Some(plugins) = object.get("plugins").and_then(JsonValue::as_object) { + result.merge(validate_object_keys( + plugins, + PLUGINS_FIELDS, + "plugins", + source, + &path_display, + )); + } + if let Some(sandbox) = object.get("sandbox").and_then(JsonValue::as_object) { + result.merge(validate_object_keys( + sandbox, + SANDBOX_FIELDS, + "sandbox", + source, + &path_display, + )); + } + if let Some(oauth) = object.get("oauth").and_then(JsonValue::as_object) { + result.merge(validate_object_keys( + oauth, + OAUTH_FIELDS, + "oauth", + source, + &path_display, + )); + } + + result +} + +/// Check whether a file path uses an unsupported config format (e.g. TOML). +pub fn check_unsupported_format(file_path: &Path) -> Result<(), ConfigError> { + if let Some(ext) = file_path.extension().and_then(|e| e.to_str()) { + if ext.eq_ignore_ascii_case("toml") { + return Err(ConfigError::Parse(format!( + "{}: TOML config files are not supported. Use JSON (settings.json) instead", + file_path.display() + ))); + } + } + Ok(()) +} + +/// Format all diagnostics into a human-readable report. +#[must_use] +pub fn format_diagnostics(result: &ValidationResult) -> String { + let mut lines = Vec::new(); + for warning in &result.warnings { + lines.push(format!("warning: {warning}")); + } + for error in &result.errors { + lines.push(format!("error: {error}")); + } + lines.join("\n") +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + fn test_path() -> PathBuf { + PathBuf::from("/test/settings.json") + } + + #[test] + fn detects_unknown_top_level_key() { + // given + let source = r#"{"model": "opus", "unknownField": true}"#; + let parsed = JsonValue::parse(source).expect("valid json"); + let object = parsed.as_object().expect("object"); + + // when + let result = validate_config_file(object, source, &test_path()); + + // then + assert_eq!(result.errors.len(), 1); + assert_eq!(result.errors[0].field, "unknownField"); + assert!(matches!( + result.errors[0].kind, + DiagnosticKind::UnknownKey { .. } + )); + } + + #[test] + fn detects_wrong_type_for_model() { + // given + let source = r#"{"model": 123}"#; + let parsed = JsonValue::parse(source).expect("valid json"); + let object = parsed.as_object().expect("object"); + + // when + let result = validate_config_file(object, source, &test_path()); + + // then + assert_eq!(result.errors.len(), 1); + assert_eq!(result.errors[0].field, "model"); + assert!(matches!( + result.errors[0].kind, + DiagnosticKind::WrongType { + expected: "a string", + got: "a number" + } + )); + } + + #[test] + fn detects_deprecated_permission_mode() { + // given + let source = r#"{"permissionMode": "plan"}"#; + let parsed = JsonValue::parse(source).expect("valid json"); + let object = parsed.as_object().expect("object"); + + // when + let result = validate_config_file(object, source, &test_path()); + + // then + assert_eq!(result.warnings.len(), 1); + assert_eq!(result.warnings[0].field, "permissionMode"); + assert!(matches!( + result.warnings[0].kind, + DiagnosticKind::Deprecated { + replacement: "permissions.defaultMode" + } + )); + } + + #[test] + fn detects_deprecated_enabled_plugins() { + // given + let source = r#"{"enabledPlugins": {"tool-guard@builtin": true}}"#; + let parsed = JsonValue::parse(source).expect("valid json"); + let object = parsed.as_object().expect("object"); + + // when + let result = validate_config_file(object, source, &test_path()); + + // then + assert_eq!(result.warnings.len(), 1); + assert_eq!(result.warnings[0].field, "enabledPlugins"); + assert!(matches!( + result.warnings[0].kind, + DiagnosticKind::Deprecated { + replacement: "plugins.enabled" + } + )); + } + + #[test] + fn reports_line_number_for_unknown_key() { + // given + let source = "{\n \"model\": \"opus\",\n \"badKey\": true\n}"; + let parsed = JsonValue::parse(source).expect("valid json"); + let object = parsed.as_object().expect("object"); + + // when + let result = validate_config_file(object, source, &test_path()); + + // then + assert_eq!(result.errors.len(), 1); + assert_eq!(result.errors[0].line, Some(3)); + assert_eq!(result.errors[0].field, "badKey"); + } + + #[test] + fn reports_line_number_for_wrong_type() { + // given + let source = "{\n \"model\": 42\n}"; + let parsed = JsonValue::parse(source).expect("valid json"); + let object = parsed.as_object().expect("object"); + + // when + let result = validate_config_file(object, source, &test_path()); + + // then + assert_eq!(result.errors.len(), 1); + assert_eq!(result.errors[0].line, Some(2)); + } + + #[test] + fn validates_nested_hooks_keys() { + // given + let source = r#"{"hooks": {"PreToolUse": ["cmd"], "BadHook": ["x"]}}"#; + let parsed = JsonValue::parse(source).expect("valid json"); + let object = parsed.as_object().expect("object"); + + // when + let result = validate_config_file(object, source, &test_path()); + + // then + assert_eq!(result.errors.len(), 1); + assert_eq!(result.errors[0].field, "hooks.BadHook"); + } + + #[test] + fn validates_nested_permissions_keys() { + // given + let source = r#"{"permissions": {"allow": ["Read"], "denyAll": true}}"#; + let parsed = JsonValue::parse(source).expect("valid json"); + let object = parsed.as_object().expect("object"); + + // when + let result = validate_config_file(object, source, &test_path()); + + // then + assert_eq!(result.errors.len(), 1); + assert_eq!(result.errors[0].field, "permissions.denyAll"); + } + + #[test] + fn validates_nested_sandbox_keys() { + // given + let source = r#"{"sandbox": {"enabled": true, "containerMode": "strict"}}"#; + let parsed = JsonValue::parse(source).expect("valid json"); + let object = parsed.as_object().expect("object"); + + // when + let result = validate_config_file(object, source, &test_path()); + + // then + assert_eq!(result.errors.len(), 1); + assert_eq!(result.errors[0].field, "sandbox.containerMode"); + } + + #[test] + fn validates_nested_plugins_keys() { + // given + let source = r#"{"plugins": {"installRoot": "/tmp", "autoUpdate": true}}"#; + let parsed = JsonValue::parse(source).expect("valid json"); + let object = parsed.as_object().expect("object"); + + // when + let result = validate_config_file(object, source, &test_path()); + + // then + assert_eq!(result.errors.len(), 1); + assert_eq!(result.errors[0].field, "plugins.autoUpdate"); + } + + #[test] + fn validates_nested_oauth_keys() { + // given + let source = r#"{"oauth": {"clientId": "abc", "secret": "hidden"}}"#; + let parsed = JsonValue::parse(source).expect("valid json"); + let object = parsed.as_object().expect("object"); + + // when + let result = validate_config_file(object, source, &test_path()); + + // then + assert_eq!(result.errors.len(), 1); + assert_eq!(result.errors[0].field, "oauth.secret"); + } + + #[test] + fn valid_config_produces_no_diagnostics() { + // given + let source = r#"{ + "model": "opus", + "hooks": {"PreToolUse": ["guard"]}, + "permissions": {"defaultMode": "plan", "allow": ["Read"]}, + "mcpServers": {}, + "sandbox": {"enabled": false} +}"#; + let parsed = JsonValue::parse(source).expect("valid json"); + let object = parsed.as_object().expect("object"); + + // when + let result = validate_config_file(object, source, &test_path()); + + // then + assert!(result.is_ok()); + assert!(result.warnings.is_empty()); + } + + #[test] + fn suggests_close_field_name() { + // given + let source = r#"{"modle": "opus"}"#; + let parsed = JsonValue::parse(source).expect("valid json"); + let object = parsed.as_object().expect("object"); + + // when + let result = validate_config_file(object, source, &test_path()); + + // then + assert_eq!(result.errors.len(), 1); + match &result.errors[0].kind { + DiagnosticKind::UnknownKey { + suggestion: Some(s), + } => assert_eq!(s, "model"), + other => panic!("expected suggestion, got {other:?}"), + } + } + + #[test] + fn format_diagnostics_includes_all_entries() { + // given + let source = r#"{"permissionMode": "plan", "badKey": 1}"#; + let parsed = JsonValue::parse(source).expect("valid json"); + let object = parsed.as_object().expect("object"); + let result = validate_config_file(object, source, &test_path()); + + // when + let output = format_diagnostics(&result); + + // then + assert!(output.contains("warning:")); + assert!(output.contains("error:")); + assert!(output.contains("badKey")); + assert!(output.contains("permissionMode")); + } + + #[test] + fn check_unsupported_format_rejects_toml() { + // given + let path = PathBuf::from("/home/.claw/settings.toml"); + + // when + let result = check_unsupported_format(&path); + + // then + assert!(result.is_err()); + let message = result.unwrap_err().to_string(); + assert!(message.contains("TOML")); + assert!(message.contains("settings.toml")); + } + + #[test] + fn check_unsupported_format_allows_json() { + // given + let path = PathBuf::from("/home/.claw/settings.json"); + + // when / then + assert!(check_unsupported_format(&path).is_ok()); + } + + #[test] + fn wrong_type_in_nested_sandbox_field() { + // given + let source = r#"{"sandbox": {"enabled": "yes"}}"#; + let parsed = JsonValue::parse(source).expect("valid json"); + let object = parsed.as_object().expect("object"); + + // when + let result = validate_config_file(object, source, &test_path()); + + // then + assert_eq!(result.errors.len(), 1); + assert_eq!(result.errors[0].field, "sandbox.enabled"); + assert!(matches!( + result.errors[0].kind, + DiagnosticKind::WrongType { + expected: "a boolean", + got: "a string" + } + )); + } + + #[test] + fn display_format_unknown_key_with_line() { + // given + let diag = ConfigDiagnostic { + path: "/test/settings.json".to_string(), + field: "badKey".to_string(), + line: Some(5), + kind: DiagnosticKind::UnknownKey { suggestion: None }, + }; + + // when + let output = diag.to_string(); + + // then + assert_eq!( + output, + r#"/test/settings.json: unknown key "badKey" (line 5)"# + ); + } + + #[test] + fn display_format_wrong_type_with_line() { + // given + let diag = ConfigDiagnostic { + path: "/test/settings.json".to_string(), + field: "model".to_string(), + line: Some(2), + kind: DiagnosticKind::WrongType { + expected: "a string", + got: "a number", + }, + }; + + // when + let output = diag.to_string(); + + // then + assert_eq!( + output, + r#"/test/settings.json: field "model" must be a string, got a number (line 2)"# + ); + } + + #[test] + fn display_format_deprecated_with_line() { + // given + let diag = ConfigDiagnostic { + path: "/test/settings.json".to_string(), + field: "permissionMode".to_string(), + line: Some(3), + kind: DiagnosticKind::Deprecated { + replacement: "permissions.defaultMode", + }, + }; + + // when + let output = diag.to_string(); + + // then + assert_eq!( + output, + r#"/test/settings.json: field "permissionMode" is deprecated (line 3). Use "permissions.defaultMode" instead"# + ); + } +} diff --git a/crates/runtime/src/conversation.rs b/crates/runtime/src/conversation.rs index 0a5435b..1dde559 100644 --- a/crates/runtime/src/conversation.rs +++ b/crates/runtime/src/conversation.rs @@ -1,42 +1,65 @@ use std::collections::BTreeMap; use std::fmt::{Display, Formatter}; +use serde_json::{Map, Value}; +use telemetry::SessionTracer; + use crate::compact::{ compact_session, estimate_session_tokens, CompactionConfig, CompactionResult, }; use crate::config::RuntimeFeatureConfig; -use crate::hooks::{HookRunResult, HookRunner}; -use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter}; +use crate::hooks::{HookAbortSignal, HookProgressReporter, HookRunResult, HookRunner}; +use crate::permissions::{ + PermissionContext, PermissionOutcome, PermissionPolicy, PermissionPrompter, +}; use crate::session::{ContentBlock, ConversationMessage, Session}; use crate::usage::{TokenUsage, UsageTracker}; +const DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD: u32 = 100_000; +const AUTO_COMPACTION_THRESHOLD_ENV_VAR: &str = "CLAUDE_CODE_AUTO_COMPACT_INPUT_TOKENS"; + +/// Fully assembled request payload sent to the upstream model client. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ApiRequest { pub system_prompt: Vec, pub messages: Vec, } +/// Streamed events emitted while processing a single assistant turn. #[derive(Debug, Clone, PartialEq, Eq)] pub enum AssistantEvent { TextDelta(String), - ThinkingDelta(String), ToolUse { id: String, name: String, input: String, }, Usage(TokenUsage), + PromptCache(PromptCacheEvent), MessageStop, } +/// Prompt-cache telemetry captured from the provider response stream. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PromptCacheEvent { + pub unexpected: bool, + pub reason: String, + pub previous_cache_read_input_tokens: u32, + pub current_cache_read_input_tokens: u32, + pub token_drop: u32, +} + +/// Minimal streaming API contract required by [`ConversationRuntime`]. pub trait ApiClient { fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError>; } +/// Trait implemented by tool dispatchers that execute model-requested tools. pub trait ToolExecutor { fn execute(&mut self, tool_name: &str, input: &str) -> Result; } +/// Error returned when a tool invocation fails locally. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ToolError { message: String, @@ -59,6 +82,7 @@ impl Display for ToolError { impl std::error::Error for ToolError {} +/// Error returned when a conversation turn cannot be completed. #[derive(Debug, Clone, PartialEq, Eq)] pub struct RuntimeError { message: String, @@ -81,14 +105,24 @@ impl Display for RuntimeError { impl std::error::Error for RuntimeError {} +/// Summary of one completed runtime turn, including tool results and usage. #[derive(Debug, Clone, PartialEq, Eq)] pub struct TurnSummary { pub assistant_messages: Vec, pub tool_results: Vec, + pub prompt_cache_events: Vec, pub iterations: usize, pub usage: TokenUsage, + pub auto_compaction: Option, } +/// Details about automatic session compaction applied during a turn. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct AutoCompactionEvent { + pub removed_message_count: usize, +} + +/// Coordinates the model loop, tool execution, hooks, and session updates. pub struct ConversationRuntime { session: Session, api_client: C, @@ -98,6 +132,10 @@ pub struct ConversationRuntime { max_iterations: usize, usage_tracker: UsageTracker, hook_runner: HookRunner, + auto_compaction_input_tokens_threshold: u32, + hook_abort_signal: HookAbortSignal, + hook_progress_reporter: Option>, + session_tracer: Option, } impl ConversationRuntime @@ -119,19 +157,19 @@ where tool_executor, permission_policy, system_prompt, - RuntimeFeatureConfig::default(), + &RuntimeFeatureConfig::default(), ) } - #[allow(clippy::needless_pass_by_value)] #[must_use] + #[allow(clippy::needless_pass_by_value)] pub fn new_with_features( session: Session, api_client: C, tool_executor: T, permission_policy: PermissionPolicy, system_prompt: Vec, - feature_config: RuntimeFeatureConfig, + feature_config: &RuntimeFeatureConfig, ) -> Self { let usage_tracker = UsageTracker::from_session(&session); Self { @@ -142,7 +180,11 @@ where system_prompt, max_iterations: usize::MAX, usage_tracker, - hook_runner: HookRunner::from_feature_config(&feature_config), + hook_runner: HookRunner::from_feature_config(feature_config), + auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(), + hook_abort_signal: HookAbortSignal::default(), + hook_progress_reporter: None, + session_tracer: None, } } @@ -152,36 +194,154 @@ where self } + #[must_use] + pub fn with_auto_compaction_input_tokens_threshold(mut self, threshold: u32) -> Self { + self.auto_compaction_input_tokens_threshold = threshold; + self + } + + #[must_use] + pub fn with_hook_abort_signal(mut self, hook_abort_signal: HookAbortSignal) -> Self { + self.hook_abort_signal = hook_abort_signal; + self + } + + #[must_use] + pub fn with_hook_progress_reporter( + mut self, + hook_progress_reporter: Box, + ) -> Self { + self.hook_progress_reporter = Some(hook_progress_reporter); + self + } + + #[must_use] + pub fn with_session_tracer(mut self, session_tracer: SessionTracer) -> Self { + self.session_tracer = Some(session_tracer); + self + } + + fn run_pre_tool_use_hook(&mut self, tool_name: &str, input: &str) -> HookRunResult { + if let Some(reporter) = self.hook_progress_reporter.as_mut() { + self.hook_runner.run_pre_tool_use_with_context( + tool_name, + input, + Some(&self.hook_abort_signal), + Some(reporter.as_mut()), + ) + } else { + self.hook_runner.run_pre_tool_use_with_context( + tool_name, + input, + Some(&self.hook_abort_signal), + None, + ) + } + } + + fn run_post_tool_use_hook( + &mut self, + tool_name: &str, + input: &str, + output: &str, + is_error: bool, + ) -> HookRunResult { + if let Some(reporter) = self.hook_progress_reporter.as_mut() { + self.hook_runner.run_post_tool_use_with_context( + tool_name, + input, + output, + is_error, + Some(&self.hook_abort_signal), + Some(reporter.as_mut()), + ) + } else { + self.hook_runner.run_post_tool_use_with_context( + tool_name, + input, + output, + is_error, + Some(&self.hook_abort_signal), + None, + ) + } + } + + fn run_post_tool_use_failure_hook( + &mut self, + tool_name: &str, + input: &str, + output: &str, + ) -> HookRunResult { + if let Some(reporter) = self.hook_progress_reporter.as_mut() { + self.hook_runner.run_post_tool_use_failure_with_context( + tool_name, + input, + output, + Some(&self.hook_abort_signal), + Some(reporter.as_mut()), + ) + } else { + self.hook_runner.run_post_tool_use_failure_with_context( + tool_name, + input, + output, + Some(&self.hook_abort_signal), + None, + ) + } + } + + #[allow(clippy::too_many_lines)] pub fn run_turn( &mut self, user_input: impl Into, mut prompter: Option<&mut dyn PermissionPrompter>, ) -> Result { + let user_input = user_input.into(); + self.record_turn_started(&user_input); self.session - .messages - .push(ConversationMessage::user_text(user_input.into())); + .push_user_text(user_input) + .map_err(|error| RuntimeError::new(error.to_string()))?; let mut assistant_messages = Vec::new(); let mut tool_results = Vec::new(); + let mut prompt_cache_events = Vec::new(); let mut iterations = 0; loop { iterations += 1; if iterations > self.max_iterations { - return Err(RuntimeError::new( + let error = RuntimeError::new( "conversation loop exceeded the maximum number of iterations", - )); + ); + self.record_turn_failed(iterations, &error); + return Err(error); } let request = ApiRequest { system_prompt: self.system_prompt.clone(), messages: self.session.messages.clone(), }; - let events = self.api_client.stream(request)?; - let (assistant_message, usage) = build_assistant_message(events)?; + let events = match self.api_client.stream(request) { + Ok(events) => events, + Err(error) => { + self.record_turn_failed(iterations, &error); + return Err(error); + } + }; + let (assistant_message, usage, turn_prompt_cache_events) = + match build_assistant_message(events) { + Ok(result) => result, + Err(error) => { + self.record_turn_failed(iterations, &error); + return Err(error); + } + }; if let Some(usage) = usage { self.usage_tracker.record(usage); } + prompt_cache_events.extend(turn_prompt_cache_events); let pending_tool_uses = assistant_message .blocks .iter() @@ -192,8 +352,15 @@ where _ => None, }) .collect::>(); + self.record_assistant_iteration( + iterations, + &assistant_message, + pending_tool_uses.len(), + ); - self.session.messages.push(assistant_message.clone()); + self.session + .push_message(assistant_message.clone()) + .map_err(|error| RuntimeError::new(error.to_string()))?; assistant_messages.push(assistant_message); if pending_tool_uses.is_empty() { @@ -201,67 +368,120 @@ where } for (tool_use_id, tool_name, input) in pending_tool_uses { - let permission_outcome = if let Some(prompt) = prompter.as_mut() { - self.permission_policy - .authorize(&tool_name, &input, Some(*prompt)) + let pre_hook_result = self.run_pre_tool_use_hook(&tool_name, &input); + let effective_input = pre_hook_result + .updated_input() + .map_or_else(|| input.clone(), ToOwned::to_owned); + let permission_context = PermissionContext::new( + pre_hook_result.permission_override(), + pre_hook_result.permission_reason().map(ToOwned::to_owned), + ); + + let permission_outcome = if pre_hook_result.is_cancelled() { + PermissionOutcome::Deny { + reason: format_hook_message( + &pre_hook_result, + &format!("PreToolUse hook cancelled tool `{tool_name}`"), + ), + } + } else if pre_hook_result.is_failed() { + PermissionOutcome::Deny { + reason: format_hook_message( + &pre_hook_result, + &format!("PreToolUse hook failed for tool `{tool_name}`"), + ), + } + } else if pre_hook_result.is_denied() { + PermissionOutcome::Deny { + reason: format_hook_message( + &pre_hook_result, + &format!("PreToolUse hook denied tool `{tool_name}`"), + ), + } + } else if let Some(prompt) = prompter.as_mut() { + self.permission_policy.authorize_with_context( + &tool_name, + &effective_input, + &permission_context, + Some(*prompt), + ) } else { - self.permission_policy.authorize(&tool_name, &input, None) + self.permission_policy.authorize_with_context( + &tool_name, + &effective_input, + &permission_context, + None, + ) }; let result_message = match permission_outcome { PermissionOutcome::Allow => { - let pre_hook_result = self.hook_runner.run_pre_tool_use(&tool_name, &input); - if pre_hook_result.is_denied() { - let deny_message = format!("PreToolUse hook denied tool `{tool_name}`"); - ConversationMessage::tool_result( - tool_use_id, - tool_name, - format_hook_message(&pre_hook_result, &deny_message), - true, + self.record_tool_started(iterations, &tool_name); + let (mut output, mut is_error) = + match self.tool_executor.execute(&tool_name, &effective_input) { + Ok(output) => (output, false), + Err(error) => (error.to_string(), true), + }; + output = merge_hook_feedback(pre_hook_result.messages(), output, false); + + let post_hook_result = if is_error { + self.run_post_tool_use_failure_hook( + &tool_name, + &effective_input, + &output, ) } else { - let (mut output, mut is_error) = - match self.tool_executor.execute(&tool_name, &input) { - Ok(output) => (output, false), - Err(error) => (error.to_string(), true), - }; - output = merge_hook_feedback(pre_hook_result.messages(), output, false); - - let post_hook_result = self - .hook_runner - .run_post_tool_use(&tool_name, &input, &output, is_error); - if post_hook_result.is_denied() { - is_error = true; - } - output = merge_hook_feedback( - post_hook_result.messages(), - output, - post_hook_result.is_denied(), - ); - - ConversationMessage::tool_result( - tool_use_id, - tool_name, - output, - is_error, + self.run_post_tool_use_hook( + &tool_name, + &effective_input, + &output, + false, ) + }; + if post_hook_result.is_denied() + || post_hook_result.is_failed() + || post_hook_result.is_cancelled() + { + is_error = true; } + output = merge_hook_feedback( + post_hook_result.messages(), + output, + post_hook_result.is_denied() + || post_hook_result.is_failed() + || post_hook_result.is_cancelled(), + ); + + ConversationMessage::tool_result(tool_use_id, tool_name, output, is_error) } - PermissionOutcome::Deny { reason } => { - ConversationMessage::tool_result(tool_use_id, tool_name, reason, true) - } + PermissionOutcome::Deny { reason } => ConversationMessage::tool_result( + tool_use_id, + tool_name, + merge_hook_feedback(pre_hook_result.messages(), reason, true), + true, + ), }; - self.session.messages.push(result_message.clone()); + self.session + .push_message(result_message.clone()) + .map_err(|error| RuntimeError::new(error.to_string()))?; + self.record_tool_finished(iterations, &result_message); tool_results.push(result_message); } } - Ok(TurnSummary { + let auto_compaction = self.maybe_auto_compact(); + + let summary = TurnSummary { assistant_messages, tool_results, + prompt_cache_events, iterations, usage: self.usage_tracker.cumulative_usage(), - }) + auto_compaction, + }; + self.record_turn_completed(&summary); + + Ok(summary) } #[must_use] @@ -284,38 +504,200 @@ where &self.session } + pub fn api_client_mut(&mut self) -> &mut C { + &mut self.api_client + } + + pub fn session_mut(&mut self) -> &mut Session { + &mut self.session + } + + #[must_use] + pub fn fork_session(&self, branch_name: Option) -> Session { + self.session.fork(branch_name) + } + #[must_use] pub fn into_session(self) -> Session { self.session } + + fn maybe_auto_compact(&mut self) -> Option { + if self.usage_tracker.cumulative_usage().input_tokens + < self.auto_compaction_input_tokens_threshold + { + return None; + } + + let result = compact_session( + &self.session, + CompactionConfig { + max_estimated_tokens: 0, + ..CompactionConfig::default() + }, + ); + + if result.removed_message_count == 0 { + return None; + } + + self.session = result.compacted_session; + Some(AutoCompactionEvent { + removed_message_count: result.removed_message_count, + }) + } + + fn record_turn_started(&self, user_input: &str) { + let Some(session_tracer) = &self.session_tracer else { + return; + }; + + let mut attributes = Map::new(); + attributes.insert( + "user_input".to_string(), + Value::String(user_input.to_string()), + ); + session_tracer.record("turn_started", attributes); + } + + fn record_assistant_iteration( + &self, + iteration: usize, + assistant_message: &ConversationMessage, + pending_tool_use_count: usize, + ) { + let Some(session_tracer) = &self.session_tracer else { + return; + }; + + let mut attributes = Map::new(); + attributes.insert("iteration".to_string(), Value::from(iteration as u64)); + attributes.insert( + "assistant_blocks".to_string(), + Value::from(assistant_message.blocks.len() as u64), + ); + attributes.insert( + "pending_tool_use_count".to_string(), + Value::from(pending_tool_use_count as u64), + ); + session_tracer.record("assistant_iteration_completed", attributes); + } + + fn record_tool_started(&self, iteration: usize, tool_name: &str) { + let Some(session_tracer) = &self.session_tracer else { + return; + }; + + let mut attributes = Map::new(); + attributes.insert("iteration".to_string(), Value::from(iteration as u64)); + attributes.insert( + "tool_name".to_string(), + Value::String(tool_name.to_string()), + ); + session_tracer.record("tool_execution_started", attributes); + } + + fn record_tool_finished(&self, iteration: usize, result_message: &ConversationMessage) { + let Some(session_tracer) = &self.session_tracer else { + return; + }; + + let Some(ContentBlock::ToolResult { + tool_name, + is_error, + .. + }) = result_message.blocks.first() + else { + return; + }; + + let mut attributes = Map::new(); + attributes.insert("iteration".to_string(), Value::from(iteration as u64)); + attributes.insert("tool_name".to_string(), Value::String(tool_name.clone())); + attributes.insert("is_error".to_string(), Value::Bool(*is_error)); + session_tracer.record("tool_execution_finished", attributes); + } + + fn record_turn_completed(&self, summary: &TurnSummary) { + let Some(session_tracer) = &self.session_tracer else { + return; + }; + + let mut attributes = Map::new(); + attributes.insert( + "iterations".to_string(), + Value::from(summary.iterations as u64), + ); + attributes.insert( + "assistant_messages".to_string(), + Value::from(summary.assistant_messages.len() as u64), + ); + attributes.insert( + "tool_results".to_string(), + Value::from(summary.tool_results.len() as u64), + ); + attributes.insert( + "prompt_cache_events".to_string(), + Value::from(summary.prompt_cache_events.len() as u64), + ); + session_tracer.record("turn_completed", attributes); + } + + fn record_turn_failed(&self, iteration: usize, error: &RuntimeError) { + let Some(session_tracer) = &self.session_tracer else { + return; + }; + + let mut attributes = Map::new(); + attributes.insert("iteration".to_string(), Value::from(iteration as u64)); + attributes.insert("error".to_string(), Value::String(error.to_string())); + session_tracer.record("turn_failed", attributes); + } +} + +/// Reads the automatic compaction threshold from the environment. +#[must_use] +pub fn auto_compaction_threshold_from_env() -> u32 { + parse_auto_compaction_threshold( + std::env::var(AUTO_COMPACTION_THRESHOLD_ENV_VAR) + .ok() + .as_deref(), + ) +} + +#[must_use] +fn parse_auto_compaction_threshold(value: Option<&str>) -> u32 { + value + .and_then(|raw| raw.trim().parse::().ok()) + .filter(|threshold| *threshold > 0) + .unwrap_or(DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD) } fn build_assistant_message( events: Vec, -) -> Result<(ConversationMessage, Option), RuntimeError> { +) -> Result< + ( + ConversationMessage, + Option, + Vec, + ), + RuntimeError, +> { let mut text = String::new(); let mut blocks = Vec::new(); + let mut prompt_cache_events = Vec::new(); let mut finished = false; let mut usage = None; for event in events { match event { AssistantEvent::TextDelta(delta) => text.push_str(&delta), - AssistantEvent::ThinkingDelta(delta) => { - if let Some(ContentBlock::Thinking { thinking, .. }) = blocks.last_mut() { - thinking.push_str(&delta); - } else { - blocks.push(ContentBlock::Thinking { - thinking: delta, - signature: None, - }); - } - } AssistantEvent::ToolUse { id, name, input } => { flush_text_block(&mut text, &mut blocks); blocks.push(ContentBlock::ToolUse { id, name, input }); } AssistantEvent::Usage(value) => usage = Some(value), + AssistantEvent::PromptCache(event) => prompt_cache_events.push(event), AssistantEvent::MessageStop => { finished = true; } @@ -336,6 +718,7 @@ fn build_assistant_message( Ok(( ConversationMessage::assistant_with_usage(blocks, usage), usage, + prompt_cache_events, )) } @@ -355,7 +738,7 @@ fn format_hook_message(result: &HookRunResult, fallback: &str) -> String { } } -fn merge_hook_feedback(messages: &[String], output: String, denied: bool) -> String { +fn merge_hook_feedback(messages: &[String], output: String, is_error: bool) -> String { if messages.is_empty() { return output; } @@ -364,8 +747,8 @@ fn merge_hook_feedback(messages: &[String], output: String, denied: bool) -> Str if !output.trim().is_empty() { sections.push(output); } - let label = if denied { - "Hook feedback (denied)" + let label = if is_error { + "Hook feedback (error)" } else { "Hook feedback" }; @@ -375,6 +758,7 @@ fn merge_hook_feedback(messages: &[String], output: String, denied: bool) -> Str type ToolHandler = Box Result>; +/// Simple in-memory tool executor for tests and lightweight integrations. #[derive(Default)] pub struct StaticToolExecutor { handlers: BTreeMap, @@ -408,8 +792,9 @@ impl ToolExecutor for StaticToolExecutor { #[cfg(test)] mod tests { use super::{ - ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, - StaticToolExecutor, + build_assistant_message, parse_auto_compaction_threshold, ApiClient, ApiRequest, + AssistantEvent, AutoCompactionEvent, ConversationRuntime, PromptCacheEvent, RuntimeError, + StaticToolExecutor, ToolExecutor, DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD, }; use crate::compact::CompactionConfig; use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; @@ -420,7 +805,12 @@ mod tests { use crate::prompt::{ProjectContext, SystemPromptBuilder}; use crate::session::{ContentBlock, MessageRole, Session}; use crate::usage::TokenUsage; + use crate::ToolError; + use std::fs; use std::path::PathBuf; + use std::sync::Arc; + use std::time::{SystemTime, UNIX_EPOCH}; + use telemetry::{MemoryTelemetrySink, SessionTracer, TelemetryEvent}; struct ScriptedApiClient { call_count: usize, @@ -465,10 +855,19 @@ mod tests { cache_creation_input_tokens: 1, cache_read_input_tokens: 3, }), + AssistantEvent::PromptCache(PromptCacheEvent { + unexpected: true, + reason: + "cache read tokens dropped while prompt fingerprint remained stable" + .to_string(), + previous_cache_read_input_tokens: 6_000, + current_cache_read_input_tokens: 1_000, + token_drop: 5_000, + }), AssistantEvent::MessageStop, ]) } - _ => Err(RuntimeError::new("unexpected extra API call")), + _ => unreachable!("extra API call"), } } } @@ -499,6 +898,7 @@ mod tests { current_date: "2026-03-31".to_string(), git_status: None, git_diff: None, + git_context: None, instruction_files: Vec::new(), }) .with_os("linux", "6.8") @@ -518,8 +918,10 @@ mod tests { assert_eq!(summary.iterations, 2); assert_eq!(summary.assistant_messages.len(), 2); assert_eq!(summary.tool_results.len(), 1); + assert_eq!(summary.prompt_cache_events.len(), 1); assert_eq!(runtime.session().messages.len(), 4); assert_eq!(summary.usage.output_tokens, 10); + assert_eq!(summary.auto_compaction, None); assert!(matches!( runtime.session().messages[1].blocks[1], ContentBlock::ToolUse { .. } @@ -533,6 +935,39 @@ mod tests { )); } + #[test] + fn records_runtime_session_trace_events() { + let sink = Arc::new(MemoryTelemetrySink::default()); + let tracer = SessionTracer::new("session-runtime", sink.clone()); + let mut runtime = ConversationRuntime::new( + Session::new(), + ScriptedApiClient { call_count: 0 }, + StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), + PermissionPolicy::new(PermissionMode::WorkspaceWrite), + vec!["system".to_string()], + ) + .with_session_tracer(tracer); + + runtime + .run_turn("what is 2 + 2?", Some(&mut PromptAllowOnce)) + .expect("conversation loop should succeed"); + + let events = sink.events(); + let trace_names = events + .iter() + .filter_map(|event| match event { + TelemetryEvent::SessionTrace(trace) => Some(trace.name.as_str()), + _ => None, + }) + .collect::>(); + + assert!(trace_names.contains(&"turn_started")); + assert!(trace_names.contains(&"assistant_iteration_completed")); + assert!(trace_names.contains(&"tool_execution_started")); + assert!(trace_names.contains(&"tool_execution_finished")); + assert!(trace_names.contains(&"turn_completed")); + } + #[test] fn records_denied_tool_results_when_prompt_rejects() { struct RejectPrompter; @@ -621,9 +1056,10 @@ mod tests { }), PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], - RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'blocked by hook'; exit 2")], Vec::new(), + Vec::new(), )), ); @@ -648,6 +1084,71 @@ mod tests { ); } + #[test] + fn denies_tool_use_when_pre_tool_hook_fails() { + struct SingleCallApiClient; + impl ApiClient for SingleCallApiClient { + fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { + if request + .messages + .iter() + .any(|message| message.role == MessageRole::Tool) + { + return Ok(vec![ + AssistantEvent::TextDelta("failed".to_string()), + AssistantEvent::MessageStop, + ]); + } + Ok(vec![ + AssistantEvent::ToolUse { + id: "tool-1".to_string(), + name: "blocked".to_string(), + input: r#"{"path":"secret.txt"}"#.to_string(), + }, + AssistantEvent::MessageStop, + ]) + } + } + + // given + let mut runtime = ConversationRuntime::new_with_features( + Session::new(), + SingleCallApiClient, + StaticToolExecutor::new().register("blocked", |_input| { + panic!("tool should not execute when hook fails") + }), + PermissionPolicy::new(PermissionMode::DangerFullAccess), + vec!["system".to_string()], + &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + vec![shell_snippet("printf 'broken hook'; exit 1")], + Vec::new(), + Vec::new(), + )), + ); + + // when + let summary = runtime + .run_turn("use the tool", None) + .expect("conversation should continue after hook failure"); + + // then + assert_eq!(summary.tool_results.len(), 1); + let ContentBlock::ToolResult { + is_error, output, .. + } = &summary.tool_results[0].blocks[0] + else { + panic!("expected tool result block"); + }; + assert!( + *is_error, + "hook failure should produce an error result: {output}" + ); + assert!( + output.contains("exited with status 1") || output.contains("broken hook"), + "unexpected hook failure output: {output:?}" + ); + } + #[test] fn appends_post_tool_hook_feedback_to_tool_result() { struct TwoCallApiClient { @@ -676,7 +1177,7 @@ mod tests { AssistantEvent::MessageStop, ]) } - _ => Err(RuntimeError::new("unexpected extra API call")), + _ => unreachable!("extra API call"), } } } @@ -687,9 +1188,10 @@ mod tests { StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], - RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'pre hook ran'")], vec![shell_snippet("printf 'post hook ran'")], + Vec::new(), )), ); @@ -722,6 +1224,85 @@ mod tests { ); } + #[test] + fn appends_post_tool_use_failure_hook_feedback_to_tool_result() { + struct TwoCallApiClient { + calls: usize, + } + + impl ApiClient for TwoCallApiClient { + fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { + self.calls += 1; + match self.calls { + 1 => Ok(vec![ + AssistantEvent::ToolUse { + id: "tool-1".to_string(), + name: "fail".to_string(), + input: r#"{"path":"README.md"}"#.to_string(), + }, + AssistantEvent::MessageStop, + ]), + 2 => { + assert!(request + .messages + .iter() + .any(|message| message.role == MessageRole::Tool)); + Ok(vec![ + AssistantEvent::TextDelta("done".to_string()), + AssistantEvent::MessageStop, + ]) + } + _ => unreachable!("extra API call"), + } + } + } + + // given + let mut runtime = ConversationRuntime::new_with_features( + Session::new(), + TwoCallApiClient { calls: 0 }, + StaticToolExecutor::new() + .register("fail", |_input| Err(ToolError::new("tool exploded"))), + PermissionPolicy::new(PermissionMode::DangerFullAccess), + vec!["system".to_string()], + &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + Vec::new(), + vec![shell_snippet("printf 'post hook should not run'")], + vec![shell_snippet("printf 'failure hook ran'")], + )), + ); + + // when + let summary = runtime + .run_turn("use fail", None) + .expect("tool loop succeeds"); + + // then + assert_eq!(summary.tool_results.len(), 1); + let ContentBlock::ToolResult { + is_error, output, .. + } = &summary.tool_results[0].blocks[0] + else { + panic!("expected tool result block"); + }; + assert!( + *is_error, + "failure hook path should preserve error result: {output:?}" + ); + assert!( + output.contains("tool exploded"), + "tool output missing failure reason: {output:?}" + ); + assert!( + output.contains("failure hook ran"), + "tool output missing failure hook feedback: {output:?}" + ); + assert!( + !output.contains("post hook should not run"), + "normal post hook should not run on tool failure: {output:?}" + ); + } + #[test] fn reconstructs_usage_tracker_from_restored_session() { struct SimpleApi; @@ -799,6 +1380,86 @@ mod tests { result.compacted_session.messages[0].role, MessageRole::System ); + assert_eq!( + result.compacted_session.session_id, + runtime.session().session_id + ); + assert!(result.compacted_session.compaction.is_some()); + } + + #[test] + fn persists_conversation_turn_messages_to_jsonl_session() { + struct SimpleApi; + impl ApiClient for SimpleApi { + fn stream( + &mut self, + _request: ApiRequest, + ) -> Result, RuntimeError> { + Ok(vec![ + AssistantEvent::TextDelta("done".to_string()), + AssistantEvent::MessageStop, + ]) + } + } + + let path = temp_session_path("persisted-turn"); + let session = Session::new().with_persistence_path(path.clone()); + let mut runtime = ConversationRuntime::new( + session, + SimpleApi, + StaticToolExecutor::new(), + PermissionPolicy::new(PermissionMode::DangerFullAccess), + vec!["system".to_string()], + ); + + runtime + .run_turn("persist this turn", None) + .expect("turn should succeed"); + + let restored = Session::load_from_path(&path).expect("persisted session should reload"); + fs::remove_file(&path).expect("temp session file should be removable"); + + assert_eq!(restored.messages.len(), 2); + assert_eq!(restored.messages[0].role, MessageRole::User); + assert_eq!(restored.messages[1].role, MessageRole::Assistant); + assert_eq!(restored.session_id, runtime.session().session_id); + } + + #[test] + fn forks_runtime_session_without_mutating_original() { + let mut session = Session::new(); + session + .push_user_text("branch me") + .expect("message should append"); + + let runtime = ConversationRuntime::new( + session.clone(), + ScriptedApiClient { call_count: 0 }, + StaticToolExecutor::new(), + PermissionPolicy::new(PermissionMode::DangerFullAccess), + vec!["system".to_string()], + ); + + let forked = runtime.fork_session(Some("alt-path".to_string())); + + assert_eq!(forked.messages, session.messages); + assert_ne!(forked.session_id, session.session_id); + assert_eq!( + forked + .fork + .as_ref() + .map(|fork| (fork.parent_session_id.as_str(), fork.branch_name.as_deref())), + Some((session.session_id.as_str(), Some("alt-path"))) + ); + assert!(runtime.session().fork.is_none()); + } + + fn temp_session_path(label: &str) -> PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("runtime-conversation-{label}-{nanos}.json")) } #[cfg(windows)] @@ -810,4 +1471,229 @@ mod tests { fn shell_snippet(script: &str) -> String { script.to_string() } + + #[test] + fn auto_compacts_when_cumulative_input_threshold_is_crossed() { + struct SimpleApi; + impl ApiClient for SimpleApi { + fn stream( + &mut self, + _request: ApiRequest, + ) -> Result, RuntimeError> { + Ok(vec![ + AssistantEvent::TextDelta("done".to_string()), + AssistantEvent::Usage(TokenUsage { + input_tokens: 120_000, + output_tokens: 4, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }), + AssistantEvent::MessageStop, + ]) + } + } + + let mut session = Session::new(); + session.messages = vec![ + crate::session::ConversationMessage::user_text("one"), + crate::session::ConversationMessage::assistant(vec![ContentBlock::Text { + text: "two".to_string(), + }]), + crate::session::ConversationMessage::user_text("three"), + crate::session::ConversationMessage::assistant(vec![ContentBlock::Text { + text: "four".to_string(), + }]), + ]; + + let mut runtime = ConversationRuntime::new( + session, + SimpleApi, + StaticToolExecutor::new(), + PermissionPolicy::new(PermissionMode::DangerFullAccess), + vec!["system".to_string()], + ) + .with_auto_compaction_input_tokens_threshold(100_000); + + let summary = runtime + .run_turn("trigger", None) + .expect("turn should succeed"); + + assert_eq!( + summary.auto_compaction, + Some(AutoCompactionEvent { + removed_message_count: 2, + }) + ); + assert_eq!(runtime.session().messages[0].role, MessageRole::System); + } + + #[test] + fn skips_auto_compaction_below_threshold() { + struct SimpleApi; + impl ApiClient for SimpleApi { + fn stream( + &mut self, + _request: ApiRequest, + ) -> Result, RuntimeError> { + Ok(vec![ + AssistantEvent::TextDelta("done".to_string()), + AssistantEvent::Usage(TokenUsage { + input_tokens: 99_999, + output_tokens: 4, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }), + AssistantEvent::MessageStop, + ]) + } + } + + let mut runtime = ConversationRuntime::new( + Session::new(), + SimpleApi, + StaticToolExecutor::new(), + PermissionPolicy::new(PermissionMode::DangerFullAccess), + vec!["system".to_string()], + ) + .with_auto_compaction_input_tokens_threshold(100_000); + + let summary = runtime + .run_turn("trigger", None) + .expect("turn should succeed"); + assert_eq!(summary.auto_compaction, None); + assert_eq!(runtime.session().messages.len(), 2); + } + + #[test] + fn auto_compaction_threshold_defaults_and_parses_values() { + assert_eq!( + parse_auto_compaction_threshold(None), + DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD + ); + assert_eq!(parse_auto_compaction_threshold(Some("4321")), 4321); + assert_eq!( + parse_auto_compaction_threshold(Some("0")), + DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD + ); + assert_eq!( + parse_auto_compaction_threshold(Some("not-a-number")), + DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD + ); + } + + #[test] + fn build_assistant_message_requires_message_stop_event() { + // given + let events = vec![AssistantEvent::TextDelta("hello".to_string())]; + + // when + let error = build_assistant_message(events) + .expect_err("assistant messages should require a stop event"); + + // then + assert!(error + .to_string() + .contains("assistant stream ended without a message stop event")); + } + + #[test] + fn build_assistant_message_requires_content() { + // given + let events = vec![AssistantEvent::MessageStop]; + + // when + let error = + build_assistant_message(events).expect_err("assistant messages should require content"); + + // then + assert!(error + .to_string() + .contains("assistant stream produced no content")); + } + + #[test] + fn static_tool_executor_rejects_unknown_tools() { + // given + let mut executor = StaticToolExecutor::new(); + + // when + let error = executor + .execute("missing", "{}") + .expect_err("unregistered tools should fail"); + + // then + assert_eq!(error.to_string(), "unknown tool: missing"); + } + + #[test] + fn run_turn_errors_when_max_iterations_is_exceeded() { + struct LoopingApi; + + impl ApiClient for LoopingApi { + fn stream( + &mut self, + _request: ApiRequest, + ) -> Result, RuntimeError> { + Ok(vec![ + AssistantEvent::ToolUse { + id: "tool-1".to_string(), + name: "echo".to_string(), + input: "payload".to_string(), + }, + AssistantEvent::MessageStop, + ]) + } + } + + // given + let mut runtime = ConversationRuntime::new( + Session::new(), + LoopingApi, + StaticToolExecutor::new().register("echo", |input| Ok(input.to_string())), + PermissionPolicy::new(PermissionMode::DangerFullAccess), + vec!["system".to_string()], + ) + .with_max_iterations(1); + + // when + let error = runtime + .run_turn("loop", None) + .expect_err("conversation loop should stop after the configured limit"); + + // then + assert!(error + .to_string() + .contains("conversation loop exceeded the maximum number of iterations")); + } + + #[test] + fn run_turn_propagates_api_errors() { + struct FailingApi; + + impl ApiClient for FailingApi { + fn stream( + &mut self, + _request: ApiRequest, + ) -> Result, RuntimeError> { + Err(RuntimeError::new("upstream failed")) + } + } + + // given + let mut runtime = ConversationRuntime::new( + Session::new(), + FailingApi, + StaticToolExecutor::new(), + PermissionPolicy::new(PermissionMode::DangerFullAccess), + vec!["system".to_string()], + ); + + // when + let error = runtime + .run_turn("hello", None) + .expect_err("API failures should propagate"); + + // then + assert_eq!(error.to_string(), "upstream failed"); + } } diff --git a/crates/runtime/src/file_ops.rs b/crates/runtime/src/file_ops.rs index 1faf9ab..a9db1db 100644 --- a/crates/runtime/src/file_ops.rs +++ b/crates/runtime/src/file_ops.rs @@ -9,6 +9,41 @@ use regex::RegexBuilder; use serde::{Deserialize, Serialize}; use walkdir::WalkDir; +/// Maximum file size that can be read (10 MB). +const MAX_READ_SIZE: u64 = 10 * 1024 * 1024; + +/// Maximum file size that can be written (10 MB). +const MAX_WRITE_SIZE: usize = 10 * 1024 * 1024; + +/// Check whether a file appears to contain binary content by examining +/// the first chunk for NUL bytes. +fn is_binary_file(path: &Path) -> io::Result { + use std::io::Read; + let mut file = fs::File::open(path)?; + let mut buffer = [0u8; 8192]; + let bytes_read = file.read(&mut buffer)?; + Ok(buffer[..bytes_read].contains(&0)) +} + +/// Validate that a resolved path stays within the given workspace root. +/// Returns the canonical path on success, or an error if the path escapes +/// the workspace boundary (e.g. via `../` traversal or symlink). +#[allow(dead_code)] +fn validate_workspace_boundary(resolved: &Path, workspace_root: &Path) -> io::Result<()> { + if !resolved.starts_with(workspace_root) { + return Err(io::Error::new( + io::ErrorKind::PermissionDenied, + format!( + "path {} escapes workspace boundary {}", + resolved.display(), + workspace_root.display() + ), + )); + } + Ok(()) +} + +/// Text payload returned by file-reading operations. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct TextFilePayload { #[serde(rename = "filePath")] @@ -22,6 +57,7 @@ pub struct TextFilePayload { pub total_lines: usize, } +/// Output envelope for the `read_file` tool. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct ReadFileOutput { #[serde(rename = "type")] @@ -29,6 +65,7 @@ pub struct ReadFileOutput { pub file: TextFilePayload, } +/// Structured patch hunk emitted by write and edit operations. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct StructuredPatchHunk { #[serde(rename = "oldStart")] @@ -42,6 +79,7 @@ pub struct StructuredPatchHunk { pub lines: Vec, } +/// Output envelope for full-file write operations. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct WriteFileOutput { #[serde(rename = "type")] @@ -57,6 +95,7 @@ pub struct WriteFileOutput { pub git_diff: Option, } +/// Output envelope for targeted string-replacement edits. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct EditFileOutput { #[serde(rename = "filePath")] @@ -77,6 +116,7 @@ pub struct EditFileOutput { pub git_diff: Option, } +/// Result of a glob-based filename search. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct GlobSearchOutput { #[serde(rename = "durationMs")] @@ -87,6 +127,7 @@ pub struct GlobSearchOutput { pub truncated: bool, } +/// Parameters accepted by the grep-style search tool. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct GrepSearchInput { pub pattern: String, @@ -112,6 +153,7 @@ pub struct GrepSearchInput { pub multiline: Option, } +/// Result payload returned by the grep-style search tool. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct GrepSearchOutput { pub mode: Option, @@ -129,12 +171,35 @@ pub struct GrepSearchOutput { pub applied_offset: Option, } +/// Reads a text file and returns a line-windowed payload. pub fn read_file( path: &str, offset: Option, limit: Option, ) -> io::Result { let absolute_path = normalize_path(path)?; + + // Check file size before reading + let metadata = fs::metadata(&absolute_path)?; + if metadata.len() > MAX_READ_SIZE { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "file is too large ({} bytes, max {} bytes)", + metadata.len(), + MAX_READ_SIZE + ), + )); + } + + // Detect binary files + if is_binary_file(&absolute_path)? { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "file appears to be binary", + )); + } + let content = fs::read_to_string(&absolute_path)?; let lines: Vec<&str> = content.lines().collect(); let start_index = offset.unwrap_or(0).min(lines.len()); @@ -155,7 +220,19 @@ pub fn read_file( }) } +/// Replaces a file's contents and returns patch metadata. pub fn write_file(path: &str, content: &str) -> io::Result { + if content.len() > MAX_WRITE_SIZE { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "content is too large ({} bytes, max {} bytes)", + content.len(), + MAX_WRITE_SIZE + ), + )); + } + let absolute_path = normalize_path_allow_missing(path)?; let original_file = fs::read_to_string(&absolute_path).ok(); if let Some(parent) = absolute_path.parent() { @@ -177,6 +254,7 @@ pub fn write_file(path: &str, content: &str) -> io::Result { }) } +/// Performs an in-file string replacement and returns patch metadata. pub fn edit_file( path: &str, old_string: &str, @@ -217,6 +295,7 @@ pub fn edit_file( }) } +/// Expands a glob pattern and returns matching filenames. pub fn glob_search(pattern: &str, path: Option<&str>) -> io::Result { let started = Instant::now(); let base_dir = path @@ -229,12 +308,20 @@ pub fn glob_search(pattern: &str, path: Option<&str>) -> io::Result) -> io::Result io::Result { let base_path = input .path @@ -477,18 +565,105 @@ fn normalize_path_allow_missing(path: &str) -> io::Result { Ok(candidate) } +/// Read a file with workspace boundary enforcement. +#[allow(dead_code)] +pub fn read_file_in_workspace( + path: &str, + offset: Option, + limit: Option, + workspace_root: &Path, +) -> io::Result { + let absolute_path = normalize_path(path)?; + let canonical_root = workspace_root + .canonicalize() + .unwrap_or_else(|_| workspace_root.to_path_buf()); + validate_workspace_boundary(&absolute_path, &canonical_root)?; + read_file(path, offset, limit) +} + +/// Write a file with workspace boundary enforcement. +#[allow(dead_code)] +pub fn write_file_in_workspace( + path: &str, + content: &str, + workspace_root: &Path, +) -> io::Result { + let absolute_path = normalize_path_allow_missing(path)?; + let canonical_root = workspace_root + .canonicalize() + .unwrap_or_else(|_| workspace_root.to_path_buf()); + validate_workspace_boundary(&absolute_path, &canonical_root)?; + write_file(path, content) +} + +/// Edit a file with workspace boundary enforcement. +#[allow(dead_code)] +pub fn edit_file_in_workspace( + path: &str, + old_string: &str, + new_string: &str, + replace_all: bool, + workspace_root: &Path, +) -> io::Result { + let absolute_path = normalize_path(path)?; + let canonical_root = workspace_root + .canonicalize() + .unwrap_or_else(|_| workspace_root.to_path_buf()); + validate_workspace_boundary(&absolute_path, &canonical_root)?; + edit_file(path, old_string, new_string, replace_all) +} + +/// Check whether a path is a symlink that resolves outside the workspace. +#[allow(dead_code)] +pub fn is_symlink_escape(path: &Path, workspace_root: &Path) -> io::Result { + let metadata = fs::symlink_metadata(path)?; + if !metadata.is_symlink() { + return Ok(false); + } + let resolved = path.canonicalize()?; + let canonical_root = workspace_root + .canonicalize() + .unwrap_or_else(|_| workspace_root.to_path_buf()); + Ok(!resolved.starts_with(&canonical_root)) +} + +/// Expand shell-style brace groups in a glob pattern. +/// +/// Handles one level of braces: `foo.{a,b,c}` → `["foo.a", "foo.b", "foo.c"]`. +/// Nested braces are not expanded (uncommon in practice). +/// Patterns without braces pass through unchanged. +fn expand_braces(pattern: &str) -> Vec { + let Some(open) = pattern.find('{') else { + return vec![pattern.to_owned()]; + }; + let Some(close) = pattern[open..].find('}').map(|i| open + i) else { + // Unmatched brace — treat as literal. + return vec![pattern.to_owned()]; + }; + let prefix = &pattern[..open]; + let suffix = &pattern[close + 1..]; + let alternatives = &pattern[open + 1..close]; + alternatives + .split(',') + .flat_map(|alt| expand_braces(&format!("{prefix}{alt}{suffix}"))) + .collect() +} + #[cfg(test)] mod tests { use std::time::{SystemTime, UNIX_EPOCH}; - use super::{edit_file, glob_search, grep_search, read_file, write_file, GrepSearchInput}; + use super::{ + edit_file, expand_braces, glob_search, grep_search, is_symlink_escape, read_file, + read_file_in_workspace, write_file, GrepSearchInput, MAX_WRITE_SIZE, + }; fn temp_path(name: &str) -> std::path::PathBuf { let unique = SystemTime::now() .duration_since(UNIX_EPOCH) .expect("time should move forward") .as_nanos(); - std::env::temp_dir().join(format!("claw-native-{name}-{unique}")) + std::env::temp_dir().join(format!("clawd-native-{name}-{unique}")) } #[test] @@ -513,6 +688,73 @@ mod tests { assert!(output.replace_all); } + #[test] + fn rejects_binary_files() { + let path = temp_path("binary-test.bin"); + std::fs::write(&path, b"\x00\x01\x02\x03binary content").expect("write should succeed"); + let result = read_file(path.to_string_lossy().as_ref(), None, None); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert_eq!(error.kind(), std::io::ErrorKind::InvalidData); + assert!(error.to_string().contains("binary")); + } + + #[test] + fn rejects_oversized_writes() { + let path = temp_path("oversize-write.txt"); + let huge = "x".repeat(MAX_WRITE_SIZE + 1); + let result = write_file(path.to_string_lossy().as_ref(), &huge); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert_eq!(error.kind(), std::io::ErrorKind::InvalidData); + assert!(error.to_string().contains("too large")); + } + + #[test] + fn enforces_workspace_boundary() { + let workspace = temp_path("workspace-boundary"); + std::fs::create_dir_all(&workspace).expect("workspace dir should be created"); + let inside = workspace.join("inside.txt"); + write_file(inside.to_string_lossy().as_ref(), "safe content") + .expect("write inside workspace should succeed"); + + // Reading inside workspace should succeed + let result = + read_file_in_workspace(inside.to_string_lossy().as_ref(), None, None, &workspace); + assert!(result.is_ok()); + + // Reading outside workspace should fail + let outside = temp_path("outside-boundary.txt"); + write_file(outside.to_string_lossy().as_ref(), "unsafe content") + .expect("write outside should succeed"); + let result = + read_file_in_workspace(outside.to_string_lossy().as_ref(), None, None, &workspace); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert_eq!(error.kind(), std::io::ErrorKind::PermissionDenied); + assert!(error.to_string().contains("escapes workspace")); + } + + #[test] + fn detects_symlink_escape() { + let workspace = temp_path("symlink-workspace"); + std::fs::create_dir_all(&workspace).expect("workspace dir should be created"); + let outside = temp_path("symlink-target.txt"); + std::fs::write(&outside, "target content").expect("target should write"); + + let _link_path = workspace.join("escape-link.txt"); + #[cfg(unix)] + { + std::os::unix::fs::symlink(&outside, &link_path).expect("symlink should create"); + assert!(is_symlink_escape(&link_path, &workspace).expect("check should succeed")); + } + + // Non-symlink file should not be an escape + let normal = workspace.join("normal.txt"); + std::fs::write(&normal, "normal content").expect("normal file should write"); + assert!(!is_symlink_escape(&normal, &workspace).expect("check should succeed")); + } + #[test] fn globs_and_greps_directory() { let dir = temp_path("search-dir"); @@ -547,4 +789,51 @@ mod tests { .expect("grep should succeed"); assert!(grep_output.content.unwrap_or_default().contains("hello")); } + + #[test] + fn expand_braces_no_braces() { + assert_eq!(expand_braces("*.rs"), vec!["*.rs"]); + } + + #[test] + fn expand_braces_single_group() { + let mut result = expand_braces("Assets/**/*.{cs,uxml,uss}"); + result.sort(); + assert_eq!( + result, + vec!["Assets/**/*.cs", "Assets/**/*.uss", "Assets/**/*.uxml",] + ); + } + + #[test] + fn expand_braces_nested() { + let mut result = expand_braces("src/{a,b}.{rs,toml}"); + result.sort(); + assert_eq!( + result, + vec!["src/a.rs", "src/a.toml", "src/b.rs", "src/b.toml"] + ); + } + + #[test] + fn expand_braces_unmatched() { + assert_eq!(expand_braces("foo.{bar"), vec!["foo.{bar"]); + } + + #[test] + fn glob_search_with_braces_finds_files() { + let dir = temp_path("glob-braces"); + std::fs::create_dir_all(&dir).unwrap(); + std::fs::write(dir.join("a.rs"), "fn main() {}").unwrap(); + std::fs::write(dir.join("b.toml"), "[package]").unwrap(); + std::fs::write(dir.join("c.txt"), "hello").unwrap(); + + let result = + glob_search("*.{rs,toml}", Some(dir.to_str().unwrap())).expect("glob should succeed"); + assert_eq!( + result.num_files, 2, + "should match .rs and .toml but not .txt" + ); + let _ = std::fs::remove_dir_all(&dir); + } } diff --git a/crates/runtime/src/git_context.rs b/crates/runtime/src/git_context.rs new file mode 100644 index 0000000..5703ebe --- /dev/null +++ b/crates/runtime/src/git_context.rs @@ -0,0 +1,324 @@ +use std::path::Path; +use std::process::Command; + +/// A single git commit entry from the log. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GitCommitEntry { + pub hash: String, + pub subject: String, +} + +/// Git-aware context gathered at startup for injection into the system prompt. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GitContext { + pub branch: Option, + pub recent_commits: Vec, + pub staged_files: Vec, +} + +const MAX_RECENT_COMMITS: usize = 5; + +impl GitContext { + /// Detect the git context from the given working directory. + /// + /// Returns `None` when the directory is not inside a git repository. + #[must_use] + pub fn detect(cwd: &Path) -> Option { + // Quick gate: is this a git repo at all? + let rev_parse = Command::new("git") + .args(["rev-parse", "--is-inside-work-tree"]) + .current_dir(cwd) + .output() + .ok()?; + if !rev_parse.status.success() { + return None; + } + + Some(Self { + branch: read_branch(cwd), + recent_commits: read_recent_commits(cwd), + staged_files: read_staged_files(cwd), + }) + } + + /// Render a human-readable summary suitable for system-prompt injection. + #[must_use] + pub fn render(&self) -> String { + let mut lines = Vec::new(); + + if let Some(branch) = &self.branch { + lines.push(format!("Git branch: {branch}")); + } + + if !self.recent_commits.is_empty() { + lines.push(String::new()); + lines.push("Recent commits:".to_string()); + for entry in &self.recent_commits { + lines.push(format!(" {} {}", entry.hash, entry.subject)); + } + } + + if !self.staged_files.is_empty() { + lines.push(String::new()); + lines.push("Staged files:".to_string()); + for file in &self.staged_files { + lines.push(format!(" {file}")); + } + } + + lines.join("\n") + } +} + +fn read_branch(cwd: &Path) -> Option { + let output = Command::new("git") + .args(["rev-parse", "--abbrev-ref", "HEAD"]) + .current_dir(cwd) + .output() + .ok()?; + if !output.status.success() { + return None; + } + let branch = String::from_utf8(output.stdout).ok()?; + let trimmed = branch.trim(); + if trimmed.is_empty() || trimmed == "HEAD" { + None + } else { + Some(trimmed.to_string()) + } +} + +fn read_recent_commits(cwd: &Path) -> Vec { + let output = Command::new("git") + .args([ + "--no-optional-locks", + "log", + "--oneline", + "-n", + &MAX_RECENT_COMMITS.to_string(), + "--no-decorate", + ]) + .current_dir(cwd) + .output() + .ok(); + let Some(output) = output else { + return Vec::new(); + }; + if !output.status.success() { + return Vec::new(); + } + let stdout = String::from_utf8(output.stdout).unwrap_or_default(); + stdout + .lines() + .filter_map(|line| { + let line = line.trim(); + if line.is_empty() { + return None; + } + let (hash, subject) = line.split_once(' ')?; + Some(GitCommitEntry { + hash: hash.to_string(), + subject: subject.to_string(), + }) + }) + .collect() +} + +fn read_staged_files(cwd: &Path) -> Vec { + let output = Command::new("git") + .args(["--no-optional-locks", "diff", "--cached", "--name-only"]) + .current_dir(cwd) + .output() + .ok(); + let Some(output) = output else { + return Vec::new(); + }; + if !output.status.success() { + return Vec::new(); + } + let stdout = String::from_utf8(output.stdout).unwrap_or_default(); + stdout + .lines() + .filter(|line| !line.trim().is_empty()) + .map(|line| line.trim().to_string()) + .collect() +} + +#[cfg(test)] +mod tests { + use super::{GitCommitEntry, GitContext}; + use std::fs; + use std::process::Command; + use std::time::{SystemTime, UNIX_EPOCH}; + + fn temp_dir(label: &str) -> std::path::PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("runtime-git-context-{label}-{nanos}")) + } + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + crate::test_env_lock() + } + + fn ensure_valid_cwd() { + if std::env::current_dir().is_err() { + std::env::set_current_dir(env!("CARGO_MANIFEST_DIR")) + .expect("test cwd should be recoverable"); + } + } + + #[test] + fn returns_none_for_non_git_directory() { + // given + let _guard = env_lock(); + ensure_valid_cwd(); + let root = temp_dir("non-git"); + fs::create_dir_all(&root).expect("create dir"); + + // when + let context = GitContext::detect(&root); + + // then + assert!(context.is_none()); + fs::remove_dir_all(root).expect("cleanup"); + } + + #[test] + fn detects_branch_name_and_commits() { + // given + let _guard = env_lock(); + ensure_valid_cwd(); + let root = temp_dir("branch-commits"); + fs::create_dir_all(&root).expect("create dir"); + git(&root, &["init", "--quiet", "--initial-branch=main"]); + git(&root, &["config", "user.email", "tests@example.com"]); + git(&root, &["config", "user.name", "Git Context Tests"]); + fs::write(root.join("a.txt"), "a\n").expect("write a"); + git(&root, &["add", "a.txt"]); + git(&root, &["commit", "-m", "first commit", "--quiet"]); + fs::write(root.join("b.txt"), "b\n").expect("write b"); + git(&root, &["add", "b.txt"]); + git(&root, &["commit", "-m", "second commit", "--quiet"]); + + // when + let context = GitContext::detect(&root).expect("should detect git repo"); + + // then + assert_eq!(context.branch.as_deref(), Some("main")); + assert_eq!(context.recent_commits.len(), 2); + assert_eq!(context.recent_commits[0].subject, "second commit"); + assert_eq!(context.recent_commits[1].subject, "first commit"); + assert!(context.staged_files.is_empty()); + fs::remove_dir_all(root).expect("cleanup"); + } + + #[test] + fn detects_staged_files() { + // given + let _guard = env_lock(); + ensure_valid_cwd(); + let root = temp_dir("staged"); + fs::create_dir_all(&root).expect("create dir"); + git(&root, &["init", "--quiet", "--initial-branch=main"]); + git(&root, &["config", "user.email", "tests@example.com"]); + git(&root, &["config", "user.name", "Git Context Tests"]); + fs::write(root.join("init.txt"), "init\n").expect("write init"); + git(&root, &["add", "init.txt"]); + git(&root, &["commit", "-m", "initial", "--quiet"]); + fs::write(root.join("staged.txt"), "staged\n").expect("write staged"); + git(&root, &["add", "staged.txt"]); + + // when + let context = GitContext::detect(&root).expect("should detect git repo"); + + // then + assert_eq!(context.staged_files, vec!["staged.txt"]); + fs::remove_dir_all(root).expect("cleanup"); + } + + #[test] + fn render_formats_all_sections() { + // given + let context = GitContext { + branch: Some("feat/test".to_string()), + recent_commits: vec![ + GitCommitEntry { + hash: "abc1234".to_string(), + subject: "add feature".to_string(), + }, + GitCommitEntry { + hash: "def5678".to_string(), + subject: "fix bug".to_string(), + }, + ], + staged_files: vec!["src/main.rs".to_string()], + }; + + // when + let rendered = context.render(); + + // then + assert!(rendered.contains("Git branch: feat/test")); + assert!(rendered.contains("abc1234 add feature")); + assert!(rendered.contains("def5678 fix bug")); + assert!(rendered.contains("src/main.rs")); + } + + #[test] + fn render_omits_empty_sections() { + // given + let context = GitContext { + branch: Some("main".to_string()), + recent_commits: Vec::new(), + staged_files: Vec::new(), + }; + + // when + let rendered = context.render(); + + // then + assert!(rendered.contains("Git branch: main")); + assert!(!rendered.contains("Recent commits:")); + assert!(!rendered.contains("Staged files:")); + } + + #[test] + fn limits_to_five_recent_commits() { + // given + let _guard = env_lock(); + ensure_valid_cwd(); + let root = temp_dir("five-commits"); + fs::create_dir_all(&root).expect("create dir"); + git(&root, &["init", "--quiet", "--initial-branch=main"]); + git(&root, &["config", "user.email", "tests@example.com"]); + git(&root, &["config", "user.name", "Git Context Tests"]); + for i in 1..=8 { + let name = format!("file{i}.txt"); + fs::write(root.join(&name), format!("{i}\n")).expect("write file"); + git(&root, &["add", &name]); + git(&root, &["commit", "-m", &format!("commit {i}"), "--quiet"]); + } + + // when + let context = GitContext::detect(&root).expect("should detect git repo"); + + // then + assert_eq!(context.recent_commits.len(), 5); + assert_eq!(context.recent_commits[0].subject, "commit 8"); + assert_eq!(context.recent_commits[4].subject, "commit 4"); + fs::remove_dir_all(root).expect("cleanup"); + } + + fn git(cwd: &std::path::Path, args: &[&str]) { + let status = Command::new("git") + .args(args) + .current_dir(cwd) + .output() + .unwrap_or_else(|_| panic!("git {args:?} should run")) + .status; + assert!(status.success(), "git {args:?} failed"); + } +} diff --git a/crates/runtime/src/green_contract.rs b/crates/runtime/src/green_contract.rs new file mode 100644 index 0000000..d65ce91 --- /dev/null +++ b/crates/runtime/src/green_contract.rs @@ -0,0 +1,152 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum GreenLevel { + TargetedTests, + Package, + Workspace, + MergeReady, +} + +impl GreenLevel { + #[must_use] + pub fn as_str(self) -> &'static str { + match self { + Self::TargetedTests => "targeted_tests", + Self::Package => "package", + Self::Workspace => "workspace", + Self::MergeReady => "merge_ready", + } + } +} + +impl std::fmt::Display for GreenLevel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub struct GreenContract { + pub required_level: GreenLevel, +} + +impl GreenContract { + #[must_use] + pub fn new(required_level: GreenLevel) -> Self { + Self { required_level } + } + + #[must_use] + pub fn evaluate(self, observed_level: Option) -> GreenContractOutcome { + match observed_level { + Some(level) if level >= self.required_level => GreenContractOutcome::Satisfied { + required_level: self.required_level, + observed_level: level, + }, + _ => GreenContractOutcome::Unsatisfied { + required_level: self.required_level, + observed_level, + }, + } + } + + #[must_use] + pub fn is_satisfied_by(self, observed_level: GreenLevel) -> bool { + observed_level >= self.required_level + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "outcome", rename_all = "snake_case")] +pub enum GreenContractOutcome { + Satisfied { + required_level: GreenLevel, + observed_level: GreenLevel, + }, + Unsatisfied { + required_level: GreenLevel, + observed_level: Option, + }, +} + +impl GreenContractOutcome { + #[must_use] + pub fn is_satisfied(&self) -> bool { + matches!(self, Self::Satisfied { .. }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn given_matching_level_when_evaluating_contract_then_it_is_satisfied() { + // given + let contract = GreenContract::new(GreenLevel::Package); + + // when + let outcome = contract.evaluate(Some(GreenLevel::Package)); + + // then + assert_eq!( + outcome, + GreenContractOutcome::Satisfied { + required_level: GreenLevel::Package, + observed_level: GreenLevel::Package, + } + ); + assert!(outcome.is_satisfied()); + } + + #[test] + fn given_higher_level_when_checking_requirement_then_it_still_satisfies_contract() { + // given + let contract = GreenContract::new(GreenLevel::TargetedTests); + + // when + let is_satisfied = contract.is_satisfied_by(GreenLevel::Workspace); + + // then + assert!(is_satisfied); + } + + #[test] + fn given_lower_level_when_evaluating_contract_then_it_is_unsatisfied() { + // given + let contract = GreenContract::new(GreenLevel::Workspace); + + // when + let outcome = contract.evaluate(Some(GreenLevel::Package)); + + // then + assert_eq!( + outcome, + GreenContractOutcome::Unsatisfied { + required_level: GreenLevel::Workspace, + observed_level: Some(GreenLevel::Package), + } + ); + assert!(!outcome.is_satisfied()); + } + + #[test] + fn given_no_green_level_when_evaluating_contract_then_contract_is_unsatisfied() { + // given + let contract = GreenContract::new(GreenLevel::MergeReady); + + // when + let outcome = contract.evaluate(None); + + // then + assert_eq!( + outcome, + GreenContractOutcome::Unsatisfied { + required_level: GreenLevel::MergeReady, + observed_level: None, + } + ); + } +} diff --git a/crates/runtime/src/hooks.rs b/crates/runtime/src/hooks.rs index eaa7f85..94e4727 100644 --- a/crates/runtime/src/hooks.rs +++ b/crates/runtime/src/hooks.rs @@ -1,29 +1,91 @@ use std::ffi::OsStr; -use std::process::Command; +use std::io::Write; +use std::process::{Command, Stdio}; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; +use std::thread; +use std::time::Duration; -use serde_json::json; +use serde_json::{json, Value}; use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; +use crate::permissions::PermissionOverride; + +pub type HookPermissionDecision = PermissionOverride; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum HookEvent { PreToolUse, PostToolUse, + PostToolUseFailure, } impl HookEvent { - fn as_str(self) -> &'static str { + #[must_use] + pub fn as_str(self) -> &'static str { match self { Self::PreToolUse => "PreToolUse", Self::PostToolUse => "PostToolUse", + Self::PostToolUseFailure => "PostToolUseFailure", } } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum HookProgressEvent { + Started { + event: HookEvent, + tool_name: String, + command: String, + }, + Completed { + event: HookEvent, + tool_name: String, + command: String, + }, + Cancelled { + event: HookEvent, + tool_name: String, + command: String, + }, +} + +pub trait HookProgressReporter: Send { + fn on_event(&mut self, event: &HookProgressEvent); +} + +#[derive(Debug, Clone, Default)] +pub struct HookAbortSignal { + aborted: Arc, +} + +impl HookAbortSignal { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn abort(&self) { + self.aborted.store(true, Ordering::SeqCst); + } + + #[must_use] + pub fn is_aborted(&self) -> bool { + self.aborted.load(Ordering::SeqCst) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct HookRunResult { denied: bool, + failed: bool, + cancelled: bool, messages: Vec, + permission_override: Option, + permission_reason: Option, + updated_input: Option, } impl HookRunResult { @@ -31,7 +93,12 @@ impl HookRunResult { pub fn allow(messages: Vec) -> Self { Self { denied: false, + failed: false, + cancelled: false, messages, + permission_override: None, + permission_reason: None, + updated_input: None, } } @@ -40,10 +107,45 @@ impl HookRunResult { self.denied } + #[must_use] + pub fn is_failed(&self) -> bool { + self.failed + } + + #[must_use] + pub fn is_cancelled(&self) -> bool { + self.cancelled + } + #[must_use] pub fn messages(&self) -> &[String] { &self.messages } + + #[must_use] + pub fn permission_override(&self) -> Option { + self.permission_override + } + + #[must_use] + pub fn permission_decision(&self) -> Option { + self.permission_override + } + + #[must_use] + pub fn permission_reason(&self) -> Option<&str> { + self.permission_reason.as_deref() + } + + #[must_use] + pub fn updated_input(&self) -> Option<&str> { + self.updated_input.as_deref() + } + + #[must_use] + pub fn updated_input_json(&self) -> Option<&str> { + self.updated_input() + } } #[derive(Debug, Clone, PartialEq, Eq, Default)] @@ -51,16 +153,6 @@ pub struct HookRunner { config: RuntimeHookConfig, } -#[derive(Debug, Clone, Copy)] -struct HookCommandRequest<'a> { - event: HookEvent, - tool_name: &'a str, - tool_input: &'a str, - tool_output: Option<&'a str>, - is_error: bool, - payload: &'a str, -} - impl HookRunner { #[must_use] pub fn new(config: RuntimeHookConfig) -> Self { @@ -74,6 +166,17 @@ impl HookRunner { #[must_use] pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult { + self.run_pre_tool_use_with_context(tool_name, tool_input, None, None) + } + + #[must_use] + pub fn run_pre_tool_use_with_context( + &self, + tool_name: &str, + tool_input: &str, + abort_signal: Option<&HookAbortSignal>, + reporter: Option<&mut dyn HookProgressReporter>, + ) -> HookRunResult { Self::run_commands( HookEvent::PreToolUse, self.config.pre_tool_use(), @@ -81,9 +184,21 @@ impl HookRunner { tool_input, None, false, + abort_signal, + reporter, ) } + #[must_use] + pub fn run_pre_tool_use_with_signal( + &self, + tool_name: &str, + tool_input: &str, + abort_signal: Option<&HookAbortSignal>, + ) -> HookRunResult { + self.run_pre_tool_use_with_context(tool_name, tool_input, abort_signal, None) + } + #[must_use] pub fn run_post_tool_use( &self, @@ -91,6 +206,26 @@ impl HookRunner { tool_input: &str, tool_output: &str, is_error: bool, + ) -> HookRunResult { + self.run_post_tool_use_with_context( + tool_name, + tool_input, + tool_output, + is_error, + None, + None, + ) + } + + #[must_use] + pub fn run_post_tool_use_with_context( + &self, + tool_name: &str, + tool_input: &str, + tool_output: &str, + is_error: bool, + abort_signal: Option<&HookAbortSignal>, + reporter: Option<&mut dyn HookProgressReporter>, ) -> HookRunResult { Self::run_commands( HookEvent::PostToolUse, @@ -99,9 +234,79 @@ impl HookRunner { tool_input, Some(tool_output), is_error, + abort_signal, + reporter, ) } + #[must_use] + pub fn run_post_tool_use_with_signal( + &self, + tool_name: &str, + tool_input: &str, + tool_output: &str, + is_error: bool, + abort_signal: Option<&HookAbortSignal>, + ) -> HookRunResult { + self.run_post_tool_use_with_context( + tool_name, + tool_input, + tool_output, + is_error, + abort_signal, + None, + ) + } + + #[must_use] + pub fn run_post_tool_use_failure( + &self, + tool_name: &str, + tool_input: &str, + tool_error: &str, + ) -> HookRunResult { + self.run_post_tool_use_failure_with_context(tool_name, tool_input, tool_error, None, None) + } + + #[must_use] + pub fn run_post_tool_use_failure_with_context( + &self, + tool_name: &str, + tool_input: &str, + tool_error: &str, + abort_signal: Option<&HookAbortSignal>, + reporter: Option<&mut dyn HookProgressReporter>, + ) -> HookRunResult { + Self::run_commands( + HookEvent::PostToolUseFailure, + self.config.post_tool_use_failure(), + tool_name, + tool_input, + Some(tool_error), + true, + abort_signal, + reporter, + ) + } + + #[must_use] + pub fn run_post_tool_use_failure_with_signal( + &self, + tool_name: &str, + tool_input: &str, + tool_error: &str, + abort_signal: Option<&HookAbortSignal>, + ) -> HookRunResult { + self.run_post_tool_use_failure_with_context( + tool_name, + tool_input, + tool_error, + abort_signal, + None, + ) + } + + #[allow(clippy::too_many_arguments)] fn run_commands( event: HookEvent, commands: &[String], @@ -109,122 +314,313 @@ impl HookRunner { tool_input: &str, tool_output: Option<&str>, is_error: bool, + abort_signal: Option<&HookAbortSignal>, + mut reporter: Option<&mut dyn HookProgressReporter>, ) -> HookRunResult { if commands.is_empty() { return HookRunResult::allow(Vec::new()); } - let payload = json!({ - "hook_event_name": event.as_str(), - "tool_name": tool_name, - "tool_input": parse_tool_input(tool_input), - "tool_input_json": tool_input, - "tool_output": tool_output, - "tool_result_is_error": is_error, - }) - .to_string(); + if abort_signal.is_some_and(HookAbortSignal::is_aborted) { + return HookRunResult { + denied: false, + failed: false, + cancelled: true, + messages: vec![format!( + "{} hook cancelled before execution", + event.as_str() + )], + permission_override: None, + permission_reason: None, + updated_input: None, + }; + } - let mut messages = Vec::new(); + let payload = hook_payload(event, tool_name, tool_input, tool_output, is_error).to_string(); + let mut result = HookRunResult::allow(Vec::new()); for command in commands { + if let Some(reporter) = reporter.as_deref_mut() { + reporter.on_event(&HookProgressEvent::Started { + event, + tool_name: tool_name.to_string(), + command: command.clone(), + }); + } + match Self::run_command( command, - HookCommandRequest { - event, - tool_name, - tool_input, - tool_output, - is_error, - payload: &payload, - }, + event, + tool_name, + tool_input, + tool_output, + is_error, + &payload, + abort_signal, ) { - HookCommandOutcome::Allow { message } => { - if let Some(message) = message { - messages.push(message); + HookCommandOutcome::Allow { parsed } => { + if let Some(reporter) = reporter.as_deref_mut() { + reporter.on_event(&HookProgressEvent::Completed { + event, + tool_name: tool_name.to_string(), + command: command.clone(), + }); } + merge_parsed_hook_output(&mut result, parsed); } - HookCommandOutcome::Deny { message } => { - let message = message.unwrap_or_else(|| { - format!("{} hook denied tool `{tool_name}`", event.as_str()) - }); - messages.push(message); - return HookRunResult { - denied: true, - messages, - }; + HookCommandOutcome::Deny { parsed } => { + if let Some(reporter) = reporter.as_deref_mut() { + reporter.on_event(&HookProgressEvent::Completed { + event, + tool_name: tool_name.to_string(), + command: command.clone(), + }); + } + merge_parsed_hook_output(&mut result, parsed); + result.denied = true; + return result; + } + HookCommandOutcome::Failed { parsed } => { + if let Some(reporter) = reporter.as_deref_mut() { + reporter.on_event(&HookProgressEvent::Completed { + event, + tool_name: tool_name.to_string(), + command: command.clone(), + }); + } + merge_parsed_hook_output(&mut result, parsed); + result.failed = true; + return result; + } + HookCommandOutcome::Cancelled { message } => { + if let Some(reporter) = reporter.as_deref_mut() { + reporter.on_event(&HookProgressEvent::Cancelled { + event, + tool_name: tool_name.to_string(), + command: command.clone(), + }); + } + result.cancelled = true; + result.messages.push(message); + return result; } - HookCommandOutcome::Warn { message } => messages.push(message), } } - HookRunResult::allow(messages) + result } - fn run_command(command: &str, request: HookCommandRequest<'_>) -> HookCommandOutcome { + #[allow(clippy::too_many_arguments)] + fn run_command( + command: &str, + event: HookEvent, + tool_name: &str, + tool_input: &str, + tool_output: Option<&str>, + is_error: bool, + payload: &str, + abort_signal: Option<&HookAbortSignal>, + ) -> HookCommandOutcome { let mut child = shell_command(command); - child.stdin(std::process::Stdio::piped()); - child.stdout(std::process::Stdio::piped()); - child.stderr(std::process::Stdio::piped()); - child.env("HOOK_EVENT", request.event.as_str()); - child.env("HOOK_TOOL_NAME", request.tool_name); - child.env("HOOK_TOOL_INPUT", request.tool_input); - child.env( - "HOOK_TOOL_IS_ERROR", - if request.is_error { "1" } else { "0" }, - ); - if let Some(tool_output) = request.tool_output { + child.stdin(Stdio::piped()); + child.stdout(Stdio::piped()); + child.stderr(Stdio::piped()); + child.env("HOOK_EVENT", event.as_str()); + child.env("HOOK_TOOL_NAME", tool_name); + child.env("HOOK_TOOL_INPUT", tool_input); + child.env("HOOK_TOOL_IS_ERROR", if is_error { "1" } else { "0" }); + if let Some(tool_output) = tool_output { child.env("HOOK_TOOL_OUTPUT", tool_output); } - match child.output_with_stdin(request.payload.as_bytes()) { - Ok(output) => { + match child.output_with_stdin(payload.as_bytes(), abort_signal) { + Ok(CommandExecution::Finished(output)) => { let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); - let message = (!stdout.is_empty()).then_some(stdout); + let parsed = parse_hook_output(&stdout); + let primary_message = parsed.primary_message().map(ToOwned::to_owned); match output.status.code() { - Some(0) => HookCommandOutcome::Allow { message }, - Some(2) => HookCommandOutcome::Deny { message }, - Some(code) => HookCommandOutcome::Warn { - message: format_hook_warning( + Some(0) => { + if parsed.deny { + HookCommandOutcome::Deny { parsed } + } else { + HookCommandOutcome::Allow { parsed } + } + } + Some(2) => HookCommandOutcome::Deny { + parsed: parsed.with_fallback_message(format!( + "{} hook denied tool `{tool_name}`", + event.as_str() + )), + }, + Some(code) => HookCommandOutcome::Failed { + parsed: parsed.with_fallback_message(format_hook_failure( command, code, - message.as_deref(), + primary_message.as_deref(), stderr.as_str(), - ), + )), }, - None => HookCommandOutcome::Warn { - message: format!( + None => HookCommandOutcome::Failed { + parsed: parsed.with_fallback_message(format!( "{} hook `{command}` terminated by signal while handling `{}`", - request.event.as_str(), - request.tool_name - ), + event.as_str(), + tool_name + )), }, } } - Err(error) => HookCommandOutcome::Warn { + Ok(CommandExecution::Cancelled) => HookCommandOutcome::Cancelled { message: format!( - "{} hook `{command}` failed to start for `{}`: {error}", - request.event.as_str(), - request.tool_name + "{} hook `{command}` cancelled while handling `{tool_name}`", + event.as_str() ), }, + Err(error) => HookCommandOutcome::Failed { + parsed: ParsedHookOutput { + messages: vec![format!( + "{} hook `{command}` failed to start for `{}`: {error}", + event.as_str(), + tool_name + )], + ..ParsedHookOutput::default() + }, + }, } } } enum HookCommandOutcome { - Allow { message: Option }, - Deny { message: Option }, - Warn { message: String }, + Allow { parsed: ParsedHookOutput }, + Deny { parsed: ParsedHookOutput }, + Failed { parsed: ParsedHookOutput }, + Cancelled { message: String }, } -fn parse_tool_input(tool_input: &str) -> serde_json::Value { +#[derive(Debug, Clone, PartialEq, Eq, Default)] +struct ParsedHookOutput { + messages: Vec, + deny: bool, + permission_override: Option, + permission_reason: Option, + updated_input: Option, +} + +impl ParsedHookOutput { + fn with_fallback_message(mut self, fallback: String) -> Self { + if self.messages.is_empty() { + self.messages.push(fallback); + } + self + } + + fn primary_message(&self) -> Option<&str> { + self.messages.first().map(String::as_str) + } +} + +fn merge_parsed_hook_output(target: &mut HookRunResult, parsed: ParsedHookOutput) { + target.messages.extend(parsed.messages); + if parsed.permission_override.is_some() { + target.permission_override = parsed.permission_override; + } + if parsed.permission_reason.is_some() { + target.permission_reason = parsed.permission_reason; + } + if parsed.updated_input.is_some() { + target.updated_input = parsed.updated_input; + } +} + +fn parse_hook_output(stdout: &str) -> ParsedHookOutput { + if stdout.is_empty() { + return ParsedHookOutput::default(); + } + + let Ok(Value::Object(root)) = serde_json::from_str::(stdout) else { + return ParsedHookOutput { + messages: vec![stdout.to_string()], + ..ParsedHookOutput::default() + }; + }; + + let mut parsed = ParsedHookOutput::default(); + + if let Some(message) = root.get("systemMessage").and_then(Value::as_str) { + parsed.messages.push(message.to_string()); + } + if let Some(message) = root.get("reason").and_then(Value::as_str) { + parsed.messages.push(message.to_string()); + } + if root.get("continue").and_then(Value::as_bool) == Some(false) + || root.get("decision").and_then(Value::as_str) == Some("block") + { + parsed.deny = true; + } + + if let Some(Value::Object(specific)) = root.get("hookSpecificOutput") { + if let Some(Value::String(additional_context)) = specific.get("additionalContext") { + parsed.messages.push(additional_context.clone()); + } + if let Some(decision) = specific.get("permissionDecision").and_then(Value::as_str) { + parsed.permission_override = match decision { + "allow" => Some(PermissionOverride::Allow), + "deny" => Some(PermissionOverride::Deny), + "ask" => Some(PermissionOverride::Ask), + _ => None, + }; + } + if let Some(reason) = specific + .get("permissionDecisionReason") + .and_then(Value::as_str) + { + parsed.permission_reason = Some(reason.to_string()); + } + if let Some(updated_input) = specific.get("updatedInput") { + parsed.updated_input = serde_json::to_string(updated_input).ok(); + } + } + + if parsed.messages.is_empty() { + parsed.messages.push(stdout.to_string()); + } + + parsed +} + +fn hook_payload( + event: HookEvent, + tool_name: &str, + tool_input: &str, + tool_output: Option<&str>, + is_error: bool, +) -> Value { + match event { + HookEvent::PostToolUseFailure => json!({ + "hook_event_name": event.as_str(), + "tool_name": tool_name, + "tool_input": parse_tool_input(tool_input), + "tool_input_json": tool_input, + "tool_error": tool_output, + "tool_result_is_error": true, + }), + _ => json!({ + "hook_event_name": event.as_str(), + "tool_name": tool_name, + "tool_input": parse_tool_input(tool_input), + "tool_input_json": tool_input, + "tool_output": tool_output, + "tool_result_is_error": is_error, + }), + } +} + +fn parse_tool_input(tool_input: &str) -> Value { serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input })) } -fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String { - let mut message = - format!("Hook `{command}` exited with status {code}; allowing tool execution to continue"); +fn format_hook_failure(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String { + let mut message = format!("Hook `{command}` exited with status {code}"); if let Some(stdout) = stdout.filter(|stdout| !stdout.is_empty()) { message.push_str(": "); message.push_str(stdout); @@ -237,7 +633,8 @@ fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: & fn shell_command(command: &str) -> CommandWithStdin { #[cfg(windows)] - let command_builder = { + #[allow(unused_mut)] + let mut command_builder = { let mut command_builder = Command::new("cmd"); command_builder.arg("/C").arg(command); CommandWithStdin::new(command_builder) @@ -262,17 +659,17 @@ impl CommandWithStdin { Self { command } } - fn stdin(&mut self, cfg: std::process::Stdio) -> &mut Self { + fn stdin(&mut self, cfg: Stdio) -> &mut Self { self.command.stdin(cfg); self } - fn stdout(&mut self, cfg: std::process::Stdio) -> &mut Self { + fn stdout(&mut self, cfg: Stdio) -> &mut Self { self.command.stdout(cfg); self } - fn stderr(&mut self, cfg: std::process::Stdio) -> &mut Self { + fn stderr(&mut self, cfg: Stdio) -> &mut Self { self.command.stderr(cfg); self } @@ -286,26 +683,64 @@ impl CommandWithStdin { self } - fn output_with_stdin(&mut self, stdin: &[u8]) -> std::io::Result { + fn output_with_stdin( + &mut self, + stdin: &[u8], + abort_signal: Option<&HookAbortSignal>, + ) -> std::io::Result { let mut child = self.command.spawn()?; if let Some(mut child_stdin) = child.stdin.take() { - use std::io::Write; child_stdin.write_all(stdin)?; } - child.wait_with_output() + + loop { + if abort_signal.is_some_and(HookAbortSignal::is_aborted) { + let _ = child.kill(); + let _ = child.wait_with_output(); + return Ok(CommandExecution::Cancelled); + } + + match child.try_wait()? { + Some(_) => return child.wait_with_output().map(CommandExecution::Finished), + None => thread::sleep(Duration::from_millis(20)), + } + } } } +enum CommandExecution { + Finished(std::process::Output), + Cancelled, +} + #[cfg(test)] mod tests { - use super::{HookRunResult, HookRunner}; + use std::thread; + use std::time::Duration; + + use super::{ + HookAbortSignal, HookEvent, HookProgressEvent, HookProgressReporter, HookRunResult, + HookRunner, + }; use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; + use crate::permissions::PermissionOverride; + + struct RecordingReporter { + events: Vec, + } + + impl HookProgressReporter for RecordingReporter { + fn on_event(&mut self, event: &HookProgressEvent) { + self.events.push(event.clone()); + } + } #[test] fn allows_exit_code_zero_and_captures_stdout() { let runner = HookRunner::new(RuntimeHookConfig::new( vec![shell_snippet("printf 'pre ok'")], Vec::new(), + Vec::new(), )); let result = runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#); @@ -318,6 +753,7 @@ mod tests { let runner = HookRunner::new(RuntimeHookConfig::new( vec![shell_snippet("printf 'blocked by hook'; exit 2")], Vec::new(), + Vec::new(), )); let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#); @@ -327,21 +763,217 @@ mod tests { } #[test] - fn warns_for_other_non_zero_statuses() { + fn propagates_other_non_zero_statuses_as_failures() { let runner = HookRunner::from_feature_config(&RuntimeFeatureConfig::default().with_hooks( RuntimeHookConfig::new( vec![shell_snippet("printf 'warning hook'; exit 1")], Vec::new(), + Vec::new(), ), )); + // given + // when let result = runner.run_pre_tool_use("Edit", r#"{"file":"src/lib.rs"}"#); - assert!(!result.is_denied()); + // then + assert!(result.is_failed()); assert!(result .messages() .iter() - .any(|message| message.contains("allowing tool execution to continue"))); + .any(|message| message.contains("warning hook"))); + } + + #[test] + fn parses_pre_hook_permission_override_and_updated_input() { + let runner = HookRunner::new(RuntimeHookConfig::new( + vec![shell_snippet( + r#"printf '%s' '{"systemMessage":"updated","hookSpecificOutput":{"permissionDecision":"allow","permissionDecisionReason":"hook ok","updatedInput":{"command":"git status"}}}'"#, + )], + Vec::new(), + Vec::new(), + )); + + let result = runner.run_pre_tool_use("bash", r#"{"command":"pwd"}"#); + + assert_eq!( + result.permission_override(), + Some(PermissionOverride::Allow) + ); + assert_eq!(result.permission_reason(), Some("hook ok")); + assert_eq!(result.updated_input(), Some(r#"{"command":"git status"}"#)); + assert!(result.messages().iter().any(|message| message == "updated")); + } + + #[test] + fn runs_post_tool_use_failure_hooks() { + // given + let runner = HookRunner::new(RuntimeHookConfig::new( + Vec::new(), + Vec::new(), + vec![shell_snippet("printf 'failure hook ran'")], + )); + + // when + let result = + runner.run_post_tool_use_failure("bash", r#"{"command":"false"}"#, "command failed"); + + // then + assert!(!result.is_denied()); + assert_eq!(result.messages(), &["failure hook ran".to_string()]); + } + + #[test] + fn stops_running_failure_hooks_after_failure() { + // given + let runner = HookRunner::new(RuntimeHookConfig::new( + Vec::new(), + Vec::new(), + vec![ + shell_snippet("printf 'broken failure hook'; exit 1"), + shell_snippet("printf 'later failure hook'"), + ], + )); + + // when + let result = + runner.run_post_tool_use_failure("bash", r#"{"command":"false"}"#, "command failed"); + + // then + assert!(result.is_failed()); + assert!(result + .messages() + .iter() + .any(|message| message.contains("broken failure hook"))); + assert!(!result + .messages() + .iter() + .any(|message| message == "later failure hook")); + } + + #[test] + fn executes_hooks_in_configured_order() { + // given + let runner = HookRunner::new(RuntimeHookConfig::new( + vec![ + shell_snippet("printf 'first'"), + shell_snippet("printf 'second'"), + ], + Vec::new(), + Vec::new(), + )); + let mut reporter = RecordingReporter { events: Vec::new() }; + + // when + let result = runner.run_pre_tool_use_with_context( + "Read", + r#"{"path":"README.md"}"#, + None, + Some(&mut reporter), + ); + + // then + assert_eq!( + result, + HookRunResult::allow(vec!["first".to_string(), "second".to_string()]) + ); + assert_eq!(reporter.events.len(), 4); + assert!(matches!( + &reporter.events[0], + HookProgressEvent::Started { + event: HookEvent::PreToolUse, + command, + .. + } if command == "printf 'first'" + )); + assert!(matches!( + &reporter.events[1], + HookProgressEvent::Completed { + event: HookEvent::PreToolUse, + command, + .. + } if command == "printf 'first'" + )); + assert!(matches!( + &reporter.events[2], + HookProgressEvent::Started { + event: HookEvent::PreToolUse, + command, + .. + } if command == "printf 'second'" + )); + assert!(matches!( + &reporter.events[3], + HookProgressEvent::Completed { + event: HookEvent::PreToolUse, + command, + .. + } if command == "printf 'second'" + )); + } + + #[test] + fn stops_running_hooks_after_failure() { + // given + let runner = HookRunner::new(RuntimeHookConfig::new( + vec![ + shell_snippet("printf 'broken'; exit 1"), + shell_snippet("printf 'later'"), + ], + Vec::new(), + Vec::new(), + )); + + // when + let result = runner.run_pre_tool_use("Edit", r#"{"file":"src/lib.rs"}"#); + + // then + assert!(result.is_failed()); + assert!(result + .messages() + .iter() + .any(|message| message.contains("broken"))); + assert!(!result.messages().iter().any(|message| message == "later")); + } + + #[test] + fn abort_signal_cancels_long_running_hook_and_reports_progress() { + let runner = HookRunner::new(RuntimeHookConfig::new( + vec![shell_snippet("sleep 5")], + Vec::new(), + Vec::new(), + )); + let abort_signal = HookAbortSignal::new(); + let abort_signal_for_thread = abort_signal.clone(); + let mut reporter = RecordingReporter { events: Vec::new() }; + + thread::spawn(move || { + thread::sleep(Duration::from_millis(100)); + abort_signal_for_thread.abort(); + }); + + let result = runner.run_pre_tool_use_with_context( + "bash", + r#"{"command":"sleep 5"}"#, + Some(&abort_signal), + Some(&mut reporter), + ); + + assert!(result.is_cancelled()); + assert!(reporter.events.iter().any(|event| matches!( + event, + HookProgressEvent::Started { + event: HookEvent::PreToolUse, + .. + } + ))); + assert!(reporter.events.iter().any(|event| matches!( + event, + HookProgressEvent::Cancelled { + event: HookEvent::PreToolUse, + .. + } + ))); } #[cfg(windows)] diff --git a/crates/runtime/src/json.rs b/crates/runtime/src/json.rs index 2c89795..d829a15 100644 --- a/crates/runtime/src/json.rs +++ b/crates/runtime/src/json.rs @@ -1,8 +1,7 @@ use std::collections::BTreeMap; use std::fmt::{Display, Formatter}; -use serde::{Deserialize, Serialize}; -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum JsonValue { Null, Bool(bool), diff --git a/crates/runtime/src/lane_events.rs b/crates/runtime/src/lane_events.rs new file mode 100644 index 0000000..96a9ac8 --- /dev/null +++ b/crates/runtime/src/lane_events.rs @@ -0,0 +1,383 @@ +#![allow(clippy::similar_names)] +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum LaneEventName { + #[serde(rename = "lane.started")] + Started, + #[serde(rename = "lane.ready")] + Ready, + #[serde(rename = "lane.prompt_misdelivery")] + PromptMisdelivery, + #[serde(rename = "lane.blocked")] + Blocked, + #[serde(rename = "lane.red")] + Red, + #[serde(rename = "lane.green")] + Green, + #[serde(rename = "lane.commit.created")] + CommitCreated, + #[serde(rename = "lane.pr.opened")] + PrOpened, + #[serde(rename = "lane.merge.ready")] + MergeReady, + #[serde(rename = "lane.finished")] + Finished, + #[serde(rename = "lane.failed")] + Failed, + #[serde(rename = "lane.reconciled")] + Reconciled, + #[serde(rename = "lane.merged")] + Merged, + #[serde(rename = "lane.superseded")] + Superseded, + #[serde(rename = "lane.closed")] + Closed, + #[serde(rename = "branch.stale_against_main")] + BranchStaleAgainstMain, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum LaneEventStatus { + Running, + Ready, + Blocked, + Red, + Green, + Completed, + Failed, + Reconciled, + Merged, + Superseded, + Closed, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum LaneFailureClass { + PromptDelivery, + TrustGate, + BranchDivergence, + Compile, + Test, + PluginStartup, + McpStartup, + McpHandshake, + GatewayRouting, + ToolRuntime, + Infra, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct LaneEventBlocker { + #[serde(rename = "failureClass")] + pub failure_class: LaneFailureClass, + pub detail: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct LaneCommitProvenance { + pub commit: String, + pub branch: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub worktree: Option, + #[serde(rename = "canonicalCommit", skip_serializing_if = "Option::is_none")] + pub canonical_commit: Option, + #[serde(rename = "supersededBy", skip_serializing_if = "Option::is_none")] + pub superseded_by: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub lineage: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct LaneEvent { + pub event: LaneEventName, + pub status: LaneEventStatus, + #[serde(rename = "emittedAt")] + pub emitted_at: String, + #[serde(rename = "failureClass", skip_serializing_if = "Option::is_none")] + pub failure_class: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub detail: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +impl LaneEvent { + #[must_use] + pub fn new( + event: LaneEventName, + status: LaneEventStatus, + emitted_at: impl Into, + ) -> Self { + Self { + event, + status, + emitted_at: emitted_at.into(), + failure_class: None, + detail: None, + data: None, + } + } + + #[must_use] + pub fn started(emitted_at: impl Into) -> Self { + Self::new(LaneEventName::Started, LaneEventStatus::Running, emitted_at) + } + + #[must_use] + pub fn finished(emitted_at: impl Into, detail: Option) -> Self { + Self::new( + LaneEventName::Finished, + LaneEventStatus::Completed, + emitted_at, + ) + .with_optional_detail(detail) + } + + #[must_use] + pub fn commit_created( + emitted_at: impl Into, + detail: Option, + provenance: LaneCommitProvenance, + ) -> Self { + Self::new( + LaneEventName::CommitCreated, + LaneEventStatus::Completed, + emitted_at, + ) + .with_optional_detail(detail) + .with_data(serde_json::to_value(provenance).expect("commit provenance should serialize")) + } + + #[must_use] + pub fn superseded( + emitted_at: impl Into, + detail: Option, + provenance: LaneCommitProvenance, + ) -> Self { + Self::new( + LaneEventName::Superseded, + LaneEventStatus::Superseded, + emitted_at, + ) + .with_optional_detail(detail) + .with_data(serde_json::to_value(provenance).expect("commit provenance should serialize")) + } + + #[must_use] + pub fn blocked(emitted_at: impl Into, blocker: &LaneEventBlocker) -> Self { + Self::new(LaneEventName::Blocked, LaneEventStatus::Blocked, emitted_at) + .with_failure_class(blocker.failure_class) + .with_detail(blocker.detail.clone()) + } + + #[must_use] + pub fn failed(emitted_at: impl Into, blocker: &LaneEventBlocker) -> Self { + Self::new(LaneEventName::Failed, LaneEventStatus::Failed, emitted_at) + .with_failure_class(blocker.failure_class) + .with_detail(blocker.detail.clone()) + } + + #[must_use] + pub fn with_failure_class(mut self, failure_class: LaneFailureClass) -> Self { + self.failure_class = Some(failure_class); + self + } + + #[must_use] + pub fn with_detail(mut self, detail: impl Into) -> Self { + self.detail = Some(detail.into()); + self + } + + #[must_use] + pub fn with_optional_detail(mut self, detail: Option) -> Self { + self.detail = detail; + self + } + + #[must_use] + pub fn with_data(mut self, data: Value) -> Self { + self.data = Some(data); + self + } +} + +#[must_use] +pub fn dedupe_superseded_commit_events(events: &[LaneEvent]) -> Vec { + let mut keep = vec![true; events.len()]; + let mut latest_by_key = std::collections::BTreeMap::::new(); + + for (index, event) in events.iter().enumerate() { + if event.event != LaneEventName::CommitCreated { + continue; + } + let Some(data) = event.data.as_ref() else { + continue; + }; + let key = data + .get("canonicalCommit") + .or_else(|| data.get("commit")) + .and_then(serde_json::Value::as_str) + .map(str::to_string); + let superseded = data + .get("supersededBy") + .and_then(serde_json::Value::as_str) + .is_some(); + if superseded { + keep[index] = false; + continue; + } + if let Some(key) = key { + if let Some(previous) = latest_by_key.insert(key, index) { + keep[previous] = false; + } + } + } + + events + .iter() + .cloned() + .zip(keep) + .filter_map(|(event, retain)| retain.then_some(event)) + .collect() +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::{ + dedupe_superseded_commit_events, LaneCommitProvenance, LaneEvent, LaneEventBlocker, + LaneEventName, LaneEventStatus, LaneFailureClass, + }; + + #[test] + fn canonical_lane_event_names_serialize_to_expected_wire_values() { + let cases = [ + (LaneEventName::Started, "lane.started"), + (LaneEventName::Ready, "lane.ready"), + (LaneEventName::PromptMisdelivery, "lane.prompt_misdelivery"), + (LaneEventName::Blocked, "lane.blocked"), + (LaneEventName::Red, "lane.red"), + (LaneEventName::Green, "lane.green"), + (LaneEventName::CommitCreated, "lane.commit.created"), + (LaneEventName::PrOpened, "lane.pr.opened"), + (LaneEventName::MergeReady, "lane.merge.ready"), + (LaneEventName::Finished, "lane.finished"), + (LaneEventName::Failed, "lane.failed"), + (LaneEventName::Reconciled, "lane.reconciled"), + (LaneEventName::Merged, "lane.merged"), + (LaneEventName::Superseded, "lane.superseded"), + (LaneEventName::Closed, "lane.closed"), + ( + LaneEventName::BranchStaleAgainstMain, + "branch.stale_against_main", + ), + ]; + + for (event, expected) in cases { + assert_eq!( + serde_json::to_value(event).expect("serialize event"), + json!(expected) + ); + } + } + + #[test] + fn failure_classes_cover_canonical_taxonomy_wire_values() { + let cases = [ + (LaneFailureClass::PromptDelivery, "prompt_delivery"), + (LaneFailureClass::TrustGate, "trust_gate"), + (LaneFailureClass::BranchDivergence, "branch_divergence"), + (LaneFailureClass::Compile, "compile"), + (LaneFailureClass::Test, "test"), + (LaneFailureClass::PluginStartup, "plugin_startup"), + (LaneFailureClass::McpStartup, "mcp_startup"), + (LaneFailureClass::McpHandshake, "mcp_handshake"), + (LaneFailureClass::GatewayRouting, "gateway_routing"), + (LaneFailureClass::ToolRuntime, "tool_runtime"), + (LaneFailureClass::Infra, "infra"), + ]; + + for (failure_class, expected) in cases { + assert_eq!( + serde_json::to_value(failure_class).expect("serialize failure class"), + json!(expected) + ); + } + } + + #[test] + fn blocked_and_failed_events_reuse_blocker_details() { + let blocker = LaneEventBlocker { + failure_class: LaneFailureClass::McpStartup, + detail: "broken server".to_string(), + }; + + let blocked = LaneEvent::blocked("2026-04-04T00:00:00Z", &blocker); + let failed = LaneEvent::failed("2026-04-04T00:00:01Z", &blocker); + + assert_eq!(blocked.event, LaneEventName::Blocked); + assert_eq!(blocked.status, LaneEventStatus::Blocked); + assert_eq!(blocked.failure_class, Some(LaneFailureClass::McpStartup)); + assert_eq!(failed.event, LaneEventName::Failed); + assert_eq!(failed.status, LaneEventStatus::Failed); + assert_eq!(failed.detail.as_deref(), Some("broken server")); + } + + #[test] + fn commit_events_can_carry_worktree_and_supersession_metadata() { + let event = LaneEvent::commit_created( + "2026-04-04T00:00:00Z", + Some("commit created".to_string()), + LaneCommitProvenance { + commit: "abc123".to_string(), + branch: "feature/provenance".to_string(), + worktree: Some("wt-a".to_string()), + canonical_commit: Some("abc123".to_string()), + superseded_by: None, + lineage: vec!["abc123".to_string()], + }, + ); + let event_json = serde_json::to_value(&event).expect("lane event should serialize"); + assert_eq!(event_json["event"], "lane.commit.created"); + assert_eq!(event_json["data"]["branch"], "feature/provenance"); + assert_eq!(event_json["data"]["worktree"], "wt-a"); + } + + #[test] + fn dedupes_superseded_commit_events_by_canonical_commit() { + let retained = dedupe_superseded_commit_events(&[ + LaneEvent::commit_created( + "2026-04-04T00:00:00Z", + Some("old".to_string()), + LaneCommitProvenance { + commit: "old123".to_string(), + branch: "feature/provenance".to_string(), + worktree: Some("wt-a".to_string()), + canonical_commit: Some("canon123".to_string()), + superseded_by: Some("new123".to_string()), + lineage: vec!["old123".to_string(), "new123".to_string()], + }, + ), + LaneEvent::commit_created( + "2026-04-04T00:00:01Z", + Some("new".to_string()), + LaneCommitProvenance { + commit: "new123".to_string(), + branch: "feature/provenance".to_string(), + worktree: Some("wt-b".to_string()), + canonical_commit: Some("canon123".to_string()), + superseded_by: None, + lineage: vec!["old123".to_string(), "new123".to_string()], + }, + ), + ]); + assert_eq!(retained.len(), 1); + assert_eq!(retained[0].detail.as_deref(), Some("new")); + } +} diff --git a/crates/runtime/src/lib.rs b/crates/runtime/src/lib.rs index c714f95..e691df2 100644 --- a/crates/runtime/src/lib.rs +++ b/crates/runtime/src/lib.rs @@ -1,64 +1,112 @@ +//! Core runtime primitives for the `claw` CLI and supporting crates. +//! +//! This crate owns session persistence, permission evaluation, prompt assembly, +//! MCP plumbing, tool-facing file operations, and the core conversation loop +//! that drives interactive and one-shot turns. + mod bash; +pub mod bash_validation; mod bootstrap; +pub mod branch_lock; mod compact; mod config; +pub mod config_validate; mod conversation; mod file_ops; +mod git_context; +pub mod green_contract; mod hooks; mod json; +mod lane_events; +pub mod lsp_client; mod mcp; mod mcp_client; +pub mod mcp_lifecycle_hardened; +pub mod mcp_server; mod mcp_stdio; +pub mod mcp_tool_bridge; mod oauth; +pub mod permission_enforcer; mod permissions; +pub mod plugin_lifecycle; +mod policy_engine; mod prompt; +pub mod recovery_recipes; mod remote; pub mod sandbox; mod session; +pub mod session_control; +pub use session_control::SessionStore; +mod sse; +pub mod stale_base; +pub mod stale_branch; +pub mod summary_compression; +pub mod task_packet; +pub mod task_registry; +pub mod team_cron_registry; +#[cfg(test)] +mod trust_resolver; mod usage; +pub mod worker_boot; -pub use lsp::{ - FileDiagnostics, LspContextEnrichment, LspError, LspManager, LspServerConfig, - SymbolLocation, WorkspaceDiagnostics, -}; pub use bash::{execute_bash, BashCommandInput, BashCommandOutput}; pub use bootstrap::{BootstrapPhase, BootstrapPlan}; +pub use branch_lock::{detect_branch_lock_collisions, BranchLockCollision, BranchLockIntent}; pub use compact::{ compact_session, estimate_session_tokens, format_compact_summary, get_compact_continuation_message, should_compact, CompactionConfig, CompactionResult, }; pub use config::{ - ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpManagedProxyServerConfig, - McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, + ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpConfigCollection, + McpManagedProxyServerConfig, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig, - ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, RuntimeHookConfig, - RuntimePluginConfig, ScopedMcpServerConfig, CLAW_SETTINGS_SCHEMA_NAME, + ProviderFallbackConfig, ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, + RuntimeHookConfig, RuntimePermissionRuleConfig, RuntimePluginConfig, ScopedMcpServerConfig, + CLAW_SETTINGS_SCHEMA_NAME, +}; +pub use config_validate::{ + check_unsupported_format, format_diagnostics, validate_config_file, ConfigDiagnostic, + DiagnosticKind, ValidationResult, }; pub use conversation::{ - ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor, - ToolError, ToolExecutor, TurnSummary, + auto_compaction_threshold_from_env, ApiClient, ApiRequest, AssistantEvent, AutoCompactionEvent, + ConversationRuntime, PromptCacheEvent, RuntimeError, StaticToolExecutor, ToolError, + ToolExecutor, TurnSummary, }; pub use file_ops::{ edit_file, glob_search, grep_search, read_file, write_file, EditFileOutput, GlobSearchOutput, GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload, WriteFileOutput, }; -pub use hooks::{HookEvent, HookRunResult, HookRunner}; +pub use git_context::{GitCommitEntry, GitContext}; +pub use hooks::{ + HookAbortSignal, HookEvent, HookProgressEvent, HookProgressReporter, HookRunResult, HookRunner, +}; +pub use lane_events::{ + dedupe_superseded_commit_events, LaneCommitProvenance, LaneEvent, LaneEventBlocker, + LaneEventName, LaneEventStatus, LaneFailureClass, +}; pub use mcp::{ mcp_server_signature, mcp_tool_name, mcp_tool_prefix, normalize_name_for_mcp, scoped_mcp_config_hash, unwrap_ccr_proxy_url, }; pub use mcp_client::{ - McpManagedProxyTransport, McpClientAuth, McpClientBootstrap, McpClientTransport, + McpClientAuth, McpClientBootstrap, McpClientTransport, McpManagedProxyTransport, McpRemoteTransport, McpSdkTransport, McpStdioTransport, }; +pub use mcp_lifecycle_hardened::{ + McpDegradedReport, McpErrorSurface, McpFailedServer, McpLifecyclePhase, McpLifecycleState, + McpLifecycleValidator, McpPhaseResult, +}; +pub use mcp_server::{McpServer, McpServerSpec, ToolCallHandler, MCP_SERVER_PROTOCOL_VERSION}; pub use mcp_stdio::{ spawn_mcp_stdio_process, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse, - ManagedMcpTool, McpInitializeClientInfo, McpInitializeParams, McpInitializeResult, - McpInitializeServerInfo, McpListResourcesParams, McpListResourcesResult, McpListToolsParams, - McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpResource, - McpResourceContents, McpServerManager, McpServerManagerError, McpStdioProcess, McpTool, - McpToolCallContent, McpToolCallParams, McpToolCallResult, UnsupportedMcpServer, + ManagedMcpTool, McpDiscoveryFailure, McpInitializeClientInfo, McpInitializeParams, + McpInitializeResult, McpInitializeServerInfo, McpListResourcesParams, McpListResourcesResult, + McpListToolsParams, McpListToolsResult, McpReadResourceParams, McpReadResourceResult, + McpResource, McpResourceContents, McpServerManager, McpServerManagerError, McpStdioProcess, + McpTool, McpToolCallContent, McpToolCallParams, McpToolCallResult, McpToolDiscoveryReport, + UnsupportedMcpServer, }; pub use oauth::{ clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair, @@ -68,22 +116,59 @@ pub use oauth::{ PkceChallengeMethod, PkceCodePair, }; pub use permissions::{ - PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision, - PermissionPrompter, PermissionRequest, + PermissionContext, PermissionMode, PermissionOutcome, PermissionOverride, PermissionPolicy, + PermissionPromptDecision, PermissionPrompter, PermissionRequest, +}; +pub use plugin_lifecycle::{ + DegradedMode, DiscoveryResult, PluginHealthcheck, PluginLifecycle, PluginLifecycleEvent, + PluginState, ResourceInfo, ServerHealth, ServerStatus, ToolInfo, +}; +pub use policy_engine::{ + evaluate, DiffScope, GreenLevel, LaneBlocker, LaneContext, PolicyAction, PolicyCondition, + PolicyEngine, PolicyRule, ReconcileReason, ReviewStatus, }; pub use prompt::{ load_system_prompt, prepend_bullets, ContextFile, ProjectContext, PromptBuildError, SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY, }; +pub use recovery_recipes::{ + attempt_recovery, recipe_for, EscalationPolicy, FailureScenario, RecoveryContext, + RecoveryEvent, RecoveryRecipe, RecoveryResult, RecoveryStep, +}; pub use remote::{ inherited_upstream_proxy_env, no_proxy_list, read_token, upstream_proxy_ws_url, RemoteSessionContext, UpstreamProxyBootstrap, UpstreamProxyState, DEFAULT_REMOTE_BASE_URL, DEFAULT_SESSION_TOKEN_PATH, DEFAULT_SYSTEM_CA_BUNDLE, NO_PROXY_HOSTS, UPSTREAM_PROXY_ENV_KEYS, }; -pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, SessionError}; +pub use sandbox::{ + build_linux_sandbox_command, detect_container_environment, detect_container_environment_from, + resolve_sandbox_status, resolve_sandbox_status_for_request, ContainerEnvironment, + FilesystemIsolationMode, LinuxSandboxCommand, SandboxConfig, SandboxDetectionInputs, + SandboxRequest, SandboxStatus, +}; +pub use session::{ + ContentBlock, ConversationMessage, MessageRole, Session, SessionCompaction, SessionError, + SessionFork, SessionPromptEntry, +}; +pub use sse::{IncrementalSseParser, SseEvent}; +pub use stale_base::{ + check_base_commit, format_stale_base_warning, read_claw_base_file, resolve_expected_base, + BaseCommitSource, BaseCommitState, +}; +pub use stale_branch::{ + apply_policy, check_freshness, BranchFreshness, StaleBranchAction, StaleBranchEvent, + StaleBranchPolicy, +}; +pub use task_packet::{validate_packet, TaskPacket, TaskPacketValidationError, ValidatedPacket}; +#[cfg(test)] +pub use trust_resolver::{TrustConfig, TrustDecision, TrustEvent, TrustPolicy, TrustResolver}; pub use usage::{ format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker, }; +pub use worker_boot::{ + Worker, WorkerEvent, WorkerEventKind, WorkerEventPayload, WorkerFailure, WorkerFailureKind, + WorkerPromptTarget, WorkerReadySnapshot, WorkerRegistry, WorkerStatus, WorkerTrustResolution, +}; #[cfg(test)] pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> { diff --git a/crates/runtime/src/lsp_client.rs b/crates/runtime/src/lsp_client.rs new file mode 100644 index 0000000..6302713 --- /dev/null +++ b/crates/runtime/src/lsp_client.rs @@ -0,0 +1,747 @@ +#![allow(clippy::should_implement_trait, clippy::must_use_candidate)] +//! LSP (Language Server Protocol) client registry for tool dispatch. + +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +use serde::{Deserialize, Serialize}; + +/// Supported LSP actions. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum LspAction { + Diagnostics, + Hover, + Definition, + References, + Completion, + Symbols, + Format, +} + +impl LspAction { + pub fn from_str(s: &str) -> Option { + match s { + "diagnostics" => Some(Self::Diagnostics), + "hover" => Some(Self::Hover), + "definition" | "goto_definition" => Some(Self::Definition), + "references" | "find_references" => Some(Self::References), + "completion" | "completions" => Some(Self::Completion), + "symbols" | "document_symbols" => Some(Self::Symbols), + "format" | "formatting" => Some(Self::Format), + _ => None, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LspDiagnostic { + pub path: String, + pub line: u32, + pub character: u32, + pub severity: String, + pub message: String, + pub source: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LspLocation { + pub path: String, + pub line: u32, + pub character: u32, + pub end_line: Option, + pub end_character: Option, + pub preview: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LspHoverResult { + pub content: String, + pub language: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LspCompletionItem { + pub label: String, + pub kind: Option, + pub detail: Option, + pub insert_text: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LspSymbol { + pub name: String, + pub kind: String, + pub path: String, + pub line: u32, + pub character: u32, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum LspServerStatus { + Connected, + Disconnected, + Starting, + Error, +} + +impl std::fmt::Display for LspServerStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Connected => write!(f, "connected"), + Self::Disconnected => write!(f, "disconnected"), + Self::Starting => write!(f, "starting"), + Self::Error => write!(f, "error"), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LspServerState { + pub language: String, + pub status: LspServerStatus, + pub root_path: Option, + pub capabilities: Vec, + pub diagnostics: Vec, +} + +#[derive(Debug, Clone, Default)] +pub struct LspRegistry { + inner: Arc>, +} + +#[derive(Debug, Default)] +struct RegistryInner { + servers: HashMap, +} + +impl LspRegistry { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn register( + &self, + language: &str, + status: LspServerStatus, + root_path: Option<&str>, + capabilities: Vec, + ) { + let mut inner = self.inner.lock().expect("lsp registry lock poisoned"); + inner.servers.insert( + language.to_owned(), + LspServerState { + language: language.to_owned(), + status, + root_path: root_path.map(str::to_owned), + capabilities, + diagnostics: Vec::new(), + }, + ); + } + + pub fn get(&self, language: &str) -> Option { + let inner = self.inner.lock().expect("lsp registry lock poisoned"); + inner.servers.get(language).cloned() + } + + /// Find the appropriate server for a file path based on extension. + pub fn find_server_for_path(&self, path: &str) -> Option { + let ext = std::path::Path::new(path) + .extension() + .and_then(|e| e.to_str()) + .unwrap_or(""); + + let language = match ext { + "rs" => "rust", + "ts" | "tsx" => "typescript", + "js" | "jsx" => "javascript", + "py" => "python", + "go" => "go", + "java" => "java", + "c" | "h" => "c", + "cpp" | "hpp" | "cc" => "cpp", + "rb" => "ruby", + "lua" => "lua", + _ => return None, + }; + + self.get(language) + } + + /// List all registered servers. + pub fn list_servers(&self) -> Vec { + let inner = self.inner.lock().expect("lsp registry lock poisoned"); + inner.servers.values().cloned().collect() + } + + /// Add diagnostics to a server. + pub fn add_diagnostics( + &self, + language: &str, + diagnostics: Vec, + ) -> Result<(), String> { + let mut inner = self.inner.lock().expect("lsp registry lock poisoned"); + let server = inner + .servers + .get_mut(language) + .ok_or_else(|| format!("LSP server not found for language: {language}"))?; + server.diagnostics.extend(diagnostics); + Ok(()) + } + + /// Get diagnostics for a specific file path. + pub fn get_diagnostics(&self, path: &str) -> Vec { + let inner = self.inner.lock().expect("lsp registry lock poisoned"); + inner + .servers + .values() + .flat_map(|s| &s.diagnostics) + .filter(|d| d.path == path) + .cloned() + .collect() + } + + /// Clear diagnostics for a language server. + pub fn clear_diagnostics(&self, language: &str) -> Result<(), String> { + let mut inner = self.inner.lock().expect("lsp registry lock poisoned"); + let server = inner + .servers + .get_mut(language) + .ok_or_else(|| format!("LSP server not found for language: {language}"))?; + server.diagnostics.clear(); + Ok(()) + } + + /// Disconnect a server. + pub fn disconnect(&self, language: &str) -> Option { + let mut inner = self.inner.lock().expect("lsp registry lock poisoned"); + inner.servers.remove(language) + } + + #[must_use] + pub fn len(&self) -> usize { + let inner = self.inner.lock().expect("lsp registry lock poisoned"); + inner.servers.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Dispatch an LSP action and return a structured result. + pub fn dispatch( + &self, + action: &str, + path: Option<&str>, + line: Option, + character: Option, + _query: Option<&str>, + ) -> Result { + let lsp_action = + LspAction::from_str(action).ok_or_else(|| format!("unknown LSP action: {action}"))?; + + // For diagnostics, we can check existing cached diagnostics + if lsp_action == LspAction::Diagnostics { + if let Some(path) = path { + let diags = self.get_diagnostics(path); + return Ok(serde_json::json!({ + "action": "diagnostics", + "path": path, + "diagnostics": diags, + "count": diags.len() + })); + } + // All diagnostics across all servers + let inner = self.inner.lock().expect("lsp registry lock poisoned"); + let all_diags: Vec<_> = inner + .servers + .values() + .flat_map(|s| &s.diagnostics) + .collect(); + return Ok(serde_json::json!({ + "action": "diagnostics", + "diagnostics": all_diags, + "count": all_diags.len() + })); + } + + // For other actions, we need a connected server for the given file + let path = path.ok_or("path is required for this LSP action")?; + let server = self + .find_server_for_path(path) + .ok_or_else(|| format!("no LSP server available for path: {path}"))?; + + if server.status != LspServerStatus::Connected { + return Err(format!( + "LSP server for '{}' is not connected (status: {})", + server.language, server.status + )); + } + + // Return structured placeholder — actual LSP JSON-RPC calls would + // go through the real LSP process here. + Ok(serde_json::json!({ + "action": action, + "path": path, + "line": line, + "character": character, + "language": server.language, + "status": "dispatched", + "message": format!("LSP {} dispatched to {} server", action, server.language) + })) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn registers_and_retrieves_server() { + let registry = LspRegistry::new(); + registry.register( + "rust", + LspServerStatus::Connected, + Some("/workspace"), + vec!["hover".into(), "completion".into()], + ); + + let server = registry.get("rust").expect("should exist"); + assert_eq!(server.language, "rust"); + assert_eq!(server.status, LspServerStatus::Connected); + assert_eq!(server.capabilities.len(), 2); + } + + #[test] + fn finds_server_by_file_extension() { + let registry = LspRegistry::new(); + registry.register("rust", LspServerStatus::Connected, None, vec![]); + registry.register("typescript", LspServerStatus::Connected, None, vec![]); + + let rs_server = registry.find_server_for_path("src/main.rs").unwrap(); + assert_eq!(rs_server.language, "rust"); + + let ts_server = registry.find_server_for_path("src/index.ts").unwrap(); + assert_eq!(ts_server.language, "typescript"); + + assert!(registry.find_server_for_path("data.csv").is_none()); + } + + #[test] + fn manages_diagnostics() { + let registry = LspRegistry::new(); + registry.register("rust", LspServerStatus::Connected, None, vec![]); + + registry + .add_diagnostics( + "rust", + vec![LspDiagnostic { + path: "src/main.rs".into(), + line: 10, + character: 5, + severity: "error".into(), + message: "mismatched types".into(), + source: Some("rust-analyzer".into()), + }], + ) + .unwrap(); + + let diags = registry.get_diagnostics("src/main.rs"); + assert_eq!(diags.len(), 1); + assert_eq!(diags[0].message, "mismatched types"); + + registry.clear_diagnostics("rust").unwrap(); + assert!(registry.get_diagnostics("src/main.rs").is_empty()); + } + + #[test] + fn dispatches_diagnostics_action() { + let registry = LspRegistry::new(); + registry.register("rust", LspServerStatus::Connected, None, vec![]); + registry + .add_diagnostics( + "rust", + vec![LspDiagnostic { + path: "src/lib.rs".into(), + line: 1, + character: 0, + severity: "warning".into(), + message: "unused import".into(), + source: None, + }], + ) + .unwrap(); + + let result = registry + .dispatch("diagnostics", Some("src/lib.rs"), None, None, None) + .unwrap(); + assert_eq!(result["count"], 1); + } + + #[test] + fn dispatches_hover_action() { + let registry = LspRegistry::new(); + registry.register("rust", LspServerStatus::Connected, None, vec![]); + + let result = registry + .dispatch("hover", Some("src/main.rs"), Some(10), Some(5), None) + .unwrap(); + assert_eq!(result["action"], "hover"); + assert_eq!(result["language"], "rust"); + } + + #[test] + fn rejects_action_on_disconnected_server() { + let registry = LspRegistry::new(); + registry.register("rust", LspServerStatus::Disconnected, None, vec![]); + + assert!(registry + .dispatch("hover", Some("src/main.rs"), Some(1), Some(0), None) + .is_err()); + } + + #[test] + fn rejects_unknown_action() { + let registry = LspRegistry::new(); + assert!(registry + .dispatch("unknown_action", Some("file.rs"), None, None, None) + .is_err()); + } + + #[test] + fn disconnects_server() { + let registry = LspRegistry::new(); + registry.register("rust", LspServerStatus::Connected, None, vec![]); + assert_eq!(registry.len(), 1); + + let removed = registry.disconnect("rust"); + assert!(removed.is_some()); + assert!(registry.is_empty()); + } + + #[test] + fn lsp_action_from_str_all_aliases() { + // given + let cases = [ + ("diagnostics", Some(LspAction::Diagnostics)), + ("hover", Some(LspAction::Hover)), + ("definition", Some(LspAction::Definition)), + ("goto_definition", Some(LspAction::Definition)), + ("references", Some(LspAction::References)), + ("find_references", Some(LspAction::References)), + ("completion", Some(LspAction::Completion)), + ("completions", Some(LspAction::Completion)), + ("symbols", Some(LspAction::Symbols)), + ("document_symbols", Some(LspAction::Symbols)), + ("format", Some(LspAction::Format)), + ("formatting", Some(LspAction::Format)), + ("unknown", None), + ]; + + // when + let resolved: Vec<_> = cases + .into_iter() + .map(|(input, expected)| (input, LspAction::from_str(input), expected)) + .collect(); + + // then + for (input, actual, expected) in resolved { + assert_eq!(actual, expected, "unexpected action resolution for {input}"); + } + } + + #[test] + fn lsp_server_status_display_all_variants() { + // given + let cases = [ + (LspServerStatus::Connected, "connected"), + (LspServerStatus::Disconnected, "disconnected"), + (LspServerStatus::Starting, "starting"), + (LspServerStatus::Error, "error"), + ]; + + // when + let rendered: Vec<_> = cases + .into_iter() + .map(|(status, expected)| (status.to_string(), expected)) + .collect(); + + // then + assert_eq!( + rendered, + vec![ + ("connected".to_string(), "connected"), + ("disconnected".to_string(), "disconnected"), + ("starting".to_string(), "starting"), + ("error".to_string(), "error"), + ] + ); + } + + #[test] + fn dispatch_diagnostics_without_path_aggregates() { + // given + let registry = LspRegistry::new(); + registry.register("rust", LspServerStatus::Connected, None, vec![]); + registry.register("python", LspServerStatus::Connected, None, vec![]); + registry + .add_diagnostics( + "rust", + vec![LspDiagnostic { + path: "src/lib.rs".into(), + line: 1, + character: 0, + severity: "warning".into(), + message: "unused import".into(), + source: Some("rust-analyzer".into()), + }], + ) + .expect("rust diagnostics should add"); + registry + .add_diagnostics( + "python", + vec![LspDiagnostic { + path: "script.py".into(), + line: 2, + character: 4, + severity: "error".into(), + message: "undefined name".into(), + source: Some("pyright".into()), + }], + ) + .expect("python diagnostics should add"); + + // when + let result = registry + .dispatch("diagnostics", None, None, None, None) + .expect("aggregate diagnostics should work"); + + // then + assert_eq!(result["action"], "diagnostics"); + assert_eq!(result["count"], 2); + assert_eq!(result["diagnostics"].as_array().map(Vec::len), Some(2)); + } + + #[test] + fn dispatch_non_diagnostics_requires_path() { + // given + let registry = LspRegistry::new(); + + // when + let result = registry.dispatch("hover", None, Some(1), Some(0), None); + + // then + assert_eq!( + result.expect_err("path should be required"), + "path is required for this LSP action" + ); + } + + #[test] + fn dispatch_no_server_for_path_errors() { + // given + let registry = LspRegistry::new(); + + // when + let result = registry.dispatch("hover", Some("notes.md"), Some(1), Some(0), None); + + // then + let error = result.expect_err("missing server should fail"); + assert!(error.contains("no LSP server available for path: notes.md")); + } + + #[test] + fn dispatch_disconnected_server_error_payload() { + // given + let registry = LspRegistry::new(); + registry.register("typescript", LspServerStatus::Disconnected, None, vec![]); + + // when + let result = registry.dispatch("hover", Some("src/index.ts"), Some(3), Some(2), None); + + // then + let error = result.expect_err("disconnected server should fail"); + assert!(error.contains("typescript")); + assert!(error.contains("disconnected")); + } + + #[test] + fn find_server_for_all_extensions() { + // given + let registry = LspRegistry::new(); + for language in [ + "rust", + "typescript", + "javascript", + "python", + "go", + "java", + "c", + "cpp", + "ruby", + "lua", + ] { + registry.register(language, LspServerStatus::Connected, None, vec![]); + } + let cases = [ + ("src/main.rs", "rust"), + ("src/index.ts", "typescript"), + ("src/view.tsx", "typescript"), + ("src/app.js", "javascript"), + ("src/app.jsx", "javascript"), + ("script.py", "python"), + ("main.go", "go"), + ("Main.java", "java"), + ("native.c", "c"), + ("native.h", "c"), + ("native.cpp", "cpp"), + ("native.hpp", "cpp"), + ("native.cc", "cpp"), + ("script.rb", "ruby"), + ("script.lua", "lua"), + ]; + + // when + let resolved: Vec<_> = cases + .into_iter() + .map(|(path, expected)| { + ( + path, + registry + .find_server_for_path(path) + .map(|server| server.language), + expected, + ) + }) + .collect(); + + // then + for (path, actual, expected) in resolved { + assert_eq!( + actual.as_deref(), + Some(expected), + "unexpected mapping for {path}" + ); + } + } + + #[test] + fn find_server_for_path_no_extension() { + // given + let registry = LspRegistry::new(); + registry.register("rust", LspServerStatus::Connected, None, vec![]); + + // when + let result = registry.find_server_for_path("Makefile"); + + // then + assert!(result.is_none()); + } + + #[test] + fn list_servers_with_multiple() { + // given + let registry = LspRegistry::new(); + registry.register("rust", LspServerStatus::Connected, None, vec![]); + registry.register("typescript", LspServerStatus::Starting, None, vec![]); + registry.register("python", LspServerStatus::Error, None, vec![]); + + // when + let servers = registry.list_servers(); + + // then + assert_eq!(servers.len(), 3); + assert!(servers.iter().any(|server| server.language == "rust")); + assert!(servers.iter().any(|server| server.language == "typescript")); + assert!(servers.iter().any(|server| server.language == "python")); + } + + #[test] + fn get_missing_server_returns_none() { + // given + let registry = LspRegistry::new(); + + // when + let server = registry.get("missing"); + + // then + assert!(server.is_none()); + } + + #[test] + fn add_diagnostics_missing_language_errors() { + // given + let registry = LspRegistry::new(); + + // when + let result = registry.add_diagnostics("missing", vec![]); + + // then + let error = result.expect_err("missing language should fail"); + assert!(error.contains("LSP server not found for language: missing")); + } + + #[test] + fn get_diagnostics_across_servers() { + // given + let registry = LspRegistry::new(); + let shared_path = "shared/file.txt"; + registry.register("rust", LspServerStatus::Connected, None, vec![]); + registry.register("python", LspServerStatus::Connected, None, vec![]); + registry + .add_diagnostics( + "rust", + vec![LspDiagnostic { + path: shared_path.into(), + line: 4, + character: 1, + severity: "warning".into(), + message: "warn".into(), + source: None, + }], + ) + .expect("rust diagnostics should add"); + registry + .add_diagnostics( + "python", + vec![LspDiagnostic { + path: shared_path.into(), + line: 8, + character: 3, + severity: "error".into(), + message: "err".into(), + source: None, + }], + ) + .expect("python diagnostics should add"); + + // when + let diagnostics = registry.get_diagnostics(shared_path); + + // then + assert_eq!(diagnostics.len(), 2); + assert!(diagnostics + .iter() + .any(|diagnostic| diagnostic.message == "warn")); + assert!(diagnostics + .iter() + .any(|diagnostic| diagnostic.message == "err")); + } + + #[test] + fn clear_diagnostics_missing_language_errors() { + // given + let registry = LspRegistry::new(); + + // when + let result = registry.clear_diagnostics("missing"); + + // then + let error = result.expect_err("missing language should fail"); + assert!(error.contains("LSP server not found for language: missing")); + } +} diff --git a/crates/runtime/src/mcp.rs b/crates/runtime/src/mcp.rs index b37ea33..e65cd08 100644 --- a/crates/runtime/src/mcp.rs +++ b/crates/runtime/src/mcp.rs @@ -84,10 +84,13 @@ pub fn mcp_server_signature(config: &McpServerConfig) -> Option { pub fn scoped_mcp_config_hash(config: &ScopedMcpServerConfig) -> String { let rendered = match &config.config { McpServerConfig::Stdio(stdio) => format!( - "stdio|{}|{}|{}", + "stdio|{}|{}|{}|{}", stdio.command, render_command_signature(&stdio.args), - render_env_signature(&stdio.env) + render_env_signature(&stdio.env), + stdio + .tool_call_timeout_ms + .map_or_else(String::new, |timeout_ms| timeout_ms.to_string()) ), McpServerConfig::Sse(remote) => format!( "sse|{}|{}|{}|{}", @@ -245,6 +248,7 @@ mod tests { command: "uvx".to_string(), args: vec!["mcp-server".to_string()], env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]), + tool_call_timeout_ms: None, }); assert_eq!( mcp_server_signature(&stdio), diff --git a/crates/runtime/src/mcp_client.rs b/crates/runtime/src/mcp_client.rs index e0e1f2c..96a6db2 100644 --- a/crates/runtime/src/mcp_client.rs +++ b/crates/runtime/src/mcp_client.rs @@ -3,6 +3,8 @@ use std::collections::BTreeMap; use crate::config::{McpOAuthConfig, McpServerConfig, ScopedMcpServerConfig}; use crate::mcp::{mcp_server_signature, mcp_tool_prefix, normalize_name_for_mcp}; +pub const DEFAULT_MCP_TOOL_CALL_TIMEOUT_MS: u64 = 60_000; + #[derive(Debug, Clone, PartialEq, Eq)] pub enum McpClientTransport { Stdio(McpStdioTransport), @@ -18,6 +20,7 @@ pub struct McpStdioTransport { pub command: String, pub args: Vec, pub env: BTreeMap, + pub tool_call_timeout_ms: Option, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -75,6 +78,7 @@ impl McpClientTransport { command: config.command.clone(), args: config.args.clone(), env: config.env.clone(), + tool_call_timeout_ms: config.tool_call_timeout_ms, }), McpServerConfig::Sse(config) => Self::Sse(McpRemoteTransport { url: config.url.clone(), @@ -105,6 +109,14 @@ impl McpClientTransport { } } +impl McpStdioTransport { + #[must_use] + pub fn resolved_tool_call_timeout_ms(&self) -> u64 { + self.tool_call_timeout_ms + .unwrap_or(DEFAULT_MCP_TOOL_CALL_TIMEOUT_MS) + } +} + impl McpClientAuth { #[must_use] pub fn from_oauth(oauth: Option) -> Self { @@ -136,6 +148,7 @@ mod tests { command: "uvx".to_string(), args: vec!["mcp-server".to_string()], env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]), + tool_call_timeout_ms: Some(15_000), }), }; @@ -154,6 +167,7 @@ mod tests { transport.env.get("TOKEN").map(String::as_str), Some("secret") ); + assert_eq!(transport.tool_call_timeout_ms, Some(15_000)); } other => panic!("expected stdio transport, got {other:?}"), } diff --git a/crates/runtime/src/mcp_lifecycle_hardened.rs b/crates/runtime/src/mcp_lifecycle_hardened.rs new file mode 100644 index 0000000..330ff63 --- /dev/null +++ b/crates/runtime/src/mcp_lifecycle_hardened.rs @@ -0,0 +1,843 @@ +#![allow(clippy::unnested_or_patterns, clippy::map_unwrap_or)] +use std::collections::{BTreeMap, BTreeSet}; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; + +fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum McpLifecyclePhase { + ConfigLoad, + ServerRegistration, + SpawnConnect, + InitializeHandshake, + ToolDiscovery, + ResourceDiscovery, + Ready, + Invocation, + ErrorSurfacing, + Shutdown, + Cleanup, +} + +impl McpLifecyclePhase { + #[must_use] + pub fn all() -> [Self; 11] { + [ + Self::ConfigLoad, + Self::ServerRegistration, + Self::SpawnConnect, + Self::InitializeHandshake, + Self::ToolDiscovery, + Self::ResourceDiscovery, + Self::Ready, + Self::Invocation, + Self::ErrorSurfacing, + Self::Shutdown, + Self::Cleanup, + ] + } +} + +impl std::fmt::Display for McpLifecyclePhase { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ConfigLoad => write!(f, "config_load"), + Self::ServerRegistration => write!(f, "server_registration"), + Self::SpawnConnect => write!(f, "spawn_connect"), + Self::InitializeHandshake => write!(f, "initialize_handshake"), + Self::ToolDiscovery => write!(f, "tool_discovery"), + Self::ResourceDiscovery => write!(f, "resource_discovery"), + Self::Ready => write!(f, "ready"), + Self::Invocation => write!(f, "invocation"), + Self::ErrorSurfacing => write!(f, "error_surfacing"), + Self::Shutdown => write!(f, "shutdown"), + Self::Cleanup => write!(f, "cleanup"), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct McpErrorSurface { + pub phase: McpLifecyclePhase, + pub server_name: Option, + pub message: String, + pub context: BTreeMap, + pub recoverable: bool, + pub timestamp: u64, +} + +impl McpErrorSurface { + #[must_use] + pub fn new( + phase: McpLifecyclePhase, + server_name: Option, + message: impl Into, + context: BTreeMap, + recoverable: bool, + ) -> Self { + Self { + phase, + server_name, + message: message.into(), + context, + recoverable, + timestamp: now_secs(), + } + } +} + +impl std::fmt::Display for McpErrorSurface { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "MCP lifecycle error during {}: {}", + self.phase, self.message + )?; + if let Some(server_name) = &self.server_name { + write!(f, " (server: {server_name})")?; + } + if !self.context.is_empty() { + write!(f, " with context {:?}", self.context)?; + } + if self.recoverable { + write!(f, " [recoverable]")?; + } + Ok(()) + } +} + +impl std::error::Error for McpErrorSurface {} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum McpPhaseResult { + Success { + phase: McpLifecyclePhase, + duration: Duration, + }, + Failure { + phase: McpLifecyclePhase, + error: McpErrorSurface, + }, + Timeout { + phase: McpLifecyclePhase, + waited: Duration, + error: McpErrorSurface, + }, +} + +impl McpPhaseResult { + #[must_use] + pub fn phase(&self) -> McpLifecyclePhase { + match self { + Self::Success { phase, .. } + | Self::Failure { phase, .. } + | Self::Timeout { phase, .. } => *phase, + } + } +} + +#[derive(Debug, Clone, Default)] +pub struct McpLifecycleState { + current_phase: Option, + phase_errors: BTreeMap>, + phase_timestamps: BTreeMap, + phase_results: Vec, +} + +impl McpLifecycleState { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + #[must_use] + pub fn current_phase(&self) -> Option { + self.current_phase + } + + #[must_use] + pub fn errors_for_phase(&self, phase: McpLifecyclePhase) -> &[McpErrorSurface] { + self.phase_errors + .get(&phase) + .map(Vec::as_slice) + .unwrap_or(&[]) + } + + #[must_use] + pub fn results(&self) -> &[McpPhaseResult] { + &self.phase_results + } + + #[must_use] + pub fn phase_timestamps(&self) -> &BTreeMap { + &self.phase_timestamps + } + + #[must_use] + pub fn phase_timestamp(&self, phase: McpLifecyclePhase) -> Option { + self.phase_timestamps.get(&phase).copied() + } + + fn record_phase(&mut self, phase: McpLifecyclePhase) { + self.current_phase = Some(phase); + self.phase_timestamps.insert(phase, now_secs()); + } + + fn record_error(&mut self, error: McpErrorSurface) { + self.phase_errors + .entry(error.phase) + .or_default() + .push(error); + } + + fn record_result(&mut self, result: McpPhaseResult) { + self.phase_results.push(result); + } + + fn can_resume_after_error(&self) -> bool { + match self.phase_results.last() { + Some(McpPhaseResult::Failure { error, .. } | McpPhaseResult::Timeout { error, .. }) => { + error.recoverable + } + _ => false, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct McpFailedServer { + pub server_name: String, + pub phase: McpLifecyclePhase, + pub error: McpErrorSurface, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct McpDegradedReport { + pub working_servers: Vec, + pub failed_servers: Vec, + pub available_tools: Vec, + pub missing_tools: Vec, +} + +impl McpDegradedReport { + #[must_use] + pub fn new( + working_servers: Vec, + failed_servers: Vec, + available_tools: Vec, + expected_tools: Vec, + ) -> Self { + let working_servers = dedupe_sorted(working_servers); + let available_tools = dedupe_sorted(available_tools); + let available_tool_set: BTreeSet<_> = available_tools.iter().cloned().collect(); + let expected_tools = dedupe_sorted(expected_tools); + let missing_tools = expected_tools + .into_iter() + .filter(|tool| !available_tool_set.contains(tool)) + .collect(); + + Self { + working_servers, + failed_servers, + available_tools, + missing_tools, + } + } +} + +#[derive(Debug, Clone, Default)] +pub struct McpLifecycleValidator { + state: McpLifecycleState, +} + +impl McpLifecycleValidator { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + #[must_use] + pub fn state(&self) -> &McpLifecycleState { + &self.state + } + + #[must_use] + pub fn validate_phase_transition(from: McpLifecyclePhase, to: McpLifecyclePhase) -> bool { + match (from, to) { + (McpLifecyclePhase::ConfigLoad, McpLifecyclePhase::ServerRegistration) + | (McpLifecyclePhase::ServerRegistration, McpLifecyclePhase::SpawnConnect) + | (McpLifecyclePhase::SpawnConnect, McpLifecyclePhase::InitializeHandshake) + | (McpLifecyclePhase::InitializeHandshake, McpLifecyclePhase::ToolDiscovery) + | (McpLifecyclePhase::ToolDiscovery, McpLifecyclePhase::ResourceDiscovery) + | (McpLifecyclePhase::ToolDiscovery, McpLifecyclePhase::Ready) + | (McpLifecyclePhase::ResourceDiscovery, McpLifecyclePhase::Ready) + | (McpLifecyclePhase::Ready, McpLifecyclePhase::Invocation) + | (McpLifecyclePhase::Invocation, McpLifecyclePhase::Ready) + | (McpLifecyclePhase::ErrorSurfacing, McpLifecyclePhase::Ready) + | (McpLifecyclePhase::ErrorSurfacing, McpLifecyclePhase::Shutdown) + | (McpLifecyclePhase::Shutdown, McpLifecyclePhase::Cleanup) => true, + (_, McpLifecyclePhase::Shutdown) => from != McpLifecyclePhase::Cleanup, + (_, McpLifecyclePhase::ErrorSurfacing) => { + from != McpLifecyclePhase::Cleanup && from != McpLifecyclePhase::Shutdown + } + _ => false, + } + } + + pub fn run_phase(&mut self, phase: McpLifecyclePhase) -> McpPhaseResult { + let started = Instant::now(); + + if let Some(current_phase) = self.state.current_phase() { + if current_phase == McpLifecyclePhase::ErrorSurfacing + && phase == McpLifecyclePhase::Ready + && !self.state.can_resume_after_error() + { + return self.record_failure(McpErrorSurface::new( + phase, + None, + "cannot return to ready after a non-recoverable MCP lifecycle failure", + BTreeMap::from([ + ("from".to_string(), current_phase.to_string()), + ("to".to_string(), phase.to_string()), + ]), + false, + )); + } + + if !Self::validate_phase_transition(current_phase, phase) { + return self.record_failure(McpErrorSurface::new( + phase, + None, + format!("invalid MCP lifecycle transition from {current_phase} to {phase}"), + BTreeMap::from([ + ("from".to_string(), current_phase.to_string()), + ("to".to_string(), phase.to_string()), + ]), + false, + )); + } + } else if phase != McpLifecyclePhase::ConfigLoad { + return self.record_failure(McpErrorSurface::new( + phase, + None, + format!("invalid initial MCP lifecycle phase {phase}"), + BTreeMap::from([("phase".to_string(), phase.to_string())]), + false, + )); + } + + self.state.record_phase(phase); + let result = McpPhaseResult::Success { + phase, + duration: started.elapsed(), + }; + self.state.record_result(result.clone()); + result + } + + pub fn record_failure(&mut self, error: McpErrorSurface) -> McpPhaseResult { + let phase = error.phase; + self.state.record_error(error.clone()); + self.state.record_phase(McpLifecyclePhase::ErrorSurfacing); + let result = McpPhaseResult::Failure { phase, error }; + self.state.record_result(result.clone()); + result + } + + pub fn record_timeout( + &mut self, + phase: McpLifecyclePhase, + waited: Duration, + server_name: Option, + mut context: BTreeMap, + ) -> McpPhaseResult { + context.insert("waited_ms".to_string(), waited.as_millis().to_string()); + let error = McpErrorSurface::new( + phase, + server_name, + format!( + "MCP lifecycle phase {phase} timed out after {} ms", + waited.as_millis() + ), + context, + true, + ); + self.state.record_error(error.clone()); + self.state.record_phase(McpLifecyclePhase::ErrorSurfacing); + let result = McpPhaseResult::Timeout { + phase, + waited, + error, + }; + self.state.record_result(result.clone()); + result + } +} + +fn dedupe_sorted(mut values: Vec) -> Vec { + values.sort(); + values.dedup(); + values +} + +#[cfg(test)] +mod tests { + use super::*; + + use serde_json::json; + + #[test] + fn phase_display_matches_serde_name() { + // given + let phases = McpLifecyclePhase::all(); + + // when + let serialized = phases + .into_iter() + .map(|phase| { + ( + phase.to_string(), + serde_json::to_value(phase).expect("serialize phase"), + ) + }) + .collect::>(); + + // then + for (display, json_value) in serialized { + assert_eq!(json_value, json!(display)); + } + } + + #[test] + fn given_startup_path_when_running_to_cleanup_then_each_control_transition_succeeds() { + // given + let mut validator = McpLifecycleValidator::new(); + let phases = [ + McpLifecyclePhase::ConfigLoad, + McpLifecyclePhase::ServerRegistration, + McpLifecyclePhase::SpawnConnect, + McpLifecyclePhase::InitializeHandshake, + McpLifecyclePhase::ToolDiscovery, + McpLifecyclePhase::ResourceDiscovery, + McpLifecyclePhase::Ready, + McpLifecyclePhase::Invocation, + McpLifecyclePhase::Ready, + McpLifecyclePhase::Shutdown, + McpLifecyclePhase::Cleanup, + ]; + + // when + let results = phases + .into_iter() + .map(|phase| validator.run_phase(phase)) + .collect::>(); + + // then + assert!(results + .iter() + .all(|result| matches!(result, McpPhaseResult::Success { .. }))); + assert_eq!( + validator.state().current_phase(), + Some(McpLifecyclePhase::Cleanup) + ); + for phase in [ + McpLifecyclePhase::ConfigLoad, + McpLifecyclePhase::ServerRegistration, + McpLifecyclePhase::SpawnConnect, + McpLifecyclePhase::InitializeHandshake, + McpLifecyclePhase::ToolDiscovery, + McpLifecyclePhase::ResourceDiscovery, + McpLifecyclePhase::Ready, + McpLifecyclePhase::Invocation, + McpLifecyclePhase::Shutdown, + McpLifecyclePhase::Cleanup, + ] { + assert!(validator.state().phase_timestamp(phase).is_some()); + } + } + + #[test] + fn given_tool_discovery_when_resource_discovery_is_skipped_then_ready_is_still_allowed() { + // given + let mut validator = McpLifecycleValidator::new(); + for phase in [ + McpLifecyclePhase::ConfigLoad, + McpLifecyclePhase::ServerRegistration, + McpLifecyclePhase::SpawnConnect, + McpLifecyclePhase::InitializeHandshake, + McpLifecyclePhase::ToolDiscovery, + ] { + let result = validator.run_phase(phase); + assert!(matches!(result, McpPhaseResult::Success { .. })); + } + + // when + let result = validator.run_phase(McpLifecyclePhase::Ready); + + // then + assert!(matches!(result, McpPhaseResult::Success { .. })); + assert_eq!( + validator.state().current_phase(), + Some(McpLifecyclePhase::Ready) + ); + } + + #[test] + fn validates_expected_phase_transitions() { + // given + let valid_transitions = [ + ( + McpLifecyclePhase::ConfigLoad, + McpLifecyclePhase::ServerRegistration, + ), + ( + McpLifecyclePhase::ServerRegistration, + McpLifecyclePhase::SpawnConnect, + ), + ( + McpLifecyclePhase::SpawnConnect, + McpLifecyclePhase::InitializeHandshake, + ), + ( + McpLifecyclePhase::InitializeHandshake, + McpLifecyclePhase::ToolDiscovery, + ), + ( + McpLifecyclePhase::ToolDiscovery, + McpLifecyclePhase::ResourceDiscovery, + ), + (McpLifecyclePhase::ToolDiscovery, McpLifecyclePhase::Ready), + ( + McpLifecyclePhase::ResourceDiscovery, + McpLifecyclePhase::Ready, + ), + (McpLifecyclePhase::Ready, McpLifecyclePhase::Invocation), + (McpLifecyclePhase::Invocation, McpLifecyclePhase::Ready), + (McpLifecyclePhase::Ready, McpLifecyclePhase::Shutdown), + ( + McpLifecyclePhase::Invocation, + McpLifecyclePhase::ErrorSurfacing, + ), + ( + McpLifecyclePhase::ErrorSurfacing, + McpLifecyclePhase::Shutdown, + ), + (McpLifecyclePhase::Shutdown, McpLifecyclePhase::Cleanup), + ]; + + // when / then + for (from, to) in valid_transitions { + assert!(McpLifecycleValidator::validate_phase_transition(from, to)); + } + assert!(!McpLifecycleValidator::validate_phase_transition( + McpLifecyclePhase::Ready, + McpLifecyclePhase::ConfigLoad, + )); + assert!(!McpLifecycleValidator::validate_phase_transition( + McpLifecyclePhase::Cleanup, + McpLifecyclePhase::Ready, + )); + } + + #[test] + fn given_invalid_transition_when_running_phase_then_structured_failure_is_recorded() { + // given + let mut validator = McpLifecycleValidator::new(); + let _ = validator.run_phase(McpLifecyclePhase::ConfigLoad); + let _ = validator.run_phase(McpLifecyclePhase::ServerRegistration); + + // when + let result = validator.run_phase(McpLifecyclePhase::Ready); + + // then + match result { + McpPhaseResult::Failure { phase, error } => { + assert_eq!(phase, McpLifecyclePhase::Ready); + assert!(!error.recoverable); + assert_eq!(error.phase, McpLifecyclePhase::Ready); + assert_eq!( + error.context.get("from").map(String::as_str), + Some("server_registration") + ); + assert_eq!(error.context.get("to").map(String::as_str), Some("ready")); + } + other => panic!("expected failure result, got {other:?}"), + } + assert_eq!( + validator.state().current_phase(), + Some(McpLifecyclePhase::ErrorSurfacing) + ); + assert_eq!( + validator + .state() + .errors_for_phase(McpLifecyclePhase::Ready) + .len(), + 1 + ); + } + + #[test] + fn given_each_phase_when_failure_is_recorded_then_error_is_tracked_per_phase() { + // given + let mut validator = McpLifecycleValidator::new(); + + // when / then + for phase in McpLifecyclePhase::all() { + let result = validator.record_failure(McpErrorSurface::new( + phase, + Some("alpha".to_string()), + format!("failure at {phase}"), + BTreeMap::from([("server".to_string(), "alpha".to_string())]), + phase == McpLifecyclePhase::ResourceDiscovery, + )); + + match result { + McpPhaseResult::Failure { + phase: failed_phase, + error, + } => { + assert_eq!(failed_phase, phase); + assert_eq!(error.phase, phase); + assert_eq!( + error.recoverable, + phase == McpLifecyclePhase::ResourceDiscovery + ); + } + other => panic!("expected failure result, got {other:?}"), + } + assert_eq!(validator.state().errors_for_phase(phase).len(), 1); + } + } + + #[test] + fn given_spawn_connect_timeout_when_recorded_then_waited_duration_is_preserved() { + // given + let mut validator = McpLifecycleValidator::new(); + let waited = Duration::from_millis(250); + + // when + let result = validator.record_timeout( + McpLifecyclePhase::SpawnConnect, + waited, + Some("alpha".to_string()), + BTreeMap::from([("attempt".to_string(), "1".to_string())]), + ); + + // then + match result { + McpPhaseResult::Timeout { + phase, + waited: actual, + error, + } => { + assert_eq!(phase, McpLifecyclePhase::SpawnConnect); + assert_eq!(actual, waited); + assert!(error.recoverable); + assert_eq!(error.server_name.as_deref(), Some("alpha")); + } + other => panic!("expected timeout result, got {other:?}"), + } + let errors = validator + .state() + .errors_for_phase(McpLifecyclePhase::SpawnConnect); + assert_eq!(errors.len(), 1); + assert_eq!( + errors[0].context.get("waited_ms").map(String::as_str), + Some("250") + ); + assert_eq!( + validator.state().current_phase(), + Some(McpLifecyclePhase::ErrorSurfacing) + ); + } + + #[test] + fn given_partial_server_health_when_building_degraded_report_then_missing_tools_are_reported() { + // given + let failed = vec![McpFailedServer { + server_name: "broken".to_string(), + phase: McpLifecyclePhase::InitializeHandshake, + error: McpErrorSurface::new( + McpLifecyclePhase::InitializeHandshake, + Some("broken".to_string()), + "initialize failed", + BTreeMap::from([("reason".to_string(), "broken pipe".to_string())]), + false, + ), + }]; + + // when + let report = McpDegradedReport::new( + vec!["alpha".to_string(), "beta".to_string(), "alpha".to_string()], + failed, + vec![ + "alpha.echo".to_string(), + "beta.search".to_string(), + "alpha.echo".to_string(), + ], + vec![ + "alpha.echo".to_string(), + "beta.search".to_string(), + "broken.fetch".to_string(), + ], + ); + + // then + assert_eq!( + report.working_servers, + vec!["alpha".to_string(), "beta".to_string()] + ); + assert_eq!(report.failed_servers.len(), 1); + assert_eq!(report.failed_servers[0].server_name, "broken"); + assert_eq!( + report.available_tools, + vec!["alpha.echo".to_string(), "beta.search".to_string()] + ); + assert_eq!(report.missing_tools, vec!["broken.fetch".to_string()]); + } + + #[test] + fn given_failure_during_resource_discovery_when_shutting_down_then_cleanup_still_succeeds() { + // given + let mut validator = McpLifecycleValidator::new(); + for phase in [ + McpLifecyclePhase::ConfigLoad, + McpLifecyclePhase::ServerRegistration, + McpLifecyclePhase::SpawnConnect, + McpLifecyclePhase::InitializeHandshake, + McpLifecyclePhase::ToolDiscovery, + ] { + let result = validator.run_phase(phase); + assert!(matches!(result, McpPhaseResult::Success { .. })); + } + let _ = validator.record_failure(McpErrorSurface::new( + McpLifecyclePhase::ResourceDiscovery, + Some("alpha".to_string()), + "resource listing failed", + BTreeMap::from([("reason".to_string(), "timeout".to_string())]), + true, + )); + + // when + let shutdown = validator.run_phase(McpLifecyclePhase::Shutdown); + let cleanup = validator.run_phase(McpLifecyclePhase::Cleanup); + + // then + assert!(matches!(shutdown, McpPhaseResult::Success { .. })); + assert!(matches!(cleanup, McpPhaseResult::Success { .. })); + assert_eq!( + validator.state().current_phase(), + Some(McpLifecyclePhase::Cleanup) + ); + assert!(validator + .state() + .phase_timestamp(McpLifecyclePhase::ErrorSurfacing) + .is_some()); + } + + #[test] + fn error_surface_display_includes_phase_server_and_recoverable_flag() { + // given + let error = McpErrorSurface::new( + McpLifecyclePhase::SpawnConnect, + Some("alpha".to_string()), + "process exited early", + BTreeMap::from([("exit_code".to_string(), "1".to_string())]), + true, + ); + + // when + let rendered = error.to_string(); + + // then + assert!(rendered.contains("spawn_connect")); + assert!(rendered.contains("process exited early")); + assert!(rendered.contains("server: alpha")); + assert!(rendered.contains("recoverable")); + let trait_object: &dyn std::error::Error = &error; + assert_eq!(trait_object.to_string(), rendered); + } + + #[test] + fn given_nonrecoverable_failure_when_returning_to_ready_then_validator_rejects_resume() { + // given + let mut validator = McpLifecycleValidator::new(); + for phase in [ + McpLifecyclePhase::ConfigLoad, + McpLifecyclePhase::ServerRegistration, + McpLifecyclePhase::SpawnConnect, + McpLifecyclePhase::InitializeHandshake, + McpLifecyclePhase::ToolDiscovery, + McpLifecyclePhase::Ready, + ] { + let result = validator.run_phase(phase); + assert!(matches!(result, McpPhaseResult::Success { .. })); + } + let _ = validator.record_failure(McpErrorSurface::new( + McpLifecyclePhase::Invocation, + Some("alpha".to_string()), + "tool call corrupted the session", + BTreeMap::from([("reason".to_string(), "invalid frame".to_string())]), + false, + )); + + // when + let result = validator.run_phase(McpLifecyclePhase::Ready); + + // then + match result { + McpPhaseResult::Failure { phase, error } => { + assert_eq!(phase, McpLifecyclePhase::Ready); + assert!(!error.recoverable); + assert!(error.message.contains("non-recoverable")); + } + other => panic!("expected failure result, got {other:?}"), + } + assert_eq!( + validator.state().current_phase(), + Some(McpLifecyclePhase::ErrorSurfacing) + ); + } + + #[test] + fn given_recoverable_failure_when_returning_to_ready_then_validator_allows_resume() { + // given + let mut validator = McpLifecycleValidator::new(); + for phase in [ + McpLifecyclePhase::ConfigLoad, + McpLifecyclePhase::ServerRegistration, + McpLifecyclePhase::SpawnConnect, + McpLifecyclePhase::InitializeHandshake, + McpLifecyclePhase::ToolDiscovery, + McpLifecyclePhase::Ready, + ] { + let result = validator.run_phase(phase); + assert!(matches!(result, McpPhaseResult::Success { .. })); + } + let _ = validator.record_failure(McpErrorSurface::new( + McpLifecyclePhase::Invocation, + Some("alpha".to_string()), + "tool call failed but can be retried", + BTreeMap::from([("reason".to_string(), "upstream timeout".to_string())]), + true, + )); + + // when + let result = validator.run_phase(McpLifecyclePhase::Ready); + + // then + assert!(matches!(result, McpPhaseResult::Success { .. })); + assert_eq!( + validator.state().current_phase(), + Some(McpLifecyclePhase::Ready) + ); + } +} diff --git a/crates/runtime/src/mcp_server.rs b/crates/runtime/src/mcp_server.rs new file mode 100644 index 0000000..4610ed4 --- /dev/null +++ b/crates/runtime/src/mcp_server.rs @@ -0,0 +1,440 @@ +//! Minimal Model Context Protocol (MCP) server. +//! +//! Implements a newline-safe, LSP-framed JSON-RPC server over stdio that +//! answers `initialize`, `tools/list`, and `tools/call` requests. The framing +//! matches the client transport implemented in [`crate::mcp_stdio`] so this +//! server can be driven by either an external MCP client (e.g. Claude +//! Desktop) or `claw`'s own [`McpServerManager`](crate::McpServerManager). +//! +//! The server is intentionally small: it exposes a list of pre-built +//! [`McpTool`] descriptors and delegates `tools/call` to a caller-supplied +//! handler. Tool execution itself lives in the `tools` crate; this module is +//! purely the transport + dispatch loop. +//! +//! [`McpTool`]: crate::mcp_stdio::McpTool + +use std::io; + +use serde_json::{json, Value as JsonValue}; +use tokio::io::{ + stdin, stdout, AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, Stdin, Stdout, +}; + +use crate::mcp_stdio::{ + JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse, McpInitializeResult, + McpInitializeServerInfo, McpListToolsResult, McpTool, McpToolCallContent, McpToolCallParams, + McpToolCallResult, +}; + +/// Protocol version the server advertises during `initialize`. +/// +/// Matches the version used by the built-in client in +/// [`crate::mcp_stdio`], so the two stay in lockstep. +pub const MCP_SERVER_PROTOCOL_VERSION: &str = "2025-03-26"; + +/// Synchronous handler invoked for every `tools/call` request. +/// +/// Returning `Ok(text)` yields a single `text` content block and +/// `isError: false`. Returning `Err(message)` yields a `text` block with the +/// error and `isError: true`, mirroring the error-surfacing convention used +/// elsewhere in claw. +pub type ToolCallHandler = + Box Result + Send + Sync + 'static>; + +/// Configuration for an [`McpServer`] instance. +/// +/// Named `McpServerSpec` rather than `McpServerConfig` to avoid colliding +/// with the existing client-side [`crate::config::McpServerConfig`] that +/// describes *remote* MCP servers the runtime connects to. +pub struct McpServerSpec { + /// Name advertised in the `serverInfo` field of the `initialize` response. + pub server_name: String, + /// Version advertised in the `serverInfo` field of the `initialize` + /// response. + pub server_version: String, + /// Tool descriptors returned for `tools/list`. + pub tools: Vec, + /// Handler invoked for `tools/call`. + pub tool_handler: ToolCallHandler, +} + +/// Minimal MCP stdio server. +/// +/// The server runs a blocking read/dispatch/write loop over the current +/// process's stdin/stdout, terminating cleanly when the peer closes the +/// stream. +pub struct McpServer { + spec: McpServerSpec, + stdin: BufReader, + stdout: Stdout, +} + +impl McpServer { + #[must_use] + pub fn new(spec: McpServerSpec) -> Self { + Self { + spec, + stdin: BufReader::new(stdin()), + stdout: stdout(), + } + } + + /// Runs the server until the client closes stdin. + /// + /// Returns `Ok(())` on clean EOF; any other I/O error is propagated so + /// callers can log and exit non-zero. + pub async fn run(&mut self) -> io::Result<()> { + loop { + let Some(payload) = read_frame(&mut self.stdin).await? else { + return Ok(()); + }; + + // Requests and notifications share a wire format; the absence of + // `id` distinguishes notifications, which must never receive a + // response. + let message: JsonValue = match serde_json::from_slice(&payload) { + Ok(value) => value, + Err(error) => { + // Parse error with null id per JSON-RPC 2.0 §4.2. + let response = JsonRpcResponse:: { + jsonrpc: "2.0".to_string(), + id: JsonRpcId::Null, + result: None, + error: Some(JsonRpcError { + code: -32700, + message: format!("parse error: {error}"), + data: None, + }), + }; + write_response(&mut self.stdout, &response).await?; + continue; + } + }; + + if message.get("id").is_none() { + // Notification: dispatch for side effects only (e.g. log), + // but send no reply. + continue; + } + + let request: JsonRpcRequest = match serde_json::from_value(message) { + Ok(request) => request, + Err(error) => { + let response = JsonRpcResponse:: { + jsonrpc: "2.0".to_string(), + id: JsonRpcId::Null, + result: None, + error: Some(JsonRpcError { + code: -32600, + message: format!("invalid request: {error}"), + data: None, + }), + }; + write_response(&mut self.stdout, &response).await?; + continue; + } + }; + + let response = self.dispatch(request); + write_response(&mut self.stdout, &response).await?; + } + } + + fn dispatch(&self, request: JsonRpcRequest) -> JsonRpcResponse { + let id = request.id.clone(); + match request.method.as_str() { + "initialize" => self.handle_initialize(id), + "tools/list" => self.handle_tools_list(id), + "tools/call" => self.handle_tools_call(id, request.params), + other => JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: None, + error: Some(JsonRpcError { + code: -32601, + message: format!("method not found: {other}"), + data: None, + }), + }, + } + } + + fn handle_initialize(&self, id: JsonRpcId) -> JsonRpcResponse { + let result = McpInitializeResult { + protocol_version: MCP_SERVER_PROTOCOL_VERSION.to_string(), + capabilities: json!({ "tools": {} }), + server_info: McpInitializeServerInfo { + name: self.spec.server_name.clone(), + version: self.spec.server_version.clone(), + }, + }; + JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: serde_json::to_value(result).ok(), + error: None, + } + } + + fn handle_tools_list(&self, id: JsonRpcId) -> JsonRpcResponse { + let result = McpListToolsResult { + tools: self.spec.tools.clone(), + next_cursor: None, + }; + JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: serde_json::to_value(result).ok(), + error: None, + } + } + + fn handle_tools_call( + &self, + id: JsonRpcId, + params: Option, + ) -> JsonRpcResponse { + let Some(params) = params else { + return invalid_params_response(id, "missing params for tools/call"); + }; + let call: McpToolCallParams = match serde_json::from_value(params) { + Ok(value) => value, + Err(error) => { + return invalid_params_response(id, &format!("invalid tools/call params: {error}")); + } + }; + let arguments = call.arguments.unwrap_or_else(|| json!({})); + let tool_result = (self.spec.tool_handler)(&call.name, &arguments); + let (text, is_error) = match tool_result { + Ok(text) => (text, false), + Err(message) => (message, true), + }; + let mut data = std::collections::BTreeMap::new(); + data.insert("text".to_string(), JsonValue::String(text)); + let call_result = McpToolCallResult { + content: vec![McpToolCallContent { + kind: "text".to_string(), + data, + }], + structured_content: None, + is_error: Some(is_error), + meta: None, + }; + JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: serde_json::to_value(call_result).ok(), + error: None, + } + } +} + +fn invalid_params_response(id: JsonRpcId, message: &str) -> JsonRpcResponse { + JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: None, + error: Some(JsonRpcError { + code: -32602, + message: message.to_string(), + data: None, + }), + } +} + +/// Reads a single LSP-framed JSON-RPC payload from `reader`. +/// +/// Returns `Ok(None)` on clean EOF before any header bytes have been read, +/// matching how [`crate::mcp_stdio::McpStdioProcess`] treats stream closure. +async fn read_frame(reader: &mut BufReader) -> io::Result>> { + let mut content_length: Option = None; + let mut first_header = true; + loop { + let mut line = String::new(); + let bytes_read = reader.read_line(&mut line).await?; + if bytes_read == 0 { + if first_header { + return Ok(None); + } + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "MCP stdio stream closed while reading headers", + )); + } + first_header = false; + if line == "\r\n" || line == "\n" { + break; + } + let header = line.trim_end_matches(['\r', '\n']); + if let Some((name, value)) = header.split_once(':') { + if name.trim().eq_ignore_ascii_case("Content-Length") { + let parsed = value + .trim() + .parse::() + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + content_length = Some(parsed); + } + } + } + + let content_length = content_length.ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "missing Content-Length header") + })?; + let mut payload = vec![0_u8; content_length]; + reader.read_exact(&mut payload).await?; + Ok(Some(payload)) +} + +async fn write_response( + stdout: &mut Stdout, + response: &JsonRpcResponse, +) -> io::Result<()> { + let body = serde_json::to_vec(response) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + let header = format!("Content-Length: {}\r\n\r\n", body.len()); + stdout.write_all(header.as_bytes()).await?; + stdout.write_all(&body).await?; + stdout.flush().await +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn dispatch_initialize_returns_server_info() { + let server = McpServer { + spec: McpServerSpec { + server_name: "test".to_string(), + server_version: "9.9.9".to_string(), + tools: Vec::new(), + tool_handler: Box::new(|_, _| Ok(String::new())), + }, + stdin: BufReader::new(stdin()), + stdout: stdout(), + }; + let request = JsonRpcRequest:: { + jsonrpc: "2.0".to_string(), + id: JsonRpcId::Number(1), + method: "initialize".to_string(), + params: None, + }; + let response = server.dispatch(request); + assert_eq!(response.id, JsonRpcId::Number(1)); + assert!(response.error.is_none()); + let result = response.result.expect("initialize result"); + assert_eq!(result["protocolVersion"], MCP_SERVER_PROTOCOL_VERSION); + assert_eq!(result["serverInfo"]["name"], "test"); + assert_eq!(result["serverInfo"]["version"], "9.9.9"); + } + + #[test] + fn dispatch_tools_list_returns_registered_tools() { + let tool = McpTool { + name: "echo".to_string(), + description: Some("Echo".to_string()), + input_schema: Some(json!({"type": "object"})), + annotations: None, + meta: None, + }; + let server = McpServer { + spec: McpServerSpec { + server_name: "test".to_string(), + server_version: "0.0.0".to_string(), + tools: vec![tool.clone()], + tool_handler: Box::new(|_, _| Ok(String::new())), + }, + stdin: BufReader::new(stdin()), + stdout: stdout(), + }; + let request = JsonRpcRequest:: { + jsonrpc: "2.0".to_string(), + id: JsonRpcId::Number(2), + method: "tools/list".to_string(), + params: None, + }; + let response = server.dispatch(request); + assert!(response.error.is_none()); + let result = response.result.expect("tools/list result"); + assert_eq!(result["tools"][0]["name"], "echo"); + } + + #[test] + fn dispatch_tools_call_wraps_handler_output() { + let server = McpServer { + spec: McpServerSpec { + server_name: "test".to_string(), + server_version: "0.0.0".to_string(), + tools: Vec::new(), + tool_handler: Box::new(|name, args| Ok(format!("called {name} with {args}"))), + }, + stdin: BufReader::new(stdin()), + stdout: stdout(), + }; + let request = JsonRpcRequest:: { + jsonrpc: "2.0".to_string(), + id: JsonRpcId::Number(3), + method: "tools/call".to_string(), + params: Some(json!({ + "name": "echo", + "arguments": {"text": "hi"} + })), + }; + let response = server.dispatch(request); + assert!(response.error.is_none()); + let result = response.result.expect("tools/call result"); + assert_eq!(result["isError"], false); + assert_eq!(result["content"][0]["type"], "text"); + assert!(result["content"][0]["text"] + .as_str() + .unwrap() + .starts_with("called echo")); + } + + #[test] + fn dispatch_tools_call_surfaces_handler_error() { + let server = McpServer { + spec: McpServerSpec { + server_name: "test".to_string(), + server_version: "0.0.0".to_string(), + tools: Vec::new(), + tool_handler: Box::new(|_, _| Err("boom".to_string())), + }, + stdin: BufReader::new(stdin()), + stdout: stdout(), + }; + let request = JsonRpcRequest:: { + jsonrpc: "2.0".to_string(), + id: JsonRpcId::Number(4), + method: "tools/call".to_string(), + params: Some(json!({"name": "broken"})), + }; + let response = server.dispatch(request); + let result = response.result.expect("tools/call result"); + assert_eq!(result["isError"], true); + assert_eq!(result["content"][0]["text"], "boom"); + } + + #[test] + fn dispatch_unknown_method_returns_method_not_found() { + let server = McpServer { + spec: McpServerSpec { + server_name: "test".to_string(), + server_version: "0.0.0".to_string(), + tools: Vec::new(), + tool_handler: Box::new(|_, _| Ok(String::new())), + }, + stdin: BufReader::new(stdin()), + stdout: stdout(), + }; + let request = JsonRpcRequest:: { + jsonrpc: "2.0".to_string(), + id: JsonRpcId::Number(5), + method: "nonsense".to_string(), + params: None, + }; + let response = server.dispatch(request); + let error = response.error.expect("error payload"); + assert_eq!(error.code, -32601); + } +} diff --git a/crates/runtime/src/mcp_stdio.rs b/crates/runtime/src/mcp_stdio.rs index f3f16b9..c5b4c75 100644 --- a/crates/runtime/src/mcp_stdio.rs +++ b/crates/runtime/src/mcp_stdio.rs @@ -1,16 +1,32 @@ use std::collections::BTreeMap; +use std::future::Future; use std::io; use std::process::Stdio; +use std::time::Duration; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use serde_json::Value as JsonValue; use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::process::{Child, ChildStdin, ChildStdout, Command}; +use tokio::time::timeout; use crate::config::{McpTransport, RuntimeConfig, ScopedMcpServerConfig}; use crate::mcp::mcp_tool_name; use crate::mcp_client::{McpClientBootstrap, McpClientTransport, McpStdioTransport}; +use crate::mcp_lifecycle_hardened::{ + McpDegradedReport, McpErrorSurface, McpFailedServer, McpLifecyclePhase, +}; + +#[cfg(test)] +const MCP_INITIALIZE_TIMEOUT_MS: u64 = 200; +#[cfg(not(test))] +const MCP_INITIALIZE_TIMEOUT_MS: u64 = 10_000; + +#[cfg(test)] +const MCP_LIST_TOOLS_TIMEOUT_MS: u64 = 300; +#[cfg(not(test))] +const MCP_LIST_TOOLS_TIMEOUT_MS: u64 = 30_000; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(untagged)] @@ -217,9 +233,31 @@ pub struct UnsupportedMcpServer { pub reason: String, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct McpDiscoveryFailure { + pub server_name: String, + pub phase: McpLifecyclePhase, + pub error: String, + pub recoverable: bool, + pub context: BTreeMap, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct McpToolDiscoveryReport { + pub tools: Vec, + pub failed_servers: Vec, + pub unsupported_servers: Vec, + pub degraded_startup: Option, +} + #[derive(Debug)] pub enum McpServerManagerError { Io(io::Error), + Transport { + server_name: String, + method: &'static str, + source: io::Error, + }, JsonRpc { server_name: String, method: &'static str, @@ -230,6 +268,11 @@ pub enum McpServerManagerError { method: &'static str, details: String, }, + Timeout { + server_name: String, + method: &'static str, + timeout_ms: u64, + }, UnknownTool { qualified_name: String, }, @@ -242,6 +285,14 @@ impl std::fmt::Display for McpServerManagerError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Io(error) => write!(f, "{error}"), + Self::Transport { + server_name, + method, + source, + } => write!( + f, + "MCP server `{server_name}` transport failed during {method}: {source}" + ), Self::JsonRpc { server_name, method, @@ -259,6 +310,14 @@ impl std::fmt::Display for McpServerManagerError { f, "MCP server `{server_name}` returned invalid response for {method}: {details}" ), + Self::Timeout { + server_name, + method, + timeout_ms, + } => write!( + f, + "MCP server `{server_name}` timed out after {timeout_ms} ms while handling {method}" + ), Self::UnknownTool { qualified_name } => { write!(f, "unknown MCP tool `{qualified_name}`") } @@ -271,8 +330,10 @@ impl std::error::Error for McpServerManagerError { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { Self::Io(error) => Some(error), + Self::Transport { source, .. } => Some(source), Self::JsonRpc { .. } | Self::InvalidResponse { .. } + | Self::Timeout { .. } | Self::UnknownTool { .. } | Self::UnknownServer { .. } => None, } @@ -285,6 +346,113 @@ impl From for McpServerManagerError { } } +impl McpServerManagerError { + fn lifecycle_phase(&self) -> McpLifecyclePhase { + match self { + Self::Io(_) => McpLifecyclePhase::SpawnConnect, + Self::Transport { method, .. } + | Self::JsonRpc { method, .. } + | Self::InvalidResponse { method, .. } + | Self::Timeout { method, .. } => lifecycle_phase_for_method(method), + Self::UnknownTool { .. } => McpLifecyclePhase::ToolDiscovery, + Self::UnknownServer { .. } => McpLifecyclePhase::ServerRegistration, + } + } + + fn recoverable(&self) -> bool { + !matches!( + self.lifecycle_phase(), + McpLifecyclePhase::InitializeHandshake + ) && matches!(self, Self::Transport { .. } | Self::Timeout { .. }) + } + + fn discovery_failure(&self, server_name: &str) -> McpDiscoveryFailure { + let phase = self.lifecycle_phase(); + let recoverable = self.recoverable(); + let context = self.error_context(); + + McpDiscoveryFailure { + server_name: server_name.to_string(), + phase, + error: self.to_string(), + recoverable, + context, + } + } + + fn error_context(&self) -> BTreeMap { + match self { + Self::Io(error) => BTreeMap::from([("kind".to_string(), error.kind().to_string())]), + Self::Transport { + server_name, + method, + source, + } => BTreeMap::from([ + ("server".to_string(), server_name.clone()), + ("method".to_string(), (*method).to_string()), + ("io_kind".to_string(), source.kind().to_string()), + ]), + Self::JsonRpc { + server_name, + method, + error, + } => BTreeMap::from([ + ("server".to_string(), server_name.clone()), + ("method".to_string(), (*method).to_string()), + ("jsonrpc_code".to_string(), error.code.to_string()), + ]), + Self::InvalidResponse { + server_name, + method, + details, + } => BTreeMap::from([ + ("server".to_string(), server_name.clone()), + ("method".to_string(), (*method).to_string()), + ("details".to_string(), details.clone()), + ]), + Self::Timeout { + server_name, + method, + timeout_ms, + } => BTreeMap::from([ + ("server".to_string(), server_name.clone()), + ("method".to_string(), (*method).to_string()), + ("timeout_ms".to_string(), timeout_ms.to_string()), + ]), + Self::UnknownTool { qualified_name } => { + BTreeMap::from([("qualified_tool".to_string(), qualified_name.clone())]) + } + Self::UnknownServer { server_name } => { + BTreeMap::from([("server".to_string(), server_name.clone())]) + } + } + } +} + +fn lifecycle_phase_for_method(method: &str) -> McpLifecyclePhase { + match method { + "initialize" => McpLifecyclePhase::InitializeHandshake, + "tools/list" => McpLifecyclePhase::ToolDiscovery, + "resources/list" => McpLifecyclePhase::ResourceDiscovery, + "resources/read" | "tools/call" => McpLifecyclePhase::Invocation, + _ => McpLifecyclePhase::ErrorSurfacing, + } +} + +fn unsupported_server_failed_server(server: &UnsupportedMcpServer) -> McpFailedServer { + McpFailedServer { + server_name: server.server_name.clone(), + phase: McpLifecyclePhase::ServerRegistration, + error: McpErrorSurface::new( + McpLifecyclePhase::ServerRegistration, + Some(server.server_name.clone()), + server.reason.clone(), + BTreeMap::from([("transport".to_string(), format!("{:?}", server.transport))]), + false, + ), + } +} + #[derive(Debug, Clone, PartialEq, Eq)] struct ToolRoute { server_name: String, @@ -356,80 +524,103 @@ impl McpServerManager { &self.unsupported_servers } + #[must_use] + pub fn server_names(&self) -> Vec { + self.servers.keys().cloned().collect() + } + pub async fn discover_tools(&mut self) -> Result, McpServerManagerError> { let server_names = self.servers.keys().cloned().collect::>(); let mut discovered_tools = Vec::new(); for server_name in server_names { - self.ensure_server_ready(&server_name).await?; + let server_tools = self.discover_tools_for_server(&server_name).await?; self.clear_routes_for_server(&server_name); - let mut cursor = None; - loop { - let request_id = self.take_request_id(); - let response = { - let server = self.server_mut(&server_name)?; - let process = server.process.as_mut().ok_or_else(|| { - McpServerManagerError::InvalidResponse { - server_name: server_name.clone(), - method: "tools/list", - details: "server process missing after initialization".to_string(), - } - })?; - process - .list_tools( - request_id, - Some(McpListToolsParams { - cursor: cursor.clone(), - }), - ) - .await? - }; - - if let Some(error) = response.error { - return Err(McpServerManagerError::JsonRpc { - server_name: server_name.clone(), - method: "tools/list", - error, - }); - } - - let result = - response - .result - .ok_or_else(|| McpServerManagerError::InvalidResponse { - server_name: server_name.clone(), - method: "tools/list", - details: "missing result payload".to_string(), - })?; - - for tool in result.tools { - let qualified_name = mcp_tool_name(&server_name, &tool.name); - self.tool_index.insert( - qualified_name.clone(), - ToolRoute { - server_name: server_name.clone(), - raw_name: tool.name.clone(), - }, - ); - discovered_tools.push(ManagedMcpTool { - server_name: server_name.clone(), - qualified_name, - raw_name: tool.name.clone(), - tool, - }); - } - - match result.next_cursor { - Some(next_cursor) => cursor = Some(next_cursor), - None => break, - } + for tool in server_tools { + self.tool_index.insert( + tool.qualified_name.clone(), + ToolRoute { + server_name: tool.server_name.clone(), + raw_name: tool.raw_name.clone(), + }, + ); + discovered_tools.push(tool); } } Ok(discovered_tools) } + pub async fn discover_tools_best_effort(&mut self) -> McpToolDiscoveryReport { + let server_names = self.server_names(); + let mut discovered_tools = Vec::new(); + let mut working_servers = Vec::new(); + let mut failed_servers = Vec::new(); + + for server_name in server_names { + match self.discover_tools_for_server(&server_name).await { + Ok(server_tools) => { + working_servers.push(server_name.clone()); + self.clear_routes_for_server(&server_name); + for tool in server_tools { + self.tool_index.insert( + tool.qualified_name.clone(), + ToolRoute { + server_name: tool.server_name.clone(), + raw_name: tool.raw_name.clone(), + }, + ); + discovered_tools.push(tool); + } + } + Err(error) => { + self.clear_routes_for_server(&server_name); + failed_servers.push(error.discovery_failure(&server_name)); + } + } + } + + let degraded_failed_servers = failed_servers + .iter() + .map(|failure| McpFailedServer { + server_name: failure.server_name.clone(), + phase: failure.phase, + error: McpErrorSurface::new( + failure.phase, + Some(failure.server_name.clone()), + failure.error.clone(), + failure.context.clone(), + failure.recoverable, + ), + }) + .chain( + self.unsupported_servers + .iter() + .map(unsupported_server_failed_server), + ) + .collect::>(); + let degraded_startup = (!working_servers.is_empty() && !degraded_failed_servers.is_empty()) + .then(|| { + McpDegradedReport::new( + working_servers, + degraded_failed_servers, + discovered_tools + .iter() + .map(|tool| tool.qualified_name.clone()) + .collect(), + Vec::new(), + ) + }); + + McpToolDiscoveryReport { + tools: discovered_tools, + failed_servers, + unsupported_servers: self.unsupported_servers.clone(), + degraded_startup, + } + } + pub async fn call_tool( &mut self, qualified_tool_name: &str, @@ -443,6 +634,8 @@ impl McpServerManager { qualified_name: qualified_tool_name.to_string(), })?; + let timeout_ms = self.tool_call_timeout_ms(&route.server_name)?; + self.ensure_server_ready(&route.server_name).await?; let request_id = self.take_request_id(); let response = @@ -455,18 +648,76 @@ impl McpServerManager { details: "server process missing after initialization".to_string(), } })?; - process - .call_tool( + Self::run_process_request( + &route.server_name, + "tools/call", + timeout_ms, + process.call_tool( request_id, McpToolCallParams { name: route.raw_name, arguments, meta: None, }, - ) - .await? + ), + ) + .await }; - Ok(response) + + if let Err(error) = &response { + if Self::should_reset_server(error) { + self.reset_server(&route.server_name).await?; + } + } + + response + } + + pub async fn list_resources( + &mut self, + server_name: &str, + ) -> Result { + let mut attempts = 0; + + loop { + match self.list_resources_once(server_name).await { + Ok(resources) => return Ok(resources), + Err(error) if attempts == 0 && Self::is_retryable_error(&error) => { + self.reset_server(server_name).await?; + attempts += 1; + } + Err(error) => { + if Self::should_reset_server(&error) { + self.reset_server(server_name).await?; + } + return Err(error); + } + } + } + } + + pub async fn read_resource( + &mut self, + server_name: &str, + uri: &str, + ) -> Result { + let mut attempts = 0; + + loop { + match self.read_resource_once(server_name, uri).await { + Ok(resource) => return Ok(resource), + Err(error) if attempts == 0 && Self::is_retryable_error(&error) => { + self.reset_server(server_name).await?; + attempts += 1; + } + Err(error) => { + if Self::should_reset_server(&error) { + self.reset_server(server_name).await?; + } + return Err(error); + } + } + } } pub async fn shutdown(&mut self) -> Result<(), McpServerManagerError> { @@ -504,33 +755,331 @@ impl McpServerManager { JsonRpcId::Number(id) } + fn tool_call_timeout_ms(&self, server_name: &str) -> Result { + let server = + self.servers + .get(server_name) + .ok_or_else(|| McpServerManagerError::UnknownServer { + server_name: server_name.to_string(), + })?; + match &server.bootstrap.transport { + McpClientTransport::Stdio(transport) => Ok(transport.resolved_tool_call_timeout_ms()), + other => Err(McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method: "tools/call", + details: format!("unsupported MCP transport for stdio manager: {other:?}"), + }), + } + } + + fn server_process_exited(&mut self, server_name: &str) -> Result { + let server = self.server_mut(server_name)?; + match server.process.as_mut() { + Some(process) => Ok(process.has_exited()?), + None => Ok(false), + } + } + + async fn discover_tools_for_server( + &mut self, + server_name: &str, + ) -> Result, McpServerManagerError> { + let mut attempts = 0; + + loop { + match self.discover_tools_for_server_once(server_name).await { + Ok(tools) => return Ok(tools), + Err(error) if attempts == 0 && Self::is_retryable_error(&error) => { + self.reset_server(server_name).await?; + attempts += 1; + } + Err(error) => { + if Self::should_reset_server(&error) { + self.reset_server(server_name).await?; + } + return Err(error); + } + } + } + } + + async fn discover_tools_for_server_once( + &mut self, + server_name: &str, + ) -> Result, McpServerManagerError> { + self.ensure_server_ready(server_name).await?; + + let mut discovered_tools = Vec::new(); + let mut cursor = None; + loop { + let request_id = self.take_request_id(); + let response = { + let server = self.server_mut(server_name)?; + let process = server.process.as_mut().ok_or_else(|| { + McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method: "tools/list", + details: "server process missing after initialization".to_string(), + } + })?; + Self::run_process_request( + server_name, + "tools/list", + MCP_LIST_TOOLS_TIMEOUT_MS, + process.list_tools( + request_id, + Some(McpListToolsParams { + cursor: cursor.clone(), + }), + ), + ) + .await? + }; + + if let Some(error) = response.error { + return Err(McpServerManagerError::JsonRpc { + server_name: server_name.to_string(), + method: "tools/list", + error, + }); + } + + let result = response + .result + .ok_or_else(|| McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method: "tools/list", + details: "missing result payload".to_string(), + })?; + + for tool in result.tools { + let qualified_name = mcp_tool_name(server_name, &tool.name); + discovered_tools.push(ManagedMcpTool { + server_name: server_name.to_string(), + qualified_name, + raw_name: tool.name.clone(), + tool, + }); + } + + match result.next_cursor { + Some(next_cursor) => cursor = Some(next_cursor), + None => break, + } + } + + Ok(discovered_tools) + } + + async fn list_resources_once( + &mut self, + server_name: &str, + ) -> Result { + self.ensure_server_ready(server_name).await?; + + let mut resources = Vec::new(); + let mut cursor = None; + loop { + let request_id = self.take_request_id(); + let response = { + let server = self.server_mut(server_name)?; + let process = server.process.as_mut().ok_or_else(|| { + McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method: "resources/list", + details: "server process missing after initialization".to_string(), + } + })?; + Self::run_process_request( + server_name, + "resources/list", + MCP_LIST_TOOLS_TIMEOUT_MS, + process.list_resources( + request_id, + Some(McpListResourcesParams { + cursor: cursor.clone(), + }), + ), + ) + .await? + }; + + if let Some(error) = response.error { + return Err(McpServerManagerError::JsonRpc { + server_name: server_name.to_string(), + method: "resources/list", + error, + }); + } + + let result = response + .result + .ok_or_else(|| McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method: "resources/list", + details: "missing result payload".to_string(), + })?; + + resources.extend(result.resources); + + match result.next_cursor { + Some(next_cursor) => cursor = Some(next_cursor), + None => break, + } + } + + Ok(McpListResourcesResult { + resources, + next_cursor: None, + }) + } + + async fn read_resource_once( + &mut self, + server_name: &str, + uri: &str, + ) -> Result { + self.ensure_server_ready(server_name).await?; + + let request_id = self.take_request_id(); + let response = + { + let server = self.server_mut(server_name)?; + let process = server.process.as_mut().ok_or_else(|| { + McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method: "resources/read", + details: "server process missing after initialization".to_string(), + } + })?; + Self::run_process_request( + server_name, + "resources/read", + MCP_LIST_TOOLS_TIMEOUT_MS, + process.read_resource( + request_id, + McpReadResourceParams { + uri: uri.to_string(), + }, + ), + ) + .await? + }; + + if let Some(error) = response.error { + return Err(McpServerManagerError::JsonRpc { + server_name: server_name.to_string(), + method: "resources/read", + error, + }); + } + + response + .result + .ok_or_else(|| McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method: "resources/read", + details: "missing result payload".to_string(), + }) + } + + async fn reset_server(&mut self, server_name: &str) -> Result<(), McpServerManagerError> { + let mut process = { + let server = self.server_mut(server_name)?; + server.initialized = false; + server.process.take() + }; + + if let Some(process) = process.as_mut() { + let _ = process.shutdown().await; + } + + Ok(()) + } + + fn is_retryable_error(error: &McpServerManagerError) -> bool { + matches!( + error, + McpServerManagerError::Transport { .. } | McpServerManagerError::Timeout { .. } + ) + } + + fn should_reset_server(error: &McpServerManagerError) -> bool { + matches!( + error, + McpServerManagerError::Transport { .. } + | McpServerManagerError::Timeout { .. } + | McpServerManagerError::InvalidResponse { .. } + ) + } + + async fn run_process_request( + server_name: &str, + method: &'static str, + timeout_ms: u64, + future: F, + ) -> Result + where + F: Future>, + { + match timeout(Duration::from_millis(timeout_ms), future).await { + Ok(Ok(value)) => Ok(value), + Ok(Err(error)) if error.kind() == io::ErrorKind::InvalidData => { + Err(McpServerManagerError::InvalidResponse { + server_name: server_name.to_string(), + method, + details: error.to_string(), + }) + } + Ok(Err(source)) => Err(McpServerManagerError::Transport { + server_name: server_name.to_string(), + method, + source, + }), + Err(_) => Err(McpServerManagerError::Timeout { + server_name: server_name.to_string(), + method, + timeout_ms, + }), + } + } + async fn ensure_server_ready( &mut self, server_name: &str, ) -> Result<(), McpServerManagerError> { - let needs_spawn = self - .servers - .get(server_name) - .map(|server| server.process.is_none()) - .ok_or_else(|| McpServerManagerError::UnknownServer { - server_name: server_name.to_string(), - })?; - - if needs_spawn { - let server = self.server_mut(server_name)?; - server.process = Some(spawn_mcp_stdio_process(&server.bootstrap)?); - server.initialized = false; + if self.server_process_exited(server_name)? { + self.reset_server(server_name).await?; } - let needs_initialize = self - .servers - .get(server_name) - .map(|server| !server.initialized) - .ok_or_else(|| McpServerManagerError::UnknownServer { - server_name: server_name.to_string(), - })?; + let mut attempts = 0; + loop { + let needs_spawn = self + .servers + .get(server_name) + .map(|server| server.process.is_none()) + .ok_or_else(|| McpServerManagerError::UnknownServer { + server_name: server_name.to_string(), + })?; + + if needs_spawn { + let server = self.server_mut(server_name)?; + server.process = Some(spawn_mcp_stdio_process(&server.bootstrap)?); + server.initialized = false; + } + + let needs_initialize = self + .servers + .get(server_name) + .map(|server| !server.initialized) + .ok_or_else(|| McpServerManagerError::UnknownServer { + server_name: server_name.to_string(), + })?; + + if !needs_initialize { + return Ok(()); + } - if needs_initialize { let request_id = self.take_request_id(); let response = { let server = self.server_mut(server_name)?; @@ -541,9 +1090,28 @@ impl McpServerManager { details: "server process missing before initialize".to_string(), } })?; - process - .initialize(request_id, default_initialize_params()) - .await? + Self::run_process_request( + server_name, + "initialize", + MCP_INITIALIZE_TIMEOUT_MS, + process.initialize(request_id, default_initialize_params()), + ) + .await + }; + + let response = match response { + Ok(response) => response, + Err(error) if attempts == 0 && Self::is_retryable_error(&error) => { + self.reset_server(server_name).await?; + attempts += 1; + continue; + } + Err(error) => { + if Self::should_reset_server(&error) { + self.reset_server(server_name).await?; + } + return Err(error); + } }; if let Some(error) = response.error { @@ -555,18 +1123,19 @@ impl McpServerManager { } if response.result.is_none() { - return Err(McpServerManagerError::InvalidResponse { + let error = McpServerManagerError::InvalidResponse { server_name: server_name.to_string(), method: "initialize", details: "missing result payload".to_string(), - }); + }; + self.reset_server(server_name).await?; + return Err(error); } let server = self.server_mut(server_name)?; server.initialized = true; + return Ok(()); } - - Ok(()) } } @@ -657,12 +1226,15 @@ impl McpStdioProcess { if line == "\r\n" { break; } - if let Some(value) = line.strip_prefix("Content-Length:") { - let parsed = value - .trim() - .parse::() - .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; - content_length = Some(parsed); + let header = line.trim_end_matches(['\r', '\n']); + if let Some((name, value)) = header.split_once(':') { + if name.trim().eq_ignore_ascii_case("Content-Length") { + let parsed = value + .trim() + .parse::() + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + content_length = Some(parsed); + } } } @@ -703,9 +1275,32 @@ impl McpStdioProcess { method: impl Into, params: Option, ) -> io::Result> { - let request = JsonRpcRequest::new(id, method, params); + let method = method.into(); + let request = JsonRpcRequest::new(id.clone(), method.clone(), params); self.send_request(&request).await?; - self.read_response().await + let response = self.read_response().await?; + + if response.jsonrpc != "2.0" { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "MCP response for {method} used unsupported jsonrpc version `{}`", + response.jsonrpc + ), + )); + } + + if response.id != id { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "MCP response for {method} used mismatched id: expected {id:?}, got {:?}", + response.id + ), + )); + } + + Ok(response) } pub async fn initialize( @@ -756,9 +1351,17 @@ impl McpStdioProcess { self.child.wait().await } + pub fn has_exited(&mut self) -> io::Result { + Ok(self.child.try_wait()?.is_some()) + } + async fn shutdown(&mut self) -> io::Result<()> { if self.child.try_wait()?.is_none() { - self.child.kill().await?; + match self.child.kill().await { + Ok(()) => {} + Err(error) if error.kind() == io::ErrorKind::InvalidInput => {} + Err(error) => return Err(error), + } } let _ = self.child.wait().await?; Ok(()) @@ -802,15 +1405,14 @@ fn default_initialize_params() -> McpInitializeParams { } } -#[cfg(test)] +#[cfg(all(test, unix))] mod tests { use std::collections::BTreeMap; use std::fs; use std::io::ErrorKind; - #[cfg(unix)] use std::os::unix::fs::PermissionsExt; use std::path::{Path, PathBuf}; - use std::process::Command; + use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{SystemTime, UNIX_EPOCH}; use serde_json::json; @@ -824,18 +1426,21 @@ mod tests { use crate::mcp_client::McpClientBootstrap; use super::{ - spawn_mcp_stdio_process, JsonRpcId, JsonRpcRequest, JsonRpcResponse, - McpInitializeClientInfo, McpInitializeParams, McpInitializeResult, McpInitializeServerInfo, - McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpServerManager, - McpServerManagerError, McpStdioProcess, McpTool, McpToolCallParams, + spawn_mcp_stdio_process, unsupported_server_failed_server, JsonRpcId, JsonRpcRequest, + JsonRpcResponse, McpInitializeClientInfo, McpInitializeParams, McpInitializeResult, + McpInitializeServerInfo, McpListToolsResult, McpReadResourceParams, McpReadResourceResult, + McpServerManager, McpServerManagerError, McpStdioProcess, McpTool, McpToolCallParams, }; + use crate::McpLifecyclePhase; fn temp_dir() -> PathBuf { + static NEXT_TEMP_DIR_ID: AtomicU64 = AtomicU64::new(0); let nanos = SystemTime::now() .duration_since(UNIX_EPOCH) .expect("time should be after epoch") .as_nanos(); - std::env::temp_dir().join(format!("runtime-mcp-stdio-{nanos}")) + let unique_id = NEXT_TEMP_DIR_ID.fetch_add(1, Ordering::Relaxed); + std::env::temp_dir().join(format!("runtime-mcp-stdio-{nanos}-{unique_id}")) } fn write_echo_script() -> PathBuf { @@ -847,12 +1452,9 @@ mod tests { "#!/bin/sh\nprintf 'READY:%s\\n' \"$MCP_TEST_TOKEN\"\nIFS= read -r line\nprintf 'ECHO:%s\\n' \"$line\"\n", ) .expect("write script"); - #[cfg(unix)] - { - let mut permissions = fs::metadata(&script_path).expect("metadata").permissions(); - permissions.set_mode(0o755); - fs::set_permissions(&script_path, permissions).expect("chmod"); - } + let mut permissions = fs::metadata(&script_path).expect("metadata").permissions(); + permissions.set_mode(0o755); + fs::set_permissions(&script_path, permissions).expect("chmod"); script_path } @@ -862,7 +1464,9 @@ mod tests { let script_path = root.join("jsonrpc-mcp.py"); let script = [ "#!/usr/bin/env python3", - "import json, sys", + "import json, os, sys", + "LOWERCASE_CONTENT_LENGTH = os.environ.get('MCP_LOWERCASE_CONTENT_LENGTH') == '1'", + "MISMATCHED_RESPONSE_ID = os.environ.get('MCP_MISMATCHED_RESPONSE_ID') == '1'", "header = b''", r"while not header.endswith(b'\r\n\r\n'):", " chunk = sys.stdin.buffer.read(1)", @@ -877,27 +1481,26 @@ mod tests { "request = json.loads(payload.decode())", r"assert request['jsonrpc'] == '2.0'", r"assert request['method'] == 'initialize'", + "response_id = 'wrong-id' if MISMATCHED_RESPONSE_ID else request['id']", + "header_name = 'content-length' if LOWERCASE_CONTENT_LENGTH else 'Content-Length'", r"response = json.dumps({", r" 'jsonrpc': '2.0',", - r" 'id': request['id'],", + r" 'id': response_id,", r" 'result': {", r" 'protocolVersion': request['params']['protocolVersion'],", r" 'capabilities': {'tools': {}},", r" 'serverInfo': {'name': 'fake-mcp', 'version': '0.1.0'}", r" }", r"}).encode()", - r"sys.stdout.buffer.write(f'Content-Length: {len(response)}\r\n\r\n'.encode() + response)", + r"sys.stdout.buffer.write(f'{header_name}: {len(response)}\r\n\r\n'.encode() + response)", "sys.stdout.buffer.flush()", "", ] .join("\n"); fs::write(&script_path, script).expect("write script"); - #[cfg(unix)] - { - let mut permissions = fs::metadata(&script_path).expect("metadata").permissions(); - permissions.set_mode(0o755); - fs::set_permissions(&script_path, permissions).expect("chmod"); - } + let mut permissions = fs::metadata(&script_path).expect("metadata").permissions(); + permissions.set_mode(0o755); + fs::set_permissions(&script_path, permissions).expect("chmod"); script_path } @@ -908,7 +1511,9 @@ mod tests { let script_path = root.join("fake-mcp-server.py"); let script = [ "#!/usr/bin/env python3", - "import json, sys", + "import json, os, sys, time", + "TOOL_CALL_DELAY_MS = int(os.environ.get('MCP_TOOL_CALL_DELAY_MS', '0'))", + "INVALID_TOOL_CALL_RESPONSE = os.environ.get('MCP_INVALID_TOOL_CALL_RESPONSE') == '1'", "", "def read_message():", " header = b''", @@ -963,6 +1568,12 @@ mod tests { " }", " })", " elif method == 'tools/call':", + " if INVALID_TOOL_CALL_RESPONSE:", + " sys.stdout.buffer.write(b'Content-Length: 5\\r\\n\\r\\nnope!')", + " sys.stdout.buffer.flush()", + " continue", + " if TOOL_CALL_DELAY_MS:", + " time.sleep(TOOL_CALL_DELAY_MS / 1000)", " args = request['params'].get('arguments') or {}", " if request['params']['name'] == 'fail':", " send_message({", @@ -1021,12 +1632,9 @@ mod tests { ] .join("\n"); fs::write(&script_path, script).expect("write script"); - #[cfg(unix)] - { - let mut permissions = fs::metadata(&script_path).expect("metadata").permissions(); - permissions.set_mode(0o755); - fs::set_permissions(&script_path, permissions).expect("chmod"); - } + let mut permissions = fs::metadata(&script_path).expect("metadata").permissions(); + permissions.set_mode(0o755); + fs::set_permissions(&script_path, permissions).expect("chmod"); script_path } @@ -1037,10 +1645,13 @@ mod tests { let script_path = root.join("manager-mcp-server.py"); let script = [ "#!/usr/bin/env python3", - "import json, os, sys", + "import json, os, sys, time", "", "LABEL = os.environ.get('MCP_SERVER_LABEL', 'server')", "LOG_PATH = os.environ.get('MCP_LOG_PATH')", + "EXIT_AFTER_TOOLS_LIST = os.environ.get('MCP_EXIT_AFTER_TOOLS_LIST') == '1'", + "FAIL_ONCE_MODE = os.environ.get('MCP_FAIL_ONCE_MODE')", + "FAIL_ONCE_MARKER = os.environ.get('MCP_FAIL_ONCE_MARKER')", "initialize_count = 0", "", "def log(method):", @@ -1048,6 +1659,15 @@ mod tests { " with open(LOG_PATH, 'a', encoding='utf-8') as handle:", " handle.write(f'{method}\\n')", "", + "def should_fail_once():", + " if not FAIL_ONCE_MODE or not FAIL_ONCE_MARKER:", + " return False", + " if os.path.exists(FAIL_ONCE_MARKER):", + " return False", + " with open(FAIL_ONCE_MARKER, 'w', encoding='utf-8') as handle:", + " handle.write(FAIL_ONCE_MODE)", + " return True", + "", "def read_message():", " header = b''", r" while not header.endswith(b'\r\n\r\n'):", @@ -1074,6 +1694,10 @@ mod tests { " method = request['method']", " log(method)", " if method == 'initialize':", + " if FAIL_ONCE_MODE == 'initialize_hang' and should_fail_once():", + " log('initialize-hang')", + " while True:", + " time.sleep(1)", " initialize_count += 1", " send_message({", " 'jsonrpc': '2.0',", @@ -1102,7 +1726,12 @@ mod tests { " ]", " }", " })", + " if EXIT_AFTER_TOOLS_LIST:", + " raise SystemExit(0)", " elif method == 'tools/call':", + " if FAIL_ONCE_MODE == 'tool_call_disconnect' and should_fail_once():", + " log('tools/call-disconnect')", + " raise SystemExit(0)", " args = request['params'].get('arguments') or {}", " text = args.get('text', '')", " send_message({", @@ -1128,12 +1757,9 @@ mod tests { ] .join("\n"); fs::write(&script_path, script).expect("write script"); - #[cfg(unix)] - { - let mut permissions = fs::metadata(&script_path).expect("metadata").permissions(); - permissions.set_mode(0o755); - fs::set_permissions(&script_path, permissions).expect("chmod"); - } + let mut permissions = fs::metadata(&script_path).expect("metadata").permissions(); + permissions.set_mode(0o755); + fs::set_permissions(&script_path, permissions).expect("chmod"); script_path } @@ -1144,43 +1770,42 @@ mod tests { command: "/bin/sh".to_string(), args: vec![script_path.to_string_lossy().into_owned()], env: BTreeMap::from([("MCP_TEST_TOKEN".to_string(), "secret-value".to_string())]), + tool_call_timeout_ms: None, }), }; McpClientBootstrap::from_scoped_config("stdio server", &config) } fn script_transport(script_path: &Path) -> crate::mcp_client::McpStdioTransport { - crate::mcp_client::McpStdioTransport { - command: python_command(), - args: vec![script_path.to_string_lossy().into_owned()], - env: BTreeMap::new(), - } + script_transport_with_env(script_path, BTreeMap::new()) } - fn python_command() -> String { - for key in ["MCP_TEST_PYTHON", "PYTHON3", "PYTHON"] { - if let Ok(value) = std::env::var(key) { - if !value.trim().is_empty() { - return value; - } - } + fn script_transport_with_env( + script_path: &Path, + env: BTreeMap, + ) -> crate::mcp_client::McpStdioTransport { + crate::mcp_client::McpStdioTransport { + command: "python3".to_string(), + args: vec![script_path.to_string_lossy().into_owned()], + env, + tool_call_timeout_ms: None, } - - for candidate in ["python3", "python"] { - if Command::new(candidate).arg("--version").output().is_ok() { - return candidate.to_string(); - } - } - - panic!("expected a Python interpreter for MCP stdio tests") } fn cleanup_script(script_path: &Path) { if let Err(error) = fs::remove_file(script_path) { - assert_eq!(error.kind(), std::io::ErrorKind::NotFound, "cleanup script"); + assert_eq!( + error.kind(), + std::io::ErrorKind::NotFound, + "cleanup script: {error}" + ); } if let Err(error) = fs::remove_dir_all(script_path.parent().expect("script parent")) { - assert_eq!(error.kind(), std::io::ErrorKind::NotFound, "cleanup dir"); + assert_eq!( + error.kind(), + std::io::ErrorKind::NotFound, + "cleanup dir: {error}" + ); } } @@ -1189,18 +1814,30 @@ mod tests { label: &str, log_path: &Path, ) -> ScopedMcpServerConfig { + manager_server_config_with_env(script_path, label, log_path, BTreeMap::new()) + } + + fn manager_server_config_with_env( + script_path: &Path, + label: &str, + log_path: &Path, + extra_env: BTreeMap, + ) -> ScopedMcpServerConfig { + let mut env = BTreeMap::from([ + ("MCP_SERVER_LABEL".to_string(), label.to_string()), + ( + "MCP_LOG_PATH".to_string(), + log_path.to_string_lossy().into_owned(), + ), + ]); + env.extend(extra_env); ScopedMcpServerConfig { scope: ConfigSource::Local, config: McpServerConfig::Stdio(McpStdioServerConfig { - command: python_command(), + command: "python3".to_string(), args: vec![script_path.to_string_lossy().into_owned()], - env: BTreeMap::from([ - ("MCP_SERVER_LABEL".to_string(), label.to_string()), - ( - "MCP_LOG_PATH".to_string(), - log_path.to_string_lossy().into_owned(), - ), - ]), + env, + tool_call_timeout_ms: None, }), } } @@ -1328,6 +1965,85 @@ mod tests { }); } + #[test] + fn given_lowercase_content_length_when_initialize_then_response_parses() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_jsonrpc_script(); + let transport = script_transport_with_env( + &script_path, + BTreeMap::from([("MCP_LOWERCASE_CONTENT_LENGTH".to_string(), "1".to_string())]), + ); + let mut process = McpStdioProcess::spawn(&transport).expect("spawn transport directly"); + + let response = process + .initialize( + JsonRpcId::Number(8), + McpInitializeParams { + protocol_version: "2025-03-26".to_string(), + capabilities: json!({"roots": {}}), + client_info: McpInitializeClientInfo { + name: "runtime-tests".to_string(), + version: "0.1.0".to_string(), + }, + }, + ) + .await + .expect("initialize roundtrip"); + + assert_eq!(response.id, JsonRpcId::Number(8)); + assert_eq!(response.error, None); + assert!(response.result.is_some()); + + let status = process.wait().await.expect("wait for exit"); + assert!(status.success()); + + cleanup_script(&script_path); + }); + } + + #[test] + fn given_mismatched_response_id_when_initialize_then_invalid_data_is_returned() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_jsonrpc_script(); + let transport = script_transport_with_env( + &script_path, + BTreeMap::from([("MCP_MISMATCHED_RESPONSE_ID".to_string(), "1".to_string())]), + ); + let mut process = McpStdioProcess::spawn(&transport).expect("spawn transport directly"); + + let error = process + .initialize( + JsonRpcId::Number(9), + McpInitializeParams { + protocol_version: "2025-03-26".to_string(), + capabilities: json!({"roots": {}}), + client_info: McpInitializeClientInfo { + name: "runtime-tests".to_string(), + version: "0.1.0".to_string(), + }, + }, + ) + .await + .expect_err("mismatched response id should fail"); + + assert_eq!(error.kind(), ErrorKind::InvalidData); + assert!(error.to_string().contains("mismatched id")); + + let status = process.wait().await.expect("wait for exit"); + assert!(status.success()); + + cleanup_script(&script_path); + }); + } + #[test] fn direct_spawn_uses_transport_env() { let runtime = Builder::new_current_thread() @@ -1340,6 +2056,7 @@ mod tests { command: "/bin/sh".to_string(), args: vec![script_path.to_string_lossy().into_owned()], env: BTreeMap::from([("MCP_TEST_TOKEN".to_string(), "direct-secret".to_string())]), + tool_call_timeout_ms: None, }; let mut process = McpStdioProcess::spawn(&transport).expect("spawn transport directly"); let ready = process.read_available().await.expect("read ready"); @@ -1580,6 +2297,480 @@ mod tests { }); } + #[test] + fn manager_times_out_slow_tool_calls() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_mcp_server_script(); + let root = script_path.parent().expect("script parent"); + let log_path = root.join("timeout.log"); + let servers = BTreeMap::from([( + "slow".to_string(), + ScopedMcpServerConfig { + scope: ConfigSource::Local, + config: McpServerConfig::Stdio(McpStdioServerConfig { + command: "python3".to_string(), + args: vec![script_path.to_string_lossy().into_owned()], + env: BTreeMap::from([( + "MCP_TOOL_CALL_DELAY_MS".to_string(), + "200".to_string(), + )]), + tool_call_timeout_ms: Some(25), + }), + }, + )]); + let mut manager = McpServerManager::from_servers(&servers); + + manager.discover_tools().await.expect("discover tools"); + let error = manager + .call_tool( + &mcp_tool_name("slow", "echo"), + Some(json!({"text": "slow"})), + ) + .await + .expect_err("slow tool call should time out"); + + match error { + McpServerManagerError::Timeout { + server_name, + method, + timeout_ms, + } => { + assert_eq!(server_name, "slow"); + assert_eq!(method, "tools/call"); + assert_eq!(timeout_ms, 25); + } + other => panic!("expected timeout error, got {other:?}"), + } + + manager.shutdown().await.expect("shutdown"); + cleanup_script(&script_path); + let _ = fs::remove_file(log_path); + }); + } + + #[test] + fn manager_surfaces_parse_errors_from_tool_calls() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_mcp_server_script(); + let servers = BTreeMap::from([( + "broken".to_string(), + ScopedMcpServerConfig { + scope: ConfigSource::Local, + config: McpServerConfig::Stdio(McpStdioServerConfig { + command: "python3".to_string(), + args: vec![script_path.to_string_lossy().into_owned()], + env: BTreeMap::from([( + "MCP_INVALID_TOOL_CALL_RESPONSE".to_string(), + "1".to_string(), + )]), + tool_call_timeout_ms: Some(1_000), + }), + }, + )]); + let mut manager = McpServerManager::from_servers(&servers); + + manager.discover_tools().await.expect("discover tools"); + let error = manager + .call_tool( + &mcp_tool_name("broken", "echo"), + Some(json!({"text": "invalid-json"})), + ) + .await + .expect_err("invalid json should fail"); + + match error { + McpServerManagerError::InvalidResponse { + server_name, + method, + details, + } => { + assert_eq!(server_name, "broken"); + assert_eq!(method, "tools/call"); + assert!( + details.contains("expected ident") || details.contains("expected value") + ); + } + other => panic!("expected invalid response error, got {other:?}"), + } + + manager.shutdown().await.expect("shutdown"); + cleanup_script(&script_path); + }); + } + + #[test] + fn given_child_exits_after_discovery_when_calling_twice_then_second_call_succeeds_after_reset() + { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_manager_mcp_server_script(); + let root = script_path.parent().expect("script parent"); + let log_path = root.join("dropping.log"); + let servers = BTreeMap::from([( + "alpha".to_string(), + manager_server_config_with_env( + &script_path, + "alpha", + &log_path, + BTreeMap::from([("MCP_EXIT_AFTER_TOOLS_LIST".to_string(), "1".to_string())]), + ), + )]); + let mut manager = McpServerManager::from_servers(&servers); + + manager.discover_tools().await.expect("discover tools"); + let first_error = manager + .call_tool( + &mcp_tool_name("alpha", "echo"), + Some(json!({"text": "reconnect"})), + ) + .await + .expect_err("first call should fail after transport drops"); + + match first_error { + McpServerManagerError::Transport { + server_name, + method, + source, + } => { + assert_eq!(server_name, "alpha"); + assert_eq!(method, "tools/call"); + assert_eq!(source.kind(), ErrorKind::UnexpectedEof); + } + other => panic!("expected transport error, got {other:?}"), + } + + let response = manager + .call_tool( + &mcp_tool_name("alpha", "echo"), + Some(json!({"text": "reconnect"})), + ) + .await + .expect("second tool call should succeed after reset"); + + assert_eq!( + response + .result + .as_ref() + .and_then(|result| result.structured_content.as_ref()) + .and_then(|value| value.get("server")), + Some(&json!("alpha")) + ); + let log = fs::read_to_string(&log_path).expect("read log"); + assert_eq!( + log.lines().collect::>(), + vec!["initialize", "tools/list", "initialize", "tools/call"] + ); + + manager.shutdown().await.expect("shutdown"); + cleanup_script(&script_path); + }); + } + + #[test] + fn given_initialize_hangs_once_when_discover_tools_then_manager_retries_and_succeeds() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_manager_mcp_server_script(); + let root = script_path.parent().expect("script parent"); + let log_path = root.join("initialize-hang.log"); + let marker_path = root.join("initialize-hang.marker"); + let servers = BTreeMap::from([( + "alpha".to_string(), + manager_server_config_with_env( + &script_path, + "alpha", + &log_path, + BTreeMap::from([ + ( + "MCP_FAIL_ONCE_MODE".to_string(), + "initialize_hang".to_string(), + ), + ( + "MCP_FAIL_ONCE_MARKER".to_string(), + marker_path.to_string_lossy().into_owned(), + ), + ]), + ), + )]); + let mut manager = McpServerManager::from_servers(&servers); + + let tools = manager + .discover_tools() + .await + .expect("discover tools after retry"); + + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].qualified_name, mcp_tool_name("alpha", "echo")); + let log = fs::read_to_string(&log_path).expect("read log"); + assert_eq!( + log.lines().collect::>(), + vec!["initialize", "initialize-hang", "initialize", "tools/list"] + ); + + manager.shutdown().await.expect("shutdown"); + cleanup_script(&script_path); + }); + } + + #[test] + fn given_tool_call_disconnects_once_when_calling_twice_then_manager_resets_and_next_call_succeeds( + ) { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_manager_mcp_server_script(); + let root = script_path.parent().expect("script parent"); + let log_path = root.join("tool-call-disconnect.log"); + let marker_path = root.join("tool-call-disconnect.marker"); + let servers = BTreeMap::from([( + "alpha".to_string(), + manager_server_config_with_env( + &script_path, + "alpha", + &log_path, + BTreeMap::from([ + ( + "MCP_FAIL_ONCE_MODE".to_string(), + "tool_call_disconnect".to_string(), + ), + ( + "MCP_FAIL_ONCE_MARKER".to_string(), + marker_path.to_string_lossy().into_owned(), + ), + ]), + ), + )]); + let mut manager = McpServerManager::from_servers(&servers); + + manager.discover_tools().await.expect("discover tools"); + let first_error = manager + .call_tool( + &mcp_tool_name("alpha", "echo"), + Some(json!({"text": "first"})), + ) + .await + .expect_err("first tool call should fail when transport drops"); + + match first_error { + McpServerManagerError::Transport { + server_name, + method, + source, + } => { + assert_eq!(server_name, "alpha"); + assert_eq!(method, "tools/call"); + assert_eq!(source.kind(), ErrorKind::UnexpectedEof); + } + other => panic!("expected transport error, got {other:?}"), + } + + let response = manager + .call_tool( + &mcp_tool_name("alpha", "echo"), + Some(json!({"text": "second"})), + ) + .await + .expect("second tool call should succeed after reset"); + + assert_eq!( + response + .result + .as_ref() + .and_then(|result| result.structured_content.as_ref()) + .and_then(|value| value.get("echoed")), + Some(&json!("second")) + ); + let log = fs::read_to_string(&log_path).expect("read log"); + assert_eq!( + log.lines().collect::>(), + vec![ + "initialize", + "tools/list", + "tools/call", + "tools/call-disconnect", + "initialize", + "tools/call", + ] + ); + + manager.shutdown().await.expect("shutdown"); + cleanup_script(&script_path); + }); + } + + #[test] + fn manager_lists_and_reads_resources_from_stdio_servers() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_mcp_server_script(); + let root = script_path.parent().expect("script parent"); + let log_path = root.join("resources.log"); + let servers = BTreeMap::from([( + "alpha".to_string(), + manager_server_config(&script_path, "alpha", &log_path), + )]); + let mut manager = McpServerManager::from_servers(&servers); + + let listed = manager + .list_resources("alpha") + .await + .expect("list resources"); + assert_eq!(listed.resources.len(), 1); + assert_eq!(listed.resources[0].uri, "file://guide.txt"); + + let read = manager + .read_resource("alpha", "file://guide.txt") + .await + .expect("read resource"); + assert_eq!(read.contents.len(), 1); + assert_eq!( + read.contents[0].text.as_deref(), + Some("contents for file://guide.txt") + ); + + manager.shutdown().await.expect("shutdown"); + cleanup_script(&script_path); + }); + } + + fn write_initialize_disconnect_script() -> PathBuf { + let root = temp_dir(); + fs::create_dir_all(&root).expect("temp dir"); + let script_path = root.join("initialize-disconnect.py"); + let script = [ + "#!/usr/bin/env python3", + "import sys", + "header = b''", + r"while not header.endswith(b'\r\n\r\n'):", + " chunk = sys.stdin.buffer.read(1)", + " if not chunk:", + " raise SystemExit(1)", + " header += chunk", + "length = 0", + r"for line in header.decode().split('\r\n'):", + r" if line.lower().startswith('content-length:'):", + r" length = int(line.split(':', 1)[1].strip())", + "if length:", + " sys.stdin.buffer.read(length)", + "raise SystemExit(0)", + "", + ] + .join("\n"); + fs::write(&script_path, script).expect("write script"); + let mut permissions = fs::metadata(&script_path).expect("metadata").permissions(); + permissions.set_mode(0o755); + fs::set_permissions(&script_path, permissions).expect("chmod"); + script_path + } + + #[test] + fn manager_discovery_report_keeps_healthy_servers_when_one_server_fails() { + let runtime = Builder::new_current_thread() + .enable_all() + .build() + .expect("runtime"); + runtime.block_on(async { + let script_path = write_manager_mcp_server_script(); + let root = script_path.parent().expect("script parent"); + let alpha_log = root.join("alpha.log"); + let broken_script_path = write_initialize_disconnect_script(); + let servers = BTreeMap::from([ + ( + "alpha".to_string(), + manager_server_config(&script_path, "alpha", &alpha_log), + ), + ( + "broken".to_string(), + ScopedMcpServerConfig { + scope: ConfigSource::Local, + config: McpServerConfig::Stdio(McpStdioServerConfig { + command: broken_script_path.display().to_string(), + args: Vec::new(), + env: BTreeMap::new(), + tool_call_timeout_ms: None, + }), + }, + ), + ]); + let mut manager = McpServerManager::from_servers(&servers); + + let report = manager.discover_tools_best_effort().await; + + assert_eq!(report.tools.len(), 1); + assert_eq!( + report.tools[0].qualified_name, + mcp_tool_name("alpha", "echo") + ); + assert_eq!(report.failed_servers.len(), 1); + assert_eq!(report.failed_servers[0].server_name, "broken"); + assert_eq!( + report.failed_servers[0].phase, + McpLifecyclePhase::InitializeHandshake + ); + assert!(!report.failed_servers[0].recoverable); + assert_eq!( + report.failed_servers[0] + .context + .get("method") + .map(String::as_str), + Some("initialize") + ); + assert!(report.failed_servers[0].error.contains("initialize")); + let degraded = report + .degraded_startup + .as_ref() + .expect("partial startup should surface degraded report"); + assert_eq!(degraded.working_servers, vec!["alpha".to_string()]); + assert_eq!(degraded.failed_servers.len(), 1); + assert_eq!(degraded.failed_servers[0].server_name, "broken"); + assert_eq!( + degraded.failed_servers[0].phase, + McpLifecyclePhase::InitializeHandshake + ); + assert_eq!( + degraded.available_tools, + vec![mcp_tool_name("alpha", "echo")] + ); + assert!(degraded.missing_tools.is_empty()); + + let response = manager + .call_tool(&mcp_tool_name("alpha", "echo"), Some(json!({"text": "ok"}))) + .await + .expect("healthy server should remain callable"); + assert_eq!( + response + .result + .as_ref() + .and_then(|result| result.structured_content.as_ref()) + .and_then(|value| value.get("echoed")), + Some(&json!("ok")) + ); + + manager.shutdown().await.expect("shutdown"); + cleanup_script(&script_path); + cleanup_script(&broken_script_path); + }); + } + #[test] fn manager_records_unsupported_non_stdio_servers_without_panicking() { let servers = BTreeMap::from([ @@ -1624,6 +2815,10 @@ mod tests { assert_eq!(unsupported[0].server_name, "http"); assert_eq!(unsupported[1].server_name, "sdk"); assert_eq!(unsupported[2].server_name, "ws"); + assert_eq!( + unsupported_server_failed_server(&unsupported[0]).phase, + McpLifecyclePhase::ServerRegistration + ); } #[test] diff --git a/crates/runtime/src/mcp_tool_bridge.rs b/crates/runtime/src/mcp_tool_bridge.rs new file mode 100644 index 0000000..a159e34 --- /dev/null +++ b/crates/runtime/src/mcp_tool_bridge.rs @@ -0,0 +1,920 @@ +#![allow( + clippy::await_holding_lock, + clippy::doc_markdown, + clippy::match_same_arms, + clippy::must_use_candidate, + clippy::uninlined_format_args, + clippy::unnested_or_patterns +)] +//! Bridge between MCP tool surface (ListMcpResources, ReadMcpResource, McpAuth, MCP) +//! and the existing McpServerManager runtime. +//! +//! Provides a stateful client registry that tool handlers can use to +//! connect to MCP servers and invoke their capabilities. + +use std::collections::HashMap; +use std::sync::{Arc, Mutex, OnceLock}; + +use crate::mcp::mcp_tool_name; +use crate::mcp_stdio::McpServerManager; +use serde::{Deserialize, Serialize}; + +/// Status of a managed MCP server connection. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum McpConnectionStatus { + Disconnected, + Connecting, + Connected, + AuthRequired, + Error, +} + +impl std::fmt::Display for McpConnectionStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Disconnected => write!(f, "disconnected"), + Self::Connecting => write!(f, "connecting"), + Self::Connected => write!(f, "connected"), + Self::AuthRequired => write!(f, "auth_required"), + Self::Error => write!(f, "error"), + } + } +} + +/// Metadata about an MCP resource. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpResourceInfo { + pub uri: String, + pub name: String, + pub description: Option, + pub mime_type: Option, +} + +/// Metadata about an MCP tool exposed by a server. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpToolInfo { + pub name: String, + pub description: Option, + pub input_schema: Option, +} + +/// Tracked state of an MCP server connection. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpServerState { + pub server_name: String, + pub status: McpConnectionStatus, + pub tools: Vec, + pub resources: Vec, + pub server_info: Option, + pub error_message: Option, +} + +#[derive(Debug, Clone, Default)] +pub struct McpToolRegistry { + inner: Arc>>, + manager: Arc>>>, +} + +impl McpToolRegistry { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn set_manager( + &self, + manager: Arc>, + ) -> Result<(), Arc>> { + self.manager.set(manager) + } + + pub fn register_server( + &self, + server_name: &str, + status: McpConnectionStatus, + tools: Vec, + resources: Vec, + server_info: Option, + ) { + let mut inner = self.inner.lock().expect("mcp registry lock poisoned"); + inner.insert( + server_name.to_owned(), + McpServerState { + server_name: server_name.to_owned(), + status, + tools, + resources, + server_info, + error_message: None, + }, + ); + } + + pub fn get_server(&self, server_name: &str) -> Option { + let inner = self.inner.lock().expect("mcp registry lock poisoned"); + inner.get(server_name).cloned() + } + + pub fn list_servers(&self) -> Vec { + let inner = self.inner.lock().expect("mcp registry lock poisoned"); + inner.values().cloned().collect() + } + + pub fn list_resources(&self, server_name: &str) -> Result, String> { + let inner = self.inner.lock().expect("mcp registry lock poisoned"); + match inner.get(server_name) { + Some(state) => { + if state.status != McpConnectionStatus::Connected { + return Err(format!( + "server '{}' is not connected (status: {})", + server_name, state.status + )); + } + Ok(state.resources.clone()) + } + None => Err(format!("server '{}' not found", server_name)), + } + } + + pub fn read_resource(&self, server_name: &str, uri: &str) -> Result { + let inner = self.inner.lock().expect("mcp registry lock poisoned"); + let state = inner + .get(server_name) + .ok_or_else(|| format!("server '{}' not found", server_name))?; + + if state.status != McpConnectionStatus::Connected { + return Err(format!( + "server '{}' is not connected (status: {})", + server_name, state.status + )); + } + + state + .resources + .iter() + .find(|r| r.uri == uri) + .cloned() + .ok_or_else(|| format!("resource '{}' not found on server '{}'", uri, server_name)) + } + + pub fn list_tools(&self, server_name: &str) -> Result, String> { + let inner = self.inner.lock().expect("mcp registry lock poisoned"); + match inner.get(server_name) { + Some(state) => { + if state.status != McpConnectionStatus::Connected { + return Err(format!( + "server '{}' is not connected (status: {})", + server_name, state.status + )); + } + Ok(state.tools.clone()) + } + None => Err(format!("server '{}' not found", server_name)), + } + } + + fn spawn_tool_call( + manager: Arc>, + qualified_tool_name: String, + arguments: Option, + ) -> Result { + let join_handle = std::thread::Builder::new() + .name(format!("mcp-tool-call-{qualified_tool_name}")) + .spawn(move || { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .map_err(|error| format!("failed to create MCP tool runtime: {error}"))?; + + runtime.block_on(async move { + let response = { + let mut manager = manager + .lock() + .map_err(|_| "mcp server manager lock poisoned".to_string())?; + manager + .discover_tools() + .await + .map_err(|error| error.to_string())?; + let response = manager + .call_tool(&qualified_tool_name, arguments) + .await + .map_err(|error| error.to_string()); + let shutdown = manager.shutdown().await.map_err(|error| error.to_string()); + + match (response, shutdown) { + (Ok(response), Ok(())) => Ok(response), + (Err(error), Ok(())) | (Err(error), Err(_)) => Err(error), + (Ok(_), Err(error)) => Err(error), + } + }?; + + if let Some(error) = response.error { + return Err(format!( + "MCP server returned JSON-RPC error for tools/call: {} ({})", + error.message, error.code + )); + } + + let result = response.result.ok_or_else(|| { + "MCP server returned no result for tools/call".to_string() + })?; + + serde_json::to_value(result) + .map_err(|error| format!("failed to serialize MCP tool result: {error}")) + }) + }) + .map_err(|error| format!("failed to spawn MCP tool call thread: {error}"))?; + + join_handle.join().map_err(|panic_payload| { + if let Some(message) = panic_payload.downcast_ref::<&str>() { + format!("MCP tool call thread panicked: {message}") + } else if let Some(message) = panic_payload.downcast_ref::() { + format!("MCP tool call thread panicked: {message}") + } else { + "MCP tool call thread panicked".to_string() + } + })? + } + + pub fn call_tool( + &self, + server_name: &str, + tool_name: &str, + arguments: &serde_json::Value, + ) -> Result { + let inner = self.inner.lock().expect("mcp registry lock poisoned"); + let state = inner + .get(server_name) + .ok_or_else(|| format!("server '{}' not found", server_name))?; + + if state.status != McpConnectionStatus::Connected { + return Err(format!( + "server '{}' is not connected (status: {})", + server_name, state.status + )); + } + + if !state.tools.iter().any(|t| t.name == tool_name) { + return Err(format!( + "tool '{}' not found on server '{}'", + tool_name, server_name + )); + } + + drop(inner); + + let manager = self + .manager + .get() + .cloned() + .ok_or_else(|| "MCP server manager is not configured".to_string())?; + + Self::spawn_tool_call( + manager, + mcp_tool_name(server_name, tool_name), + (!arguments.is_null()).then(|| arguments.clone()), + ) + } + + /// Set auth status for a server. + pub fn set_auth_status( + &self, + server_name: &str, + status: McpConnectionStatus, + ) -> Result<(), String> { + let mut inner = self.inner.lock().expect("mcp registry lock poisoned"); + let state = inner + .get_mut(server_name) + .ok_or_else(|| format!("server '{}' not found", server_name))?; + state.status = status; + Ok(()) + } + + /// Disconnect / remove a server. + pub fn disconnect(&self, server_name: &str) -> Option { + let mut inner = self.inner.lock().expect("mcp registry lock poisoned"); + inner.remove(server_name) + } + + /// Number of registered servers. + #[must_use] + pub fn len(&self) -> usize { + let inner = self.inner.lock().expect("mcp registry lock poisoned"); + inner.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +#[cfg(all(test, unix))] +mod tests { + use std::collections::BTreeMap; + use std::fs; + use std::os::unix::fs::PermissionsExt; + use std::path::{Path, PathBuf}; + use std::sync::atomic::{AtomicU64, Ordering}; + use std::time::{SystemTime, UNIX_EPOCH}; + + use super::*; + use crate::config::{ + ConfigSource, McpServerConfig, McpStdioServerConfig, ScopedMcpServerConfig, + }; + + fn temp_dir() -> PathBuf { + static NEXT_TEMP_DIR_ID: AtomicU64 = AtomicU64::new(0); + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + let unique_id = NEXT_TEMP_DIR_ID.fetch_add(1, Ordering::Relaxed); + std::env::temp_dir().join(format!("runtime-mcp-tool-bridge-{nanos}-{unique_id}")) + } + + fn cleanup_script(script_path: &Path) { + if let Some(root) = script_path.parent() { + let _ = fs::remove_dir_all(root); + } + } + + fn write_bridge_mcp_server_script() -> PathBuf { + let root = temp_dir(); + fs::create_dir_all(&root).expect("temp dir"); + let script_path = root.join("bridge-mcp-server.py"); + let script = [ + "#!/usr/bin/env python3", + "import json, os, sys", + "LABEL = os.environ.get('MCP_SERVER_LABEL', 'server')", + "LOG_PATH = os.environ.get('MCP_LOG_PATH')", + "", + "def log(method):", + " if LOG_PATH:", + " with open(LOG_PATH, 'a', encoding='utf-8') as handle:", + " handle.write(f'{method}\\n')", + "", + "def read_message():", + " header = b''", + r" while not header.endswith(b'\r\n\r\n'):", + " chunk = sys.stdin.buffer.read(1)", + " if not chunk:", + " return None", + " header += chunk", + " length = 0", + r" for line in header.decode().split('\r\n'):", + r" if line.lower().startswith('content-length:'):", + r" length = int(line.split(':', 1)[1].strip())", + " payload = sys.stdin.buffer.read(length)", + " return json.loads(payload.decode())", + "", + "def send_message(message):", + " payload = json.dumps(message).encode()", + r" sys.stdout.buffer.write(f'Content-Length: {len(payload)}\r\n\r\n'.encode() + payload)", + " sys.stdout.buffer.flush()", + "", + "while True:", + " request = read_message()", + " if request is None:", + " break", + " method = request['method']", + " log(method)", + " if method == 'initialize':", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'result': {", + " 'protocolVersion': request['params']['protocolVersion'],", + " 'capabilities': {'tools': {}},", + " 'serverInfo': {'name': LABEL, 'version': '1.0.0'}", + " }", + " })", + " elif method == 'tools/list':", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'result': {", + " 'tools': [", + " {", + " 'name': 'echo',", + " 'description': f'Echo tool for {LABEL}',", + " 'inputSchema': {", + " 'type': 'object',", + " 'properties': {'text': {'type': 'string'}},", + " 'required': ['text']", + " }", + " }", + " ]", + " }", + " })", + " elif method == 'tools/call':", + " args = request['params'].get('arguments') or {}", + " text = args.get('text', '')", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'result': {", + " 'content': [{'type': 'text', 'text': f'{LABEL}:{text}'}],", + " 'structuredContent': {'server': LABEL, 'echoed': text},", + " 'isError': False", + " }", + " })", + " else:", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'error': {'code': -32601, 'message': f'unknown method: {method}'},", + " })", + "", + ] + .join("\n"); + fs::write(&script_path, script).expect("write script"); + let mut permissions = fs::metadata(&script_path).expect("metadata").permissions(); + permissions.set_mode(0o755); + fs::set_permissions(&script_path, permissions).expect("chmod"); + script_path + } + + fn manager_server_config( + script_path: &Path, + server_name: &str, + log_path: &Path, + ) -> ScopedMcpServerConfig { + ScopedMcpServerConfig { + scope: ConfigSource::Local, + config: McpServerConfig::Stdio(McpStdioServerConfig { + command: "python3".to_string(), + args: vec![script_path.to_string_lossy().into_owned()], + env: BTreeMap::from([ + ("MCP_SERVER_LABEL".to_string(), server_name.to_string()), + ( + "MCP_LOG_PATH".to_string(), + log_path.to_string_lossy().into_owned(), + ), + ]), + tool_call_timeout_ms: Some(1_000), + }), + } + } + + #[test] + fn registers_and_retrieves_server() { + let registry = McpToolRegistry::new(); + registry.register_server( + "test-server", + McpConnectionStatus::Connected, + vec![McpToolInfo { + name: "greet".into(), + description: Some("Greet someone".into()), + input_schema: None, + }], + vec![McpResourceInfo { + uri: "res://data".into(), + name: "Data".into(), + description: None, + mime_type: Some("application/json".into()), + }], + Some("TestServer v1.0".into()), + ); + + let server = registry.get_server("test-server").expect("should exist"); + assert_eq!(server.status, McpConnectionStatus::Connected); + assert_eq!(server.tools.len(), 1); + assert_eq!(server.resources.len(), 1); + } + + #[test] + fn lists_resources_from_connected_server() { + let registry = McpToolRegistry::new(); + registry.register_server( + "srv", + McpConnectionStatus::Connected, + vec![], + vec![McpResourceInfo { + uri: "res://alpha".into(), + name: "Alpha".into(), + description: None, + mime_type: None, + }], + None, + ); + + let resources = registry.list_resources("srv").expect("should succeed"); + assert_eq!(resources.len(), 1); + assert_eq!(resources[0].uri, "res://alpha"); + } + + #[test] + fn rejects_resource_listing_for_disconnected_server() { + let registry = McpToolRegistry::new(); + registry.register_server( + "srv", + McpConnectionStatus::Disconnected, + vec![], + vec![], + None, + ); + assert!(registry.list_resources("srv").is_err()); + } + + #[test] + fn reads_specific_resource() { + let registry = McpToolRegistry::new(); + registry.register_server( + "srv", + McpConnectionStatus::Connected, + vec![], + vec![McpResourceInfo { + uri: "res://data".into(), + name: "Data".into(), + description: Some("Test data".into()), + mime_type: Some("text/plain".into()), + }], + None, + ); + + let resource = registry + .read_resource("srv", "res://data") + .expect("should find"); + assert_eq!(resource.name, "Data"); + + assert!(registry.read_resource("srv", "res://missing").is_err()); + } + + #[test] + fn given_connected_server_without_manager_when_calling_tool_then_it_errors() { + let registry = McpToolRegistry::new(); + registry.register_server( + "srv", + McpConnectionStatus::Connected, + vec![McpToolInfo { + name: "greet".into(), + description: None, + input_schema: None, + }], + vec![], + None, + ); + + let error = registry + .call_tool("srv", "greet", &serde_json::json!({"name": "world"})) + .expect_err("should require a configured manager"); + assert!(error.contains("MCP server manager is not configured")); + + // Unknown tool should fail + assert!(registry + .call_tool("srv", "missing", &serde_json::json!({})) + .is_err()); + } + + #[test] + fn given_connected_server_with_manager_when_calling_tool_then_it_returns_live_result() { + let script_path = write_bridge_mcp_server_script(); + let root = script_path.parent().expect("script parent"); + let log_path = root.join("bridge.log"); + let servers = BTreeMap::from([( + "alpha".to_string(), + manager_server_config(&script_path, "alpha", &log_path), + )]); + let manager = Arc::new(Mutex::new(McpServerManager::from_servers(&servers))); + + let registry = McpToolRegistry::new(); + registry.register_server( + "alpha", + McpConnectionStatus::Connected, + vec![McpToolInfo { + name: "echo".into(), + description: Some("Echo tool for alpha".into()), + input_schema: Some(serde_json::json!({ + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"] + })), + }], + vec![], + Some("bridge test server".into()), + ); + registry + .set_manager(Arc::clone(&manager)) + .expect("manager should only be set once"); + + let result = registry + .call_tool("alpha", "echo", &serde_json::json!({"text": "hello"})) + .expect("should return live MCP result"); + + assert_eq!( + result["structuredContent"]["server"], + serde_json::json!("alpha") + ); + assert_eq!( + result["structuredContent"]["echoed"], + serde_json::json!("hello") + ); + assert_eq!( + result["content"][0]["text"], + serde_json::json!("alpha:hello") + ); + + let log = fs::read_to_string(&log_path).expect("read log"); + assert_eq!( + log.lines().collect::>(), + vec!["initialize", "tools/list", "tools/call"] + ); + + cleanup_script(&script_path); + } + + #[test] + fn rejects_tool_call_on_disconnected_server() { + let registry = McpToolRegistry::new(); + registry.register_server( + "srv", + McpConnectionStatus::AuthRequired, + vec![McpToolInfo { + name: "greet".into(), + description: None, + input_schema: None, + }], + vec![], + None, + ); + + assert!(registry + .call_tool("srv", "greet", &serde_json::json!({})) + .is_err()); + } + + #[test] + fn sets_auth_and_disconnects() { + let registry = McpToolRegistry::new(); + registry.register_server( + "srv", + McpConnectionStatus::AuthRequired, + vec![], + vec![], + None, + ); + + registry + .set_auth_status("srv", McpConnectionStatus::Connected) + .expect("should succeed"); + let state = registry.get_server("srv").unwrap(); + assert_eq!(state.status, McpConnectionStatus::Connected); + + let removed = registry.disconnect("srv"); + assert!(removed.is_some()); + assert!(registry.is_empty()); + } + + #[test] + fn rejects_operations_on_missing_server() { + let registry = McpToolRegistry::new(); + assert!(registry.list_resources("missing").is_err()); + assert!(registry.read_resource("missing", "uri").is_err()); + assert!(registry.list_tools("missing").is_err()); + assert!(registry + .call_tool("missing", "tool", &serde_json::json!({})) + .is_err()); + assert!(registry + .set_auth_status("missing", McpConnectionStatus::Connected) + .is_err()); + } + + #[test] + fn mcp_connection_status_display_all_variants() { + // given + let cases = [ + (McpConnectionStatus::Disconnected, "disconnected"), + (McpConnectionStatus::Connecting, "connecting"), + (McpConnectionStatus::Connected, "connected"), + (McpConnectionStatus::AuthRequired, "auth_required"), + (McpConnectionStatus::Error, "error"), + ]; + + // when + let rendered: Vec<_> = cases + .into_iter() + .map(|(status, expected)| (status.to_string(), expected)) + .collect(); + + // then + assert_eq!( + rendered, + vec![ + ("disconnected".to_string(), "disconnected"), + ("connecting".to_string(), "connecting"), + ("connected".to_string(), "connected"), + ("auth_required".to_string(), "auth_required"), + ("error".to_string(), "error"), + ] + ); + } + + #[test] + fn list_servers_returns_all_registered() { + // given + let registry = McpToolRegistry::new(); + registry.register_server( + "alpha", + McpConnectionStatus::Connected, + vec![], + vec![], + None, + ); + registry.register_server( + "beta", + McpConnectionStatus::Connecting, + vec![], + vec![], + None, + ); + + // when + let servers = registry.list_servers(); + + // then + assert_eq!(servers.len(), 2); + assert!(servers.iter().any(|server| server.server_name == "alpha")); + assert!(servers.iter().any(|server| server.server_name == "beta")); + } + + #[test] + fn list_tools_from_connected_server() { + // given + let registry = McpToolRegistry::new(); + registry.register_server( + "srv", + McpConnectionStatus::Connected, + vec![McpToolInfo { + name: "inspect".into(), + description: Some("Inspect data".into()), + input_schema: Some(serde_json::json!({"type": "object"})), + }], + vec![], + None, + ); + + // when + let tools = registry.list_tools("srv").expect("tools should list"); + + // then + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].name, "inspect"); + } + + #[test] + fn list_tools_rejects_disconnected_server() { + // given + let registry = McpToolRegistry::new(); + registry.register_server( + "srv", + McpConnectionStatus::AuthRequired, + vec![], + vec![], + None, + ); + + // when + let result = registry.list_tools("srv"); + + // then + let error = result.expect_err("non-connected server should fail"); + assert!(error.contains("not connected")); + assert!(error.contains("auth_required")); + } + + #[test] + fn list_tools_rejects_missing_server() { + // given + let registry = McpToolRegistry::new(); + + // when + let result = registry.list_tools("missing"); + + // then + assert_eq!( + result.expect_err("missing server should fail"), + "server 'missing' not found" + ); + } + + #[test] + fn get_server_returns_none_for_missing() { + // given + let registry = McpToolRegistry::new(); + + // when + let server = registry.get_server("missing"); + + // then + assert!(server.is_none()); + } + + #[test] + fn call_tool_payload_structure() { + let script_path = write_bridge_mcp_server_script(); + let root = script_path.parent().expect("script parent"); + let log_path = root.join("payload.log"); + let servers = BTreeMap::from([( + "srv".to_string(), + manager_server_config(&script_path, "srv", &log_path), + )]); + let registry = McpToolRegistry::new(); + let arguments = serde_json::json!({"text": "world"}); + registry.register_server( + "srv", + McpConnectionStatus::Connected, + vec![McpToolInfo { + name: "echo".into(), + description: Some("Echo tool for srv".into()), + input_schema: Some(serde_json::json!({ + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"] + })), + }], + vec![], + None, + ); + registry + .set_manager(Arc::new(Mutex::new(McpServerManager::from_servers( + &servers, + )))) + .expect("manager should only be set once"); + + let result = registry + .call_tool("srv", "echo", &arguments) + .expect("tool should return live payload"); + + assert_eq!(result["structuredContent"]["server"], "srv"); + assert_eq!(result["structuredContent"]["echoed"], "world"); + assert_eq!(result["content"][0]["text"], "srv:world"); + + cleanup_script(&script_path); + } + + #[test] + fn upsert_overwrites_existing_server() { + // given + let registry = McpToolRegistry::new(); + registry.register_server("srv", McpConnectionStatus::Connecting, vec![], vec![], None); + + // when + registry.register_server( + "srv", + McpConnectionStatus::Connected, + vec![McpToolInfo { + name: "inspect".into(), + description: None, + input_schema: None, + }], + vec![], + Some("Inspector".into()), + ); + let state = registry.get_server("srv").expect("server should exist"); + + // then + assert_eq!(state.status, McpConnectionStatus::Connected); + assert_eq!(state.tools.len(), 1); + assert_eq!(state.server_info.as_deref(), Some("Inspector")); + } + + #[test] + fn disconnect_missing_returns_none() { + // given + let registry = McpToolRegistry::new(); + + // when + let removed = registry.disconnect("missing"); + + // then + assert!(removed.is_none()); + } + + #[test] + fn len_and_is_empty_transitions() { + // given + let registry = McpToolRegistry::new(); + + // when + registry.register_server( + "alpha", + McpConnectionStatus::Connected, + vec![], + vec![], + None, + ); + registry.register_server("beta", McpConnectionStatus::Connected, vec![], vec![], None); + let after_create = registry.len(); + registry.disconnect("alpha"); + let after_first_remove = registry.len(); + registry.disconnect("beta"); + + // then + assert_eq!(after_create, 2); + assert_eq!(after_first_remove, 1); + assert_eq!(registry.len(), 0); + assert!(registry.is_empty()); + } +} diff --git a/crates/runtime/src/oauth.rs b/crates/runtime/src/oauth.rs index e4756c1..aa3ca15 100644 --- a/crates/runtime/src/oauth.rs +++ b/crates/runtime/src/oauth.rs @@ -9,6 +9,7 @@ use sha2::{Digest, Sha256}; use crate::config::OAuthConfig; +/// Persisted OAuth access token bundle used by the CLI. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct OAuthTokenSet { pub access_token: String, @@ -17,6 +18,7 @@ pub struct OAuthTokenSet { pub scopes: Vec, } +/// PKCE verifier/challenge pair generated for an OAuth authorization flow. #[derive(Debug, Clone, PartialEq, Eq)] pub struct PkceCodePair { pub verifier: String, @@ -24,6 +26,7 @@ pub struct PkceCodePair { pub challenge_method: PkceChallengeMethod, } +/// Challenge algorithms supported by the local PKCE helpers. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PkceChallengeMethod { S256, @@ -38,6 +41,7 @@ impl PkceChallengeMethod { } } +/// Parameters needed to build an authorization URL for browser-based login. #[derive(Debug, Clone, PartialEq, Eq)] pub struct OAuthAuthorizationRequest { pub authorize_url: String, @@ -50,6 +54,7 @@ pub struct OAuthAuthorizationRequest { pub extra_params: BTreeMap, } +/// Request body for exchanging an OAuth authorization code for tokens. #[derive(Debug, Clone, PartialEq, Eq)] pub struct OAuthTokenExchangeRequest { pub grant_type: &'static str, @@ -60,6 +65,7 @@ pub struct OAuthTokenExchangeRequest { pub state: String, } +/// Request body for refreshing an existing OAuth token set. #[derive(Debug, Clone, PartialEq, Eq)] pub struct OAuthRefreshRequest { pub grant_type: &'static str, @@ -68,6 +74,7 @@ pub struct OAuthRefreshRequest { pub scopes: Vec, } +/// Parsed query parameters returned to the local OAuth callback endpoint. #[derive(Debug, Clone, PartialEq, Eq)] pub struct OAuthCallbackParams { pub code: Option, @@ -327,15 +334,16 @@ fn credentials_home_dir() -> io::Result { if let Some(path) = std::env::var_os("CLAW_CONFIG_HOME") { return Ok(PathBuf::from(path)); } - if let Some(path) = std::env::var_os("HOME") { - return Ok(PathBuf::from(path).join(".claw")); - } - if cfg!(target_os = "windows") { - if let Some(path) = std::env::var_os("USERPROFILE") { - return Ok(PathBuf::from(path).join(".claw")); - } - } - Err(io::Error::new(io::ErrorKind::NotFound, "HOME or USERPROFILE is not set")) + let home = std::env::var_os("HOME") + .or_else(|| std::env::var_os("USERPROFILE")) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::NotFound, + "HOME is not set (on Windows, set USERPROFILE or HOME, \ + or use CLAW_CONFIG_HOME to point directly at the config directory)", + ) + })?; + Ok(PathBuf::from(home).join(".claw")) } fn read_credentials_root(path: &PathBuf) -> io::Result> { @@ -448,7 +456,7 @@ fn decode_hex(byte: u8) -> Result { b'0'..=b'9' => Ok(byte - b'0'), b'a'..=b'f' => Ok(byte - b'a' + 10), b'A'..=b'F' => Ok(byte - b'A' + 10), - _ => Err(format!("invalid percent-encoding byte: {byte}")), + _ => Err(format!("invalid percent byte: {byte}")), } } diff --git a/crates/runtime/src/permission_enforcer.rs b/crates/runtime/src/permission_enforcer.rs new file mode 100644 index 0000000..3a45dc8 --- /dev/null +++ b/crates/runtime/src/permission_enforcer.rs @@ -0,0 +1,551 @@ +#![allow( + clippy::match_wildcard_for_single_variants, + clippy::must_use_candidate, + clippy::uninlined_format_args +)] +//! Permission enforcement layer that gates tool execution based on the +//! active `PermissionPolicy`. + +use crate::permissions::{PermissionMode, PermissionOutcome, PermissionPolicy}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "outcome")] +pub enum EnforcementResult { + /// Tool execution is allowed. + Allowed, + /// Tool execution was denied due to insufficient permissions. + Denied { + tool: String, + active_mode: String, + required_mode: String, + reason: String, + }, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PermissionEnforcer { + policy: PermissionPolicy, +} + +impl PermissionEnforcer { + #[must_use] + pub fn new(policy: PermissionPolicy) -> Self { + Self { policy } + } + + /// Check whether a tool can be executed under the current permission policy. + /// Auto-denies when prompting is required but no prompter is provided. + pub fn check(&self, tool_name: &str, input: &str) -> EnforcementResult { + // When the active mode is Prompt, defer to the caller's interactive + // prompt flow rather than hard-denying (the enforcer has no prompter). + if self.policy.active_mode() == PermissionMode::Prompt { + return EnforcementResult::Allowed; + } + + let outcome = self.policy.authorize(tool_name, input, None); + + match outcome { + PermissionOutcome::Allow => EnforcementResult::Allowed, + PermissionOutcome::Deny { reason } => { + let active_mode = self.policy.active_mode(); + let required_mode = self.policy.required_mode_for(tool_name); + EnforcementResult::Denied { + tool: tool_name.to_owned(), + active_mode: active_mode.as_str().to_owned(), + required_mode: required_mode.as_str().to_owned(), + reason, + } + } + } + } + + #[must_use] + pub fn is_allowed(&self, tool_name: &str, input: &str) -> bool { + matches!(self.check(tool_name, input), EnforcementResult::Allowed) + } + + #[must_use] + pub fn active_mode(&self) -> PermissionMode { + self.policy.active_mode() + } + + /// Classify a file operation against workspace boundaries. + pub fn check_file_write(&self, path: &str, workspace_root: &str) -> EnforcementResult { + let mode = self.policy.active_mode(); + + match mode { + PermissionMode::ReadOnly => EnforcementResult::Denied { + tool: "write_file".to_owned(), + active_mode: mode.as_str().to_owned(), + required_mode: PermissionMode::WorkspaceWrite.as_str().to_owned(), + reason: format!("file writes are not allowed in '{}' mode", mode.as_str()), + }, + PermissionMode::WorkspaceWrite => { + if is_within_workspace(path, workspace_root) { + EnforcementResult::Allowed + } else { + EnforcementResult::Denied { + tool: "write_file".to_owned(), + active_mode: mode.as_str().to_owned(), + required_mode: PermissionMode::DangerFullAccess.as_str().to_owned(), + reason: format!( + "path '{}' is outside workspace root '{}'", + path, workspace_root + ), + } + } + } + // Allow and DangerFullAccess permit all writes + PermissionMode::Allow | PermissionMode::DangerFullAccess => EnforcementResult::Allowed, + PermissionMode::Prompt => EnforcementResult::Denied { + tool: "write_file".to_owned(), + active_mode: mode.as_str().to_owned(), + required_mode: PermissionMode::WorkspaceWrite.as_str().to_owned(), + reason: "file write requires confirmation in prompt mode".to_owned(), + }, + } + } + + /// Check if a bash command should be allowed based on current mode. + pub fn check_bash(&self, command: &str) -> EnforcementResult { + let mode = self.policy.active_mode(); + + match mode { + PermissionMode::ReadOnly => { + if is_read_only_command(command) { + EnforcementResult::Allowed + } else { + EnforcementResult::Denied { + tool: "bash".to_owned(), + active_mode: mode.as_str().to_owned(), + required_mode: PermissionMode::WorkspaceWrite.as_str().to_owned(), + reason: format!( + "command may modify state; not allowed in '{}' mode", + mode.as_str() + ), + } + } + } + PermissionMode::Prompt => EnforcementResult::Denied { + tool: "bash".to_owned(), + active_mode: mode.as_str().to_owned(), + required_mode: PermissionMode::DangerFullAccess.as_str().to_owned(), + reason: "bash requires confirmation in prompt mode".to_owned(), + }, + // WorkspaceWrite, Allow, DangerFullAccess: permit bash + _ => EnforcementResult::Allowed, + } + } +} + +/// Simple workspace boundary check via string prefix. +fn is_within_workspace(path: &str, workspace_root: &str) -> bool { + let normalized = if path.starts_with('/') { + path.to_owned() + } else { + format!("{workspace_root}/{path}") + }; + + let root = if workspace_root.ends_with('/') { + workspace_root.to_owned() + } else { + format!("{workspace_root}/") + }; + + normalized.starts_with(&root) || normalized == workspace_root.trim_end_matches('/') +} + +/// Conservative heuristic: is this bash command read-only? +fn is_read_only_command(command: &str) -> bool { + let first_token = command + .split_whitespace() + .next() + .unwrap_or("") + .rsplit('/') + .next() + .unwrap_or(""); + + matches!( + first_token, + "cat" + | "head" + | "tail" + | "less" + | "more" + | "wc" + | "ls" + | "find" + | "grep" + | "rg" + | "awk" + | "sed" + | "echo" + | "printf" + | "which" + | "where" + | "whoami" + | "pwd" + | "env" + | "printenv" + | "date" + | "cal" + | "df" + | "du" + | "free" + | "uptime" + | "uname" + | "file" + | "stat" + | "diff" + | "sort" + | "uniq" + | "tr" + | "cut" + | "paste" + | "tee" + | "xargs" + | "test" + | "true" + | "false" + | "type" + | "readlink" + | "realpath" + | "basename" + | "dirname" + | "sha256sum" + | "md5sum" + | "b3sum" + | "xxd" + | "hexdump" + | "od" + | "strings" + | "tree" + | "jq" + | "yq" + | "python3" + | "python" + | "node" + | "ruby" + | "cargo" + | "rustc" + | "git" + | "gh" + ) && !command.contains("-i ") + && !command.contains("--in-place") + && !command.contains(" > ") + && !command.contains(" >> ") +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_enforcer(mode: PermissionMode) -> PermissionEnforcer { + let policy = PermissionPolicy::new(mode); + PermissionEnforcer::new(policy) + } + + #[test] + fn allow_mode_permits_everything() { + let enforcer = make_enforcer(PermissionMode::Allow); + assert!(enforcer.is_allowed("bash", "")); + assert!(enforcer.is_allowed("write_file", "")); + assert!(enforcer.is_allowed("edit_file", "")); + assert_eq!( + enforcer.check_file_write("/outside/path", "/workspace"), + EnforcementResult::Allowed + ); + assert_eq!(enforcer.check_bash("rm -rf /"), EnforcementResult::Allowed); + } + + #[test] + fn read_only_denies_writes() { + let policy = PermissionPolicy::new(PermissionMode::ReadOnly) + .with_tool_requirement("read_file", PermissionMode::ReadOnly) + .with_tool_requirement("grep_search", PermissionMode::ReadOnly) + .with_tool_requirement("write_file", PermissionMode::WorkspaceWrite); + + let enforcer = PermissionEnforcer::new(policy); + assert!(enforcer.is_allowed("read_file", "")); + assert!(enforcer.is_allowed("grep_search", "")); + + // write_file requires WorkspaceWrite but we're in ReadOnly + let result = enforcer.check("write_file", ""); + assert!(matches!(result, EnforcementResult::Denied { .. })); + + let result = enforcer.check_file_write("/workspace/file.rs", "/workspace"); + assert!(matches!(result, EnforcementResult::Denied { .. })); + } + + #[test] + fn read_only_allows_read_commands() { + let enforcer = make_enforcer(PermissionMode::ReadOnly); + assert_eq!( + enforcer.check_bash("cat src/main.rs"), + EnforcementResult::Allowed + ); + assert_eq!( + enforcer.check_bash("grep -r 'pattern' ."), + EnforcementResult::Allowed + ); + assert_eq!(enforcer.check_bash("ls -la"), EnforcementResult::Allowed); + } + + #[test] + fn read_only_denies_write_commands() { + let enforcer = make_enforcer(PermissionMode::ReadOnly); + let result = enforcer.check_bash("rm file.txt"); + assert!(matches!(result, EnforcementResult::Denied { .. })); + } + + #[test] + fn workspace_write_allows_within_workspace() { + let enforcer = make_enforcer(PermissionMode::WorkspaceWrite); + let result = enforcer.check_file_write("/workspace/src/main.rs", "/workspace"); + assert_eq!(result, EnforcementResult::Allowed); + } + + #[test] + fn workspace_write_denies_outside_workspace() { + let enforcer = make_enforcer(PermissionMode::WorkspaceWrite); + let result = enforcer.check_file_write("/etc/passwd", "/workspace"); + assert!(matches!(result, EnforcementResult::Denied { .. })); + } + + #[test] + fn prompt_mode_denies_without_prompter() { + let enforcer = make_enforcer(PermissionMode::Prompt); + let result = enforcer.check_bash("echo test"); + assert!(matches!(result, EnforcementResult::Denied { .. })); + + let result = enforcer.check_file_write("/workspace/file.rs", "/workspace"); + assert!(matches!(result, EnforcementResult::Denied { .. })); + } + + #[test] + fn workspace_boundary_check() { + assert!(is_within_workspace("/workspace/src/main.rs", "/workspace")); + assert!(is_within_workspace("/workspace", "/workspace")); + assert!(!is_within_workspace("/etc/passwd", "/workspace")); + assert!(!is_within_workspace("/workspacex/hack", "/workspace")); + } + + #[test] + fn read_only_command_heuristic() { + assert!(is_read_only_command("cat file.txt")); + assert!(is_read_only_command("grep pattern file")); + assert!(is_read_only_command("git log --oneline")); + assert!(!is_read_only_command("rm file.txt")); + assert!(!is_read_only_command("echo test > file.txt")); + assert!(!is_read_only_command("sed -i 's/a/b/' file")); + } + + #[test] + fn active_mode_returns_policy_mode() { + // given + let modes = [ + PermissionMode::ReadOnly, + PermissionMode::WorkspaceWrite, + PermissionMode::DangerFullAccess, + PermissionMode::Prompt, + PermissionMode::Allow, + ]; + + // when + let active_modes: Vec<_> = modes + .into_iter() + .map(|mode| make_enforcer(mode).active_mode()) + .collect(); + + // then + assert_eq!(active_modes, modes); + } + + #[test] + fn danger_full_access_permits_file_writes_and_bash() { + // given + let enforcer = make_enforcer(PermissionMode::DangerFullAccess); + + // when + let file_result = enforcer.check_file_write("/outside/workspace/file.txt", "/workspace"); + let bash_result = enforcer.check_bash("rm -rf /tmp/scratch"); + + // then + assert_eq!(file_result, EnforcementResult::Allowed); + assert_eq!(bash_result, EnforcementResult::Allowed); + } + + #[test] + fn check_denied_payload_contains_tool_and_modes() { + // given + let policy = PermissionPolicy::new(PermissionMode::ReadOnly) + .with_tool_requirement("write_file", PermissionMode::WorkspaceWrite); + let enforcer = PermissionEnforcer::new(policy); + + // when + let result = enforcer.check("write_file", "{}"); + + // then + match result { + EnforcementResult::Denied { + tool, + active_mode, + required_mode, + reason, + } => { + assert_eq!(tool, "write_file"); + assert_eq!(active_mode, "read-only"); + assert_eq!(required_mode, "workspace-write"); + assert!(reason.contains("requires workspace-write permission")); + } + other => panic!("expected denied result, got {other:?}"), + } + } + + #[test] + fn workspace_write_relative_path_resolved() { + // given + let enforcer = make_enforcer(PermissionMode::WorkspaceWrite); + + // when + let result = enforcer.check_file_write("src/main.rs", "/workspace"); + + // then + assert_eq!(result, EnforcementResult::Allowed); + } + + #[test] + fn workspace_root_with_trailing_slash() { + // given + let enforcer = make_enforcer(PermissionMode::WorkspaceWrite); + + // when + let result = enforcer.check_file_write("/workspace/src/main.rs", "/workspace/"); + + // then + assert_eq!(result, EnforcementResult::Allowed); + } + + #[test] + fn workspace_root_equality() { + // given + let root = "/workspace/"; + + // when + let equal_to_root = is_within_workspace("/workspace", root); + + // then + assert!(equal_to_root); + } + + #[test] + fn bash_heuristic_full_path_prefix() { + // given + let full_path_command = "/usr/bin/cat Cargo.toml"; + let git_path_command = "/usr/local/bin/git status"; + + // when + let cat_result = is_read_only_command(full_path_command); + let git_result = is_read_only_command(git_path_command); + + // then + assert!(cat_result); + assert!(git_result); + } + + #[test] + fn bash_heuristic_redirects_block_read_only_commands() { + // given + let overwrite = "cat Cargo.toml > out.txt"; + let append = "echo test >> out.txt"; + + // when + let overwrite_result = is_read_only_command(overwrite); + let append_result = is_read_only_command(append); + + // then + assert!(!overwrite_result); + assert!(!append_result); + } + + #[test] + fn bash_heuristic_in_place_flag_blocks() { + // given + let interactive_python = "python -i script.py"; + let in_place_sed = "sed --in-place 's/a/b/' file.txt"; + + // when + let interactive_result = is_read_only_command(interactive_python); + let in_place_result = is_read_only_command(in_place_sed); + + // then + assert!(!interactive_result); + assert!(!in_place_result); + } + + #[test] + fn bash_heuristic_empty_command() { + // given + let empty = ""; + let whitespace = " "; + + // when + let empty_result = is_read_only_command(empty); + let whitespace_result = is_read_only_command(whitespace); + + // then + assert!(!empty_result); + assert!(!whitespace_result); + } + + #[test] + fn prompt_mode_check_bash_denied_payload_fields() { + // given + let enforcer = make_enforcer(PermissionMode::Prompt); + + // when + let result = enforcer.check_bash("git status"); + + // then + match result { + EnforcementResult::Denied { + tool, + active_mode, + required_mode, + reason, + } => { + assert_eq!(tool, "bash"); + assert_eq!(active_mode, "prompt"); + assert_eq!(required_mode, "danger-full-access"); + assert_eq!(reason, "bash requires confirmation in prompt mode"); + } + other => panic!("expected denied result, got {other:?}"), + } + } + + #[test] + fn read_only_check_file_write_denied_payload() { + // given + let enforcer = make_enforcer(PermissionMode::ReadOnly); + + // when + let result = enforcer.check_file_write("/workspace/file.txt", "/workspace"); + + // then + match result { + EnforcementResult::Denied { + tool, + active_mode, + required_mode, + reason, + } => { + assert_eq!(tool, "write_file"); + assert_eq!(active_mode, "read-only"); + assert_eq!(required_mode, "workspace-write"); + assert!(reason.contains("file writes are not allowed")); + } + other => panic!("expected denied result, got {other:?}"), + } + } +} diff --git a/crates/runtime/src/permissions.rs b/crates/runtime/src/permissions.rs index bed2eab..81340dd 100644 --- a/crates/runtime/src/permissions.rs +++ b/crates/runtime/src/permissions.rs @@ -1,5 +1,10 @@ use std::collections::BTreeMap; +use serde_json::Value; + +use crate::config::RuntimePermissionRuleConfig; + +/// Permission level assigned to a tool invocation or runtime session. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum PermissionMode { ReadOnly, @@ -22,34 +27,81 @@ impl PermissionMode { } } +/// Hook-provided override applied before standard permission evaluation. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PermissionOverride { + Allow, + Deny, + Ask, +} + +/// Additional permission context supplied by hooks or higher-level orchestration. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct PermissionContext { + override_decision: Option, + override_reason: Option, +} + +impl PermissionContext { + #[must_use] + pub fn new( + override_decision: Option, + override_reason: Option, + ) -> Self { + Self { + override_decision, + override_reason, + } + } + + #[must_use] + pub fn override_decision(&self) -> Option { + self.override_decision + } + + #[must_use] + pub fn override_reason(&self) -> Option<&str> { + self.override_reason.as_deref() + } +} + +/// Full authorization request presented to a permission prompt. #[derive(Debug, Clone, PartialEq, Eq)] pub struct PermissionRequest { pub tool_name: String, pub input: String, pub current_mode: PermissionMode, pub required_mode: PermissionMode, + pub reason: Option, } +/// User-facing decision returned by a [`PermissionPrompter`]. #[derive(Debug, Clone, PartialEq, Eq)] pub enum PermissionPromptDecision { Allow, Deny { reason: String }, } +/// Prompting interface used when policy requires interactive approval. pub trait PermissionPrompter { fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision; } +/// Final authorization result after evaluating static rules and prompts. #[derive(Debug, Clone, PartialEq, Eq)] pub enum PermissionOutcome { Allow, Deny { reason: String }, } +/// Evaluates permission mode requirements plus allow/deny/ask rules. #[derive(Debug, Clone, PartialEq, Eq)] pub struct PermissionPolicy { active_mode: PermissionMode, tool_requirements: BTreeMap, + allow_rules: Vec, + deny_rules: Vec, + ask_rules: Vec, } impl PermissionPolicy { @@ -58,6 +110,9 @@ impl PermissionPolicy { Self { active_mode, tool_requirements: BTreeMap::new(), + allow_rules: Vec::new(), + deny_rules: Vec::new(), + ask_rules: Vec::new(), } } @@ -72,6 +127,26 @@ impl PermissionPolicy { self } + #[must_use] + pub fn with_permission_rules(mut self, config: &RuntimePermissionRuleConfig) -> Self { + self.allow_rules = config + .allow() + .iter() + .map(|rule| PermissionRule::parse(rule)) + .collect(); + self.deny_rules = config + .deny() + .iter() + .map(|rule| PermissionRule::parse(rule)) + .collect(); + self.ask_rules = config + .ask() + .iter() + .map(|rule| PermissionRule::parse(rule)) + .collect(); + self + } + #[must_use] pub fn active_mode(&self) -> PermissionMode { self.active_mode @@ -90,38 +165,121 @@ impl PermissionPolicy { &self, tool_name: &str, input: &str, - mut prompter: Option<&mut dyn PermissionPrompter>, + prompter: Option<&mut dyn PermissionPrompter>, ) -> PermissionOutcome { - let current_mode = self.active_mode(); - let required_mode = self.required_mode_for(tool_name); - if current_mode == PermissionMode::Allow || current_mode >= required_mode { - return PermissionOutcome::Allow; + self.authorize_with_context(tool_name, input, &PermissionContext::default(), prompter) + } + + #[must_use] + #[allow(clippy::too_many_lines)] + pub fn authorize_with_context( + &self, + tool_name: &str, + input: &str, + context: &PermissionContext, + prompter: Option<&mut dyn PermissionPrompter>, + ) -> PermissionOutcome { + if let Some(rule) = Self::find_matching_rule(&self.deny_rules, tool_name, input) { + return PermissionOutcome::Deny { + reason: format!( + "Permission to use {tool_name} has been denied by rule '{}'", + rule.raw + ), + }; } - let request = PermissionRequest { - tool_name: tool_name.to_string(), - input: input.to_string(), - current_mode, - required_mode, - }; + let current_mode = self.active_mode(); + let required_mode = self.required_mode_for(tool_name); + let ask_rule = Self::find_matching_rule(&self.ask_rules, tool_name, input); + let allow_rule = Self::find_matching_rule(&self.allow_rules, tool_name, input); + + match context.override_decision() { + Some(PermissionOverride::Deny) => { + return PermissionOutcome::Deny { + reason: context.override_reason().map_or_else( + || format!("tool '{tool_name}' denied by hook"), + ToOwned::to_owned, + ), + }; + } + Some(PermissionOverride::Ask) => { + let reason = context.override_reason().map_or_else( + || format!("tool '{tool_name}' requires approval due to hook guidance"), + ToOwned::to_owned, + ); + return Self::prompt_or_deny( + tool_name, + input, + current_mode, + required_mode, + Some(reason), + prompter, + ); + } + Some(PermissionOverride::Allow) => { + if let Some(rule) = ask_rule { + let reason = format!( + "tool '{tool_name}' requires approval due to ask rule '{}'", + rule.raw + ); + return Self::prompt_or_deny( + tool_name, + input, + current_mode, + required_mode, + Some(reason), + prompter, + ); + } + if allow_rule.is_some() + || current_mode == PermissionMode::Allow + || current_mode >= required_mode + { + return PermissionOutcome::Allow; + } + } + None => {} + } + + if let Some(rule) = ask_rule { + let reason = format!( + "tool '{tool_name}' requires approval due to ask rule '{}'", + rule.raw + ); + return Self::prompt_or_deny( + tool_name, + input, + current_mode, + required_mode, + Some(reason), + prompter, + ); + } + + if allow_rule.is_some() + || current_mode == PermissionMode::Allow + || current_mode >= required_mode + { + return PermissionOutcome::Allow; + } if current_mode == PermissionMode::Prompt || (current_mode == PermissionMode::WorkspaceWrite && required_mode == PermissionMode::DangerFullAccess) { - return match prompter.as_mut() { - Some(prompter) => match prompter.decide(&request) { - PermissionPromptDecision::Allow => PermissionOutcome::Allow, - PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason }, - }, - None => PermissionOutcome::Deny { - reason: format!( - "tool '{tool_name}' requires approval to escalate from {} to {}", - current_mode.as_str(), - required_mode.as_str() - ), - }, - }; + let reason = Some(format!( + "tool '{tool_name}' requires approval to escalate from {} to {}", + current_mode.as_str(), + required_mode.as_str() + )); + return Self::prompt_or_deny( + tool_name, + input, + current_mode, + required_mode, + reason, + prompter, + ); } PermissionOutcome::Deny { @@ -132,14 +290,191 @@ impl PermissionPolicy { ), } } + + fn prompt_or_deny( + tool_name: &str, + input: &str, + current_mode: PermissionMode, + required_mode: PermissionMode, + reason: Option, + mut prompter: Option<&mut dyn PermissionPrompter>, + ) -> PermissionOutcome { + let request = PermissionRequest { + tool_name: tool_name.to_string(), + input: input.to_string(), + current_mode, + required_mode, + reason: reason.clone(), + }; + + match prompter.as_mut() { + Some(prompter) => match prompter.decide(&request) { + PermissionPromptDecision::Allow => PermissionOutcome::Allow, + PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason }, + }, + None => PermissionOutcome::Deny { + reason: reason.unwrap_or_else(|| { + format!( + "tool '{tool_name}' requires approval to run while mode is {}", + current_mode.as_str() + ) + }), + }, + } + } + + fn find_matching_rule<'a>( + rules: &'a [PermissionRule], + tool_name: &str, + input: &str, + ) -> Option<&'a PermissionRule> { + rules.iter().find(|rule| rule.matches(tool_name, input)) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct PermissionRule { + raw: String, + tool_name: String, + matcher: PermissionRuleMatcher, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum PermissionRuleMatcher { + Any, + Exact(String), + Prefix(String), +} + +impl PermissionRule { + fn parse(raw: &str) -> Self { + let trimmed = raw.trim(); + let open = find_first_unescaped(trimmed, '('); + let close = find_last_unescaped(trimmed, ')'); + + if let (Some(open), Some(close)) = (open, close) { + if close == trimmed.len() - 1 && open < close { + let tool_name = trimmed[..open].trim(); + let content = &trimmed[open + 1..close]; + if !tool_name.is_empty() { + let matcher = parse_rule_matcher(content); + return Self { + raw: trimmed.to_string(), + tool_name: tool_name.to_string(), + matcher, + }; + } + } + } + + Self { + raw: trimmed.to_string(), + tool_name: trimmed.to_string(), + matcher: PermissionRuleMatcher::Any, + } + } + + fn matches(&self, tool_name: &str, input: &str) -> bool { + if self.tool_name != tool_name { + return false; + } + + match &self.matcher { + PermissionRuleMatcher::Any => true, + PermissionRuleMatcher::Exact(expected) => { + extract_permission_subject(input).is_some_and(|candidate| candidate == *expected) + } + PermissionRuleMatcher::Prefix(prefix) => extract_permission_subject(input) + .is_some_and(|candidate| candidate.starts_with(prefix)), + } + } +} + +fn parse_rule_matcher(content: &str) -> PermissionRuleMatcher { + let unescaped = unescape_rule_content(content.trim()); + if unescaped.is_empty() || unescaped == "*" { + PermissionRuleMatcher::Any + } else if let Some(prefix) = unescaped.strip_suffix(":*") { + PermissionRuleMatcher::Prefix(prefix.to_string()) + } else { + PermissionRuleMatcher::Exact(unescaped) + } +} + +fn unescape_rule_content(content: &str) -> String { + content + .replace(r"\(", "(") + .replace(r"\)", ")") + .replace(r"\\", r"\") +} + +fn find_first_unescaped(value: &str, needle: char) -> Option { + let mut escaped = false; + for (idx, ch) in value.char_indices() { + if ch == '\\' { + escaped = !escaped; + continue; + } + if ch == needle && !escaped { + return Some(idx); + } + escaped = false; + } + None +} + +fn find_last_unescaped(value: &str, needle: char) -> Option { + let chars = value.char_indices().collect::>(); + for (pos, (idx, ch)) in chars.iter().enumerate().rev() { + if *ch != needle { + continue; + } + let mut backslashes = 0; + for (_, prev) in chars[..pos].iter().rev() { + if *prev == '\\' { + backslashes += 1; + } else { + break; + } + } + if backslashes % 2 == 0 { + return Some(*idx); + } + } + None +} + +fn extract_permission_subject(input: &str) -> Option { + let parsed = serde_json::from_str::(input).ok(); + if let Some(Value::Object(object)) = parsed { + for key in [ + "command", + "path", + "file_path", + "filePath", + "notebook_path", + "notebookPath", + "url", + "pattern", + "code", + "message", + ] { + if let Some(value) = object.get(key).and_then(Value::as_str) { + return Some(value.to_string()); + } + } + } + + (!input.trim().is_empty()).then(|| input.to_string()) } #[cfg(test)] mod tests { use super::{ - PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision, - PermissionPrompter, PermissionRequest, + PermissionContext, PermissionMode, PermissionOutcome, PermissionOverride, PermissionPolicy, + PermissionPromptDecision, PermissionPrompter, PermissionRequest, }; + use crate::config::RuntimePermissionRuleConfig; struct RecordingPrompter { seen: Vec, @@ -229,4 +564,120 @@ mod tests { PermissionOutcome::Deny { reason } if reason == "not now" )); } + + #[test] + fn applies_rule_based_denials_and_allows() { + let rules = RuntimePermissionRuleConfig::new( + vec!["bash(git:*)".to_string()], + vec!["bash(rm -rf:*)".to_string()], + Vec::new(), + ); + let policy = PermissionPolicy::new(PermissionMode::ReadOnly) + .with_tool_requirement("bash", PermissionMode::DangerFullAccess) + .with_permission_rules(&rules); + + assert_eq!( + policy.authorize("bash", r#"{"command":"git status"}"#, None), + PermissionOutcome::Allow + ); + assert!(matches!( + policy.authorize("bash", r#"{"command":"rm -rf /tmp/x"}"#, None), + PermissionOutcome::Deny { reason } if reason.contains("denied by rule") + )); + } + + #[test] + fn ask_rules_force_prompt_even_when_mode_allows() { + let rules = RuntimePermissionRuleConfig::new( + Vec::new(), + Vec::new(), + vec!["bash(git:*)".to_string()], + ); + let policy = PermissionPolicy::new(PermissionMode::DangerFullAccess) + .with_tool_requirement("bash", PermissionMode::DangerFullAccess) + .with_permission_rules(&rules); + let mut prompter = RecordingPrompter { + seen: Vec::new(), + allow: true, + }; + + let outcome = policy.authorize("bash", r#"{"command":"git status"}"#, Some(&mut prompter)); + + assert_eq!(outcome, PermissionOutcome::Allow); + assert_eq!(prompter.seen.len(), 1); + assert!(prompter.seen[0] + .reason + .as_deref() + .is_some_and(|reason| reason.contains("ask rule"))); + } + + #[test] + fn hook_allow_still_respects_ask_rules() { + let rules = RuntimePermissionRuleConfig::new( + Vec::new(), + Vec::new(), + vec!["bash(git:*)".to_string()], + ); + let policy = PermissionPolicy::new(PermissionMode::ReadOnly) + .with_tool_requirement("bash", PermissionMode::DangerFullAccess) + .with_permission_rules(&rules); + let context = PermissionContext::new( + Some(PermissionOverride::Allow), + Some("hook approved".to_string()), + ); + let mut prompter = RecordingPrompter { + seen: Vec::new(), + allow: true, + }; + + let outcome = policy.authorize_with_context( + "bash", + r#"{"command":"git status"}"#, + &context, + Some(&mut prompter), + ); + + assert_eq!(outcome, PermissionOutcome::Allow); + assert_eq!(prompter.seen.len(), 1); + } + + #[test] + fn hook_deny_short_circuits_permission_flow() { + let policy = PermissionPolicy::new(PermissionMode::DangerFullAccess) + .with_tool_requirement("bash", PermissionMode::DangerFullAccess); + let context = PermissionContext::new( + Some(PermissionOverride::Deny), + Some("blocked by hook".to_string()), + ); + + assert_eq!( + policy.authorize_with_context("bash", "{}", &context, None), + PermissionOutcome::Deny { + reason: "blocked by hook".to_string(), + } + ); + } + + #[test] + fn hook_ask_forces_prompt() { + let policy = PermissionPolicy::new(PermissionMode::DangerFullAccess) + .with_tool_requirement("bash", PermissionMode::DangerFullAccess); + let context = PermissionContext::new( + Some(PermissionOverride::Ask), + Some("hook requested confirmation".to_string()), + ); + let mut prompter = RecordingPrompter { + seen: Vec::new(), + allow: true, + }; + + let outcome = policy.authorize_with_context("bash", "{}", &context, Some(&mut prompter)); + + assert_eq!(outcome, PermissionOutcome::Allow); + assert_eq!(prompter.seen.len(), 1); + assert_eq!( + prompter.seen[0].reason.as_deref(), + Some("hook requested confirmation") + ); + } } diff --git a/crates/runtime/src/plugin_lifecycle.rs b/crates/runtime/src/plugin_lifecycle.rs new file mode 100644 index 0000000..bd12321 --- /dev/null +++ b/crates/runtime/src/plugin_lifecycle.rs @@ -0,0 +1,533 @@ +#![allow(clippy::redundant_closure_for_method_calls)] +use std::time::{SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; + +use crate::config::RuntimePluginConfig; +use crate::mcp_tool_bridge::{McpResourceInfo, McpToolInfo}; + +fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +pub type ToolInfo = McpToolInfo; +pub type ResourceInfo = McpResourceInfo; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ServerStatus { + Healthy, + Degraded, + Failed, +} + +impl std::fmt::Display for ServerStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Healthy => write!(f, "healthy"), + Self::Degraded => write!(f, "degraded"), + Self::Failed => write!(f, "failed"), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ServerHealth { + pub server_name: String, + pub status: ServerStatus, + pub capabilities: Vec, + pub last_error: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "state")] +pub enum PluginState { + Unconfigured, + Validated, + Starting, + Healthy, + Degraded { + healthy_servers: Vec, + failed_servers: Vec, + }, + Failed { + reason: String, + }, + ShuttingDown, + Stopped, +} + +impl PluginState { + #[must_use] + pub fn from_servers(servers: &[ServerHealth]) -> Self { + if servers.is_empty() { + return Self::Failed { + reason: "no servers available".to_string(), + }; + } + + let healthy_servers = servers + .iter() + .filter(|server| server.status != ServerStatus::Failed) + .map(|server| server.server_name.clone()) + .collect::>(); + let failed_servers = servers + .iter() + .filter(|server| server.status == ServerStatus::Failed) + .cloned() + .collect::>(); + let has_degraded_server = servers + .iter() + .any(|server| server.status == ServerStatus::Degraded); + + if failed_servers.is_empty() && !has_degraded_server { + Self::Healthy + } else if healthy_servers.is_empty() { + Self::Failed { + reason: format!("all {} servers failed", failed_servers.len()), + } + } else { + Self::Degraded { + healthy_servers, + failed_servers, + } + } + } +} + +impl std::fmt::Display for PluginState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Unconfigured => write!(f, "unconfigured"), + Self::Validated => write!(f, "validated"), + Self::Starting => write!(f, "starting"), + Self::Healthy => write!(f, "healthy"), + Self::Degraded { .. } => write!(f, "degraded"), + Self::Failed { .. } => write!(f, "failed"), + Self::ShuttingDown => write!(f, "shutting_down"), + Self::Stopped => write!(f, "stopped"), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct PluginHealthcheck { + pub plugin_name: String, + pub state: PluginState, + pub servers: Vec, + pub last_check: u64, +} + +impl PluginHealthcheck { + #[must_use] + pub fn new(plugin_name: impl Into, servers: Vec) -> Self { + let state = PluginState::from_servers(&servers); + Self { + plugin_name: plugin_name.into(), + state, + servers, + last_check: now_secs(), + } + } + + #[must_use] + pub fn degraded_mode(&self, discovery: &DiscoveryResult) -> Option { + match &self.state { + PluginState::Degraded { + healthy_servers, + failed_servers, + } => Some(DegradedMode { + available_tools: discovery + .tools + .iter() + .map(|tool| tool.name.clone()) + .collect(), + unavailable_tools: failed_servers + .iter() + .flat_map(|server| server.capabilities.iter().cloned()) + .collect(), + reason: format!( + "{} servers healthy, {} servers failed", + healthy_servers.len(), + failed_servers.len() + ), + }), + _ => None, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DiscoveryResult { + pub tools: Vec, + pub resources: Vec, + pub partial: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct DegradedMode { + pub available_tools: Vec, + pub unavailable_tools: Vec, + pub reason: String, +} + +impl DegradedMode { + #[must_use] + pub fn new( + available_tools: Vec, + unavailable_tools: Vec, + reason: impl Into, + ) -> Self { + Self { + available_tools, + unavailable_tools, + reason: reason.into(), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum PluginLifecycleEvent { + ConfigValidated, + StartupHealthy, + StartupDegraded, + StartupFailed, + Shutdown, +} + +impl std::fmt::Display for PluginLifecycleEvent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ConfigValidated => write!(f, "config_validated"), + Self::StartupHealthy => write!(f, "startup_healthy"), + Self::StartupDegraded => write!(f, "startup_degraded"), + Self::StartupFailed => write!(f, "startup_failed"), + Self::Shutdown => write!(f, "shutdown"), + } + } +} + +pub trait PluginLifecycle { + fn validate_config(&self, config: &RuntimePluginConfig) -> Result<(), String>; + fn healthcheck(&self) -> PluginHealthcheck; + fn discover(&self) -> DiscoveryResult; + fn shutdown(&mut self) -> Result<(), String>; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug, Clone)] + struct MockPluginLifecycle { + plugin_name: String, + valid_config: bool, + healthcheck: PluginHealthcheck, + discovery: DiscoveryResult, + shutdown_error: Option, + shutdown_called: bool, + } + + impl MockPluginLifecycle { + fn new( + plugin_name: &str, + valid_config: bool, + servers: Vec, + discovery: DiscoveryResult, + shutdown_error: Option, + ) -> Self { + Self { + plugin_name: plugin_name.to_string(), + valid_config, + healthcheck: PluginHealthcheck::new(plugin_name, servers), + discovery, + shutdown_error, + shutdown_called: false, + } + } + } + + impl PluginLifecycle for MockPluginLifecycle { + fn validate_config(&self, _config: &RuntimePluginConfig) -> Result<(), String> { + if self.valid_config { + Ok(()) + } else { + Err(format!( + "plugin `{}` failed configuration validation", + self.plugin_name + )) + } + } + + fn healthcheck(&self) -> PluginHealthcheck { + if self.shutdown_called { + PluginHealthcheck { + plugin_name: self.plugin_name.clone(), + state: PluginState::Stopped, + servers: self.healthcheck.servers.clone(), + last_check: now_secs(), + } + } else { + self.healthcheck.clone() + } + } + + fn discover(&self) -> DiscoveryResult { + self.discovery.clone() + } + + fn shutdown(&mut self) -> Result<(), String> { + if let Some(error) = &self.shutdown_error { + return Err(error.clone()); + } + + self.shutdown_called = true; + Ok(()) + } + } + + fn healthy_server(name: &str, capabilities: &[&str]) -> ServerHealth { + ServerHealth { + server_name: name.to_string(), + status: ServerStatus::Healthy, + capabilities: capabilities + .iter() + .map(|capability| capability.to_string()) + .collect(), + last_error: None, + } + } + + fn failed_server(name: &str, capabilities: &[&str], error: &str) -> ServerHealth { + ServerHealth { + server_name: name.to_string(), + status: ServerStatus::Failed, + capabilities: capabilities + .iter() + .map(|capability| capability.to_string()) + .collect(), + last_error: Some(error.to_string()), + } + } + + fn degraded_server(name: &str, capabilities: &[&str], error: &str) -> ServerHealth { + ServerHealth { + server_name: name.to_string(), + status: ServerStatus::Degraded, + capabilities: capabilities + .iter() + .map(|capability| capability.to_string()) + .collect(), + last_error: Some(error.to_string()), + } + } + + fn tool(name: &str) -> ToolInfo { + ToolInfo { + name: name.to_string(), + description: Some(format!("{name} tool")), + input_schema: None, + } + } + + fn resource(name: &str, uri: &str) -> ResourceInfo { + ResourceInfo { + uri: uri.to_string(), + name: name.to_string(), + description: Some(format!("{name} resource")), + mime_type: Some("application/json".to_string()), + } + } + + #[test] + fn full_lifecycle_happy_path() { + // given + let mut lifecycle = MockPluginLifecycle::new( + "healthy-plugin", + true, + vec![ + healthy_server("alpha", &["search", "read"]), + healthy_server("beta", &["write"]), + ], + DiscoveryResult { + tools: vec![tool("search"), tool("read"), tool("write")], + resources: vec![resource("docs", "file:///docs")], + partial: false, + }, + None, + ); + let config = RuntimePluginConfig::default(); + + // when + let validation = lifecycle.validate_config(&config); + let healthcheck = lifecycle.healthcheck(); + let discovery = lifecycle.discover(); + let shutdown = lifecycle.shutdown(); + let post_shutdown = lifecycle.healthcheck(); + + // then + assert_eq!(validation, Ok(())); + assert_eq!(healthcheck.state, PluginState::Healthy); + assert_eq!(healthcheck.plugin_name, "healthy-plugin"); + assert_eq!(discovery.tools.len(), 3); + assert_eq!(discovery.resources.len(), 1); + assert!(!discovery.partial); + assert_eq!(shutdown, Ok(())); + assert_eq!(post_shutdown.state, PluginState::Stopped); + } + + #[test] + fn degraded_startup_when_one_of_three_servers_fails() { + // given + let lifecycle = MockPluginLifecycle::new( + "degraded-plugin", + true, + vec![ + healthy_server("alpha", &["search"]), + failed_server("beta", &["write"], "connection refused"), + healthy_server("gamma", &["read"]), + ], + DiscoveryResult { + tools: vec![tool("search"), tool("read")], + resources: vec![resource("alpha-docs", "file:///alpha")], + partial: true, + }, + None, + ); + + // when + let healthcheck = lifecycle.healthcheck(); + let discovery = lifecycle.discover(); + let degraded_mode = healthcheck + .degraded_mode(&discovery) + .expect("degraded startup should expose degraded mode"); + + // then + match healthcheck.state { + PluginState::Degraded { + healthy_servers, + failed_servers, + } => { + assert_eq!( + healthy_servers, + vec!["alpha".to_string(), "gamma".to_string()] + ); + assert_eq!(failed_servers.len(), 1); + assert_eq!(failed_servers[0].server_name, "beta"); + assert_eq!( + failed_servers[0].last_error.as_deref(), + Some("connection refused") + ); + } + other => panic!("expected degraded state, got {other:?}"), + } + assert!(discovery.partial); + assert_eq!( + degraded_mode.available_tools, + vec!["search".to_string(), "read".to_string()] + ); + assert_eq!(degraded_mode.unavailable_tools, vec!["write".to_string()]); + assert_eq!(degraded_mode.reason, "2 servers healthy, 1 servers failed"); + } + + #[test] + fn degraded_server_status_keeps_server_usable() { + // given + let lifecycle = MockPluginLifecycle::new( + "soft-degraded-plugin", + true, + vec![ + healthy_server("alpha", &["search"]), + degraded_server("beta", &["write"], "high latency"), + ], + DiscoveryResult { + tools: vec![tool("search"), tool("write")], + resources: Vec::new(), + partial: true, + }, + None, + ); + + // when + let healthcheck = lifecycle.healthcheck(); + + // then + match healthcheck.state { + PluginState::Degraded { + healthy_servers, + failed_servers, + } => { + assert_eq!( + healthy_servers, + vec!["alpha".to_string(), "beta".to_string()] + ); + assert!(failed_servers.is_empty()); + } + other => panic!("expected degraded state, got {other:?}"), + } + } + + #[test] + fn complete_failure_when_all_servers_fail() { + // given + let lifecycle = MockPluginLifecycle::new( + "failed-plugin", + true, + vec![ + failed_server("alpha", &["search"], "timeout"), + failed_server("beta", &["read"], "handshake failed"), + ], + DiscoveryResult { + tools: Vec::new(), + resources: Vec::new(), + partial: false, + }, + None, + ); + + // when + let healthcheck = lifecycle.healthcheck(); + let discovery = lifecycle.discover(); + + // then + match &healthcheck.state { + PluginState::Failed { reason } => { + assert_eq!(reason, "all 2 servers failed"); + } + other => panic!("expected failed state, got {other:?}"), + } + assert!(!discovery.partial); + assert!(discovery.tools.is_empty()); + assert!(discovery.resources.is_empty()); + assert!(healthcheck.degraded_mode(&discovery).is_none()); + } + + #[test] + fn graceful_shutdown() { + // given + let mut lifecycle = MockPluginLifecycle::new( + "shutdown-plugin", + true, + vec![healthy_server("alpha", &["search"])], + DiscoveryResult { + tools: vec![tool("search")], + resources: Vec::new(), + partial: false, + }, + None, + ); + + // when + let shutdown = lifecycle.shutdown(); + let post_shutdown = lifecycle.healthcheck(); + + // then + assert_eq!(shutdown, Ok(())); + assert_eq!(PluginLifecycleEvent::Shutdown.to_string(), "shutdown"); + assert_eq!(post_shutdown.state, PluginState::Stopped); + } +} diff --git a/crates/runtime/src/policy_engine.rs b/crates/runtime/src/policy_engine.rs new file mode 100644 index 0000000..84912a6 --- /dev/null +++ b/crates/runtime/src/policy_engine.rs @@ -0,0 +1,581 @@ +use std::time::Duration; + +pub type GreenLevel = u8; + +const STALE_BRANCH_THRESHOLD: Duration = Duration::from_secs(60 * 60); + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PolicyRule { + pub name: String, + pub condition: PolicyCondition, + pub action: PolicyAction, + pub priority: u32, +} + +impl PolicyRule { + #[must_use] + pub fn new( + name: impl Into, + condition: PolicyCondition, + action: PolicyAction, + priority: u32, + ) -> Self { + Self { + name: name.into(), + condition, + action, + priority, + } + } + + #[must_use] + pub fn matches(&self, context: &LaneContext) -> bool { + self.condition.matches(context) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PolicyCondition { + And(Vec), + Or(Vec), + GreenAt { level: GreenLevel }, + StaleBranch, + StartupBlocked, + LaneCompleted, + LaneReconciled, + ReviewPassed, + ScopedDiff, + TimedOut { duration: Duration }, +} + +impl PolicyCondition { + #[must_use] + pub fn matches(&self, context: &LaneContext) -> bool { + match self { + Self::And(conditions) => conditions + .iter() + .all(|condition| condition.matches(context)), + Self::Or(conditions) => conditions + .iter() + .any(|condition| condition.matches(context)), + Self::GreenAt { level } => context.green_level >= *level, + Self::StaleBranch => context.branch_freshness >= STALE_BRANCH_THRESHOLD, + Self::StartupBlocked => context.blocker == LaneBlocker::Startup, + Self::LaneCompleted => context.completed, + Self::LaneReconciled => context.reconciled, + Self::ReviewPassed => context.review_status == ReviewStatus::Approved, + Self::ScopedDiff => context.diff_scope == DiffScope::Scoped, + Self::TimedOut { duration } => context.branch_freshness >= *duration, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PolicyAction { + MergeToDev, + MergeForward, + RecoverOnce, + Escalate { reason: String }, + CloseoutLane, + CleanupSession, + Reconcile { reason: ReconcileReason }, + Notify { channel: String }, + Block { reason: String }, + Chain(Vec), +} + +/// Why a lane was reconciled without further action. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ReconcileReason { + /// Branch already merged into main — no PR needed. + AlreadyMerged, + /// Work superseded by another lane or direct commit. + Superseded, + /// PR would be empty — all changes already landed. + EmptyDiff, + /// Lane manually closed by operator. + ManualClose, +} + +impl PolicyAction { + fn flatten_into(&self, actions: &mut Vec) { + match self { + Self::Chain(chained) => { + for action in chained { + action.flatten_into(actions); + } + } + _ => actions.push(self.clone()), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LaneBlocker { + None, + Startup, + External, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReviewStatus { + Pending, + Approved, + Rejected, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DiffScope { + Full, + Scoped, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LaneContext { + pub lane_id: String, + pub green_level: GreenLevel, + pub branch_freshness: Duration, + pub blocker: LaneBlocker, + pub review_status: ReviewStatus, + pub diff_scope: DiffScope, + pub completed: bool, + pub reconciled: bool, +} + +impl LaneContext { + #[must_use] + pub fn new( + lane_id: impl Into, + green_level: GreenLevel, + branch_freshness: Duration, + blocker: LaneBlocker, + review_status: ReviewStatus, + diff_scope: DiffScope, + completed: bool, + ) -> Self { + Self { + lane_id: lane_id.into(), + green_level, + branch_freshness, + blocker, + review_status, + diff_scope, + completed, + reconciled: false, + } + } + + /// Create a lane context that is already reconciled (no further action needed). + #[must_use] + pub fn reconciled(lane_id: impl Into) -> Self { + Self { + lane_id: lane_id.into(), + green_level: 0, + branch_freshness: Duration::from_secs(0), + blocker: LaneBlocker::None, + review_status: ReviewStatus::Pending, + diff_scope: DiffScope::Full, + completed: true, + reconciled: true, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PolicyEngine { + rules: Vec, +} + +impl PolicyEngine { + #[must_use] + pub fn new(mut rules: Vec) -> Self { + rules.sort_by_key(|rule| rule.priority); + Self { rules } + } + + #[must_use] + pub fn rules(&self) -> &[PolicyRule] { + &self.rules + } + + #[must_use] + pub fn evaluate(&self, context: &LaneContext) -> Vec { + evaluate(self, context) + } +} + +#[must_use] +pub fn evaluate(engine: &PolicyEngine, context: &LaneContext) -> Vec { + let mut actions = Vec::new(); + for rule in &engine.rules { + if rule.matches(context) { + rule.action.flatten_into(&mut actions); + } + } + actions +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::{ + evaluate, DiffScope, LaneBlocker, LaneContext, PolicyAction, PolicyCondition, PolicyEngine, + PolicyRule, ReconcileReason, ReviewStatus, STALE_BRANCH_THRESHOLD, + }; + + fn default_context() -> LaneContext { + LaneContext::new( + "lane-7", + 0, + Duration::from_secs(0), + LaneBlocker::None, + ReviewStatus::Pending, + DiffScope::Full, + false, + ) + } + + #[test] + fn merge_to_dev_rule_fires_for_green_scoped_reviewed_lane() { + // given + let engine = PolicyEngine::new(vec![PolicyRule::new( + "merge-to-dev", + PolicyCondition::And(vec![ + PolicyCondition::GreenAt { level: 2 }, + PolicyCondition::ScopedDiff, + PolicyCondition::ReviewPassed, + ]), + PolicyAction::MergeToDev, + 20, + )]); + let context = LaneContext::new( + "lane-7", + 3, + Duration::from_secs(5), + LaneBlocker::None, + ReviewStatus::Approved, + DiffScope::Scoped, + false, + ); + + // when + let actions = engine.evaluate(&context); + + // then + assert_eq!(actions, vec![PolicyAction::MergeToDev]); + } + + #[test] + fn stale_branch_rule_fires_at_threshold() { + // given + let engine = PolicyEngine::new(vec![PolicyRule::new( + "merge-forward", + PolicyCondition::StaleBranch, + PolicyAction::MergeForward, + 10, + )]); + let context = LaneContext::new( + "lane-7", + 1, + STALE_BRANCH_THRESHOLD, + LaneBlocker::None, + ReviewStatus::Pending, + DiffScope::Full, + false, + ); + + // when + let actions = engine.evaluate(&context); + + // then + assert_eq!(actions, vec![PolicyAction::MergeForward]); + } + + #[test] + fn startup_blocked_rule_recovers_then_escalates() { + // given + let engine = PolicyEngine::new(vec![PolicyRule::new( + "startup-recovery", + PolicyCondition::StartupBlocked, + PolicyAction::Chain(vec![ + PolicyAction::RecoverOnce, + PolicyAction::Escalate { + reason: "startup remained blocked".to_string(), + }, + ]), + 15, + )]); + let context = LaneContext::new( + "lane-7", + 0, + Duration::from_secs(0), + LaneBlocker::Startup, + ReviewStatus::Pending, + DiffScope::Full, + false, + ); + + // when + let actions = engine.evaluate(&context); + + // then + assert_eq!( + actions, + vec![ + PolicyAction::RecoverOnce, + PolicyAction::Escalate { + reason: "startup remained blocked".to_string(), + }, + ] + ); + } + + #[test] + fn completed_lane_rule_closes_out_and_cleans_up() { + // given + let engine = PolicyEngine::new(vec![PolicyRule::new( + "lane-closeout", + PolicyCondition::LaneCompleted, + PolicyAction::Chain(vec![ + PolicyAction::CloseoutLane, + PolicyAction::CleanupSession, + ]), + 30, + )]); + let context = LaneContext::new( + "lane-7", + 0, + Duration::from_secs(0), + LaneBlocker::None, + ReviewStatus::Pending, + DiffScope::Full, + true, + ); + + // when + let actions = engine.evaluate(&context); + + // then + assert_eq!( + actions, + vec![PolicyAction::CloseoutLane, PolicyAction::CleanupSession] + ); + } + + #[test] + fn matching_rules_are_returned_in_priority_order_with_stable_ties() { + // given + let engine = PolicyEngine::new(vec![ + PolicyRule::new( + "late-cleanup", + PolicyCondition::And(vec![]), + PolicyAction::CleanupSession, + 30, + ), + PolicyRule::new( + "first-notify", + PolicyCondition::And(vec![]), + PolicyAction::Notify { + channel: "ops".to_string(), + }, + 10, + ), + PolicyRule::new( + "second-notify", + PolicyCondition::And(vec![]), + PolicyAction::Notify { + channel: "review".to_string(), + }, + 10, + ), + PolicyRule::new( + "merge", + PolicyCondition::And(vec![]), + PolicyAction::MergeToDev, + 20, + ), + ]); + let context = default_context(); + + // when + let actions = evaluate(&engine, &context); + + // then + assert_eq!( + actions, + vec![ + PolicyAction::Notify { + channel: "ops".to_string(), + }, + PolicyAction::Notify { + channel: "review".to_string(), + }, + PolicyAction::MergeToDev, + PolicyAction::CleanupSession, + ] + ); + } + + #[test] + fn combinators_handle_empty_cases_and_nested_chains() { + // given + let engine = PolicyEngine::new(vec![ + PolicyRule::new( + "empty-and", + PolicyCondition::And(vec![]), + PolicyAction::Notify { + channel: "orchestrator".to_string(), + }, + 5, + ), + PolicyRule::new( + "empty-or", + PolicyCondition::Or(vec![]), + PolicyAction::Block { + reason: "should not fire".to_string(), + }, + 10, + ), + PolicyRule::new( + "nested", + PolicyCondition::Or(vec![ + PolicyCondition::StartupBlocked, + PolicyCondition::And(vec![ + PolicyCondition::GreenAt { level: 2 }, + PolicyCondition::TimedOut { + duration: Duration::from_secs(5), + }, + ]), + ]), + PolicyAction::Chain(vec![ + PolicyAction::Notify { + channel: "alerts".to_string(), + }, + PolicyAction::Chain(vec![ + PolicyAction::MergeForward, + PolicyAction::CleanupSession, + ]), + ]), + 15, + ), + ]); + let context = LaneContext::new( + "lane-7", + 2, + Duration::from_secs(10), + LaneBlocker::External, + ReviewStatus::Pending, + DiffScope::Full, + false, + ); + + // when + let actions = engine.evaluate(&context); + + // then + assert_eq!( + actions, + vec![ + PolicyAction::Notify { + channel: "orchestrator".to_string(), + }, + PolicyAction::Notify { + channel: "alerts".to_string(), + }, + PolicyAction::MergeForward, + PolicyAction::CleanupSession, + ] + ); + } + + #[test] + fn reconciled_lane_emits_reconcile_and_cleanup() { + // given — a lane where branch is already merged, no PR needed, session stale + let engine = PolicyEngine::new(vec![ + PolicyRule::new( + "reconcile-closeout", + PolicyCondition::LaneReconciled, + PolicyAction::Chain(vec![ + PolicyAction::Reconcile { + reason: ReconcileReason::AlreadyMerged, + }, + PolicyAction::CloseoutLane, + PolicyAction::CleanupSession, + ]), + 5, + ), + // This rule should NOT fire — reconciled lanes are completed but we want + // the more specific reconcile rule to handle them + PolicyRule::new( + "generic-closeout", + PolicyCondition::And(vec![ + PolicyCondition::LaneCompleted, + // Only fire if NOT reconciled + PolicyCondition::And(vec![]), + ]), + PolicyAction::CloseoutLane, + 30, + ), + ]); + let context = LaneContext::reconciled("lane-9411"); + + // when + let actions = engine.evaluate(&context); + + // then — reconcile rule fires first (priority 5), then generic closeout also fires + // because reconciled context has completed=true + assert_eq!( + actions, + vec![ + PolicyAction::Reconcile { + reason: ReconcileReason::AlreadyMerged, + }, + PolicyAction::CloseoutLane, + PolicyAction::CleanupSession, + PolicyAction::CloseoutLane, + ] + ); + } + + #[test] + fn reconciled_context_has_correct_defaults() { + let ctx = LaneContext::reconciled("test-lane"); + assert_eq!(ctx.lane_id, "test-lane"); + assert!(ctx.completed); + assert!(ctx.reconciled); + assert_eq!(ctx.blocker, LaneBlocker::None); + assert_eq!(ctx.green_level, 0); + } + + #[test] + fn non_reconciled_lane_does_not_trigger_reconcile_rule() { + let engine = PolicyEngine::new(vec![PolicyRule::new( + "reconcile-closeout", + PolicyCondition::LaneReconciled, + PolicyAction::Reconcile { + reason: ReconcileReason::EmptyDiff, + }, + 5, + )]); + // Normal completed lane — not reconciled + let context = LaneContext::new( + "lane-7", + 0, + Duration::from_secs(0), + LaneBlocker::None, + ReviewStatus::Pending, + DiffScope::Full, + true, + ); + + let actions = engine.evaluate(&context); + assert!(actions.is_empty()); + } + + #[test] + fn reconcile_reason_variants_are_distinct() { + assert_ne!(ReconcileReason::AlreadyMerged, ReconcileReason::Superseded); + assert_ne!(ReconcileReason::EmptyDiff, ReconcileReason::ManualClose); + } +} diff --git a/crates/runtime/src/prompt.rs b/crates/runtime/src/prompt.rs index d3b09e3..e46b7eb 100644 --- a/crates/runtime/src/prompt.rs +++ b/crates/runtime/src/prompt.rs @@ -4,8 +4,9 @@ use std::path::{Path, PathBuf}; use std::process::Command; use crate::config::{ConfigError, ConfigLoader, RuntimeConfig}; -use lsp::LspContextEnrichment; +use crate::git_context::GitContext; +/// Errors raised while assembling the final system prompt. #[derive(Debug)] pub enum PromptBuildError { Io(std::io::Error), @@ -35,23 +36,28 @@ impl From for PromptBuildError { } } +/// Marker separating static prompt scaffolding from dynamic runtime context. pub const SYSTEM_PROMPT_DYNAMIC_BOUNDARY: &str = "__SYSTEM_PROMPT_DYNAMIC_BOUNDARY__"; -pub const FRONTIER_MODEL_NAME: &str = "Opus 4.6"; +/// Human-readable default frontier model name embedded into generated prompts. +pub const FRONTIER_MODEL_NAME: &str = "Claude Opus 4.6"; const MAX_INSTRUCTION_FILE_CHARS: usize = 4_000; const MAX_TOTAL_INSTRUCTION_CHARS: usize = 12_000; +/// Contents of an instruction file included in prompt construction. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ContextFile { pub path: PathBuf, pub content: String, } +/// Project-local context injected into the rendered system prompt. #[derive(Debug, Clone, Default, PartialEq, Eq)] pub struct ProjectContext { pub cwd: PathBuf, pub current_date: String, pub git_status: Option, pub git_diff: Option, + pub git_context: Option, pub instruction_files: Vec, } @@ -67,6 +73,7 @@ impl ProjectContext { current_date: current_date.into(), git_status: None, git_diff: None, + git_context: None, instruction_files, }) } @@ -78,10 +85,12 @@ impl ProjectContext { let mut context = Self::discover(cwd, current_date)?; context.git_status = read_git_status(&context.cwd); context.git_diff = read_git_diff(&context.cwd); + context.git_context = GitContext::detect(&context.cwd); Ok(context) } } +/// Builder for the runtime system prompt and dynamic environment sections. #[derive(Debug, Clone, Default, PartialEq, Eq)] pub struct SystemPromptBuilder { output_style_name: Option, @@ -131,15 +140,6 @@ impl SystemPromptBuilder { self } - #[must_use] - pub fn with_lsp_context(mut self, enrichment: &LspContextEnrichment) -> Self { - if !enrichment.is_empty() { - self.append_sections - .push(enrichment.render_prompt_section()); - } - self - } - #[must_use] pub fn build(&self) -> Vec { let mut sections = Vec::new(); @@ -194,6 +194,7 @@ impl SystemPromptBuilder { } } +/// Formats each item as an indented bullet for prompt sections. #[must_use] pub fn prepend_bullets(items: Vec) -> Vec { items.into_iter().map(|item| format!(" - {item}")).collect() @@ -211,9 +212,9 @@ fn discover_instruction_files(cwd: &Path) -> std::io::Result> { let mut files = Vec::new(); for dir in directories { for candidate in [ - dir.join("CLAW.md"), - dir.join("CLAW.local.md"), - dir.join(".claw").join("CLAW.md"), + dir.join("CLAUDE.md"), + dir.join("CLAUDE.local.md"), + dir.join(".claw").join("CLAUDE.md"), dir.join(".claw").join("instructions.md"), ] { push_context_file(&mut files, candidate)?; @@ -292,7 +293,7 @@ fn render_project_context(project_context: &ProjectContext) -> String { ]; if !project_context.instruction_files.is_empty() { bullets.push(format!( - "Claw instruction files discovered: {}.", + "Claude instruction files discovered: {}.", project_context.instruction_files.len() )); } @@ -302,16 +303,32 @@ fn render_project_context(project_context: &ProjectContext) -> String { lines.push("Git status snapshot:".to_string()); lines.push(status.clone()); } + if let Some(ref gc) = project_context.git_context { + if !gc.recent_commits.is_empty() { + lines.push(String::new()); + lines.push("Recent commits (last 5):".to_string()); + for c in &gc.recent_commits { + lines.push(format!(" {} {}", c.hash, c.subject)); + } + } + } if let Some(diff) = &project_context.git_diff { lines.push(String::new()); lines.push("Git diff snapshot:".to_string()); lines.push(diff.clone()); } + if let Some(git_context) = &project_context.git_context { + let rendered = git_context.render(); + if !rendered.is_empty() { + lines.push(String::new()); + lines.push(rendered); + } + } lines.join("\n") } fn render_instruction_files(files: &[ContextFile]) -> String { - let mut sections = vec!["# Claw instructions".to_string()]; + let mut sections = vec!["# Claude instructions".to_string()]; let mut remaining_chars = MAX_TOTAL_INSTRUCTION_CHARS; for file in files { if remaining_chars == 0 { @@ -411,6 +428,7 @@ fn collapse_blank_lines(content: &str) -> String { result } +/// Loads config and project context, then renders the system prompt text. pub fn load_system_prompt( cwd: impl Into, current_date: impl Into, @@ -523,24 +541,31 @@ mod tests { crate::test_env_lock() } + fn ensure_valid_cwd() { + if std::env::current_dir().is_err() { + std::env::set_current_dir(env!("CARGO_MANIFEST_DIR")) + .expect("test cwd should be recoverable"); + } + } + #[test] fn discovers_instruction_files_from_ancestor_chain() { let root = temp_dir(); let nested = root.join("apps").join("api"); fs::create_dir_all(nested.join(".claw")).expect("nested claw dir"); - fs::write(root.join("CLAW.md"), "root instructions").expect("write root instructions"); - fs::write(root.join("CLAW.local.md"), "local instructions") + fs::write(root.join("CLAUDE.md"), "root instructions").expect("write root instructions"); + fs::write(root.join("CLAUDE.local.md"), "local instructions") .expect("write local instructions"); fs::create_dir_all(root.join("apps")).expect("apps dir"); fs::create_dir_all(root.join("apps").join(".claw")).expect("apps claw dir"); - fs::write(root.join("apps").join("CLAW.md"), "apps instructions") + fs::write(root.join("apps").join("CLAUDE.md"), "apps instructions") .expect("write apps instructions"); fs::write( root.join("apps").join(".claw").join("instructions.md"), - "apps dot claw instructions", + "apps dot claude instructions", ) - .expect("write apps dot claw instructions"); - fs::write(nested.join(".claw").join("CLAW.md"), "nested rules") + .expect("write apps dot claude instructions"); + fs::write(nested.join(".claw").join("CLAUDE.md"), "nested rules") .expect("write nested rules"); fs::write( nested.join(".claw").join("instructions.md"), @@ -561,7 +586,7 @@ mod tests { "root instructions", "local instructions", "apps instructions", - "apps dot claw instructions", + "apps dot claude instructions", "nested rules", "nested instructions" ] @@ -574,8 +599,8 @@ mod tests { let root = temp_dir(); let nested = root.join("apps").join("api"); fs::create_dir_all(&nested).expect("nested dir"); - fs::write(root.join("CLAW.md"), "same rules\n\n").expect("write root"); - fs::write(nested.join("CLAW.md"), "same rules\n").expect("write nested"); + fs::write(root.join("CLAUDE.md"), "same rules\n\n").expect("write root"); + fs::write(nested.join("CLAUDE.md"), "same rules\n").expect("write nested"); let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load"); assert_eq!(context.instruction_files.len(), 1); @@ -603,14 +628,15 @@ mod tests { #[test] fn displays_context_paths_compactly() { assert_eq!( - display_context_path(Path::new("/tmp/project/.claw/CLAW.md")), - "CLAW.md" + display_context_path(Path::new("/tmp/project/.claw/CLAUDE.md")), + "CLAUDE.md" ); } #[test] fn discover_with_git_includes_status_snapshot() { let _guard = env_lock(); + ensure_valid_cwd(); let root = temp_dir(); fs::create_dir_all(&root).expect("root dir"); std::process::Command::new("git") @@ -618,7 +644,7 @@ mod tests { .current_dir(&root) .status() .expect("git init should run"); - fs::write(root.join("CLAW.md"), "rules").expect("write instructions"); + fs::write(root.join("CLAUDE.md"), "rules").expect("write instructions"); fs::write(root.join("tracked.txt"), "hello").expect("write tracked file"); let context = @@ -626,16 +652,99 @@ mod tests { let status = context.git_status.expect("git status should be present"); assert!(status.contains("## No commits yet on") || status.contains("## ")); - assert!(status.contains("?? CLAW.md")); + assert!(status.contains("?? CLAUDE.md")); assert!(status.contains("?? tracked.txt")); assert!(context.git_diff.is_none()); fs::remove_dir_all(root).expect("cleanup temp dir"); } + #[test] + fn discover_with_git_includes_recent_commits_and_renders_them() { + // given: a git repo with three commits and a current branch + let _guard = env_lock(); + ensure_valid_cwd(); + let root = temp_dir(); + fs::create_dir_all(&root).expect("root dir"); + std::process::Command::new("git") + .args(["init", "--quiet", "-b", "main"]) + .current_dir(&root) + .status() + .expect("git init should run"); + std::process::Command::new("git") + .args(["config", "user.email", "tests@example.com"]) + .current_dir(&root) + .status() + .expect("git config email should run"); + std::process::Command::new("git") + .args(["config", "user.name", "Runtime Prompt Tests"]) + .current_dir(&root) + .status() + .expect("git config name should run"); + for (file, message) in [ + ("a.txt", "first commit"), + ("b.txt", "second commit"), + ("c.txt", "third commit"), + ] { + fs::write(root.join(file), "x\n").expect("write commit file"); + std::process::Command::new("git") + .args(["add", file]) + .current_dir(&root) + .status() + .expect("git add should run"); + std::process::Command::new("git") + .args(["commit", "-m", message, "--quiet"]) + .current_dir(&root) + .status() + .expect("git commit should run"); + } + fs::write(root.join("d.txt"), "staged\n").expect("write staged file"); + std::process::Command::new("git") + .args(["add", "d.txt"]) + .current_dir(&root) + .status() + .expect("git add staged should run"); + + // when: discovering project context with git auto-include + let context = + ProjectContext::discover_with_git(&root, "2026-03-31").expect("context should load"); + let rendered = SystemPromptBuilder::new() + .with_os("linux", "6.8") + .with_project_context(context.clone()) + .render(); + + // then: branch, recent commits and staged files are present in context + let gc = context + .git_context + .as_ref() + .expect("git context should be present"); + let commits: String = gc + .recent_commits + .iter() + .map(|c| c.subject.clone()) + .collect::>() + .join("\n"); + assert!(commits.contains("first commit")); + assert!(commits.contains("second commit")); + assert!(commits.contains("third commit")); + assert_eq!(gc.recent_commits.len(), 3); + + let status = context.git_status.as_deref().expect("status snapshot"); + assert!(status.contains("## main")); + assert!(status.contains("A d.txt")); + + assert!(rendered.contains("Recent commits (last 5):")); + assert!(rendered.contains("first commit")); + assert!(rendered.contains("Git status snapshot:")); + assert!(rendered.contains("## main")); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + #[test] fn discover_with_git_includes_diff_snapshot_for_tracked_changes() { let _guard = env_lock(); + ensure_valid_cwd(); let root = temp_dir(); fs::create_dir_all(&root).expect("root dir"); std::process::Command::new("git") @@ -677,10 +786,10 @@ mod tests { } #[test] - fn load_system_prompt_reads_claw_files_and_config() { + fn load_system_prompt_reads_claude_files_and_config() { let root = temp_dir(); fs::create_dir_all(root.join(".claw")).expect("claw dir"); - fs::write(root.join("CLAW.md"), "Project rules").expect("write instructions"); + fs::write(root.join("CLAUDE.md"), "Project rules").expect("write instructions"); fs::write( root.join(".claw").join("settings.json"), r#"{"permissionMode":"acceptEdits"}"#, @@ -688,6 +797,7 @@ mod tests { .expect("write settings"); let _guard = env_lock(); + ensure_valid_cwd(); let previous = std::env::current_dir().expect("cwd"); let original_home = std::env::var("HOME").ok(); let original_claw_home = std::env::var("CLAW_CONFIG_HOME").ok(); @@ -719,10 +829,10 @@ mod tests { } #[test] - fn renders_claw_code_style_sections_with_project_context() { + fn renders_claude_code_style_sections_with_project_context() { let root = temp_dir(); fs::create_dir_all(root.join(".claw")).expect("claw dir"); - fs::write(root.join("CLAW.md"), "Project rules").expect("write CLAW.md"); + fs::write(root.join("CLAUDE.md"), "Project rules").expect("write CLAUDE.md"); fs::write( root.join(".claw").join("settings.json"), r#"{"permissionMode":"acceptEdits"}"#, @@ -743,7 +853,7 @@ mod tests { assert!(prompt.contains("# System")); assert!(prompt.contains("# Project context")); - assert!(prompt.contains("# Claw instructions")); + assert!(prompt.contains("# Claude instructions")); assert!(prompt.contains("Project rules")); assert!(prompt.contains("permissionMode")); assert!(prompt.contains(SYSTEM_PROMPT_DYNAMIC_BOUNDARY)); @@ -760,7 +870,7 @@ mod tests { } #[test] - fn discovers_dot_claw_instructions_markdown() { + fn discovers_dot_claude_instructions_markdown() { let root = temp_dir(); let nested = root.join("apps").join("api"); fs::create_dir_all(nested.join(".claw")).expect("nested claw dir"); @@ -785,10 +895,10 @@ mod tests { #[test] fn renders_instruction_file_metadata() { let rendered = render_instruction_files(&[ContextFile { - path: PathBuf::from("/tmp/project/CLAW.md"), + path: PathBuf::from("/tmp/project/CLAUDE.md"), content: "Project rules".to_string(), }]); - assert!(rendered.contains("# Claw instructions")); + assert!(rendered.contains("# Claude instructions")); assert!(rendered.contains("scope: /tmp/project")); assert!(rendered.contains("Project rules")); } diff --git a/crates/runtime/src/recovery_recipes.rs b/crates/runtime/src/recovery_recipes.rs new file mode 100644 index 0000000..5a916b8 --- /dev/null +++ b/crates/runtime/src/recovery_recipes.rs @@ -0,0 +1,631 @@ +#![allow(clippy::cast_possible_truncation, clippy::uninlined_format_args)] +//! Recovery recipes for common failure scenarios. +//! +//! Encodes known automatic recoveries for the six failure scenarios +//! listed in ROADMAP item 8, and enforces one automatic recovery +//! attempt before escalation. Each attempt is emitted as a structured +//! recovery event. + +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::worker_boot::WorkerFailureKind; + +/// The six failure scenarios that have known recovery recipes. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum FailureScenario { + TrustPromptUnresolved, + PromptMisdelivery, + StaleBranch, + CompileRedCrossCrate, + McpHandshakeFailure, + PartialPluginStartup, + ProviderFailure, +} + +impl FailureScenario { + /// Returns all known failure scenarios. + #[must_use] + pub fn all() -> &'static [FailureScenario] { + &[ + Self::TrustPromptUnresolved, + Self::PromptMisdelivery, + Self::StaleBranch, + Self::CompileRedCrossCrate, + Self::McpHandshakeFailure, + Self::PartialPluginStartup, + Self::ProviderFailure, + ] + } + + /// Map a `WorkerFailureKind` to the corresponding `FailureScenario`. + /// This is the bridge that lets recovery policy consume worker boot events. + #[must_use] + pub fn from_worker_failure_kind(kind: WorkerFailureKind) -> Self { + match kind { + WorkerFailureKind::TrustGate => Self::TrustPromptUnresolved, + WorkerFailureKind::PromptDelivery => Self::PromptMisdelivery, + WorkerFailureKind::Protocol => Self::McpHandshakeFailure, + WorkerFailureKind::Provider => Self::ProviderFailure, + } + } +} + +impl std::fmt::Display for FailureScenario { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::TrustPromptUnresolved => write!(f, "trust_prompt_unresolved"), + Self::PromptMisdelivery => write!(f, "prompt_misdelivery"), + Self::StaleBranch => write!(f, "stale_branch"), + Self::CompileRedCrossCrate => write!(f, "compile_red_cross_crate"), + Self::McpHandshakeFailure => write!(f, "mcp_handshake_failure"), + Self::PartialPluginStartup => write!(f, "partial_plugin_startup"), + Self::ProviderFailure => write!(f, "provider_failure"), + } + } +} + +/// Individual step that can be executed as part of a recovery recipe. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum RecoveryStep { + AcceptTrustPrompt, + RedirectPromptToAgent, + RebaseBranch, + CleanBuild, + RetryMcpHandshake { timeout: u64 }, + RestartPlugin { name: String }, + RestartWorker, + EscalateToHuman { reason: String }, +} + +/// Policy governing what happens when automatic recovery is exhausted. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum EscalationPolicy { + AlertHuman, + LogAndContinue, + Abort, +} + +/// A recovery recipe encodes the sequence of steps to attempt for a +/// given failure scenario, along with the maximum number of automatic +/// attempts and the escalation policy. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct RecoveryRecipe { + pub scenario: FailureScenario, + pub steps: Vec, + pub max_attempts: u32, + pub escalation_policy: EscalationPolicy, +} + +/// Outcome of a recovery attempt. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum RecoveryResult { + Recovered { + steps_taken: u32, + }, + PartialRecovery { + recovered: Vec, + remaining: Vec, + }, + EscalationRequired { + reason: String, + }, +} + +/// Structured event emitted during recovery. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum RecoveryEvent { + RecoveryAttempted { + scenario: FailureScenario, + recipe: RecoveryRecipe, + result: RecoveryResult, + }, + RecoverySucceeded, + RecoveryFailed, + Escalated, +} + +/// Minimal context for tracking recovery state and emitting events. +/// +/// Holds per-scenario attempt counts, a structured event log, and an +/// optional simulation knob for controlling step outcomes during tests. +#[derive(Debug, Clone, Default)] +pub struct RecoveryContext { + attempts: HashMap, + events: Vec, + /// Optional step index at which simulated execution fails. + /// `None` means all steps succeed. + fail_at_step: Option, +} + +impl RecoveryContext { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Configure a step index at which simulated execution will fail. + #[must_use] + pub fn with_fail_at_step(mut self, index: usize) -> Self { + self.fail_at_step = Some(index); + self + } + + /// Returns the structured event log populated during recovery. + #[must_use] + pub fn events(&self) -> &[RecoveryEvent] { + &self.events + } + + /// Returns the number of recovery attempts made for a scenario. + #[must_use] + pub fn attempt_count(&self, scenario: &FailureScenario) -> u32 { + self.attempts.get(scenario).copied().unwrap_or(0) + } +} + +/// Returns the known recovery recipe for the given failure scenario. +#[must_use] +pub fn recipe_for(scenario: &FailureScenario) -> RecoveryRecipe { + match scenario { + FailureScenario::TrustPromptUnresolved => RecoveryRecipe { + scenario: *scenario, + steps: vec![RecoveryStep::AcceptTrustPrompt], + max_attempts: 1, + escalation_policy: EscalationPolicy::AlertHuman, + }, + FailureScenario::PromptMisdelivery => RecoveryRecipe { + scenario: *scenario, + steps: vec![RecoveryStep::RedirectPromptToAgent], + max_attempts: 1, + escalation_policy: EscalationPolicy::AlertHuman, + }, + FailureScenario::StaleBranch => RecoveryRecipe { + scenario: *scenario, + steps: vec![RecoveryStep::RebaseBranch, RecoveryStep::CleanBuild], + max_attempts: 1, + escalation_policy: EscalationPolicy::AlertHuman, + }, + FailureScenario::CompileRedCrossCrate => RecoveryRecipe { + scenario: *scenario, + steps: vec![RecoveryStep::CleanBuild], + max_attempts: 1, + escalation_policy: EscalationPolicy::AlertHuman, + }, + FailureScenario::McpHandshakeFailure => RecoveryRecipe { + scenario: *scenario, + steps: vec![RecoveryStep::RetryMcpHandshake { timeout: 5000 }], + max_attempts: 1, + escalation_policy: EscalationPolicy::Abort, + }, + FailureScenario::PartialPluginStartup => RecoveryRecipe { + scenario: *scenario, + steps: vec![ + RecoveryStep::RestartPlugin { + name: "stalled".to_string(), + }, + RecoveryStep::RetryMcpHandshake { timeout: 3000 }, + ], + max_attempts: 1, + escalation_policy: EscalationPolicy::LogAndContinue, + }, + FailureScenario::ProviderFailure => RecoveryRecipe { + scenario: *scenario, + steps: vec![RecoveryStep::RestartWorker], + max_attempts: 1, + escalation_policy: EscalationPolicy::AlertHuman, + }, + } +} + +/// Attempts automatic recovery for the given failure scenario. +/// +/// Looks up the recipe, enforces the one-attempt-before-escalation +/// policy, simulates step execution (controlled by the context), and +/// emits structured [`RecoveryEvent`]s for every attempt. +pub fn attempt_recovery(scenario: &FailureScenario, ctx: &mut RecoveryContext) -> RecoveryResult { + let recipe = recipe_for(scenario); + let attempt_count = ctx.attempts.entry(*scenario).or_insert(0); + + // Enforce one automatic recovery attempt before escalation. + if *attempt_count >= recipe.max_attempts { + let result = RecoveryResult::EscalationRequired { + reason: format!( + "max recovery attempts ({}) exceeded for {}", + recipe.max_attempts, scenario + ), + }; + ctx.events.push(RecoveryEvent::RecoveryAttempted { + scenario: *scenario, + recipe, + result: result.clone(), + }); + ctx.events.push(RecoveryEvent::Escalated); + return result; + } + + *attempt_count += 1; + + // Execute steps, honoring the optional fail_at_step simulation. + let fail_index = ctx.fail_at_step; + let mut executed = Vec::new(); + let mut failed = false; + + for (i, step) in recipe.steps.iter().enumerate() { + if fail_index == Some(i) { + failed = true; + break; + } + executed.push(step.clone()); + } + + let result = if failed { + let remaining: Vec = recipe.steps[executed.len()..].to_vec(); + if executed.is_empty() { + RecoveryResult::EscalationRequired { + reason: format!("recovery failed at first step for {}", scenario), + } + } else { + RecoveryResult::PartialRecovery { + recovered: executed, + remaining, + } + } + } else { + RecoveryResult::Recovered { + steps_taken: recipe.steps.len() as u32, + } + }; + + // Emit the attempt as structured event data. + ctx.events.push(RecoveryEvent::RecoveryAttempted { + scenario: *scenario, + recipe, + result: result.clone(), + }); + + match &result { + RecoveryResult::Recovered { .. } => { + ctx.events.push(RecoveryEvent::RecoverySucceeded); + } + RecoveryResult::PartialRecovery { .. } => { + ctx.events.push(RecoveryEvent::RecoveryFailed); + } + RecoveryResult::EscalationRequired { .. } => { + ctx.events.push(RecoveryEvent::Escalated); + } + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn each_scenario_has_a_matching_recipe() { + // given + let scenarios = FailureScenario::all(); + + // when / then + for scenario in scenarios { + let recipe = recipe_for(scenario); + assert_eq!( + recipe.scenario, *scenario, + "recipe scenario should match requested scenario" + ); + assert!( + !recipe.steps.is_empty(), + "recipe for {} should have at least one step", + scenario + ); + assert!( + recipe.max_attempts >= 1, + "recipe for {} should allow at least one attempt", + scenario + ); + } + } + + #[test] + fn successful_recovery_returns_recovered_and_emits_events() { + // given + let mut ctx = RecoveryContext::new(); + let scenario = FailureScenario::TrustPromptUnresolved; + + // when + let result = attempt_recovery(&scenario, &mut ctx); + + // then + assert_eq!(result, RecoveryResult::Recovered { steps_taken: 1 }); + assert_eq!(ctx.events().len(), 2); + assert!(matches!( + &ctx.events()[0], + RecoveryEvent::RecoveryAttempted { + scenario: s, + result: r, + .. + } if *s == FailureScenario::TrustPromptUnresolved + && matches!(r, RecoveryResult::Recovered { steps_taken: 1 }) + )); + assert_eq!(ctx.events()[1], RecoveryEvent::RecoverySucceeded); + } + + #[test] + fn escalation_after_max_attempts_exceeded() { + // given + let mut ctx = RecoveryContext::new(); + let scenario = FailureScenario::PromptMisdelivery; + + // when — first attempt succeeds + let first = attempt_recovery(&scenario, &mut ctx); + assert!(matches!(first, RecoveryResult::Recovered { .. })); + + // when — second attempt should escalate + let second = attempt_recovery(&scenario, &mut ctx); + + // then + assert!( + matches!( + &second, + RecoveryResult::EscalationRequired { reason } + if reason.contains("max recovery attempts") + ), + "second attempt should require escalation, got: {second:?}" + ); + assert_eq!(ctx.attempt_count(&scenario), 1); + assert!(ctx + .events() + .iter() + .any(|e| matches!(e, RecoveryEvent::Escalated))); + } + + #[test] + fn partial_recovery_when_step_fails_midway() { + // given — PartialPluginStartup has two steps; fail at step index 1 + let mut ctx = RecoveryContext::new().with_fail_at_step(1); + let scenario = FailureScenario::PartialPluginStartup; + + // when + let result = attempt_recovery(&scenario, &mut ctx); + + // then + match &result { + RecoveryResult::PartialRecovery { + recovered, + remaining, + } => { + assert_eq!(recovered.len(), 1, "one step should have succeeded"); + assert_eq!(remaining.len(), 1, "one step should remain"); + assert!(matches!(recovered[0], RecoveryStep::RestartPlugin { .. })); + assert!(matches!( + remaining[0], + RecoveryStep::RetryMcpHandshake { .. } + )); + } + other => panic!("expected PartialRecovery, got {other:?}"), + } + assert!(ctx + .events() + .iter() + .any(|e| matches!(e, RecoveryEvent::RecoveryFailed))); + } + + #[test] + fn first_step_failure_escalates_immediately() { + // given — fail at step index 0 + let mut ctx = RecoveryContext::new().with_fail_at_step(0); + let scenario = FailureScenario::CompileRedCrossCrate; + + // when + let result = attempt_recovery(&scenario, &mut ctx); + + // then + assert!( + matches!( + &result, + RecoveryResult::EscalationRequired { reason } + if reason.contains("failed at first step") + ), + "zero-step failure should escalate, got: {result:?}" + ); + assert!(ctx + .events() + .iter() + .any(|e| matches!(e, RecoveryEvent::Escalated))); + } + + #[test] + fn emitted_events_include_structured_attempt_data() { + // given + let mut ctx = RecoveryContext::new(); + let scenario = FailureScenario::McpHandshakeFailure; + + // when + let _ = attempt_recovery(&scenario, &mut ctx); + + // then — verify the RecoveryAttempted event carries full context + let attempted = ctx + .events() + .iter() + .find(|e| matches!(e, RecoveryEvent::RecoveryAttempted { .. })) + .expect("should have emitted RecoveryAttempted event"); + + match attempted { + RecoveryEvent::RecoveryAttempted { + scenario: s, + recipe, + result, + } => { + assert_eq!(*s, scenario); + assert_eq!(recipe.scenario, scenario); + assert!(!recipe.steps.is_empty()); + assert!(matches!(result, RecoveryResult::Recovered { .. })); + } + _ => unreachable!(), + } + + // Verify the event is serializable as structured JSON + let json = serde_json::to_string(&ctx.events()[0]) + .expect("recovery event should be serializable to JSON"); + assert!( + json.contains("mcp_handshake_failure"), + "serialized event should contain scenario name" + ); + } + + #[test] + fn recovery_context_tracks_attempts_per_scenario() { + // given + let mut ctx = RecoveryContext::new(); + + // when + assert_eq!(ctx.attempt_count(&FailureScenario::StaleBranch), 0); + attempt_recovery(&FailureScenario::StaleBranch, &mut ctx); + + // then + assert_eq!(ctx.attempt_count(&FailureScenario::StaleBranch), 1); + assert_eq!(ctx.attempt_count(&FailureScenario::PromptMisdelivery), 0); + } + + #[test] + fn stale_branch_recipe_has_rebase_then_clean_build() { + // given + let recipe = recipe_for(&FailureScenario::StaleBranch); + + // then + assert_eq!(recipe.steps.len(), 2); + assert_eq!(recipe.steps[0], RecoveryStep::RebaseBranch); + assert_eq!(recipe.steps[1], RecoveryStep::CleanBuild); + } + + #[test] + fn partial_plugin_startup_recipe_has_restart_then_handshake() { + // given + let recipe = recipe_for(&FailureScenario::PartialPluginStartup); + + // then + assert_eq!(recipe.steps.len(), 2); + assert!(matches!( + recipe.steps[0], + RecoveryStep::RestartPlugin { .. } + )); + assert!(matches!( + recipe.steps[1], + RecoveryStep::RetryMcpHandshake { timeout: 3000 } + )); + assert_eq!(recipe.escalation_policy, EscalationPolicy::LogAndContinue); + } + + #[test] + fn failure_scenario_display_all_variants() { + // given + let cases = [ + ( + FailureScenario::TrustPromptUnresolved, + "trust_prompt_unresolved", + ), + (FailureScenario::PromptMisdelivery, "prompt_misdelivery"), + (FailureScenario::StaleBranch, "stale_branch"), + ( + FailureScenario::CompileRedCrossCrate, + "compile_red_cross_crate", + ), + ( + FailureScenario::McpHandshakeFailure, + "mcp_handshake_failure", + ), + ( + FailureScenario::PartialPluginStartup, + "partial_plugin_startup", + ), + ]; + + // when / then + for (scenario, expected) in &cases { + assert_eq!(scenario.to_string(), *expected); + } + } + + #[test] + fn multi_step_success_reports_correct_steps_taken() { + // given — StaleBranch has 2 steps, no simulated failure + let mut ctx = RecoveryContext::new(); + let scenario = FailureScenario::StaleBranch; + + // when + let result = attempt_recovery(&scenario, &mut ctx); + + // then + assert_eq!(result, RecoveryResult::Recovered { steps_taken: 2 }); + } + + #[test] + fn mcp_handshake_recipe_uses_abort_escalation_policy() { + // given + let recipe = recipe_for(&FailureScenario::McpHandshakeFailure); + + // then + assert_eq!(recipe.escalation_policy, EscalationPolicy::Abort); + assert_eq!(recipe.max_attempts, 1); + } + + #[test] + fn worker_failure_kind_maps_to_failure_scenario() { + // given / when / then — verify the bridge is correct + assert_eq!( + FailureScenario::from_worker_failure_kind(WorkerFailureKind::TrustGate), + FailureScenario::TrustPromptUnresolved, + ); + assert_eq!( + FailureScenario::from_worker_failure_kind(WorkerFailureKind::PromptDelivery), + FailureScenario::PromptMisdelivery, + ); + assert_eq!( + FailureScenario::from_worker_failure_kind(WorkerFailureKind::Protocol), + FailureScenario::McpHandshakeFailure, + ); + assert_eq!( + FailureScenario::from_worker_failure_kind(WorkerFailureKind::Provider), + FailureScenario::ProviderFailure, + ); + } + + #[test] + fn provider_failure_recipe_uses_restart_worker_step() { + // given + let recipe = recipe_for(&FailureScenario::ProviderFailure); + + // then + assert_eq!(recipe.scenario, FailureScenario::ProviderFailure); + assert!(recipe.steps.contains(&RecoveryStep::RestartWorker)); + assert_eq!(recipe.escalation_policy, EscalationPolicy::AlertHuman); + assert_eq!(recipe.max_attempts, 1); + } + + #[test] + fn provider_failure_recovery_attempt_succeeds_then_escalates() { + // given + let mut ctx = RecoveryContext::new(); + let scenario = FailureScenario::ProviderFailure; + + // when — first attempt + let first = attempt_recovery(&scenario, &mut ctx); + assert!(matches!(first, RecoveryResult::Recovered { .. })); + + // when — second attempt should escalate (max_attempts=1) + let second = attempt_recovery(&scenario, &mut ctx); + assert!(matches!(second, RecoveryResult::EscalationRequired { .. })); + assert!(ctx + .events() + .iter() + .any(|e| matches!(e, RecoveryEvent::Escalated))); + } +} diff --git a/crates/runtime/src/remote.rs b/crates/runtime/src/remote.rs index 5fe59a0..24ee780 100644 --- a/crates/runtime/src/remote.rs +++ b/crates/runtime/src/remote.rs @@ -72,9 +72,9 @@ impl RemoteSessionContext { #[must_use] pub fn from_env_map(env_map: &BTreeMap) -> Self { Self { - enabled: env_truthy(env_map.get("CLAW_CODE_REMOTE")), + enabled: env_truthy(env_map.get("CLAUDE_CODE_REMOTE")), session_id: env_map - .get("CLAW_CODE_REMOTE_SESSION_ID") + .get("CLAUDE_CODE_REMOTE_SESSION_ID") .filter(|value| !value.is_empty()) .cloned(), base_url: env_map @@ -272,9 +272,9 @@ mod tests { #[test] fn remote_context_reads_env_state() { let env = BTreeMap::from([ - ("CLAW_CODE_REMOTE".to_string(), "true".to_string()), + ("CLAUDE_CODE_REMOTE".to_string(), "true".to_string()), ( - "CLAW_CODE_REMOTE_SESSION_ID".to_string(), + "CLAUDE_CODE_REMOTE_SESSION_ID".to_string(), "session-123".to_string(), ), ( @@ -291,7 +291,7 @@ mod tests { #[test] fn bootstrap_fails_open_when_token_or_session_is_missing() { let env = BTreeMap::from([ - ("CLAW_CODE_REMOTE".to_string(), "1".to_string()), + ("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()), ("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()), ]); let bootstrap = UpstreamProxyBootstrap::from_env_map(&env); @@ -307,10 +307,10 @@ mod tests { fs::write(&token_path, "secret-token\n").expect("write token"); let env = BTreeMap::from([ - ("CLAW_CODE_REMOTE".to_string(), "1".to_string()), + ("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()), ("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()), ( - "CLAW_CODE_REMOTE_SESSION_ID".to_string(), + "CLAUDE_CODE_REMOTE_SESSION_ID".to_string(), "session-123".to_string(), ), ( diff --git a/crates/runtime/src/sandbox.rs b/crates/runtime/src/sandbox.rs index d0054ba..45f118a 100644 --- a/crates/runtime/src/sandbox.rs +++ b/crates/runtime/src/sandbox.rs @@ -107,23 +107,11 @@ impl SandboxConfig { #[must_use] pub fn detect_container_environment() -> ContainerEnvironment { - let proc_1_cgroup = if cfg!(target_os = "linux") { - fs::read_to_string("/proc/1/cgroup").ok() - } else { - None - }; + let proc_1_cgroup = fs::read_to_string("/proc/1/cgroup").ok(); detect_container_environment_from(SandboxDetectionInputs { env_pairs: env::vars().collect(), - dockerenv_exists: if cfg!(target_os = "linux") { - Path::new("/.dockerenv").exists() - } else { - false - }, - containerenv_exists: if cfg!(target_os = "linux") { - Path::new("/run/.containerenv").exists() - } else { - false - }, + dockerenv_exists: Path::new("/.dockerenv").exists(), + containerenv_exists: Path::new("/run/.containerenv").exists(), proc_1_cgroup: proc_1_cgroup.as_deref(), }) } @@ -173,7 +161,7 @@ pub fn resolve_sandbox_status(config: &SandboxConfig, cwd: &Path) -> SandboxStat #[must_use] pub fn resolve_sandbox_status_for_request(request: &SandboxRequest, cwd: &Path) -> SandboxStatus { let container = detect_container_environment(); - let namespace_supported = cfg!(target_os = "linux") && command_exists("unshare"); + let namespace_supported = cfg!(target_os = "linux") && unshare_user_namespace_works(); let network_supported = namespace_supported; let filesystem_active = request.enabled && request.filesystem_mode != FilesystemIsolationMode::Off; @@ -294,6 +282,27 @@ fn command_exists(command: &str) -> bool { .is_some_and(|paths| env::split_paths(&paths).any(|path| path.join(command).exists())) } +/// Check whether `unshare --user` actually works on this system. +/// On some CI environments (e.g. GitHub Actions), the binary exists but +/// user namespaces are restricted, causing silent failures. +fn unshare_user_namespace_works() -> bool { + use std::sync::OnceLock; + static RESULT: OnceLock = OnceLock::new(); + *RESULT.get_or_init(|| { + if !command_exists("unshare") { + return false; + } + std::process::Command::new("unshare") + .args(["--user", "--map-root-user", "true"]) + .stdin(std::process::Stdio::null()) + .stdout(std::process::Stdio::null()) + .stderr(std::process::Stdio::null()) + .status() + .map(|s| s.success()) + .unwrap_or(false) + }) +} + #[cfg(test)] mod tests { use super::{ diff --git a/crates/runtime/src/session.rs b/crates/runtime/src/session.rs index ad3e119..3eb0a64 100644 --- a/crates/runtime/src/session.rs +++ b/crates/runtime/src/session.rs @@ -1,15 +1,21 @@ use std::collections::BTreeMap; use std::fmt::{Display, Formatter}; -use std::fs; -use std::path::Path; - -use serde::{Deserialize, Serialize}; +use std::fs::{self, OpenOptions}; +use std::io::Write; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{SystemTime, UNIX_EPOCH}; use crate::json::{JsonError, JsonValue}; use crate::usage::TokenUsage; -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] +const SESSION_VERSION: u32 = 1; +const ROTATE_AFTER_BYTES: u64 = 256 * 1024; +const MAX_ROTATED_FILES: usize = 3; +static SESSION_ID_COUNTER: AtomicU64 = AtomicU64::new(0); + +/// Speaker role associated with a persisted conversation message. +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum MessageRole { System, User, @@ -17,17 +23,9 @@ pub enum MessageRole { Tool, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -#[serde(tag = "type", rename_all = "snake_case")] +/// Structured message content stored inside a [`Session`]. +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum ContentBlock { - Thinking { - thinking: String, - #[serde(default, skip_serializing_if = "Option::is_none")] - signature: Option, - }, - RedactedThinking { - data: JsonValue, - }, Text { text: String, }, @@ -44,19 +42,83 @@ pub enum ContentBlock { }, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +/// One conversation message with optional token-usage metadata. +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct ConversationMessage { pub role: MessageRole, pub blocks: Vec, pub usage: Option, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct Session { - pub version: u32, - pub messages: Vec, +/// Metadata describing the latest compaction that summarized a session. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionCompaction { + pub count: u32, + pub removed_message_count: usize, + pub summary: String, } +/// Provenance recorded when a session is forked from another session. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionFork { + pub parent_session_id: String, + pub branch_name: Option, +} + +/// A single user prompt recorded with a timestamp for history tracking. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionPromptEntry { + pub timestamp_ms: u64, + pub text: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct SessionPersistence { + path: PathBuf, +} + +/// Persisted conversational state for the runtime and CLI session manager. +/// +/// `workspace_root` binds the session to the worktree it was created in. The +/// global session store under `~/.local/share/opencode` is shared across every +/// `opencode serve` instance, so without an explicit workspace root parallel +/// lanes can race and report success while writes land in the wrong CWD. See +/// ROADMAP.md item 41 (Phantom completions root cause) for the full +/// background. +#[derive(Debug, Clone)] +pub struct Session { + pub version: u32, + pub session_id: String, + pub created_at_ms: u64, + pub updated_at_ms: u64, + pub messages: Vec, + pub compaction: Option, + pub fork: Option, + pub workspace_root: Option, + pub prompt_history: Vec, + /// The model used in this session, persisted so resumed sessions can + /// report which model was originally used. + pub model: Option, + persistence: Option, +} + +impl PartialEq for Session { + fn eq(&self, other: &Self) -> bool { + self.version == other.version + && self.session_id == other.session_id + && self.created_at_ms == other.created_at_ms + && self.updated_at_ms == other.updated_at_ms + && self.messages == other.messages + && self.compaction == other.compaction + && self.fork == other.fork + && self.workspace_root == other.workspace_root + && self.prompt_history == other.prompt_history + } +} + +impl Eq for Session {} + +/// Errors raised while loading, parsing, or saving sessions. #[derive(Debug)] pub enum SessionError { Io(std::io::Error), @@ -91,29 +153,143 @@ impl From for SessionError { impl Session { #[must_use] pub fn new() -> Self { + let now = current_time_millis(); Self { - version: 1, + version: SESSION_VERSION, + session_id: generate_session_id(), + created_at_ms: now, + updated_at_ms: now, messages: Vec::new(), + compaction: None, + fork: None, + workspace_root: None, + prompt_history: Vec::new(), + model: None, + persistence: None, } } + #[must_use] + pub fn with_persistence_path(mut self, path: impl Into) -> Self { + self.persistence = Some(SessionPersistence { path: path.into() }); + self + } + + /// Bind this session to the workspace root it was created in. + /// + /// This is the per-worktree counterpart to the global session store and + /// lets downstream tooling reject writes that drift to the wrong CWD when + /// multiple `opencode serve` instances share `~/.local/share/opencode`. + #[must_use] + pub fn with_workspace_root(mut self, workspace_root: impl Into) -> Self { + self.workspace_root = Some(workspace_root.into()); + self + } + + #[must_use] + pub fn workspace_root(&self) -> Option<&Path> { + self.workspace_root.as_deref() + } + + #[must_use] + pub fn persistence_path(&self) -> Option<&Path> { + self.persistence.as_ref().map(|value| value.path.as_path()) + } + pub fn save_to_path(&self, path: impl AsRef) -> Result<(), SessionError> { - fs::write(path, self.to_json().render())?; + let path = path.as_ref(); + let snapshot = self.render_jsonl_snapshot()?; + rotate_session_file_if_needed(path)?; + write_atomic(path, &snapshot)?; + cleanup_rotated_logs(path)?; Ok(()) } pub fn load_from_path(path: impl AsRef) -> Result { + let path = path.as_ref(); let contents = fs::read_to_string(path)?; - Self::from_json(&JsonValue::parse(&contents)?) + let session = match JsonValue::parse(&contents) { + Ok(value) + if value + .as_object() + .is_some_and(|object| object.contains_key("messages")) => + { + Self::from_json(&value)? + } + Err(_) | Ok(_) => Self::from_jsonl(&contents)?, + }; + Ok(session.with_persistence_path(path.to_path_buf())) + } + + pub fn push_message(&mut self, message: ConversationMessage) -> Result<(), SessionError> { + self.touch(); + self.messages.push(message); + let persist_result = { + let message_ref = self.messages.last().ok_or_else(|| { + SessionError::Format("message was just pushed but missing".to_string()) + })?; + self.append_persisted_message(message_ref) + }; + if let Err(error) = persist_result { + self.messages.pop(); + return Err(error); + } + Ok(()) + } + + pub fn push_user_text(&mut self, text: impl Into) -> Result<(), SessionError> { + self.push_message(ConversationMessage::user_text(text)) + } + + pub fn record_compaction(&mut self, summary: impl Into, removed_message_count: usize) { + self.touch(); + let count = self.compaction.as_ref().map_or(1, |value| value.count + 1); + self.compaction = Some(SessionCompaction { + count, + removed_message_count, + summary: summary.into(), + }); } #[must_use] - pub fn to_json(&self) -> JsonValue { + pub fn fork(&self, branch_name: Option) -> Self { + let now = current_time_millis(); + Self { + version: self.version, + session_id: generate_session_id(), + created_at_ms: now, + updated_at_ms: now, + messages: self.messages.clone(), + compaction: self.compaction.clone(), + fork: Some(SessionFork { + parent_session_id: self.session_id.clone(), + branch_name: normalize_optional_string(branch_name), + }), + workspace_root: self.workspace_root.clone(), + prompt_history: self.prompt_history.clone(), + model: self.model.clone(), + persistence: None, + } + } + + pub fn to_json(&self) -> Result { let mut object = BTreeMap::new(); object.insert( "version".to_string(), JsonValue::Number(i64::from(self.version)), ); + object.insert( + "session_id".to_string(), + JsonValue::String(self.session_id.clone()), + ); + object.insert( + "created_at_ms".to_string(), + JsonValue::Number(i64_from_u64(self.created_at_ms, "created_at_ms")?), + ); + object.insert( + "updated_at_ms".to_string(), + JsonValue::Number(i64_from_u64(self.updated_at_ms, "updated_at_ms")?), + ); object.insert( "messages".to_string(), JsonValue::Array( @@ -123,7 +299,30 @@ impl Session { .collect(), ), ); - JsonValue::Object(object) + if let Some(compaction) = &self.compaction { + object.insert("compaction".to_string(), compaction.to_json()?); + } + if let Some(fork) = &self.fork { + object.insert("fork".to_string(), fork.to_json()); + } + if let Some(workspace_root) = &self.workspace_root { + object.insert( + "workspace_root".to_string(), + JsonValue::String(workspace_root_to_string(workspace_root)?), + ); + } + if !self.prompt_history.is_empty() { + object.insert( + "prompt_history".to_string(), + JsonValue::Array( + self.prompt_history + .iter() + .map(SessionPromptEntry::to_jsonl_record) + .collect(), + ), + ); + } + Ok(JsonValue::Object(object)) } pub fn from_json(value: &JsonValue) -> Result { @@ -143,7 +342,268 @@ impl Session { .iter() .map(ConversationMessage::from_json) .collect::, _>>()?; - Ok(Self { version, messages }) + let now = current_time_millis(); + let session_id = object + .get("session_id") + .and_then(JsonValue::as_str) + .map_or_else(generate_session_id, ToOwned::to_owned); + let created_at_ms = object + .get("created_at_ms") + .map(|value| required_u64_from_value(value, "created_at_ms")) + .transpose()? + .unwrap_or(now); + let updated_at_ms = object + .get("updated_at_ms") + .map(|value| required_u64_from_value(value, "updated_at_ms")) + .transpose()? + .unwrap_or(created_at_ms); + let compaction = object + .get("compaction") + .map(SessionCompaction::from_json) + .transpose()?; + let fork = object.get("fork").map(SessionFork::from_json).transpose()?; + let workspace_root = object + .get("workspace_root") + .and_then(JsonValue::as_str) + .map(PathBuf::from); + let prompt_history = object + .get("prompt_history") + .and_then(JsonValue::as_array) + .map(|entries| { + entries + .iter() + .filter_map(SessionPromptEntry::from_json_opt) + .collect() + }) + .unwrap_or_default(); + let model = object + .get("model") + .and_then(JsonValue::as_str) + .map(String::from); + Ok(Self { + version, + session_id, + created_at_ms, + updated_at_ms, + messages, + compaction, + fork, + workspace_root, + prompt_history, + model, + persistence: None, + }) + } + + fn from_jsonl(contents: &str) -> Result { + let mut version = SESSION_VERSION; + let mut session_id = None; + let mut created_at_ms = None; + let mut updated_at_ms = None; + let mut messages = Vec::new(); + let mut compaction = None; + let mut fork = None; + let mut workspace_root = None; + let mut model = None; + let mut prompt_history = Vec::new(); + + for (line_number, raw_line) in contents.lines().enumerate() { + let line = raw_line.trim(); + if line.is_empty() { + continue; + } + let value = JsonValue::parse(line).map_err(|error| { + SessionError::Format(format!( + "invalid JSONL record at line {}: {}", + line_number + 1, + error + )) + })?; + let object = value.as_object().ok_or_else(|| { + SessionError::Format(format!( + "JSONL record at line {} must be an object", + line_number + 1 + )) + })?; + match object + .get("type") + .and_then(JsonValue::as_str) + .ok_or_else(|| { + SessionError::Format(format!( + "JSONL record at line {} missing type", + line_number + 1 + )) + })? { + "session_meta" => { + version = required_u32(object, "version")?; + session_id = Some(required_string(object, "session_id")?); + created_at_ms = Some(required_u64(object, "created_at_ms")?); + updated_at_ms = Some(required_u64(object, "updated_at_ms")?); + fork = object.get("fork").map(SessionFork::from_json).transpose()?; + workspace_root = object + .get("workspace_root") + .and_then(JsonValue::as_str) + .map(PathBuf::from); + model = object + .get("model") + .and_then(JsonValue::as_str) + .map(String::from); + } + "message" => { + let message_value = object.get("message").ok_or_else(|| { + SessionError::Format(format!( + "JSONL record at line {} missing message", + line_number + 1 + )) + })?; + messages.push(ConversationMessage::from_json(message_value)?); + } + "compaction" => { + compaction = Some(SessionCompaction::from_json(&JsonValue::Object( + object.clone(), + ))?); + } + "prompt_history" => { + if let Some(entry) = + SessionPromptEntry::from_json_opt(&JsonValue::Object(object.clone())) + { + prompt_history.push(entry); + } + } + other => { + return Err(SessionError::Format(format!( + "unsupported JSONL record type at line {}: {other}", + line_number + 1 + ))) + } + } + } + + let now = current_time_millis(); + Ok(Self { + version, + session_id: session_id.unwrap_or_else(generate_session_id), + created_at_ms: created_at_ms.unwrap_or(now), + updated_at_ms: updated_at_ms.unwrap_or(created_at_ms.unwrap_or(now)), + messages, + compaction, + fork, + workspace_root, + prompt_history, + model, + persistence: None, + }) + } + + /// Record a user prompt with the current wall-clock timestamp. + /// + /// The entry is appended to the in-memory history and, when a persistence + /// path is configured, incrementally written to the JSONL session file. + pub fn push_prompt_entry(&mut self, text: impl Into) -> Result<(), SessionError> { + let timestamp_ms = current_time_millis(); + let entry = SessionPromptEntry { + timestamp_ms, + text: text.into(), + }; + self.prompt_history.push(entry); + let entry_ref = self.prompt_history.last().expect("entry was just pushed"); + self.append_persisted_prompt_entry(entry_ref) + } + + fn render_jsonl_snapshot(&self) -> Result { + let mut lines = vec![self.meta_record()?.render()]; + if let Some(compaction) = &self.compaction { + lines.push(compaction.to_jsonl_record()?.render()); + } + lines.extend( + self.prompt_history + .iter() + .map(|entry| entry.to_jsonl_record().render()), + ); + lines.extend( + self.messages + .iter() + .map(|message| message_record(message).render()), + ); + let mut rendered = lines.join("\n"); + rendered.push('\n'); + Ok(rendered) + } + + fn append_persisted_message(&self, message: &ConversationMessage) -> Result<(), SessionError> { + let Some(path) = self.persistence_path() else { + return Ok(()); + }; + + let needs_bootstrap = !path.exists() || fs::metadata(path)?.len() == 0; + if needs_bootstrap { + self.save_to_path(path)?; + return Ok(()); + } + + let mut file = OpenOptions::new().append(true).open(path)?; + writeln!(file, "{}", message_record(message).render())?; + Ok(()) + } + + fn append_persisted_prompt_entry( + &self, + entry: &SessionPromptEntry, + ) -> Result<(), SessionError> { + let Some(path) = self.persistence_path() else { + return Ok(()); + }; + + let needs_bootstrap = !path.exists() || fs::metadata(path)?.len() == 0; + if needs_bootstrap { + self.save_to_path(path)?; + return Ok(()); + } + + let mut file = OpenOptions::new().append(true).open(path)?; + writeln!(file, "{}", entry.to_jsonl_record().render())?; + Ok(()) + } + + fn meta_record(&self) -> Result { + let mut object = BTreeMap::new(); + object.insert( + "type".to_string(), + JsonValue::String("session_meta".to_string()), + ); + object.insert( + "version".to_string(), + JsonValue::Number(i64::from(self.version)), + ); + object.insert( + "session_id".to_string(), + JsonValue::String(self.session_id.clone()), + ); + object.insert( + "created_at_ms".to_string(), + JsonValue::Number(i64_from_u64(self.created_at_ms, "created_at_ms")?), + ); + object.insert( + "updated_at_ms".to_string(), + JsonValue::Number(i64_from_u64(self.updated_at_ms, "updated_at_ms")?), + ); + if let Some(fork) = &self.fork { + object.insert("fork".to_string(), fork.to_json()); + } + if let Some(workspace_root) = &self.workspace_root { + object.insert( + "workspace_root".to_string(), + JsonValue::String(workspace_root_to_string(workspace_root)?), + ); + } + if let Some(model) = &self.model { + object.insert("model".to_string(), JsonValue::String(model.clone())); + } + Ok(JsonValue::Object(object)) + } + + fn touch(&mut self) { + self.updated_at_ms = current_time_millis(); } } @@ -269,26 +729,6 @@ impl ContentBlock { object.insert("type".to_string(), JsonValue::String("text".to_string())); object.insert("text".to_string(), JsonValue::String(text.clone())); } - Self::Thinking { thinking, signature } => { - object.insert("type".to_string(), JsonValue::String("thinking".to_string())); - object.insert( - "thinking".to_string(), - JsonValue::String(thinking.clone()), - ); - if let Some(signature) = signature { - object.insert( - "signature".to_string(), - JsonValue::String(signature.clone()), - ); - } - } - Self::RedactedThinking { data } => { - object.insert( - "type".to_string(), - JsonValue::String("redacted_thinking".to_string()), - ); - object.insert("data".to_string(), data.clone()); - } Self::ToolUse { id, name, input } => { object.insert( "type".to_string(), @@ -340,13 +780,6 @@ impl ContentBlock { name: required_string(object, "name")?, input: required_string(object, "input")?, }), - "thinking" => Ok(Self::Thinking { - thinking: required_string(object, "thinking")?, - signature: object.get("signature").and_then(JsonValue::as_str).map(ToOwned::to_owned), - }), - "redacted_thinking" => Ok(Self::RedactedThinking { - data: object.get("data").cloned().ok_or_else(|| SessionError::Format("missing data".to_string()))?, - }), "tool_result" => Ok(Self::ToolResult { tool_use_id: required_string(object, "tool_use_id")?, tool_name: required_string(object, "tool_name")?, @@ -363,6 +796,128 @@ impl ContentBlock { } } +impl SessionCompaction { + pub fn to_json(&self) -> Result { + let mut object = BTreeMap::new(); + object.insert( + "count".to_string(), + JsonValue::Number(i64::from(self.count)), + ); + object.insert( + "removed_message_count".to_string(), + JsonValue::Number(i64_from_usize( + self.removed_message_count, + "removed_message_count", + )?), + ); + object.insert( + "summary".to_string(), + JsonValue::String(self.summary.clone()), + ); + Ok(JsonValue::Object(object)) + } + + pub fn to_jsonl_record(&self) -> Result { + let mut object = BTreeMap::new(); + object.insert( + "type".to_string(), + JsonValue::String("compaction".to_string()), + ); + object.insert( + "count".to_string(), + JsonValue::Number(i64::from(self.count)), + ); + object.insert( + "removed_message_count".to_string(), + JsonValue::Number(i64_from_usize( + self.removed_message_count, + "removed_message_count", + )?), + ); + object.insert( + "summary".to_string(), + JsonValue::String(self.summary.clone()), + ); + Ok(JsonValue::Object(object)) + } + + fn from_json(value: &JsonValue) -> Result { + let object = value + .as_object() + .ok_or_else(|| SessionError::Format("compaction must be an object".to_string()))?; + Ok(Self { + count: required_u32(object, "count")?, + removed_message_count: required_usize(object, "removed_message_count")?, + summary: required_string(object, "summary")?, + }) + } +} + +impl SessionFork { + #[must_use] + pub fn to_json(&self) -> JsonValue { + let mut object = BTreeMap::new(); + object.insert( + "parent_session_id".to_string(), + JsonValue::String(self.parent_session_id.clone()), + ); + if let Some(branch_name) = &self.branch_name { + object.insert( + "branch_name".to_string(), + JsonValue::String(branch_name.clone()), + ); + } + JsonValue::Object(object) + } + + fn from_json(value: &JsonValue) -> Result { + let object = value + .as_object() + .ok_or_else(|| SessionError::Format("fork metadata must be an object".to_string()))?; + Ok(Self { + parent_session_id: required_string(object, "parent_session_id")?, + branch_name: object + .get("branch_name") + .and_then(JsonValue::as_str) + .map(ToOwned::to_owned), + }) + } +} + +impl SessionPromptEntry { + #[must_use] + pub fn to_jsonl_record(&self) -> JsonValue { + let mut object = BTreeMap::new(); + object.insert( + "type".to_string(), + JsonValue::String("prompt_history".to_string()), + ); + object.insert( + "timestamp_ms".to_string(), + JsonValue::Number(i64::try_from(self.timestamp_ms).unwrap_or(i64::MAX)), + ); + object.insert("text".to_string(), JsonValue::String(self.text.clone())); + JsonValue::Object(object) + } + + fn from_json_opt(value: &JsonValue) -> Option { + let object = value.as_object()?; + let timestamp_ms = object + .get("timestamp_ms") + .and_then(JsonValue::as_i64) + .and_then(|value| u64::try_from(value).ok())?; + let text = object.get("text").and_then(JsonValue::as_str)?.to_string(); + Some(Self { timestamp_ms, text }) + } +} + +fn message_record(message: &ConversationMessage) -> JsonValue { + let mut object = BTreeMap::new(); + object.insert("type".to_string(), JsonValue::String("message".to_string())); + object.insert("message".to_string(), message.to_json()); + JsonValue::Object(object) +} + fn usage_to_json(usage: TokenUsage) -> JsonValue { let mut object = BTreeMap::new(); object.insert( @@ -415,22 +970,171 @@ fn required_u32(object: &BTreeMap, key: &str) -> Result, key: &str) -> Result { + let value = object + .get(key) + .ok_or_else(|| SessionError::Format(format!("missing {key}")))?; + required_u64_from_value(value, key) +} + +fn required_u64_from_value(value: &JsonValue, key: &str) -> Result { + let value = value + .as_i64() + .ok_or_else(|| SessionError::Format(format!("missing {key}")))?; + u64::try_from(value).map_err(|_| SessionError::Format(format!("{key} out of range"))) +} + +fn required_usize(object: &BTreeMap, key: &str) -> Result { + let value = object + .get(key) + .and_then(JsonValue::as_i64) + .ok_or_else(|| SessionError::Format(format!("missing {key}")))?; + usize::try_from(value).map_err(|_| SessionError::Format(format!("{key} out of range"))) +} + +fn i64_from_u64(value: u64, key: &str) -> Result { + i64::try_from(value) + .map_err(|_| SessionError::Format(format!("{key} out of range for JSON number"))) +} + +fn i64_from_usize(value: usize, key: &str) -> Result { + i64::try_from(value) + .map_err(|_| SessionError::Format(format!("{key} out of range for JSON number"))) +} + +fn workspace_root_to_string(path: &Path) -> Result { + path.to_str().map(ToOwned::to_owned).ok_or_else(|| { + SessionError::Format(format!( + "workspace_root is not valid UTF-8: {}", + path.display() + )) + }) +} + +fn normalize_optional_string(value: Option) -> Option { + value.and_then(|value| { + let trimmed = value.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_string()) + } + }) +} + +fn current_time_millis() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| u64::try_from(duration.as_millis()).unwrap_or(u64::MAX)) + .unwrap_or_default() +} + +fn generate_session_id() -> String { + let millis = current_time_millis(); + let counter = SESSION_ID_COUNTER.fetch_add(1, Ordering::Relaxed); + format!("session-{millis}-{counter}") +} + +fn write_atomic(path: &Path, contents: &str) -> Result<(), SessionError> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + let temp_path = temporary_path_for(path); + fs::write(&temp_path, contents)?; + fs::rename(temp_path, path)?; + Ok(()) +} + +fn temporary_path_for(path: &Path) -> PathBuf { + let file_name = path + .file_name() + .and_then(|value| value.to_str()) + .unwrap_or("session"); + path.with_file_name(format!( + "{file_name}.tmp-{}-{}", + current_time_millis(), + SESSION_ID_COUNTER.fetch_add(1, Ordering::Relaxed) + )) +} + +fn rotate_session_file_if_needed(path: &Path) -> Result<(), SessionError> { + let Ok(metadata) = fs::metadata(path) else { + return Ok(()); + }; + if metadata.len() < ROTATE_AFTER_BYTES { + return Ok(()); + } + let rotated_path = rotated_log_path(path); + fs::rename(path, rotated_path)?; + Ok(()) +} + +fn rotated_log_path(path: &Path) -> PathBuf { + let stem = path + .file_stem() + .and_then(|value| value.to_str()) + .unwrap_or("session"); + path.with_file_name(format!("{stem}.rot-{}.jsonl", current_time_millis())) +} + +fn cleanup_rotated_logs(path: &Path) -> Result<(), SessionError> { + let Some(parent) = path.parent() else { + return Ok(()); + }; + let stem = path + .file_stem() + .and_then(|value| value.to_str()) + .unwrap_or("session"); + let prefix = format!("{stem}.rot-"); + let mut rotated_paths = fs::read_dir(parent)? + .filter_map(Result::ok) + .map(|entry| entry.path()) + .filter(|entry_path| { + entry_path + .file_name() + .and_then(|value| value.to_str()) + .is_some_and(|name| { + name.starts_with(&prefix) + && Path::new(name) + .extension() + .is_some_and(|ext| ext.eq_ignore_ascii_case("jsonl")) + }) + }) + .collect::>(); + + rotated_paths.sort_by_key(|entry_path| { + fs::metadata(entry_path) + .and_then(|metadata| metadata.modified()) + .unwrap_or(UNIX_EPOCH) + }); + + let remove_count = rotated_paths.len().saturating_sub(MAX_ROTATED_FILES); + for stale_path in rotated_paths.into_iter().take(remove_count) { + fs::remove_file(stale_path)?; + } + Ok(()) +} + #[cfg(test)] mod tests { - use super::{ContentBlock, ConversationMessage, MessageRole, Session}; + use super::{ + cleanup_rotated_logs, rotate_session_file_if_needed, ContentBlock, ConversationMessage, + MessageRole, Session, SessionFork, + }; + use crate::json::JsonValue; use crate::usage::TokenUsage; use std::fs; + use std::path::{Path, PathBuf}; use std::time::{SystemTime, UNIX_EPOCH}; #[test] - fn persists_and_restores_session_json() { + fn persists_and_restores_session_jsonl() { let mut session = Session::new(); session - .messages - .push(ConversationMessage::user_text("hello")); + .push_user_text("hello") + .expect("user message should append"); session - .messages - .push(ConversationMessage::assistant_with_usage( + .push_message(ConversationMessage::assistant_with_usage( vec![ ContentBlock::Text { text: "thinking".to_string(), @@ -447,16 +1151,15 @@ mod tests { cache_creation_input_tokens: 1, cache_read_input_tokens: 2, }), - )); - session.messages.push(ConversationMessage::tool_result( - "tool-1", "bash", "hi", false, - )); + )) + .expect("assistant message should append"); + session + .push_message(ConversationMessage::tool_result( + "tool-1", "bash", "hi", false, + )) + .expect("tool result should append"); - let nanos = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("system time should be after epoch") - .as_nanos(); - let path = std::env::temp_dir().join(format!("runtime-session-{nanos}.json")); + let path = temp_session_path("jsonl"); session.save_to_path(&path).expect("session should save"); let restored = Session::load_from_path(&path).expect("session should load"); fs::remove_file(&path).expect("temp file should be removable"); @@ -467,5 +1170,346 @@ mod tests { restored.messages[1].usage.expect("usage").total_tokens(), 17 ); + assert_eq!(restored.session_id, session.session_id); + } + + #[test] + fn loads_legacy_session_json_object() { + let path = temp_session_path("legacy"); + let legacy = JsonValue::Object( + [ + ("version".to_string(), JsonValue::Number(1)), + ( + "messages".to_string(), + JsonValue::Array(vec![ConversationMessage::user_text("legacy").to_json()]), + ), + ] + .into_iter() + .collect(), + ); + fs::write(&path, legacy.render()).expect("legacy file should write"); + + let restored = Session::load_from_path(&path).expect("legacy session should load"); + fs::remove_file(&path).expect("temp file should be removable"); + + assert_eq!(restored.messages.len(), 1); + assert_eq!( + restored.messages[0], + ConversationMessage::user_text("legacy") + ); + assert!(!restored.session_id.is_empty()); + } + + #[test] + fn appends_messages_to_persisted_jsonl_session() { + let path = temp_session_path("append"); + let mut session = Session::new().with_persistence_path(path.clone()); + session + .save_to_path(&path) + .expect("initial save should succeed"); + session + .push_user_text("hi") + .expect("user append should succeed"); + session + .push_message(ConversationMessage::assistant(vec![ContentBlock::Text { + text: "hello".to_string(), + }])) + .expect("assistant append should succeed"); + + let restored = Session::load_from_path(&path).expect("session should replay from jsonl"); + fs::remove_file(&path).expect("temp file should be removable"); + + assert_eq!(restored.messages.len(), 2); + assert_eq!(restored.messages[0], ConversationMessage::user_text("hi")); + } + + #[test] + fn persists_compaction_metadata() { + let path = temp_session_path("compaction"); + let mut session = Session::new(); + session + .push_user_text("before") + .expect("message should append"); + session.record_compaction("summarized earlier work", 4); + session.save_to_path(&path).expect("session should save"); + + let restored = Session::load_from_path(&path).expect("session should load"); + fs::remove_file(&path).expect("temp file should be removable"); + + let compaction = restored.compaction.expect("compaction metadata"); + assert_eq!(compaction.count, 1); + assert_eq!(compaction.removed_message_count, 4); + assert!(compaction.summary.contains("summarized")); + } + + #[test] + fn forks_sessions_with_branch_metadata_and_persists_it() { + let path = temp_session_path("fork"); + let mut session = Session::new(); + session + .push_user_text("before fork") + .expect("message should append"); + + let forked = session + .fork(Some("investigation".to_string())) + .with_persistence_path(path.clone()); + forked + .save_to_path(&path) + .expect("forked session should save"); + + let restored = Session::load_from_path(&path).expect("forked session should load"); + fs::remove_file(&path).expect("temp file should be removable"); + + assert_ne!(restored.session_id, session.session_id); + assert_eq!( + restored.fork, + Some(SessionFork { + parent_session_id: session.session_id, + branch_name: Some("investigation".to_string()), + }) + ); + assert_eq!(restored.messages, forked.messages); + } + + #[test] + fn rotates_and_cleans_up_large_session_logs() { + // given + let path = temp_session_path("rotation"); + let oversized_length = + usize::try_from(super::ROTATE_AFTER_BYTES + 10).expect("rotate threshold should fit"); + fs::write(&path, "x".repeat(oversized_length)).expect("oversized file should write"); + + // when + rotate_session_file_if_needed(&path).expect("rotation should succeed"); + + // then + assert!( + !path.exists(), + "original path should be rotated away before rewrite" + ); + + for _ in 0..5 { + let rotated = super::rotated_log_path(&path); + fs::write(&rotated, "old").expect("rotated file should write"); + } + cleanup_rotated_logs(&path).expect("cleanup should succeed"); + + let rotated_count = rotation_files(&path).len(); + assert!(rotated_count <= super::MAX_ROTATED_FILES); + for rotated in rotation_files(&path) { + fs::remove_file(rotated).expect("rotated file should be removable"); + } + } + + #[test] + fn rejects_jsonl_record_without_type() { + // given + let path = write_temp_session_file( + "missing-type", + r#"{"message":{"role":"user","blocks":[{"type":"text","text":"hello"}]}}"#, + ); + + // when + let error = Session::load_from_path(&path) + .expect_err("session should reject JSONL records without a type"); + + // then + assert!(error.to_string().contains("missing type")); + fs::remove_file(path).expect("temp file should be removable"); + } + + #[test] + fn rejects_jsonl_message_record_without_message_payload() { + // given + let path = write_temp_session_file("missing-message", r#"{"type":"message"}"#); + + // when + let error = Session::load_from_path(&path) + .expect_err("session should reject JSONL message records without message payload"); + + // then + assert!(error.to_string().contains("missing message")); + fs::remove_file(path).expect("temp file should be removable"); + } + + #[test] + fn rejects_jsonl_record_with_unknown_type() { + // given + let path = write_temp_session_file("unknown-type", r#"{"type":"mystery"}"#); + + // when + let error = Session::load_from_path(&path) + .expect_err("session should reject unknown JSONL record types"); + + // then + assert!(error.to_string().contains("unsupported JSONL record type")); + fs::remove_file(path).expect("temp file should be removable"); + } + + #[test] + fn rejects_legacy_session_json_without_messages() { + // given + let session = JsonValue::Object( + [("version".to_string(), JsonValue::Number(1))] + .into_iter() + .collect(), + ); + + // when + let error = Session::from_json(&session) + .expect_err("legacy session objects should require messages"); + + // then + assert!(error.to_string().contains("missing messages")); + } + + #[test] + fn normalizes_blank_fork_branch_name_to_none() { + // given + let session = Session::new(); + + // when + let forked = session.fork(Some(" ".to_string())); + + // then + assert_eq!(forked.fork.expect("fork metadata").branch_name, None); + } + + #[test] + fn rejects_unknown_content_block_type() { + // given + let block = JsonValue::Object( + [("type".to_string(), JsonValue::String("unknown".to_string()))] + .into_iter() + .collect(), + ); + + // when + let error = ContentBlock::from_json(&block) + .expect_err("content blocks should reject unknown types"); + + // then + assert!(error.to_string().contains("unsupported block type")); + } + + #[test] + fn persists_workspace_root_round_trip_and_forks_inherit_it() { + // given + let path = temp_session_path("workspace-root"); + let workspace_root = PathBuf::from("/tmp/b4-phantom-diag"); + let mut session = Session::new().with_workspace_root(workspace_root.clone()); + session + .push_user_text("write to the right cwd") + .expect("user message should append"); + + // when + session + .save_to_path(&path) + .expect("workspace-bound session should save"); + let restored = Session::load_from_path(&path).expect("session should load"); + let forked = restored.fork(Some("phantom-diag".to_string())); + fs::remove_file(&path).expect("temp file should be removable"); + + // then + assert_eq!(restored.workspace_root(), Some(workspace_root.as_path())); + assert_eq!(forked.workspace_root(), Some(workspace_root.as_path())); + } + + fn temp_session_path(label: &str) -> PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("runtime-session-{label}-{nanos}.json")) + } + + fn write_temp_session_file(label: &str, contents: &str) -> PathBuf { + let path = temp_session_path(label); + fs::write(&path, format!("{contents}\n")).expect("temp session file should write"); + path + } + + fn rotation_files(path: &Path) -> Vec { + let stem = path + .file_stem() + .and_then(|value| value.to_str()) + .expect("temp path should have file stem") + .to_string(); + fs::read_dir(path.parent().expect("temp path should have parent")) + .expect("temp dir should read") + .filter_map(Result::ok) + .map(|entry| entry.path()) + .filter(|entry_path| { + entry_path + .file_name() + .and_then(|value| value.to_str()) + .is_some_and(|name| { + name.starts_with(&format!("{stem}.rot-")) + && Path::new(name) + .extension() + .is_some_and(|ext| ext.eq_ignore_ascii_case("jsonl")) + }) + }) + .collect() + } +} + +/// Per-worktree session isolation: returns a session directory namespaced +/// by the workspace fingerprint of the given working directory. +/// This prevents parallel `opencode serve` instances from colliding. +/// Called by external consumers (e.g. clawhip) to enumerate sessions for a CWD. +#[allow(dead_code)] +pub fn workspace_sessions_dir(cwd: &std::path::Path) -> Result { + let store = crate::session_control::SessionStore::from_cwd(cwd).map_err(|e| { + SessionError::Io(std::io::Error::new( + std::io::ErrorKind::Other, + e.to_string(), + )) + })?; + Ok(store.sessions_dir().to_path_buf()) +} + +#[cfg(test)] +mod workspace_sessions_dir_tests { + use super::*; + use std::fs; + + #[test] + fn workspace_sessions_dir_returns_fingerprinted_path_for_valid_cwd() { + let tmp = std::env::temp_dir().join("claw-session-dir-test"); + fs::create_dir_all(&tmp).expect("create temp dir"); + + let result = workspace_sessions_dir(&tmp); + assert!( + result.is_ok(), + "workspace_sessions_dir should succeed for a valid CWD, got: {:?}", + result + ); + let dir = result.unwrap(); + // The returned path should be non-empty and end with a hash component + assert!(!dir.as_os_str().is_empty()); + // Two calls with the same CWD should produce identical paths (deterministic) + let result2 = workspace_sessions_dir(&tmp).unwrap(); + assert_eq!(dir, result2, "workspace_sessions_dir must be deterministic"); + + fs::remove_dir_all(&tmp).ok(); + } + + #[test] + fn workspace_sessions_dir_differs_for_different_cwds() { + let tmp_a = std::env::temp_dir().join("claw-session-dir-a"); + let tmp_b = std::env::temp_dir().join("claw-session-dir-b"); + fs::create_dir_all(&tmp_a).expect("create dir a"); + fs::create_dir_all(&tmp_b).expect("create dir b"); + + let dir_a = workspace_sessions_dir(&tmp_a).expect("dir a"); + let dir_b = workspace_sessions_dir(&tmp_b).expect("dir b"); + assert_ne!( + dir_a, dir_b, + "different CWDs must produce different session dirs" + ); + + fs::remove_dir_all(&tmp_a).ok(); + fs::remove_dir_all(&tmp_b).ok(); } } diff --git a/crates/runtime/src/session_control.rs b/crates/runtime/src/session_control.rs new file mode 100644 index 0000000..0524519 --- /dev/null +++ b/crates/runtime/src/session_control.rs @@ -0,0 +1,873 @@ +#![allow(dead_code)] +use std::env; +use std::fmt::{Display, Formatter}; +use std::fs; +use std::path::{Path, PathBuf}; +use std::time::UNIX_EPOCH; + +use crate::session::{Session, SessionError}; + +/// Per-worktree session store that namespaces on-disk session files by +/// workspace fingerprint so that parallel `opencode serve` instances never +/// collide. +/// +/// Create via [`SessionStore::from_cwd`] (derives the store path from the +/// server's working directory) or [`SessionStore::from_data_dir`] (honours an +/// explicit `--data-dir` flag). Both constructors produce a directory layout +/// of `/sessions//` where `` is a +/// stable hex digest of the canonical workspace root. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionStore { + /// Resolved root of the session namespace, e.g. + /// `/home/user/project/.claw/sessions/a1b2c3d4e5f60718/`. + sessions_root: PathBuf, + /// The canonical workspace path that was fingerprinted. + workspace_root: PathBuf, +} + +impl SessionStore { + /// Build a store from the server's current working directory. + /// + /// The on-disk layout becomes `/.claw/sessions//`. + pub fn from_cwd(cwd: impl AsRef) -> Result { + let cwd = cwd.as_ref(); + let sessions_root = cwd + .join(".claw") + .join("sessions") + .join(workspace_fingerprint(cwd)); + fs::create_dir_all(&sessions_root)?; + Ok(Self { + sessions_root, + workspace_root: cwd.to_path_buf(), + }) + } + + /// Build a store from an explicit `--data-dir` flag. + /// + /// The on-disk layout becomes `/sessions//` + /// where `` is derived from `workspace_root`. + pub fn from_data_dir( + data_dir: impl AsRef, + workspace_root: impl AsRef, + ) -> Result { + let workspace_root = workspace_root.as_ref(); + let sessions_root = data_dir + .as_ref() + .join("sessions") + .join(workspace_fingerprint(workspace_root)); + fs::create_dir_all(&sessions_root)?; + Ok(Self { + sessions_root, + workspace_root: workspace_root.to_path_buf(), + }) + } + + /// The fully resolved sessions directory for this namespace. + #[must_use] + pub fn sessions_dir(&self) -> &Path { + &self.sessions_root + } + + /// The workspace root this store is bound to. + #[must_use] + pub fn workspace_root(&self) -> &Path { + &self.workspace_root + } + + pub fn create_handle(&self, session_id: &str) -> SessionHandle { + let id = session_id.to_string(); + let path = self + .sessions_root + .join(format!("{id}.{PRIMARY_SESSION_EXTENSION}")); + SessionHandle { id, path } + } + + pub fn resolve_reference(&self, reference: &str) -> Result { + if is_session_reference_alias(reference) { + let latest = self.latest_session()?; + return Ok(SessionHandle { + id: latest.id, + path: latest.path, + }); + } + + let direct = PathBuf::from(reference); + let candidate = if direct.is_absolute() { + direct.clone() + } else { + self.workspace_root.join(&direct) + }; + let looks_like_path = direct.extension().is_some() || direct.components().count() > 1; + let path = if candidate.exists() { + candidate + } else if looks_like_path { + return Err(SessionControlError::Format( + format_missing_session_reference(reference), + )); + } else { + self.resolve_managed_path(reference)? + }; + + Ok(SessionHandle { + id: session_id_from_path(&path).unwrap_or_else(|| reference.to_string()), + path, + }) + } + + pub fn resolve_managed_path(&self, session_id: &str) -> Result { + for extension in [PRIMARY_SESSION_EXTENSION, LEGACY_SESSION_EXTENSION] { + let path = self.sessions_root.join(format!("{session_id}.{extension}")); + if path.exists() { + return Ok(path); + } + } + Err(SessionControlError::Format( + format_missing_session_reference(session_id), + )) + } + + pub fn list_sessions(&self) -> Result, SessionControlError> { + let mut sessions = Vec::new(); + let read_result = fs::read_dir(&self.sessions_root); + let entries = match read_result { + Ok(entries) => entries, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(sessions), + Err(err) => return Err(err.into()), + }; + for entry in entries { + let entry = entry?; + let path = entry.path(); + if !is_managed_session_file(&path) { + continue; + } + let metadata = entry.metadata()?; + let modified_epoch_millis = metadata + .modified() + .ok() + .and_then(|time| time.duration_since(UNIX_EPOCH).ok()) + .map(|duration| duration.as_millis()) + .unwrap_or_default(); + let (id, message_count, parent_session_id, branch_name) = + match Session::load_from_path(&path) { + Ok(session) => { + let parent_session_id = session + .fork + .as_ref() + .map(|fork| fork.parent_session_id.clone()); + let branch_name = session + .fork + .as_ref() + .and_then(|fork| fork.branch_name.clone()); + ( + session.session_id, + session.messages.len(), + parent_session_id, + branch_name, + ) + } + Err(_) => ( + path.file_stem() + .and_then(|value| value.to_str()) + .unwrap_or("unknown") + .to_string(), + 0, + None, + None, + ), + }; + sessions.push(ManagedSessionSummary { + id, + path, + modified_epoch_millis, + message_count, + parent_session_id, + branch_name, + }); + } + sessions.sort_by(|left, right| { + right + .modified_epoch_millis + .cmp(&left.modified_epoch_millis) + .then_with(|| right.id.cmp(&left.id)) + }); + Ok(sessions) + } + + pub fn latest_session(&self) -> Result { + self.list_sessions()? + .into_iter() + .next() + .ok_or_else(|| SessionControlError::Format(format_no_managed_sessions())) + } + + pub fn load_session( + &self, + reference: &str, + ) -> Result { + let handle = self.resolve_reference(reference)?; + let session = Session::load_from_path(&handle.path)?; + Ok(LoadedManagedSession { + handle: SessionHandle { + id: session.session_id.clone(), + path: handle.path, + }, + session, + }) + } + + pub fn fork_session( + &self, + session: &Session, + branch_name: Option, + ) -> Result { + let parent_session_id = session.session_id.clone(); + let forked = session.fork(branch_name); + let handle = self.create_handle(&forked.session_id); + let branch_name = forked + .fork + .as_ref() + .and_then(|fork| fork.branch_name.clone()); + let forked = forked.with_persistence_path(handle.path.clone()); + forked.save_to_path(&handle.path)?; + Ok(ForkedManagedSession { + parent_session_id, + handle, + session: forked, + branch_name, + }) + } +} + +/// Stable hex fingerprint of a workspace path. +/// +/// Uses FNV-1a (64-bit) to produce a 16-char hex string that partitions the +/// on-disk session directory per workspace root. +#[must_use] +pub fn workspace_fingerprint(workspace_root: &Path) -> String { + let input = workspace_root.to_string_lossy(); + let mut hash = 0xcbf2_9ce4_8422_2325_u64; + for byte in input.as_bytes() { + hash ^= u64::from(*byte); + hash = hash.wrapping_mul(0x0100_0000_01b3); + } + format!("{hash:016x}") +} + +pub const PRIMARY_SESSION_EXTENSION: &str = "jsonl"; +pub const LEGACY_SESSION_EXTENSION: &str = "json"; +pub const LATEST_SESSION_REFERENCE: &str = "latest"; + +const SESSION_REFERENCE_ALIASES: &[&str] = &[LATEST_SESSION_REFERENCE, "last", "recent"]; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionHandle { + pub id: String, + pub path: PathBuf, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ManagedSessionSummary { + pub id: String, + pub path: PathBuf, + pub modified_epoch_millis: u128, + pub message_count: usize, + pub parent_session_id: Option, + pub branch_name: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LoadedManagedSession { + pub handle: SessionHandle, + pub session: Session, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ForkedManagedSession { + pub parent_session_id: String, + pub handle: SessionHandle, + pub session: Session, + pub branch_name: Option, +} + +#[derive(Debug)] +pub enum SessionControlError { + Io(std::io::Error), + Session(SessionError), + Format(String), +} + +impl Display for SessionControlError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Io(error) => write!(f, "{error}"), + Self::Session(error) => write!(f, "{error}"), + Self::Format(error) => write!(f, "{error}"), + } + } +} + +impl std::error::Error for SessionControlError {} + +impl From for SessionControlError { + fn from(value: std::io::Error) -> Self { + Self::Io(value) + } +} + +impl From for SessionControlError { + fn from(value: SessionError) -> Self { + Self::Session(value) + } +} + +pub fn sessions_dir() -> Result { + managed_sessions_dir_for(env::current_dir()?) +} + +pub fn managed_sessions_dir_for( + base_dir: impl AsRef, +) -> Result { + let path = base_dir.as_ref().join(".claw").join("sessions"); + fs::create_dir_all(&path)?; + Ok(path) +} + +pub fn create_managed_session_handle( + session_id: &str, +) -> Result { + create_managed_session_handle_for(env::current_dir()?, session_id) +} + +pub fn create_managed_session_handle_for( + base_dir: impl AsRef, + session_id: &str, +) -> Result { + let id = session_id.to_string(); + let path = + managed_sessions_dir_for(base_dir)?.join(format!("{id}.{PRIMARY_SESSION_EXTENSION}")); + Ok(SessionHandle { id, path }) +} + +pub fn resolve_session_reference(reference: &str) -> Result { + resolve_session_reference_for(env::current_dir()?, reference) +} + +pub fn resolve_session_reference_for( + base_dir: impl AsRef, + reference: &str, +) -> Result { + let base_dir = base_dir.as_ref(); + if is_session_reference_alias(reference) { + let latest = latest_managed_session_for(base_dir)?; + return Ok(SessionHandle { + id: latest.id, + path: latest.path, + }); + } + + let direct = PathBuf::from(reference); + let candidate = if direct.is_absolute() { + direct.clone() + } else { + base_dir.join(&direct) + }; + let looks_like_path = direct.extension().is_some() || direct.components().count() > 1; + let path = if candidate.exists() { + candidate + } else if looks_like_path { + return Err(SessionControlError::Format( + format_missing_session_reference(reference), + )); + } else { + resolve_managed_session_path_for(base_dir, reference)? + }; + + Ok(SessionHandle { + id: session_id_from_path(&path).unwrap_or_else(|| reference.to_string()), + path, + }) +} + +pub fn resolve_managed_session_path(session_id: &str) -> Result { + resolve_managed_session_path_for(env::current_dir()?, session_id) +} + +pub fn resolve_managed_session_path_for( + base_dir: impl AsRef, + session_id: &str, +) -> Result { + let directory = managed_sessions_dir_for(base_dir)?; + for extension in [PRIMARY_SESSION_EXTENSION, LEGACY_SESSION_EXTENSION] { + let path = directory.join(format!("{session_id}.{extension}")); + if path.exists() { + return Ok(path); + } + } + Err(SessionControlError::Format( + format_missing_session_reference(session_id), + )) +} + +#[must_use] +pub fn is_managed_session_file(path: &Path) -> bool { + path.extension() + .and_then(|ext| ext.to_str()) + .is_some_and(|extension| { + extension == PRIMARY_SESSION_EXTENSION || extension == LEGACY_SESSION_EXTENSION + }) +} + +pub fn list_managed_sessions() -> Result, SessionControlError> { + list_managed_sessions_for(env::current_dir()?) +} + +pub fn list_managed_sessions_for( + base_dir: impl AsRef, +) -> Result, SessionControlError> { + let mut sessions = Vec::new(); + for entry in fs::read_dir(managed_sessions_dir_for(base_dir)?)? { + let entry = entry?; + let path = entry.path(); + if !is_managed_session_file(&path) { + continue; + } + let metadata = entry.metadata()?; + let modified_epoch_millis = metadata + .modified() + .ok() + .and_then(|time| time.duration_since(UNIX_EPOCH).ok()) + .map(|duration| duration.as_millis()) + .unwrap_or_default(); + let (id, message_count, parent_session_id, branch_name) = + match Session::load_from_path(&path) { + Ok(session) => { + let parent_session_id = session + .fork + .as_ref() + .map(|fork| fork.parent_session_id.clone()); + let branch_name = session + .fork + .as_ref() + .and_then(|fork| fork.branch_name.clone()); + ( + session.session_id, + session.messages.len(), + parent_session_id, + branch_name, + ) + } + Err(_) => ( + path.file_stem() + .and_then(|value| value.to_str()) + .unwrap_or("unknown") + .to_string(), + 0, + None, + None, + ), + }; + sessions.push(ManagedSessionSummary { + id, + path, + modified_epoch_millis, + message_count, + parent_session_id, + branch_name, + }); + } + sessions.sort_by(|left, right| { + right + .modified_epoch_millis + .cmp(&left.modified_epoch_millis) + .then_with(|| right.id.cmp(&left.id)) + }); + Ok(sessions) +} + +pub fn latest_managed_session() -> Result { + latest_managed_session_for(env::current_dir()?) +} + +pub fn latest_managed_session_for( + base_dir: impl AsRef, +) -> Result { + list_managed_sessions_for(base_dir)? + .into_iter() + .next() + .ok_or_else(|| SessionControlError::Format(format_no_managed_sessions())) +} + +pub fn load_managed_session(reference: &str) -> Result { + load_managed_session_for(env::current_dir()?, reference) +} + +pub fn load_managed_session_for( + base_dir: impl AsRef, + reference: &str, +) -> Result { + let handle = resolve_session_reference_for(base_dir, reference)?; + let session = Session::load_from_path(&handle.path)?; + Ok(LoadedManagedSession { + handle: SessionHandle { + id: session.session_id.clone(), + path: handle.path, + }, + session, + }) +} + +pub fn fork_managed_session( + session: &Session, + branch_name: Option, +) -> Result { + fork_managed_session_for(env::current_dir()?, session, branch_name) +} + +pub fn fork_managed_session_for( + base_dir: impl AsRef, + session: &Session, + branch_name: Option, +) -> Result { + let parent_session_id = session.session_id.clone(); + let forked = session.fork(branch_name); + let handle = create_managed_session_handle_for(base_dir, &forked.session_id)?; + let branch_name = forked + .fork + .as_ref() + .and_then(|fork| fork.branch_name.clone()); + let forked = forked.with_persistence_path(handle.path.clone()); + forked.save_to_path(&handle.path)?; + Ok(ForkedManagedSession { + parent_session_id, + handle, + session: forked, + branch_name, + }) +} + +#[must_use] +pub fn is_session_reference_alias(reference: &str) -> bool { + SESSION_REFERENCE_ALIASES + .iter() + .any(|alias| reference.eq_ignore_ascii_case(alias)) +} + +fn session_id_from_path(path: &Path) -> Option { + path.file_name() + .and_then(|value| value.to_str()) + .and_then(|name| { + name.strip_suffix(&format!(".{PRIMARY_SESSION_EXTENSION}")) + .or_else(|| name.strip_suffix(&format!(".{LEGACY_SESSION_EXTENSION}"))) + }) + .map(ToOwned::to_owned) +} + +fn format_missing_session_reference(reference: &str) -> String { + format!( + "session not found: {reference}\nHint: managed sessions live in .claw/sessions/. Try `{LATEST_SESSION_REFERENCE}` for the most recent session or `/session list` in the REPL." + ) +} + +fn format_no_managed_sessions() -> String { + format!( + "no managed sessions found in .claw/sessions/\nStart `claw` to create a session, then rerun with `--resume {LATEST_SESSION_REFERENCE}`." + ) +} + +#[cfg(test)] +mod tests { + use super::{ + create_managed_session_handle_for, fork_managed_session_for, is_session_reference_alias, + list_managed_sessions_for, load_managed_session_for, resolve_session_reference_for, + workspace_fingerprint, ManagedSessionSummary, SessionStore, LATEST_SESSION_REFERENCE, + }; + use crate::session::Session; + use std::fs; + use std::path::{Path, PathBuf}; + use std::time::{SystemTime, UNIX_EPOCH}; + + fn temp_dir() -> PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("runtime-session-control-{nanos}")) + } + + fn persist_session(root: &Path, text: &str) -> Session { + let mut session = Session::new(); + session + .push_user_text(text) + .expect("session message should save"); + let handle = create_managed_session_handle_for(root, &session.session_id) + .expect("managed session handle should build"); + let session = session.with_persistence_path(handle.path.clone()); + session + .save_to_path(&handle.path) + .expect("session should persist"); + session + } + + fn wait_for_next_millisecond() { + let start = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_millis(); + while SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_millis() + <= start + {} + } + + fn summary_by_id<'a>( + summaries: &'a [ManagedSessionSummary], + id: &str, + ) -> &'a ManagedSessionSummary { + summaries + .iter() + .find(|summary| summary.id == id) + .expect("session summary should exist") + } + + #[test] + fn creates_and_lists_managed_sessions() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("root dir should exist"); + let older = persist_session(&root, "older session"); + wait_for_next_millisecond(); + let newer = persist_session(&root, "newer session"); + + // when + let sessions = list_managed_sessions_for(&root).expect("managed sessions should list"); + + // then + assert_eq!(sessions.len(), 2); + assert_eq!(sessions[0].id, newer.session_id); + assert_eq!(summary_by_id(&sessions, &older.session_id).message_count, 1); + assert_eq!(summary_by_id(&sessions, &newer.session_id).message_count, 1); + fs::remove_dir_all(root).expect("temp dir should clean up"); + } + + #[test] + fn resolves_latest_alias_and_loads_session_from_workspace_root() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("root dir should exist"); + let older = persist_session(&root, "older session"); + wait_for_next_millisecond(); + let newer = persist_session(&root, "newer session"); + + // when + let handle = resolve_session_reference_for(&root, LATEST_SESSION_REFERENCE) + .expect("latest alias should resolve"); + let loaded = load_managed_session_for(&root, "recent") + .expect("recent alias should load the latest session"); + + // then + assert_eq!(handle.id, newer.session_id); + assert_eq!(loaded.handle.id, newer.session_id); + assert_eq!(loaded.session.messages.len(), 1); + assert_ne!(loaded.handle.id, older.session_id); + assert!(is_session_reference_alias("last")); + fs::remove_dir_all(root).expect("temp dir should clean up"); + } + + #[test] + fn forks_session_into_managed_storage_with_lineage() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("root dir should exist"); + let source = persist_session(&root, "parent session"); + + // when + let forked = fork_managed_session_for(&root, &source, Some("incident-review".to_string())) + .expect("session should fork"); + let sessions = list_managed_sessions_for(&root).expect("managed sessions should list"); + let summary = summary_by_id(&sessions, &forked.handle.id); + + // then + assert_eq!(forked.parent_session_id, source.session_id); + assert_eq!(forked.branch_name.as_deref(), Some("incident-review")); + assert_eq!( + summary.parent_session_id.as_deref(), + Some(source.session_id.as_str()) + ); + assert_eq!(summary.branch_name.as_deref(), Some("incident-review")); + assert_eq!( + forked.session.persistence_path(), + Some(forked.handle.path.as_path()) + ); + fs::remove_dir_all(root).expect("temp dir should clean up"); + } + + // ------------------------------------------------------------------ + // Per-worktree session isolation (SessionStore) tests + // ------------------------------------------------------------------ + + fn persist_session_via_store(store: &SessionStore, text: &str) -> Session { + let mut session = Session::new(); + session + .push_user_text(text) + .expect("session message should save"); + let handle = store.create_handle(&session.session_id); + let session = session.with_persistence_path(handle.path.clone()); + session + .save_to_path(&handle.path) + .expect("session should persist"); + session + } + + #[test] + fn workspace_fingerprint_is_deterministic_and_differs_per_path() { + // given + let path_a = Path::new("/tmp/worktree-alpha"); + let path_b = Path::new("/tmp/worktree-beta"); + + // when + let fp_a1 = workspace_fingerprint(path_a); + let fp_a2 = workspace_fingerprint(path_a); + let fp_b = workspace_fingerprint(path_b); + + // then + assert_eq!(fp_a1, fp_a2, "same path must produce the same fingerprint"); + assert_ne!( + fp_a1, fp_b, + "different paths must produce different fingerprints" + ); + assert_eq!(fp_a1.len(), 16, "fingerprint must be a 16-char hex string"); + } + + #[test] + fn session_store_from_cwd_isolates_sessions_by_workspace() { + // given + let base = temp_dir(); + let workspace_a = base.join("repo-alpha"); + let workspace_b = base.join("repo-beta"); + fs::create_dir_all(&workspace_a).expect("workspace a should exist"); + fs::create_dir_all(&workspace_b).expect("workspace b should exist"); + + let store_a = SessionStore::from_cwd(&workspace_a).expect("store a should build"); + let store_b = SessionStore::from_cwd(&workspace_b).expect("store b should build"); + + // when + let session_a = persist_session_via_store(&store_a, "alpha work"); + let _session_b = persist_session_via_store(&store_b, "beta work"); + + // then — each store only sees its own sessions + let list_a = store_a.list_sessions().expect("list a"); + let list_b = store_b.list_sessions().expect("list b"); + assert_eq!(list_a.len(), 1, "store a should see exactly one session"); + assert_eq!(list_b.len(), 1, "store b should see exactly one session"); + assert_eq!(list_a[0].id, session_a.session_id); + assert_ne!( + store_a.sessions_dir(), + store_b.sessions_dir(), + "session directories must differ across workspaces" + ); + fs::remove_dir_all(base).expect("temp dir should clean up"); + } + + #[test] + fn session_store_from_data_dir_namespaces_by_workspace() { + // given + let base = temp_dir(); + let data_dir = base.join("global-data"); + let workspace_a = PathBuf::from("/tmp/project-one"); + let workspace_b = PathBuf::from("/tmp/project-two"); + fs::create_dir_all(&data_dir).expect("data dir should exist"); + + let store_a = + SessionStore::from_data_dir(&data_dir, &workspace_a).expect("store a should build"); + let store_b = + SessionStore::from_data_dir(&data_dir, &workspace_b).expect("store b should build"); + + // when + persist_session_via_store(&store_a, "work in project-one"); + persist_session_via_store(&store_b, "work in project-two"); + + // then + assert_ne!( + store_a.sessions_dir(), + store_b.sessions_dir(), + "data-dir stores must namespace by workspace" + ); + assert_eq!(store_a.list_sessions().expect("list a").len(), 1); + assert_eq!(store_b.list_sessions().expect("list b").len(), 1); + assert_eq!(store_a.workspace_root(), workspace_a.as_path()); + assert_eq!(store_b.workspace_root(), workspace_b.as_path()); + fs::remove_dir_all(base).expect("temp dir should clean up"); + } + + #[test] + fn session_store_create_and_load_round_trip() { + // given + let base = temp_dir(); + fs::create_dir_all(&base).expect("base dir should exist"); + let store = SessionStore::from_cwd(&base).expect("store should build"); + let session = persist_session_via_store(&store, "round-trip message"); + + // when + let loaded = store + .load_session(&session.session_id) + .expect("session should load via store"); + + // then + assert_eq!(loaded.handle.id, session.session_id); + assert_eq!(loaded.session.messages.len(), 1); + fs::remove_dir_all(base).expect("temp dir should clean up"); + } + + #[test] + fn session_store_latest_and_resolve_reference() { + // given + let base = temp_dir(); + fs::create_dir_all(&base).expect("base dir should exist"); + let store = SessionStore::from_cwd(&base).expect("store should build"); + let _older = persist_session_via_store(&store, "older"); + wait_for_next_millisecond(); + let newer = persist_session_via_store(&store, "newer"); + + // when + let latest = store.latest_session().expect("latest should resolve"); + let handle = store + .resolve_reference("latest") + .expect("latest alias should resolve"); + + // then + assert_eq!(latest.id, newer.session_id); + assert_eq!(handle.id, newer.session_id); + fs::remove_dir_all(base).expect("temp dir should clean up"); + } + + #[test] + fn session_store_fork_stays_in_same_namespace() { + // given + let base = temp_dir(); + fs::create_dir_all(&base).expect("base dir should exist"); + let store = SessionStore::from_cwd(&base).expect("store should build"); + let source = persist_session_via_store(&store, "parent work"); + + // when + let forked = store + .fork_session(&source, Some("bugfix".to_string())) + .expect("fork should succeed"); + let sessions = store.list_sessions().expect("list sessions"); + + // then + assert_eq!( + sessions.len(), + 2, + "forked session must land in the same namespace" + ); + assert_eq!(forked.parent_session_id, source.session_id); + assert_eq!(forked.branch_name.as_deref(), Some("bugfix")); + assert!( + forked.handle.path.starts_with(store.sessions_dir()), + "forked session path must be inside the store namespace" + ); + fs::remove_dir_all(base).expect("temp dir should clean up"); + } +} diff --git a/crates/runtime/src/sse.rs b/crates/runtime/src/sse.rs index 331ae50..3c0cbee 100644 --- a/crates/runtime/src/sse.rs +++ b/crates/runtime/src/sse.rs @@ -80,7 +80,11 @@ impl IncrementalSseParser { } fn take_event(&mut self) -> Option { - if self.data_lines.is_empty() && self.event_name.is_none() && self.id.is_none() && self.retry.is_none() { + if self.data_lines.is_empty() + && self.event_name.is_none() + && self.id.is_none() + && self.retry.is_none() + { return None; } @@ -102,8 +106,13 @@ mod tests { #[test] fn parses_streaming_events() { + // given let mut parser = IncrementalSseParser::new(); + + // when let first = parser.push_chunk("event: message\ndata: hel"); + + // then assert!(first.is_empty()); let second = parser.push_chunk("lo\n\nid: 1\ndata: world\n\n"); @@ -125,4 +134,25 @@ mod tests { ] ); } + + #[test] + fn finish_flushes_a_trailing_event_without_separator() { + // given + let mut parser = IncrementalSseParser::new(); + parser.push_chunk("event: message\ndata: trailing"); + + // when + let events = parser.finish(); + + // then + assert_eq!( + events, + vec![SseEvent { + event: Some("message".to_string()), + data: "trailing".to_string(), + id: None, + retry: None, + }] + ); + } } diff --git a/crates/runtime/src/stale_base.rs b/crates/runtime/src/stale_base.rs new file mode 100644 index 0000000..b432d30 --- /dev/null +++ b/crates/runtime/src/stale_base.rs @@ -0,0 +1,429 @@ +#![allow(clippy::must_use_candidate)] +use std::path::Path; +use std::process::Command; + +/// Outcome of comparing the worktree HEAD against the expected base commit. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BaseCommitState { + /// HEAD matches the expected base commit. + Matches, + /// HEAD has diverged from the expected base. + Diverged { expected: String, actual: String }, + /// No expected base was supplied (neither flag nor file). + NoExpectedBase, + /// The working directory is not inside a git repository. + NotAGitRepo, +} + +/// Where the expected base commit originated from. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BaseCommitSource { + Flag(String), + File(String), +} + +/// Read the `.claw-base` file from the given directory and return the trimmed +/// commit hash, or `None` when the file is absent or empty. +pub fn read_claw_base_file(cwd: &Path) -> Option { + let path = cwd.join(".claw-base"); + let content = std::fs::read_to_string(path).ok()?; + let trimmed = content.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_string()) + } +} + +/// Resolve the expected base commit: prefer the `--base-commit` flag value, +/// fall back to reading `.claw-base` from `cwd`. +pub fn resolve_expected_base(flag_value: Option<&str>, cwd: &Path) -> Option { + if let Some(value) = flag_value { + let trimmed = value.trim(); + if !trimmed.is_empty() { + return Some(BaseCommitSource::Flag(trimmed.to_string())); + } + } + read_claw_base_file(cwd).map(BaseCommitSource::File) +} + +/// Verify that the worktree HEAD matches `expected_base`. +/// +/// Returns [`BaseCommitState::NoExpectedBase`] when no expected commit is +/// provided (the check is effectively a no-op in that case). +pub fn check_base_commit(cwd: &Path, expected_base: Option<&BaseCommitSource>) -> BaseCommitState { + let Some(source) = expected_base else { + return BaseCommitState::NoExpectedBase; + }; + let expected_raw = match source { + BaseCommitSource::Flag(value) | BaseCommitSource::File(value) => value.as_str(), + }; + + let Some(head_sha) = resolve_head_sha(cwd) else { + return BaseCommitState::NotAGitRepo; + }; + + let Some(expected_sha) = resolve_rev(cwd, expected_raw) else { + // If the expected ref cannot be resolved, compare raw strings as a + // best-effort fallback (e.g. partial SHA provided by the caller). + return if head_sha.starts_with(expected_raw) || expected_raw.starts_with(&head_sha) { + BaseCommitState::Matches + } else { + BaseCommitState::Diverged { + expected: expected_raw.to_string(), + actual: head_sha, + } + }; + }; + + if head_sha == expected_sha { + BaseCommitState::Matches + } else { + BaseCommitState::Diverged { + expected: expected_sha, + actual: head_sha, + } + } +} + +/// Format a human-readable warning when the base commit has diverged. +/// +/// Returns `None` for non-warning states (`Matches`, `NoExpectedBase`). +pub fn format_stale_base_warning(state: &BaseCommitState) -> Option { + match state { + BaseCommitState::Diverged { expected, actual } => Some(format!( + "warning: worktree HEAD ({actual}) does not match expected base commit ({expected}). \ + Session may run against a stale codebase." + )), + BaseCommitState::NotAGitRepo => { + Some("warning: stale-base check skipped — not inside a git repository.".to_string()) + } + BaseCommitState::Matches | BaseCommitState::NoExpectedBase => None, + } +} + +fn resolve_head_sha(cwd: &Path) -> Option { + resolve_rev(cwd, "HEAD") +} + +fn resolve_rev(cwd: &Path, rev: &str) -> Option { + let output = Command::new("git") + .args(["rev-parse", rev]) + .current_dir(cwd) + .output() + .ok()?; + if !output.status.success() { + return None; + } + let sha = String::from_utf8(output.stdout).ok()?; + let trimmed = sha.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_string()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::process::Command; + use std::time::{SystemTime, UNIX_EPOCH}; + + fn temp_dir() -> std::path::PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("runtime-stale-base-{nanos}")) + } + + fn init_repo(path: &std::path::Path) { + fs::create_dir_all(path).expect("create repo dir"); + run(path, &["init", "--quiet", "-b", "main"]); + run(path, &["config", "user.email", "tests@example.com"]); + run(path, &["config", "user.name", "Stale Base Tests"]); + fs::write(path.join("init.txt"), "initial\n").expect("write init file"); + run(path, &["add", "."]); + run(path, &["commit", "-m", "initial commit", "--quiet"]); + } + + fn run(cwd: &std::path::Path, args: &[&str]) { + let status = Command::new("git") + .args(args) + .current_dir(cwd) + .status() + .unwrap_or_else(|e| panic!("git {} failed to execute: {e}", args.join(" "))); + assert!( + status.success(), + "git {} exited with {status}", + args.join(" ") + ); + } + + fn commit_file(repo: &std::path::Path, name: &str, msg: &str) { + fs::write(repo.join(name), format!("{msg}\n")).expect("write file"); + run(repo, &["add", name]); + run(repo, &["commit", "-m", msg, "--quiet"]); + } + + fn head_sha(repo: &std::path::Path) -> String { + let output = Command::new("git") + .args(["rev-parse", "HEAD"]) + .current_dir(repo) + .output() + .expect("git rev-parse HEAD"); + String::from_utf8(output.stdout) + .expect("valid utf8") + .trim() + .to_string() + } + + #[test] + fn matches_when_head_equals_expected_base() { + // given + let root = temp_dir(); + init_repo(&root); + let sha = head_sha(&root); + let source = BaseCommitSource::Flag(sha); + + // when + let state = check_base_commit(&root, Some(&source)); + + // then + assert_eq!(state, BaseCommitState::Matches); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn diverged_when_head_moved_past_expected_base() { + // given + let root = temp_dir(); + init_repo(&root); + let old_sha = head_sha(&root); + commit_file(&root, "extra.txt", "move head forward"); + let new_sha = head_sha(&root); + let source = BaseCommitSource::Flag(old_sha.clone()); + + // when + let state = check_base_commit(&root, Some(&source)); + + // then + assert_eq!( + state, + BaseCommitState::Diverged { + expected: old_sha, + actual: new_sha, + } + ); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn no_expected_base_when_source_is_none() { + // given + let root = temp_dir(); + init_repo(&root); + + // when + let state = check_base_commit(&root, None); + + // then + assert_eq!(state, BaseCommitState::NoExpectedBase); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn not_a_git_repo_when_outside_repo() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("create dir"); + let source = BaseCommitSource::Flag("abc1234".to_string()); + + // when + let state = check_base_commit(&root, Some(&source)); + + // then + assert_eq!(state, BaseCommitState::NotAGitRepo); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn reads_claw_base_file() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("create dir"); + fs::write(root.join(".claw-base"), "abc1234def5678\n").expect("write .claw-base"); + + // when + let value = read_claw_base_file(&root); + + // then + assert_eq!(value, Some("abc1234def5678".to_string())); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn returns_none_for_missing_claw_base_file() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("create dir"); + + // when + let value = read_claw_base_file(&root); + + // then + assert!(value.is_none()); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn returns_none_for_empty_claw_base_file() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("create dir"); + fs::write(root.join(".claw-base"), " \n").expect("write empty .claw-base"); + + // when + let value = read_claw_base_file(&root); + + // then + assert!(value.is_none()); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn resolve_expected_base_prefers_flag_over_file() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("create dir"); + fs::write(root.join(".claw-base"), "from_file\n").expect("write .claw-base"); + + // when + let source = resolve_expected_base(Some("from_flag"), &root); + + // then + assert_eq!( + source, + Some(BaseCommitSource::Flag("from_flag".to_string())) + ); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn resolve_expected_base_falls_back_to_file() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("create dir"); + fs::write(root.join(".claw-base"), "from_file\n").expect("write .claw-base"); + + // when + let source = resolve_expected_base(None, &root); + + // then + assert_eq!( + source, + Some(BaseCommitSource::File("from_file".to_string())) + ); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn resolve_expected_base_returns_none_when_nothing_available() { + // given + let root = temp_dir(); + fs::create_dir_all(&root).expect("create dir"); + + // when + let source = resolve_expected_base(None, &root); + + // then + assert!(source.is_none()); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn format_warning_returns_message_for_diverged() { + // given + let state = BaseCommitState::Diverged { + expected: "abc1234".to_string(), + actual: "def5678".to_string(), + }; + + // when + let warning = format_stale_base_warning(&state); + + // then + let message = warning.expect("should produce warning"); + assert!(message.contains("abc1234")); + assert!(message.contains("def5678")); + assert!(message.contains("stale codebase")); + } + + #[test] + fn format_warning_returns_none_for_matches() { + // given + let state = BaseCommitState::Matches; + + // when + let warning = format_stale_base_warning(&state); + + // then + assert!(warning.is_none()); + } + + #[test] + fn format_warning_returns_none_for_no_expected_base() { + // given + let state = BaseCommitState::NoExpectedBase; + + // when + let warning = format_stale_base_warning(&state); + + // then + assert!(warning.is_none()); + } + + #[test] + fn matches_with_claw_base_file_in_real_repo() { + // given + let root = temp_dir(); + init_repo(&root); + let sha = head_sha(&root); + fs::write(root.join(".claw-base"), format!("{sha}\n")).expect("write .claw-base"); + let source = resolve_expected_base(None, &root); + + // when + let state = check_base_commit(&root, source.as_ref()); + + // then + assert_eq!(state, BaseCommitState::Matches); + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn diverged_with_claw_base_file_after_new_commit() { + // given + let root = temp_dir(); + init_repo(&root); + let old_sha = head_sha(&root); + fs::write(root.join(".claw-base"), format!("{old_sha}\n")).expect("write .claw-base"); + commit_file(&root, "new.txt", "advance head"); + let new_sha = head_sha(&root); + let source = resolve_expected_base(None, &root); + + // when + let state = check_base_commit(&root, source.as_ref()); + + // then + assert_eq!( + state, + BaseCommitState::Diverged { + expected: old_sha, + actual: new_sha, + } + ); + fs::remove_dir_all(&root).expect("cleanup"); + } +} diff --git a/crates/runtime/src/stale_branch.rs b/crates/runtime/src/stale_branch.rs new file mode 100644 index 0000000..ccdd3f5 --- /dev/null +++ b/crates/runtime/src/stale_branch.rs @@ -0,0 +1,417 @@ +#![allow(clippy::must_use_candidate)] +use std::path::Path; +use std::process::Command; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BranchFreshness { + Fresh, + Stale { + commits_behind: usize, + missing_fixes: Vec, + }, + Diverged { + ahead: usize, + behind: usize, + missing_fixes: Vec, + }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StaleBranchPolicy { + AutoRebase, + AutoMergeForward, + WarnOnly, + Block, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum StaleBranchEvent { + BranchStaleAgainstMain { + branch: String, + commits_behind: usize, + missing_fixes: Vec, + }, + RebaseAttempted { + branch: String, + result: String, + }, + MergeForwardAttempted { + branch: String, + result: String, + }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum StaleBranchAction { + Noop, + Warn { message: String }, + Block { message: String }, + Rebase, + MergeForward, +} + +pub fn check_freshness(branch: &str, main_ref: &str) -> BranchFreshness { + check_freshness_in(branch, main_ref, Path::new(".")) +} + +pub fn apply_policy(freshness: &BranchFreshness, policy: StaleBranchPolicy) -> StaleBranchAction { + match freshness { + BranchFreshness::Fresh => StaleBranchAction::Noop, + BranchFreshness::Stale { + commits_behind, + missing_fixes, + } => match policy { + StaleBranchPolicy::WarnOnly => StaleBranchAction::Warn { + message: format!( + "Branch is {commits_behind} commit(s) behind main. Missing fixes: {}", + if missing_fixes.is_empty() { + "(none)".to_string() + } else { + missing_fixes.join("; ") + } + ), + }, + StaleBranchPolicy::Block => StaleBranchAction::Block { + message: format!( + "Branch is {commits_behind} commit(s) behind main and must be updated before proceeding." + ), + }, + StaleBranchPolicy::AutoRebase => StaleBranchAction::Rebase, + StaleBranchPolicy::AutoMergeForward => StaleBranchAction::MergeForward, + }, + BranchFreshness::Diverged { + ahead, + behind, + missing_fixes, + } => match policy { + StaleBranchPolicy::WarnOnly => StaleBranchAction::Warn { + message: format!( + "Branch has diverged: {ahead} commit(s) ahead, {behind} commit(s) behind main. Missing fixes: {}", + format_missing_fixes(missing_fixes) + ), + }, + StaleBranchPolicy::Block => StaleBranchAction::Block { + message: format!( + "Branch has diverged ({ahead} ahead, {behind} behind) and must be reconciled before proceeding. Missing fixes: {}", + format_missing_fixes(missing_fixes) + ), + }, + StaleBranchPolicy::AutoRebase => StaleBranchAction::Rebase, + StaleBranchPolicy::AutoMergeForward => StaleBranchAction::MergeForward, + }, + } +} + +pub(crate) fn check_freshness_in( + branch: &str, + main_ref: &str, + repo_path: &Path, +) -> BranchFreshness { + let behind = rev_list_count(main_ref, branch, repo_path); + let ahead = rev_list_count(branch, main_ref, repo_path); + + if behind == 0 { + return BranchFreshness::Fresh; + } + + if ahead > 0 { + return BranchFreshness::Diverged { + ahead, + behind, + missing_fixes: missing_fix_subjects(main_ref, branch, repo_path), + }; + } + + let missing_fixes = missing_fix_subjects(main_ref, branch, repo_path); + BranchFreshness::Stale { + commits_behind: behind, + missing_fixes, + } +} + +fn format_missing_fixes(missing_fixes: &[String]) -> String { + if missing_fixes.is_empty() { + "(none)".to_string() + } else { + missing_fixes.join("; ") + } +} + +fn rev_list_count(a: &str, b: &str, repo_path: &Path) -> usize { + let output = Command::new("git") + .args(["rev-list", "--count", &format!("{b}..{a}")]) + .current_dir(repo_path) + .output(); + match output { + Ok(o) if o.status.success() => String::from_utf8_lossy(&o.stdout) + .trim() + .parse::() + .unwrap_or(0), + _ => 0, + } +} + +fn missing_fix_subjects(a: &str, b: &str, repo_path: &Path) -> Vec { + let output = Command::new("git") + .args(["log", "--format=%s", &format!("{b}..{a}")]) + .current_dir(repo_path) + .output(); + match output { + Ok(o) if o.status.success() => String::from_utf8_lossy(&o.stdout) + .lines() + .filter(|l| !l.is_empty()) + .map(String::from) + .collect(), + _ => Vec::new(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::time::{SystemTime, UNIX_EPOCH}; + + fn temp_dir() -> std::path::PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("runtime-stale-branch-{nanos}")) + } + + fn init_repo(path: &Path) { + fs::create_dir_all(path).expect("create repo dir"); + run(path, &["init", "--quiet", "-b", "main"]); + run(path, &["config", "user.email", "tests@example.com"]); + run(path, &["config", "user.name", "Stale Branch Tests"]); + fs::write(path.join("init.txt"), "initial\n").expect("write init file"); + run(path, &["add", "."]); + run(path, &["commit", "-m", "initial commit", "--quiet"]); + } + + fn run(cwd: &Path, args: &[&str]) { + let status = Command::new("git") + .args(args) + .current_dir(cwd) + .status() + .unwrap_or_else(|e| panic!("git {} failed to execute: {e}", args.join(" "))); + assert!( + status.success(), + "git {} exited with {status}", + args.join(" ") + ); + } + + fn commit_file(repo: &Path, name: &str, msg: &str) { + fs::write(repo.join(name), format!("{msg}\n")).expect("write file"); + run(repo, &["add", name]); + run(repo, &["commit", "-m", msg, "--quiet"]); + } + + #[test] + fn fresh_branch_passes() { + let root = temp_dir(); + init_repo(&root); + + // given + run(&root, &["checkout", "-b", "topic"]); + + // when + let freshness = check_freshness_in("topic", "main", &root); + + // then + assert_eq!(freshness, BranchFreshness::Fresh); + + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn fresh_branch_ahead_of_main_still_fresh() { + let root = temp_dir(); + init_repo(&root); + + // given + run(&root, &["checkout", "-b", "topic"]); + commit_file(&root, "feature.txt", "add feature"); + + // when + let freshness = check_freshness_in("topic", "main", &root); + + // then + assert_eq!(freshness, BranchFreshness::Fresh); + + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn stale_branch_detected_with_correct_behind_count_and_missing_fixes() { + let root = temp_dir(); + init_repo(&root); + + // given + run(&root, &["checkout", "-b", "topic"]); + run(&root, &["checkout", "main"]); + commit_file(&root, "fix1.txt", "fix: resolve timeout"); + commit_file(&root, "fix2.txt", "fix: handle null pointer"); + + // when + let freshness = check_freshness_in("topic", "main", &root); + + // then + match freshness { + BranchFreshness::Stale { + commits_behind, + missing_fixes, + } => { + assert_eq!(commits_behind, 2); + assert_eq!(missing_fixes.len(), 2); + assert_eq!(missing_fixes[0], "fix: handle null pointer"); + assert_eq!(missing_fixes[1], "fix: resolve timeout"); + } + other => panic!("expected Stale, got {other:?}"), + } + + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn diverged_branch_detection() { + let root = temp_dir(); + init_repo(&root); + + // given + run(&root, &["checkout", "-b", "topic"]); + commit_file(&root, "topic_work.txt", "topic work"); + run(&root, &["checkout", "main"]); + commit_file(&root, "main_fix.txt", "main fix"); + + // when + let freshness = check_freshness_in("topic", "main", &root); + + // then + match freshness { + BranchFreshness::Diverged { + ahead, + behind, + missing_fixes, + } => { + assert_eq!(ahead, 1); + assert_eq!(behind, 1); + assert_eq!(missing_fixes, vec!["main fix".to_string()]); + } + other => panic!("expected Diverged, got {other:?}"), + } + + fs::remove_dir_all(&root).expect("cleanup"); + } + + #[test] + fn policy_noop_for_fresh_branch() { + // given + let freshness = BranchFreshness::Fresh; + + // when + let action = apply_policy(&freshness, StaleBranchPolicy::WarnOnly); + + // then + assert_eq!(action, StaleBranchAction::Noop); + } + + #[test] + fn policy_warn_for_stale_branch() { + // given + let freshness = BranchFreshness::Stale { + commits_behind: 3, + missing_fixes: vec!["fix: timeout".into(), "fix: null ptr".into()], + }; + + // when + let action = apply_policy(&freshness, StaleBranchPolicy::WarnOnly); + + // then + match action { + StaleBranchAction::Warn { message } => { + assert!(message.contains("3 commit(s) behind")); + assert!(message.contains("fix: timeout")); + assert!(message.contains("fix: null ptr")); + } + other => panic!("expected Warn, got {other:?}"), + } + } + + #[test] + fn policy_block_for_stale_branch() { + // given + let freshness = BranchFreshness::Stale { + commits_behind: 1, + missing_fixes: vec!["hotfix".into()], + }; + + // when + let action = apply_policy(&freshness, StaleBranchPolicy::Block); + + // then + match action { + StaleBranchAction::Block { message } => { + assert!(message.contains("1 commit(s) behind")); + } + other => panic!("expected Block, got {other:?}"), + } + } + + #[test] + fn policy_auto_rebase_for_stale_branch() { + // given + let freshness = BranchFreshness::Stale { + commits_behind: 2, + missing_fixes: vec![], + }; + + // when + let action = apply_policy(&freshness, StaleBranchPolicy::AutoRebase); + + // then + assert_eq!(action, StaleBranchAction::Rebase); + } + + #[test] + fn policy_auto_merge_forward_for_diverged_branch() { + // given + let freshness = BranchFreshness::Diverged { + ahead: 5, + behind: 2, + missing_fixes: vec!["fix: merge main".into()], + }; + + // when + let action = apply_policy(&freshness, StaleBranchPolicy::AutoMergeForward); + + // then + assert_eq!(action, StaleBranchAction::MergeForward); + } + + #[test] + fn policy_warn_for_diverged_branch() { + // given + let freshness = BranchFreshness::Diverged { + ahead: 3, + behind: 1, + missing_fixes: vec!["main hotfix".into()], + }; + + // when + let action = apply_policy(&freshness, StaleBranchPolicy::WarnOnly); + + // then + match action { + StaleBranchAction::Warn { message } => { + assert!(message.contains("diverged")); + assert!(message.contains("3 commit(s) ahead")); + assert!(message.contains("1 commit(s) behind")); + assert!(message.contains("main hotfix")); + } + other => panic!("expected Warn, got {other:?}"), + } + } +} diff --git a/crates/runtime/src/summary_compression.rs b/crates/runtime/src/summary_compression.rs new file mode 100644 index 0000000..30ae276 --- /dev/null +++ b/crates/runtime/src/summary_compression.rs @@ -0,0 +1,300 @@ +use std::collections::BTreeSet; + +const DEFAULT_MAX_CHARS: usize = 1_200; +const DEFAULT_MAX_LINES: usize = 24; +const DEFAULT_MAX_LINE_CHARS: usize = 160; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct SummaryCompressionBudget { + pub max_chars: usize, + pub max_lines: usize, + pub max_line_chars: usize, +} + +impl Default for SummaryCompressionBudget { + fn default() -> Self { + Self { + max_chars: DEFAULT_MAX_CHARS, + max_lines: DEFAULT_MAX_LINES, + max_line_chars: DEFAULT_MAX_LINE_CHARS, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SummaryCompressionResult { + pub summary: String, + pub original_chars: usize, + pub compressed_chars: usize, + pub original_lines: usize, + pub compressed_lines: usize, + pub removed_duplicate_lines: usize, + pub omitted_lines: usize, + pub truncated: bool, +} + +#[must_use] +pub fn compress_summary( + summary: &str, + budget: SummaryCompressionBudget, +) -> SummaryCompressionResult { + let original_chars = summary.chars().count(); + let original_lines = summary.lines().count(); + + let normalized = normalize_lines(summary, budget.max_line_chars); + if normalized.lines.is_empty() || budget.max_chars == 0 || budget.max_lines == 0 { + return SummaryCompressionResult { + summary: String::new(), + original_chars, + compressed_chars: 0, + original_lines, + compressed_lines: 0, + removed_duplicate_lines: normalized.removed_duplicate_lines, + omitted_lines: normalized.lines.len(), + truncated: original_chars > 0, + }; + } + + let selected = select_line_indexes(&normalized.lines, budget); + let mut compressed_lines = selected + .iter() + .map(|index| normalized.lines[*index].clone()) + .collect::>(); + if compressed_lines.is_empty() { + compressed_lines.push(truncate_line(&normalized.lines[0], budget.max_chars)); + } + let omitted_lines = normalized + .lines + .len() + .saturating_sub(compressed_lines.len()); + + if omitted_lines > 0 { + let omission_notice = omission_notice(omitted_lines); + push_line_with_budget(&mut compressed_lines, omission_notice, budget); + } + + let compressed_summary = compressed_lines.join("\n"); + + SummaryCompressionResult { + compressed_chars: compressed_summary.chars().count(), + compressed_lines: compressed_lines.len(), + removed_duplicate_lines: normalized.removed_duplicate_lines, + omitted_lines, + truncated: compressed_summary != summary.trim(), + summary: compressed_summary, + original_chars, + original_lines, + } +} + +#[must_use] +pub fn compress_summary_text(summary: &str) -> String { + compress_summary(summary, SummaryCompressionBudget::default()).summary +} + +#[derive(Debug, Default)] +struct NormalizedSummary { + lines: Vec, + removed_duplicate_lines: usize, +} + +fn normalize_lines(summary: &str, max_line_chars: usize) -> NormalizedSummary { + let mut seen = BTreeSet::new(); + let mut lines = Vec::new(); + let mut removed_duplicate_lines = 0; + + for raw_line in summary.lines() { + let normalized = collapse_inline_whitespace(raw_line); + if normalized.is_empty() { + continue; + } + + let truncated = truncate_line(&normalized, max_line_chars); + let dedupe_key = dedupe_key(&truncated); + if !seen.insert(dedupe_key) { + removed_duplicate_lines += 1; + continue; + } + + lines.push(truncated); + } + + NormalizedSummary { + lines, + removed_duplicate_lines, + } +} + +fn select_line_indexes(lines: &[String], budget: SummaryCompressionBudget) -> Vec { + let mut selected = BTreeSet::::new(); + + for priority in 0..=3 { + for (index, line) in lines.iter().enumerate() { + if selected.contains(&index) || line_priority(line) != priority { + continue; + } + + let candidate = selected + .iter() + .map(|selected_index| lines[*selected_index].as_str()) + .chain(std::iter::once(line.as_str())) + .collect::>(); + + if candidate.len() > budget.max_lines { + continue; + } + + if joined_char_count(&candidate) > budget.max_chars { + continue; + } + + selected.insert(index); + } + } + + selected.into_iter().collect() +} + +fn push_line_with_budget(lines: &mut Vec, line: String, budget: SummaryCompressionBudget) { + let candidate = lines + .iter() + .map(String::as_str) + .chain(std::iter::once(line.as_str())) + .collect::>(); + + if candidate.len() <= budget.max_lines && joined_char_count(&candidate) <= budget.max_chars { + lines.push(line); + } +} + +fn joined_char_count(lines: &[&str]) -> usize { + lines.iter().map(|line| line.chars().count()).sum::() + lines.len().saturating_sub(1) +} + +fn line_priority(line: &str) -> usize { + if line == "Summary:" || line == "Conversation summary:" || is_core_detail(line) { + 0 + } else if is_section_header(line) { + 1 + } else if line.starts_with("- ") || line.starts_with(" - ") { + 2 + } else { + 3 + } +} + +fn is_core_detail(line: &str) -> bool { + [ + "- Scope:", + "- Current work:", + "- Pending work:", + "- Key files referenced:", + "- Tools mentioned:", + "- Recent user requests:", + "- Previously compacted context:", + "- Newly compacted context:", + ] + .iter() + .any(|prefix| line.starts_with(prefix)) +} + +fn is_section_header(line: &str) -> bool { + line.ends_with(':') +} + +fn omission_notice(omitted_lines: usize) -> String { + format!("- … {omitted_lines} additional line(s) omitted.") +} + +fn collapse_inline_whitespace(line: &str) -> String { + line.split_whitespace().collect::>().join(" ") +} + +fn truncate_line(line: &str, max_chars: usize) -> String { + if max_chars == 0 || line.chars().count() <= max_chars { + return line.to_string(); + } + + if max_chars == 1 { + return "…".to_string(); + } + + let mut truncated = line + .chars() + .take(max_chars.saturating_sub(1)) + .collect::(); + truncated.push('…'); + truncated +} + +fn dedupe_key(line: &str) -> String { + line.to_ascii_lowercase() +} + +#[cfg(test)] +mod tests { + use super::{compress_summary, compress_summary_text, SummaryCompressionBudget}; + + #[test] + fn collapses_whitespace_and_duplicate_lines() { + // given + let summary = "Conversation summary:\n\n- Scope: compact earlier messages.\n- Scope: compact earlier messages.\n- Current work: update runtime module.\n"; + + // when + let result = compress_summary(summary, SummaryCompressionBudget::default()); + + // then + assert_eq!(result.removed_duplicate_lines, 1); + assert!(result + .summary + .contains("- Scope: compact earlier messages.")); + assert!(!result.summary.contains(" compact earlier")); + } + + #[test] + fn keeps_core_lines_when_budget_is_tight() { + // given + let summary = [ + "Conversation summary:", + "- Scope: 18 earlier messages compacted.", + "- Current work: finish summary compression.", + "- Key timeline:", + " - user: asked for a working implementation.", + " - assistant: inspected runtime compaction flow.", + " - tool: cargo check succeeded.", + ] + .join("\n"); + + // when + let result = compress_summary( + &summary, + SummaryCompressionBudget { + max_chars: 120, + max_lines: 3, + max_line_chars: 80, + }, + ); + + // then + assert!(result.summary.contains("Conversation summary:")); + assert!(result + .summary + .contains("- Scope: 18 earlier messages compacted.")); + assert!(result + .summary + .contains("- Current work: finish summary compression.")); + assert!(result.omitted_lines > 0); + } + + #[test] + fn provides_a_default_text_only_helper() { + // given + let summary = "Summary:\n\nA short line."; + + // when + let compressed = compress_summary_text(summary); + + // then + assert_eq!(compressed, "Summary:\nA short line."); + } +} diff --git a/crates/runtime/src/task_packet.rs b/crates/runtime/src/task_packet.rs new file mode 100644 index 0000000..86d1c6c --- /dev/null +++ b/crates/runtime/src/task_packet.rs @@ -0,0 +1,158 @@ +use serde::{Deserialize, Serialize}; +use std::fmt::{Display, Formatter}; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct TaskPacket { + pub objective: String, + pub scope: String, + pub repo: String, + pub branch_policy: String, + pub acceptance_tests: Vec, + pub commit_policy: String, + pub reporting_contract: String, + pub escalation_policy: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TaskPacketValidationError { + errors: Vec, +} + +impl TaskPacketValidationError { + #[must_use] + pub fn new(errors: Vec) -> Self { + Self { errors } + } + + #[must_use] + pub fn errors(&self) -> &[String] { + &self.errors + } +} + +impl Display for TaskPacketValidationError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.errors.join("; ")) + } +} + +impl std::error::Error for TaskPacketValidationError {} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ValidatedPacket(TaskPacket); + +impl ValidatedPacket { + #[must_use] + pub fn packet(&self) -> &TaskPacket { + &self.0 + } + + #[must_use] + pub fn into_inner(self) -> TaskPacket { + self.0 + } +} + +pub fn validate_packet(packet: TaskPacket) -> Result { + let mut errors = Vec::new(); + + validate_required("objective", &packet.objective, &mut errors); + validate_required("scope", &packet.scope, &mut errors); + validate_required("repo", &packet.repo, &mut errors); + validate_required("branch_policy", &packet.branch_policy, &mut errors); + validate_required("commit_policy", &packet.commit_policy, &mut errors); + validate_required( + "reporting_contract", + &packet.reporting_contract, + &mut errors, + ); + validate_required("escalation_policy", &packet.escalation_policy, &mut errors); + + for (index, test) in packet.acceptance_tests.iter().enumerate() { + if test.trim().is_empty() { + errors.push(format!( + "acceptance_tests contains an empty value at index {index}" + )); + } + } + + if errors.is_empty() { + Ok(ValidatedPacket(packet)) + } else { + Err(TaskPacketValidationError::new(errors)) + } +} + +fn validate_required(field: &str, value: &str, errors: &mut Vec) { + if value.trim().is_empty() { + errors.push(format!("{field} must not be empty")); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn sample_packet() -> TaskPacket { + TaskPacket { + objective: "Implement typed task packet format".to_string(), + scope: "runtime/task system".to_string(), + repo: "claw-code-parity".to_string(), + branch_policy: "origin/main only".to_string(), + acceptance_tests: vec![ + "cargo build --workspace".to_string(), + "cargo test --workspace".to_string(), + ], + commit_policy: "single verified commit".to_string(), + reporting_contract: "print build result, test result, commit sha".to_string(), + escalation_policy: "stop only on destructive ambiguity".to_string(), + } + } + + #[test] + fn valid_packet_passes_validation() { + let packet = sample_packet(); + let validated = validate_packet(packet.clone()).expect("packet should validate"); + assert_eq!(validated.packet(), &packet); + assert_eq!(validated.into_inner(), packet); + } + + #[test] + fn invalid_packet_accumulates_errors() { + let packet = TaskPacket { + objective: " ".to_string(), + scope: String::new(), + repo: String::new(), + branch_policy: "\t".to_string(), + acceptance_tests: vec!["ok".to_string(), " ".to_string()], + commit_policy: String::new(), + reporting_contract: String::new(), + escalation_policy: String::new(), + }; + + let error = validate_packet(packet).expect_err("packet should be rejected"); + + assert!(error.errors().len() >= 7); + assert!(error + .errors() + .contains(&"objective must not be empty".to_string())); + assert!(error + .errors() + .contains(&"scope must not be empty".to_string())); + assert!(error + .errors() + .contains(&"repo must not be empty".to_string())); + assert!(error + .errors() + .contains(&"acceptance_tests contains an empty value at index 1".to_string())); + } + + #[test] + fn serialization_roundtrip_preserves_packet() { + let packet = sample_packet(); + let serialized = serde_json::to_string(&packet).expect("packet should serialize"); + let deserialized: TaskPacket = + serde_json::from_str(&serialized).expect("packet should deserialize"); + assert_eq!(deserialized, packet); + } +} diff --git a/crates/runtime/src/task_registry.rs b/crates/runtime/src/task_registry.rs new file mode 100644 index 0000000..7487115 --- /dev/null +++ b/crates/runtime/src/task_registry.rs @@ -0,0 +1,503 @@ +#![allow(clippy::must_use_candidate, clippy::unnecessary_map_or)] +//! In-memory task registry for sub-agent task lifecycle management. + +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; + +use crate::{validate_packet, TaskPacket, TaskPacketValidationError}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum TaskStatus { + Created, + Running, + Completed, + Failed, + Stopped, +} + +impl std::fmt::Display for TaskStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Created => write!(f, "created"), + Self::Running => write!(f, "running"), + Self::Completed => write!(f, "completed"), + Self::Failed => write!(f, "failed"), + Self::Stopped => write!(f, "stopped"), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Task { + pub task_id: String, + pub prompt: String, + pub description: Option, + pub task_packet: Option, + pub status: TaskStatus, + pub created_at: u64, + pub updated_at: u64, + pub messages: Vec, + pub output: String, + pub team_id: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskMessage { + pub role: String, + pub content: String, + pub timestamp: u64, +} + +#[derive(Debug, Clone, Default)] +pub struct TaskRegistry { + inner: Arc>, +} + +#[derive(Debug, Default)] +struct RegistryInner { + tasks: HashMap, + counter: u64, +} + +fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +impl TaskRegistry { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn create(&self, prompt: &str, description: Option<&str>) -> Task { + self.create_task(prompt.to_owned(), description.map(str::to_owned), None) + } + + pub fn create_from_packet( + &self, + packet: TaskPacket, + ) -> Result { + let packet = validate_packet(packet)?.into_inner(); + Ok(self.create_task( + packet.objective.clone(), + Some(packet.scope.clone()), + Some(packet), + )) + } + + fn create_task( + &self, + prompt: String, + description: Option, + task_packet: Option, + ) -> Task { + let mut inner = self.inner.lock().expect("registry lock poisoned"); + inner.counter += 1; + let ts = now_secs(); + let task_id = format!("task_{:08x}_{}", ts, inner.counter); + let task = Task { + task_id: task_id.clone(), + prompt, + description, + task_packet, + status: TaskStatus::Created, + created_at: ts, + updated_at: ts, + messages: Vec::new(), + output: String::new(), + team_id: None, + }; + inner.tasks.insert(task_id, task.clone()); + task + } + + pub fn get(&self, task_id: &str) -> Option { + let inner = self.inner.lock().expect("registry lock poisoned"); + inner.tasks.get(task_id).cloned() + } + + pub fn list(&self, status_filter: Option) -> Vec { + let inner = self.inner.lock().expect("registry lock poisoned"); + inner + .tasks + .values() + .filter(|t| status_filter.map_or(true, |s| t.status == s)) + .cloned() + .collect() + } + + pub fn stop(&self, task_id: &str) -> Result { + let mut inner = self.inner.lock().expect("registry lock poisoned"); + let task = inner + .tasks + .get_mut(task_id) + .ok_or_else(|| format!("task not found: {task_id}"))?; + + match task.status { + TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Stopped => { + return Err(format!( + "task {task_id} is already in terminal state: {}", + task.status + )); + } + _ => {} + } + + task.status = TaskStatus::Stopped; + task.updated_at = now_secs(); + Ok(task.clone()) + } + + pub fn update(&self, task_id: &str, message: &str) -> Result { + let mut inner = self.inner.lock().expect("registry lock poisoned"); + let task = inner + .tasks + .get_mut(task_id) + .ok_or_else(|| format!("task not found: {task_id}"))?; + + task.messages.push(TaskMessage { + role: String::from("user"), + content: message.to_owned(), + timestamp: now_secs(), + }); + task.updated_at = now_secs(); + Ok(task.clone()) + } + + pub fn output(&self, task_id: &str) -> Result { + let inner = self.inner.lock().expect("registry lock poisoned"); + let task = inner + .tasks + .get(task_id) + .ok_or_else(|| format!("task not found: {task_id}"))?; + Ok(task.output.clone()) + } + + pub fn append_output(&self, task_id: &str, output: &str) -> Result<(), String> { + let mut inner = self.inner.lock().expect("registry lock poisoned"); + let task = inner + .tasks + .get_mut(task_id) + .ok_or_else(|| format!("task not found: {task_id}"))?; + task.output.push_str(output); + task.updated_at = now_secs(); + Ok(()) + } + + pub fn set_status(&self, task_id: &str, status: TaskStatus) -> Result<(), String> { + let mut inner = self.inner.lock().expect("registry lock poisoned"); + let task = inner + .tasks + .get_mut(task_id) + .ok_or_else(|| format!("task not found: {task_id}"))?; + task.status = status; + task.updated_at = now_secs(); + Ok(()) + } + + pub fn assign_team(&self, task_id: &str, team_id: &str) -> Result<(), String> { + let mut inner = self.inner.lock().expect("registry lock poisoned"); + let task = inner + .tasks + .get_mut(task_id) + .ok_or_else(|| format!("task not found: {task_id}"))?; + task.team_id = Some(team_id.to_owned()); + task.updated_at = now_secs(); + Ok(()) + } + + pub fn remove(&self, task_id: &str) -> Option { + let mut inner = self.inner.lock().expect("registry lock poisoned"); + inner.tasks.remove(task_id) + } + + #[must_use] + pub fn len(&self) -> usize { + let inner = self.inner.lock().expect("registry lock poisoned"); + inner.tasks.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn creates_and_retrieves_tasks() { + let registry = TaskRegistry::new(); + let task = registry.create("Do something", Some("A test task")); + assert_eq!(task.status, TaskStatus::Created); + assert_eq!(task.prompt, "Do something"); + assert_eq!(task.description.as_deref(), Some("A test task")); + assert_eq!(task.task_packet, None); + + let fetched = registry.get(&task.task_id).expect("task should exist"); + assert_eq!(fetched.task_id, task.task_id); + } + + #[test] + fn creates_task_from_packet() { + let registry = TaskRegistry::new(); + let packet = TaskPacket { + objective: "Ship task packet support".to_string(), + scope: "runtime/task system".to_string(), + repo: "claw-code-parity".to_string(), + branch_policy: "origin/main only".to_string(), + acceptance_tests: vec!["cargo test --workspace".to_string()], + commit_policy: "single commit".to_string(), + reporting_contract: "print commit sha".to_string(), + escalation_policy: "manual escalation".to_string(), + }; + + let task = registry + .create_from_packet(packet.clone()) + .expect("packet-backed task should be created"); + + assert_eq!(task.prompt, packet.objective); + assert_eq!(task.description.as_deref(), Some("runtime/task system")); + assert_eq!(task.task_packet, Some(packet.clone())); + + let fetched = registry.get(&task.task_id).expect("task should exist"); + assert_eq!(fetched.task_packet, Some(packet)); + } + + #[test] + fn lists_tasks_with_optional_filter() { + let registry = TaskRegistry::new(); + registry.create("Task A", None); + let task_b = registry.create("Task B", None); + registry + .set_status(&task_b.task_id, TaskStatus::Running) + .expect("set status should succeed"); + + let all = registry.list(None); + assert_eq!(all.len(), 2); + + let running = registry.list(Some(TaskStatus::Running)); + assert_eq!(running.len(), 1); + assert_eq!(running[0].task_id, task_b.task_id); + + let created = registry.list(Some(TaskStatus::Created)); + assert_eq!(created.len(), 1); + } + + #[test] + fn stops_running_task() { + let registry = TaskRegistry::new(); + let task = registry.create("Stoppable", None); + registry + .set_status(&task.task_id, TaskStatus::Running) + .unwrap(); + + let stopped = registry.stop(&task.task_id).expect("stop should succeed"); + assert_eq!(stopped.status, TaskStatus::Stopped); + + // Stopping again should fail + let result = registry.stop(&task.task_id); + assert!(result.is_err()); + } + + #[test] + fn updates_task_with_messages() { + let registry = TaskRegistry::new(); + let task = registry.create("Messageable", None); + let updated = registry + .update(&task.task_id, "Here's more context") + .expect("update should succeed"); + assert_eq!(updated.messages.len(), 1); + assert_eq!(updated.messages[0].content, "Here's more context"); + assert_eq!(updated.messages[0].role, "user"); + } + + #[test] + fn appends_and_retrieves_output() { + let registry = TaskRegistry::new(); + let task = registry.create("Output task", None); + registry + .append_output(&task.task_id, "line 1\n") + .expect("append should succeed"); + registry + .append_output(&task.task_id, "line 2\n") + .expect("append should succeed"); + + let output = registry.output(&task.task_id).expect("output should exist"); + assert_eq!(output, "line 1\nline 2\n"); + } + + #[test] + fn assigns_team_and_removes_task() { + let registry = TaskRegistry::new(); + let task = registry.create("Team task", None); + registry + .assign_team(&task.task_id, "team_abc") + .expect("assign should succeed"); + + let fetched = registry.get(&task.task_id).unwrap(); + assert_eq!(fetched.team_id.as_deref(), Some("team_abc")); + + let removed = registry.remove(&task.task_id); + assert!(removed.is_some()); + assert!(registry.get(&task.task_id).is_none()); + assert!(registry.is_empty()); + } + + #[test] + fn rejects_operations_on_missing_task() { + let registry = TaskRegistry::new(); + assert!(registry.stop("nonexistent").is_err()); + assert!(registry.update("nonexistent", "msg").is_err()); + assert!(registry.output("nonexistent").is_err()); + assert!(registry.append_output("nonexistent", "data").is_err()); + assert!(registry + .set_status("nonexistent", TaskStatus::Running) + .is_err()); + } + + #[test] + fn task_status_display_all_variants() { + // given + let cases = [ + (TaskStatus::Created, "created"), + (TaskStatus::Running, "running"), + (TaskStatus::Completed, "completed"), + (TaskStatus::Failed, "failed"), + (TaskStatus::Stopped, "stopped"), + ]; + + // when + let rendered: Vec<_> = cases + .into_iter() + .map(|(status, expected)| (status.to_string(), expected)) + .collect(); + + // then + assert_eq!( + rendered, + vec![ + ("created".to_string(), "created"), + ("running".to_string(), "running"), + ("completed".to_string(), "completed"), + ("failed".to_string(), "failed"), + ("stopped".to_string(), "stopped"), + ] + ); + } + + #[test] + fn stop_rejects_completed_task() { + // given + let registry = TaskRegistry::new(); + let task = registry.create("done", None); + registry + .set_status(&task.task_id, TaskStatus::Completed) + .expect("set status should succeed"); + + // when + let result = registry.stop(&task.task_id); + + // then + let error = result.expect_err("completed task should be rejected"); + assert!(error.contains("already in terminal state")); + assert!(error.contains("completed")); + } + + #[test] + fn stop_rejects_failed_task() { + // given + let registry = TaskRegistry::new(); + let task = registry.create("failed", None); + registry + .set_status(&task.task_id, TaskStatus::Failed) + .expect("set status should succeed"); + + // when + let result = registry.stop(&task.task_id); + + // then + let error = result.expect_err("failed task should be rejected"); + assert!(error.contains("already in terminal state")); + assert!(error.contains("failed")); + } + + #[test] + fn stop_succeeds_from_created_state() { + // given + let registry = TaskRegistry::new(); + let task = registry.create("created task", None); + + // when + let stopped = registry.stop(&task.task_id).expect("stop should succeed"); + + // then + assert_eq!(stopped.status, TaskStatus::Stopped); + assert!(stopped.updated_at >= task.updated_at); + } + + #[test] + fn new_registry_is_empty() { + // given + let registry = TaskRegistry::new(); + + // when + let all_tasks = registry.list(None); + + // then + assert!(registry.is_empty()); + assert_eq!(registry.len(), 0); + assert!(all_tasks.is_empty()); + } + + #[test] + fn create_without_description() { + // given + let registry = TaskRegistry::new(); + + // when + let task = registry.create("Do the thing", None); + + // then + assert!(task.task_id.starts_with("task_")); + assert_eq!(task.description, None); + assert_eq!(task.task_packet, None); + assert!(task.messages.is_empty()); + assert!(task.output.is_empty()); + assert_eq!(task.team_id, None); + } + + #[test] + fn remove_nonexistent_returns_none() { + // given + let registry = TaskRegistry::new(); + + // when + let removed = registry.remove("missing"); + + // then + assert!(removed.is_none()); + } + + #[test] + fn assign_team_rejects_missing_task() { + // given + let registry = TaskRegistry::new(); + + // when + let result = registry.assign_team("missing", "team_123"); + + // then + let error = result.expect_err("missing task should be rejected"); + assert_eq!(error, "task not found: missing"); + } +} diff --git a/crates/runtime/src/team_cron_registry.rs b/crates/runtime/src/team_cron_registry.rs new file mode 100644 index 0000000..1e1a65f --- /dev/null +++ b/crates/runtime/src/team_cron_registry.rs @@ -0,0 +1,509 @@ +#![allow(clippy::must_use_candidate)] +//! In-memory registries for Team and Cron lifecycle management. +//! +//! Provides TeamCreate/Delete and CronCreate/Delete/List runtime backing +//! to replace the stub implementations in the tools crate. + +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; + +fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Team { + pub team_id: String, + pub name: String, + pub task_ids: Vec, + pub status: TeamStatus, + pub created_at: u64, + pub updated_at: u64, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum TeamStatus { + Created, + Running, + Completed, + Deleted, +} + +impl std::fmt::Display for TeamStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Created => write!(f, "created"), + Self::Running => write!(f, "running"), + Self::Completed => write!(f, "completed"), + Self::Deleted => write!(f, "deleted"), + } + } +} + +#[derive(Debug, Clone, Default)] +pub struct TeamRegistry { + inner: Arc>, +} + +#[derive(Debug, Default)] +struct TeamInner { + teams: HashMap, + counter: u64, +} + +impl TeamRegistry { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn create(&self, name: &str, task_ids: Vec) -> Team { + let mut inner = self.inner.lock().expect("team registry lock poisoned"); + inner.counter += 1; + let ts = now_secs(); + let team_id = format!("team_{:08x}_{}", ts, inner.counter); + let team = Team { + team_id: team_id.clone(), + name: name.to_owned(), + task_ids, + status: TeamStatus::Created, + created_at: ts, + updated_at: ts, + }; + inner.teams.insert(team_id, team.clone()); + team + } + + pub fn get(&self, team_id: &str) -> Option { + let inner = self.inner.lock().expect("team registry lock poisoned"); + inner.teams.get(team_id).cloned() + } + + pub fn list(&self) -> Vec { + let inner = self.inner.lock().expect("team registry lock poisoned"); + inner.teams.values().cloned().collect() + } + + pub fn delete(&self, team_id: &str) -> Result { + let mut inner = self.inner.lock().expect("team registry lock poisoned"); + let team = inner + .teams + .get_mut(team_id) + .ok_or_else(|| format!("team not found: {team_id}"))?; + team.status = TeamStatus::Deleted; + team.updated_at = now_secs(); + Ok(team.clone()) + } + + pub fn remove(&self, team_id: &str) -> Option { + let mut inner = self.inner.lock().expect("team registry lock poisoned"); + inner.teams.remove(team_id) + } + + #[must_use] + pub fn len(&self) -> usize { + let inner = self.inner.lock().expect("team registry lock poisoned"); + inner.teams.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CronEntry { + pub cron_id: String, + pub schedule: String, + pub prompt: String, + pub description: Option, + pub enabled: bool, + pub created_at: u64, + pub updated_at: u64, + pub last_run_at: Option, + pub run_count: u64, +} + +#[derive(Debug, Clone, Default)] +pub struct CronRegistry { + inner: Arc>, +} + +#[derive(Debug, Default)] +struct CronInner { + entries: HashMap, + counter: u64, +} + +impl CronRegistry { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn create(&self, schedule: &str, prompt: &str, description: Option<&str>) -> CronEntry { + let mut inner = self.inner.lock().expect("cron registry lock poisoned"); + inner.counter += 1; + let ts = now_secs(); + let cron_id = format!("cron_{:08x}_{}", ts, inner.counter); + let entry = CronEntry { + cron_id: cron_id.clone(), + schedule: schedule.to_owned(), + prompt: prompt.to_owned(), + description: description.map(str::to_owned), + enabled: true, + created_at: ts, + updated_at: ts, + last_run_at: None, + run_count: 0, + }; + inner.entries.insert(cron_id, entry.clone()); + entry + } + + pub fn get(&self, cron_id: &str) -> Option { + let inner = self.inner.lock().expect("cron registry lock poisoned"); + inner.entries.get(cron_id).cloned() + } + + pub fn list(&self, enabled_only: bool) -> Vec { + let inner = self.inner.lock().expect("cron registry lock poisoned"); + inner + .entries + .values() + .filter(|e| !enabled_only || e.enabled) + .cloned() + .collect() + } + + pub fn delete(&self, cron_id: &str) -> Result { + let mut inner = self.inner.lock().expect("cron registry lock poisoned"); + inner + .entries + .remove(cron_id) + .ok_or_else(|| format!("cron not found: {cron_id}")) + } + + /// Disable a cron entry without removing it. + pub fn disable(&self, cron_id: &str) -> Result<(), String> { + let mut inner = self.inner.lock().expect("cron registry lock poisoned"); + let entry = inner + .entries + .get_mut(cron_id) + .ok_or_else(|| format!("cron not found: {cron_id}"))?; + entry.enabled = false; + entry.updated_at = now_secs(); + Ok(()) + } + + /// Record a cron run. + pub fn record_run(&self, cron_id: &str) -> Result<(), String> { + let mut inner = self.inner.lock().expect("cron registry lock poisoned"); + let entry = inner + .entries + .get_mut(cron_id) + .ok_or_else(|| format!("cron not found: {cron_id}"))?; + entry.last_run_at = Some(now_secs()); + entry.run_count += 1; + entry.updated_at = now_secs(); + Ok(()) + } + + #[must_use] + pub fn len(&self) -> usize { + let inner = self.inner.lock().expect("cron registry lock poisoned"); + inner.entries.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── Team tests ────────────────────────────────────── + + #[test] + fn creates_and_retrieves_team() { + let registry = TeamRegistry::new(); + let team = registry.create("Alpha Squad", vec!["task_001".into(), "task_002".into()]); + assert_eq!(team.name, "Alpha Squad"); + assert_eq!(team.task_ids.len(), 2); + assert_eq!(team.status, TeamStatus::Created); + + let fetched = registry.get(&team.team_id).expect("team should exist"); + assert_eq!(fetched.team_id, team.team_id); + } + + #[test] + fn lists_and_deletes_teams() { + let registry = TeamRegistry::new(); + let t1 = registry.create("Team A", vec![]); + let t2 = registry.create("Team B", vec![]); + + let all = registry.list(); + assert_eq!(all.len(), 2); + + let deleted = registry.delete(&t1.team_id).expect("delete should succeed"); + assert_eq!(deleted.status, TeamStatus::Deleted); + + // Team is still listable (soft delete) + let still_there = registry.get(&t1.team_id).unwrap(); + assert_eq!(still_there.status, TeamStatus::Deleted); + + // Hard remove + registry.remove(&t2.team_id); + assert_eq!(registry.len(), 1); + } + + #[test] + fn rejects_missing_team_operations() { + let registry = TeamRegistry::new(); + assert!(registry.delete("nonexistent").is_err()); + assert!(registry.get("nonexistent").is_none()); + } + + // ── Cron tests ────────────────────────────────────── + + #[test] + fn creates_and_retrieves_cron() { + let registry = CronRegistry::new(); + let entry = registry.create("0 * * * *", "Check status", Some("hourly check")); + assert_eq!(entry.schedule, "0 * * * *"); + assert_eq!(entry.prompt, "Check status"); + assert!(entry.enabled); + assert_eq!(entry.run_count, 0); + assert!(entry.last_run_at.is_none()); + + let fetched = registry.get(&entry.cron_id).expect("cron should exist"); + assert_eq!(fetched.cron_id, entry.cron_id); + } + + #[test] + fn lists_with_enabled_filter() { + let registry = CronRegistry::new(); + let c1 = registry.create("* * * * *", "Task 1", None); + let c2 = registry.create("0 * * * *", "Task 2", None); + registry + .disable(&c1.cron_id) + .expect("disable should succeed"); + + let all = registry.list(false); + assert_eq!(all.len(), 2); + + let enabled_only = registry.list(true); + assert_eq!(enabled_only.len(), 1); + assert_eq!(enabled_only[0].cron_id, c2.cron_id); + } + + #[test] + fn deletes_cron_entry() { + let registry = CronRegistry::new(); + let entry = registry.create("* * * * *", "To delete", None); + let deleted = registry + .delete(&entry.cron_id) + .expect("delete should succeed"); + assert_eq!(deleted.cron_id, entry.cron_id); + assert!(registry.get(&entry.cron_id).is_none()); + assert!(registry.is_empty()); + } + + #[test] + fn records_cron_runs() { + let registry = CronRegistry::new(); + let entry = registry.create("*/5 * * * *", "Recurring", None); + registry.record_run(&entry.cron_id).unwrap(); + registry.record_run(&entry.cron_id).unwrap(); + + let fetched = registry.get(&entry.cron_id).unwrap(); + assert_eq!(fetched.run_count, 2); + assert!(fetched.last_run_at.is_some()); + } + + #[test] + fn rejects_missing_cron_operations() { + let registry = CronRegistry::new(); + assert!(registry.delete("nonexistent").is_err()); + assert!(registry.disable("nonexistent").is_err()); + assert!(registry.record_run("nonexistent").is_err()); + assert!(registry.get("nonexistent").is_none()); + } + + #[test] + fn team_status_display_all_variants() { + // given + let cases = [ + (TeamStatus::Created, "created"), + (TeamStatus::Running, "running"), + (TeamStatus::Completed, "completed"), + (TeamStatus::Deleted, "deleted"), + ]; + + // when + let rendered: Vec<_> = cases + .into_iter() + .map(|(status, expected)| (status.to_string(), expected)) + .collect(); + + // then + assert_eq!( + rendered, + vec![ + ("created".to_string(), "created"), + ("running".to_string(), "running"), + ("completed".to_string(), "completed"), + ("deleted".to_string(), "deleted"), + ] + ); + } + + #[test] + fn new_team_registry_is_empty() { + // given + let registry = TeamRegistry::new(); + + // when + let teams = registry.list(); + + // then + assert!(registry.is_empty()); + assert_eq!(registry.len(), 0); + assert!(teams.is_empty()); + } + + #[test] + fn team_remove_nonexistent_returns_none() { + // given + let registry = TeamRegistry::new(); + + // when + let removed = registry.remove("missing"); + + // then + assert!(removed.is_none()); + } + + #[test] + fn team_len_transitions() { + // given + let registry = TeamRegistry::new(); + + // when + let alpha = registry.create("Alpha", vec![]); + let beta = registry.create("Beta", vec![]); + let after_create = registry.len(); + registry.remove(&alpha.team_id); + let after_first_remove = registry.len(); + registry.remove(&beta.team_id); + + // then + assert_eq!(after_create, 2); + assert_eq!(after_first_remove, 1); + assert_eq!(registry.len(), 0); + assert!(registry.is_empty()); + } + + #[test] + fn cron_list_all_disabled_returns_empty_for_enabled_only() { + // given + let registry = CronRegistry::new(); + let first = registry.create("* * * * *", "Task 1", None); + let second = registry.create("0 * * * *", "Task 2", None); + registry + .disable(&first.cron_id) + .expect("disable should succeed"); + registry + .disable(&second.cron_id) + .expect("disable should succeed"); + + // when + let enabled_only = registry.list(true); + let all_entries = registry.list(false); + + // then + assert!(enabled_only.is_empty()); + assert_eq!(all_entries.len(), 2); + } + + #[test] + fn cron_create_without_description() { + // given + let registry = CronRegistry::new(); + + // when + let entry = registry.create("*/15 * * * *", "Check health", None); + + // then + assert!(entry.cron_id.starts_with("cron_")); + assert_eq!(entry.description, None); + assert!(entry.enabled); + assert_eq!(entry.run_count, 0); + assert_eq!(entry.last_run_at, None); + } + + #[test] + fn new_cron_registry_is_empty() { + // given + let registry = CronRegistry::new(); + + // when + let enabled_only = registry.list(true); + let all_entries = registry.list(false); + + // then + assert!(registry.is_empty()); + assert_eq!(registry.len(), 0); + assert!(enabled_only.is_empty()); + assert!(all_entries.is_empty()); + } + + #[test] + fn cron_record_run_updates_timestamp_and_counter() { + // given + let registry = CronRegistry::new(); + let entry = registry.create("*/5 * * * *", "Recurring", None); + + // when + registry + .record_run(&entry.cron_id) + .expect("first run should succeed"); + registry + .record_run(&entry.cron_id) + .expect("second run should succeed"); + let fetched = registry.get(&entry.cron_id).expect("entry should exist"); + + // then + assert_eq!(fetched.run_count, 2); + assert!(fetched.last_run_at.is_some()); + assert!(fetched.updated_at >= entry.updated_at); + } + + #[test] + fn cron_disable_updates_timestamp() { + // given + let registry = CronRegistry::new(); + let entry = registry.create("0 0 * * *", "Nightly", None); + + // when + registry + .disable(&entry.cron_id) + .expect("disable should succeed"); + let fetched = registry.get(&entry.cron_id).expect("entry should exist"); + + // then + assert!(!fetched.enabled); + assert!(fetched.updated_at >= entry.updated_at); + } +} diff --git a/crates/runtime/src/trust_resolver.rs b/crates/runtime/src/trust_resolver.rs new file mode 100644 index 0000000..52d46dc --- /dev/null +++ b/crates/runtime/src/trust_resolver.rs @@ -0,0 +1,299 @@ +use std::path::{Path, PathBuf}; + +const TRUST_PROMPT_CUES: &[&str] = &[ + "do you trust the files in this folder", + "trust the files in this folder", + "trust this folder", + "allow and continue", + "yes, proceed", +]; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TrustPolicy { + AutoTrust, + RequireApproval, + Deny, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TrustEvent { + TrustRequired { cwd: String }, + TrustResolved { cwd: String, policy: TrustPolicy }, + TrustDenied { cwd: String, reason: String }, +} + +#[derive(Debug, Clone, Default)] +pub struct TrustConfig { + allowlisted: Vec, + denied: Vec, +} + +impl TrustConfig { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + #[must_use] + pub fn with_allowlisted(mut self, path: impl Into) -> Self { + self.allowlisted.push(path.into()); + self + } + + #[must_use] + pub fn with_denied(mut self, path: impl Into) -> Self { + self.denied.push(path.into()); + self + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TrustDecision { + NotRequired, + Required { + policy: TrustPolicy, + events: Vec, + }, +} + +impl TrustDecision { + #[must_use] + pub fn policy(&self) -> Option { + match self { + Self::NotRequired => None, + Self::Required { policy, .. } => Some(*policy), + } + } + + #[must_use] + pub fn events(&self) -> &[TrustEvent] { + match self { + Self::NotRequired => &[], + Self::Required { events, .. } => events, + } + } +} + +#[derive(Debug, Clone)] +pub struct TrustResolver { + config: TrustConfig, +} + +impl TrustResolver { + #[must_use] + pub fn new(config: TrustConfig) -> Self { + Self { config } + } + + #[must_use] + pub fn resolve(&self, cwd: &str, screen_text: &str) -> TrustDecision { + if !detect_trust_prompt(screen_text) { + return TrustDecision::NotRequired; + } + + let mut events = vec![TrustEvent::TrustRequired { + cwd: cwd.to_owned(), + }]; + + if let Some(matched_root) = self + .config + .denied + .iter() + .find(|root| path_matches(cwd, root)) + { + let reason = format!("cwd matches denied trust root: {}", matched_root.display()); + events.push(TrustEvent::TrustDenied { + cwd: cwd.to_owned(), + reason, + }); + return TrustDecision::Required { + policy: TrustPolicy::Deny, + events, + }; + } + + if self + .config + .allowlisted + .iter() + .any(|root| path_matches(cwd, root)) + { + events.push(TrustEvent::TrustResolved { + cwd: cwd.to_owned(), + policy: TrustPolicy::AutoTrust, + }); + return TrustDecision::Required { + policy: TrustPolicy::AutoTrust, + events, + }; + } + + TrustDecision::Required { + policy: TrustPolicy::RequireApproval, + events, + } + } + + #[must_use] + pub fn trusts(&self, cwd: &str) -> bool { + !self + .config + .denied + .iter() + .any(|root| path_matches(cwd, root)) + && self + .config + .allowlisted + .iter() + .any(|root| path_matches(cwd, root)) + } +} + +#[must_use] +pub fn detect_trust_prompt(screen_text: &str) -> bool { + let lowered = screen_text.to_ascii_lowercase(); + TRUST_PROMPT_CUES + .iter() + .any(|needle| lowered.contains(needle)) +} + +#[must_use] +pub fn path_matches_trusted_root(cwd: &str, trusted_root: &str) -> bool { + path_matches(cwd, &normalize_path(Path::new(trusted_root))) +} + +fn path_matches(candidate: &str, root: &Path) -> bool { + let candidate = normalize_path(Path::new(candidate)); + let root = normalize_path(root); + candidate == root || candidate.starts_with(&root) +} + +fn normalize_path(path: &Path) -> PathBuf { + std::fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf()) +} + +#[cfg(test)] +mod tests { + use super::{ + detect_trust_prompt, path_matches_trusted_root, TrustConfig, TrustDecision, TrustEvent, + TrustPolicy, TrustResolver, + }; + + #[test] + fn detects_known_trust_prompt_copy() { + // given + let screen_text = "Do you trust the files in this folder?\n1. Yes, proceed\n2. No"; + + // when + let detected = detect_trust_prompt(screen_text); + + // then + assert!(detected); + } + + #[test] + fn does_not_emit_events_when_prompt_is_absent() { + // given + let resolver = TrustResolver::new(TrustConfig::new().with_allowlisted("/tmp/worktrees")); + + // when + let decision = resolver.resolve("/tmp/worktrees/repo-a", "Ready for your input\n>"); + + // then + assert_eq!(decision, TrustDecision::NotRequired); + assert_eq!(decision.events(), &[]); + assert_eq!(decision.policy(), None); + } + + #[test] + fn auto_trusts_allowlisted_cwd_after_prompt_detection() { + // given + let resolver = TrustResolver::new(TrustConfig::new().with_allowlisted("/tmp/worktrees")); + + // when + let decision = resolver.resolve( + "/tmp/worktrees/repo-a", + "Do you trust the files in this folder?\n1. Yes, proceed\n2. No", + ); + + // then + assert_eq!(decision.policy(), Some(TrustPolicy::AutoTrust)); + assert_eq!( + decision.events(), + &[ + TrustEvent::TrustRequired { + cwd: "/tmp/worktrees/repo-a".to_string(), + }, + TrustEvent::TrustResolved { + cwd: "/tmp/worktrees/repo-a".to_string(), + policy: TrustPolicy::AutoTrust, + }, + ] + ); + } + + #[test] + fn requires_approval_for_unknown_cwd_after_prompt_detection() { + // given + let resolver = TrustResolver::new(TrustConfig::new().with_allowlisted("/tmp/worktrees")); + + // when + let decision = resolver.resolve( + "/tmp/other/repo-b", + "Do you trust the files in this folder?\n1. Yes, proceed\n2. No", + ); + + // then + assert_eq!(decision.policy(), Some(TrustPolicy::RequireApproval)); + assert_eq!( + decision.events(), + &[TrustEvent::TrustRequired { + cwd: "/tmp/other/repo-b".to_string(), + }] + ); + } + + #[test] + fn denied_root_takes_precedence_over_allowlist() { + // given + let resolver = TrustResolver::new( + TrustConfig::new() + .with_allowlisted("/tmp/worktrees") + .with_denied("/tmp/worktrees/repo-c"), + ); + + // when + let decision = resolver.resolve( + "/tmp/worktrees/repo-c", + "Do you trust the files in this folder?\n1. Yes, proceed\n2. No", + ); + + // then + assert_eq!(decision.policy(), Some(TrustPolicy::Deny)); + assert_eq!( + decision.events(), + &[ + TrustEvent::TrustRequired { + cwd: "/tmp/worktrees/repo-c".to_string(), + }, + TrustEvent::TrustDenied { + cwd: "/tmp/worktrees/repo-c".to_string(), + reason: "cwd matches denied trust root: /tmp/worktrees/repo-c".to_string(), + }, + ] + ); + } + + #[test] + fn sibling_prefix_does_not_match_trusted_root() { + // given + let trusted_root = "/tmp/worktrees"; + let sibling_path = "/tmp/worktrees-other/repo-d"; + + // when + let matched = path_matches_trusted_root(sibling_path, trusted_root); + + // then + assert!(!matched); + } +} diff --git a/crates/runtime/src/usage.rs b/crates/runtime/src/usage.rs index 0570bc1..9741165 100644 --- a/crates/runtime/src/usage.rs +++ b/crates/runtime/src/usage.rs @@ -1,11 +1,11 @@ use crate::session::Session; -use serde::{Deserialize, Serialize}; const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0; const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0; const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75; const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5; +/// Per-million-token pricing used for cost estimation. #[derive(Debug, Clone, Copy, PartialEq)] pub struct ModelPricing { pub input_cost_per_million: f64, @@ -26,7 +26,8 @@ impl ModelPricing { } } -#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)] +/// Token counters accumulated for a conversation turn or session. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct TokenUsage { pub input_tokens: u32, pub output_tokens: u32, @@ -34,6 +35,7 @@ pub struct TokenUsage { pub cache_read_input_tokens: u32, } +/// Estimated dollar cost derived from a [`TokenUsage`] sample. #[derive(Debug, Clone, Copy, PartialEq)] pub struct UsageCostEstimate { pub input_cost_usd: f64, @@ -52,6 +54,7 @@ impl UsageCostEstimate { } } +/// Returns pricing metadata for a known model alias or family. #[must_use] pub fn pricing_for_model(model: &str) -> Option { let normalized = model.to_ascii_lowercase(); @@ -156,10 +159,12 @@ fn cost_for_tokens(tokens: u32, usd_per_million_tokens: f64) -> f64 { } #[must_use] +/// Formats a dollar-denominated value for CLI display. pub fn format_usd(amount: f64) -> String { format!("${amount:.4}") } +/// Aggregates token usage across a running session. #[derive(Debug, Clone, Default, PartialEq, Eq)] pub struct UsageTracker { latest_turn: TokenUsage, @@ -250,9 +255,9 @@ mod tests { let cost = usage.estimate_cost_usd(); assert_eq!(format_usd(cost.input_cost_usd), "$15.0000"); assert_eq!(format_usd(cost.output_cost_usd), "$37.5000"); - let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-6")); + let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-20250514")); assert!(lines[0].contains("estimated_cost=$54.6750")); - assert!(lines[0].contains("model=claude-sonnet-4-6")); + assert!(lines[0].contains("model=claude-sonnet-4-20250514")); assert!(lines[1].contains("cache_read=$0.3000")); } @@ -265,7 +270,7 @@ mod tests { cache_read_input_tokens: 0, }; - let haiku = pricing_for_model("claude-haiku-4-5-20251213").expect("haiku pricing"); + let haiku = pricing_for_model("claude-haiku-4-5-20251001").expect("haiku pricing"); let opus = pricing_for_model("claude-opus-4-6").expect("opus pricing"); let haiku_cost = usage.estimate_cost_usd_with_pricing(haiku); let opus_cost = usage.estimate_cost_usd_with_pricing(opus); @@ -287,21 +292,19 @@ mod tests { #[test] fn reconstructs_usage_from_session_messages() { - let session = Session { - version: 1, - messages: vec![ConversationMessage { - role: MessageRole::Assistant, - blocks: vec![ContentBlock::Text { - text: "done".to_string(), - }], - usage: Some(TokenUsage { - input_tokens: 5, - output_tokens: 2, - cache_creation_input_tokens: 1, - cache_read_input_tokens: 0, - }), + let mut session = Session::new(); + session.messages = vec![ConversationMessage { + role: MessageRole::Assistant, + blocks: vec![ContentBlock::Text { + text: "done".to_string(), }], - }; + usage: Some(TokenUsage { + input_tokens: 5, + output_tokens: 2, + cache_creation_input_tokens: 1, + cache_read_input_tokens: 0, + }), + }]; let tracker = UsageTracker::from_session(&session); assert_eq!(tracker.turns(), 1); diff --git a/crates/runtime/src/worker_boot.rs b/crates/runtime/src/worker_boot.rs new file mode 100644 index 0000000..6909768 --- /dev/null +++ b/crates/runtime/src/worker_boot.rs @@ -0,0 +1,1180 @@ +#![allow( + clippy::struct_excessive_bools, + clippy::too_many_lines, + clippy::question_mark, + clippy::redundant_closure, + clippy::map_unwrap_or +)] +//! In-memory worker-boot state machine and control registry. +//! +//! This provides a foundational control plane for reliable worker startup: +//! trust-gate detection, ready-for-prompt handshakes, and prompt-misdelivery +//! detection/recovery all live above raw terminal transport. + +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; + +fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum WorkerStatus { + Spawning, + TrustRequired, + ReadyForPrompt, + Running, + Finished, + Failed, +} + +impl std::fmt::Display for WorkerStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Spawning => write!(f, "spawning"), + Self::TrustRequired => write!(f, "trust_required"), + Self::ReadyForPrompt => write!(f, "ready_for_prompt"), + Self::Running => write!(f, "running"), + Self::Finished => write!(f, "finished"), + Self::Failed => write!(f, "failed"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum WorkerFailureKind { + TrustGate, + PromptDelivery, + Protocol, + Provider, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct WorkerFailure { + pub kind: WorkerFailureKind, + pub message: String, + pub created_at: u64, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum WorkerEventKind { + Spawning, + TrustRequired, + TrustResolved, + ReadyForPrompt, + PromptMisdelivery, + PromptReplayArmed, + Running, + Restarted, + Finished, + Failed, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum WorkerTrustResolution { + AutoAllowlisted, + ManualApproval, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum WorkerPromptTarget { + Shell, + WrongTarget, + Unknown, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum WorkerEventPayload { + TrustPrompt { + cwd: String, + #[serde(skip_serializing_if = "Option::is_none")] + resolution: Option, + }, + PromptDelivery { + prompt_preview: String, + observed_target: WorkerPromptTarget, + #[serde(skip_serializing_if = "Option::is_none")] + observed_cwd: Option, + recovery_armed: bool, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct WorkerEvent { + pub seq: u64, + pub kind: WorkerEventKind, + pub status: WorkerStatus, + pub detail: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub payload: Option, + pub timestamp: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct Worker { + pub worker_id: String, + pub cwd: String, + pub status: WorkerStatus, + pub trust_auto_resolve: bool, + pub trust_gate_cleared: bool, + pub auto_recover_prompt_misdelivery: bool, + pub prompt_delivery_attempts: u32, + pub prompt_in_flight: bool, + pub last_prompt: Option, + pub replay_prompt: Option, + pub last_error: Option, + pub created_at: u64, + pub updated_at: u64, + pub events: Vec, +} + +#[derive(Debug, Clone, Default)] +pub struct WorkerRegistry { + inner: Arc>, +} + +#[derive(Debug, Default)] +struct WorkerRegistryInner { + workers: HashMap, + counter: u64, +} + +impl WorkerRegistry { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + #[must_use] + pub fn create( + &self, + cwd: &str, + trusted_roots: &[String], + auto_recover_prompt_misdelivery: bool, + ) -> Worker { + let mut inner = self.inner.lock().expect("worker registry lock poisoned"); + inner.counter += 1; + let ts = now_secs(); + let worker_id = format!("worker_{:08x}_{}", ts, inner.counter); + let trust_auto_resolve = trusted_roots + .iter() + .any(|root| path_matches_allowlist(cwd, root)); + let mut worker = Worker { + worker_id: worker_id.clone(), + cwd: cwd.to_owned(), + status: WorkerStatus::Spawning, + trust_auto_resolve, + trust_gate_cleared: false, + auto_recover_prompt_misdelivery, + prompt_delivery_attempts: 0, + prompt_in_flight: false, + last_prompt: None, + replay_prompt: None, + last_error: None, + created_at: ts, + updated_at: ts, + events: Vec::new(), + }; + push_event( + &mut worker, + WorkerEventKind::Spawning, + WorkerStatus::Spawning, + Some("worker created".to_string()), + None, + ); + inner.workers.insert(worker_id, worker.clone()); + worker + } + + #[must_use] + pub fn get(&self, worker_id: &str) -> Option { + let inner = self.inner.lock().expect("worker registry lock poisoned"); + inner.workers.get(worker_id).cloned() + } + + pub fn observe(&self, worker_id: &str, screen_text: &str) -> Result { + let mut inner = self.inner.lock().expect("worker registry lock poisoned"); + let worker = inner + .workers + .get_mut(worker_id) + .ok_or_else(|| format!("worker not found: {worker_id}"))?; + let lowered = screen_text.to_ascii_lowercase(); + + if !worker.trust_gate_cleared && detect_trust_prompt(&lowered) { + worker.status = WorkerStatus::TrustRequired; + worker.last_error = Some(WorkerFailure { + kind: WorkerFailureKind::TrustGate, + message: "worker boot blocked on trust prompt".to_string(), + created_at: now_secs(), + }); + push_event( + worker, + WorkerEventKind::TrustRequired, + WorkerStatus::TrustRequired, + Some("trust prompt detected".to_string()), + Some(WorkerEventPayload::TrustPrompt { + cwd: worker.cwd.clone(), + resolution: None, + }), + ); + + if worker.trust_auto_resolve { + worker.trust_gate_cleared = true; + worker.last_error = None; + worker.status = WorkerStatus::Spawning; + push_event( + worker, + WorkerEventKind::TrustResolved, + WorkerStatus::Spawning, + Some("allowlisted repo auto-resolved trust prompt".to_string()), + Some(WorkerEventPayload::TrustPrompt { + cwd: worker.cwd.clone(), + resolution: Some(WorkerTrustResolution::AutoAllowlisted), + }), + ); + } else { + return Ok(worker.clone()); + } + } + + if let Some(observation) = prompt_misdelivery_is_relevant(worker) + .then(|| { + detect_prompt_misdelivery( + screen_text, + &lowered, + worker.last_prompt.as_deref(), + &worker.cwd, + ) + }) + .flatten() + { + let prompt_preview = prompt_preview(worker.last_prompt.as_deref().unwrap_or_default()); + let message = match observation.target { + WorkerPromptTarget::Shell => { + format!( + "worker prompt landed in shell instead of coding agent: {prompt_preview}" + ) + } + WorkerPromptTarget::WrongTarget => format!( + "worker prompt landed in the wrong target instead of {}: {}", + worker.cwd, prompt_preview + ), + WorkerPromptTarget::Unknown => format!( + "worker prompt delivery failed before reaching coding agent: {prompt_preview}" + ), + }; + worker.last_error = Some(WorkerFailure { + kind: WorkerFailureKind::PromptDelivery, + message, + created_at: now_secs(), + }); + worker.prompt_in_flight = false; + push_event( + worker, + WorkerEventKind::PromptMisdelivery, + WorkerStatus::Failed, + Some(prompt_misdelivery_detail(&observation).to_string()), + Some(WorkerEventPayload::PromptDelivery { + prompt_preview: prompt_preview.clone(), + observed_target: observation.target, + observed_cwd: observation.observed_cwd.clone(), + recovery_armed: false, + }), + ); + if worker.auto_recover_prompt_misdelivery { + worker.replay_prompt = worker.last_prompt.clone(); + worker.status = WorkerStatus::ReadyForPrompt; + push_event( + worker, + WorkerEventKind::PromptReplayArmed, + WorkerStatus::ReadyForPrompt, + Some("prompt replay armed after prompt misdelivery".to_string()), + Some(WorkerEventPayload::PromptDelivery { + prompt_preview, + observed_target: observation.target, + observed_cwd: observation.observed_cwd, + recovery_armed: true, + }), + ); + } else { + worker.status = WorkerStatus::Failed; + } + return Ok(worker.clone()); + } + + if detect_running_cue(&lowered) && worker.prompt_in_flight { + worker.prompt_in_flight = false; + worker.status = WorkerStatus::Running; + worker.last_error = None; + } + + if detect_ready_for_prompt(screen_text, &lowered) + && worker.status != WorkerStatus::ReadyForPrompt + { + worker.status = WorkerStatus::ReadyForPrompt; + worker.prompt_in_flight = false; + if matches!( + worker.last_error.as_ref().map(|failure| failure.kind), + Some(WorkerFailureKind::TrustGate) + ) { + worker.last_error = None; + } + push_event( + worker, + WorkerEventKind::ReadyForPrompt, + WorkerStatus::ReadyForPrompt, + Some("worker is ready for prompt delivery".to_string()), + None, + ); + } + + Ok(worker.clone()) + } + + pub fn resolve_trust(&self, worker_id: &str) -> Result { + let mut inner = self.inner.lock().expect("worker registry lock poisoned"); + let worker = inner + .workers + .get_mut(worker_id) + .ok_or_else(|| format!("worker not found: {worker_id}"))?; + + if worker.status != WorkerStatus::TrustRequired { + return Err(format!( + "worker {worker_id} is not waiting on trust; current status: {}", + worker.status + )); + } + + worker.trust_gate_cleared = true; + worker.last_error = None; + worker.status = WorkerStatus::Spawning; + push_event( + worker, + WorkerEventKind::TrustResolved, + WorkerStatus::Spawning, + Some("trust prompt resolved manually".to_string()), + Some(WorkerEventPayload::TrustPrompt { + cwd: worker.cwd.clone(), + resolution: Some(WorkerTrustResolution::ManualApproval), + }), + ); + Ok(worker.clone()) + } + + pub fn send_prompt(&self, worker_id: &str, prompt: Option<&str>) -> Result { + let mut inner = self.inner.lock().expect("worker registry lock poisoned"); + let worker = inner + .workers + .get_mut(worker_id) + .ok_or_else(|| format!("worker not found: {worker_id}"))?; + + if worker.status != WorkerStatus::ReadyForPrompt { + return Err(format!( + "worker {worker_id} is not ready for prompt delivery; current status: {}", + worker.status + )); + } + + let next_prompt = prompt + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(str::to_owned) + .or_else(|| worker.replay_prompt.clone()) + .ok_or_else(|| format!("worker {worker_id} has no prompt to send or replay"))?; + + worker.prompt_delivery_attempts += 1; + worker.prompt_in_flight = true; + worker.last_prompt = Some(next_prompt.clone()); + worker.replay_prompt = None; + worker.last_error = None; + worker.status = WorkerStatus::Running; + push_event( + worker, + WorkerEventKind::Running, + WorkerStatus::Running, + Some(format!( + "prompt dispatched to worker: {}", + prompt_preview(&next_prompt) + )), + None, + ); + Ok(worker.clone()) + } + + pub fn await_ready(&self, worker_id: &str) -> Result { + let worker = self + .get(worker_id) + .ok_or_else(|| format!("worker not found: {worker_id}"))?; + + Ok(WorkerReadySnapshot { + worker_id: worker.worker_id.clone(), + status: worker.status, + ready: worker.status == WorkerStatus::ReadyForPrompt, + blocked: matches!( + worker.status, + WorkerStatus::TrustRequired | WorkerStatus::Failed + ), + replay_prompt_ready: worker.replay_prompt.is_some(), + last_error: worker.last_error.clone(), + }) + } + + pub fn restart(&self, worker_id: &str) -> Result { + let mut inner = self.inner.lock().expect("worker registry lock poisoned"); + let worker = inner + .workers + .get_mut(worker_id) + .ok_or_else(|| format!("worker not found: {worker_id}"))?; + worker.status = WorkerStatus::Spawning; + worker.trust_gate_cleared = false; + worker.last_prompt = None; + worker.replay_prompt = None; + worker.last_error = None; + worker.prompt_delivery_attempts = 0; + worker.prompt_in_flight = false; + push_event( + worker, + WorkerEventKind::Restarted, + WorkerStatus::Spawning, + Some("worker restarted".to_string()), + None, + ); + Ok(worker.clone()) + } + + pub fn terminate(&self, worker_id: &str) -> Result { + let mut inner = self.inner.lock().expect("worker registry lock poisoned"); + let worker = inner + .workers + .get_mut(worker_id) + .ok_or_else(|| format!("worker not found: {worker_id}"))?; + worker.status = WorkerStatus::Finished; + worker.prompt_in_flight = false; + push_event( + worker, + WorkerEventKind::Finished, + WorkerStatus::Finished, + Some("worker terminated by control plane".to_string()), + None, + ); + Ok(worker.clone()) + } + + /// Classify session completion and transition worker to appropriate terminal state. + /// Detects degraded completions (finish="unknown" with zero tokens) as provider failures. + pub fn observe_completion( + &self, + worker_id: &str, + finish_reason: &str, + tokens_output: u64, + ) -> Result { + let mut inner = self.inner.lock().expect("worker registry lock poisoned"); + let worker = inner + .workers + .get_mut(worker_id) + .ok_or_else(|| format!("worker not found: {worker_id}"))?; + + let is_provider_failure = + (finish_reason == "unknown" && tokens_output == 0) || finish_reason == "error"; + + if is_provider_failure { + let message = if finish_reason == "unknown" && tokens_output == 0 { + "session completed with finish='unknown' and zero output — provider degraded or context exhausted".to_string() + } else { + format!("session failed with finish='{finish_reason}' — provider error") + }; + + worker.last_error = Some(WorkerFailure { + kind: WorkerFailureKind::Provider, + message, + created_at: now_secs(), + }); + worker.status = WorkerStatus::Failed; + worker.prompt_in_flight = false; + push_event( + worker, + WorkerEventKind::Failed, + WorkerStatus::Failed, + Some("provider failure classified".to_string()), + None, + ); + } else { + worker.status = WorkerStatus::Finished; + worker.prompt_in_flight = false; + worker.last_error = None; + push_event( + worker, + WorkerEventKind::Finished, + WorkerStatus::Finished, + Some(format!( + "session completed: finish='{finish_reason}', tokens={tokens_output}" + )), + None, + ); + } + + Ok(worker.clone()) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct WorkerReadySnapshot { + pub worker_id: String, + pub status: WorkerStatus, + pub ready: bool, + pub blocked: bool, + pub replay_prompt_ready: bool, + pub last_error: Option, +} + +fn prompt_misdelivery_is_relevant(worker: &Worker) -> bool { + worker.prompt_in_flight && worker.last_prompt.is_some() +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct PromptDeliveryObservation { + target: WorkerPromptTarget, + observed_cwd: Option, +} + +fn push_event( + worker: &mut Worker, + kind: WorkerEventKind, + status: WorkerStatus, + detail: Option, + payload: Option, +) { + let timestamp = now_secs(); + let seq = worker.events.len() as u64 + 1; + worker.updated_at = timestamp; + worker.status = status; + worker.events.push(WorkerEvent { + seq, + kind, + status, + detail, + payload, + timestamp, + }); + emit_state_file(worker); +} + +/// Write current worker state to `.claw/worker-state.json` under the worker's cwd. +/// This is the file-based observability surface: external observers (clawhip, orchestrators) +/// poll this file instead of requiring an HTTP route on the opencode binary. +fn emit_state_file(worker: &Worker) { + let state_dir = std::path::Path::new(&worker.cwd).join(".claw"); + if let Err(_) = std::fs::create_dir_all(&state_dir) { + return; + } + let state_path = state_dir.join("worker-state.json"); + let tmp_path = state_dir.join("worker-state.json.tmp"); + + #[derive(serde::Serialize)] + struct StateSnapshot<'a> { + worker_id: &'a str, + status: WorkerStatus, + is_ready: bool, + trust_gate_cleared: bool, + prompt_in_flight: bool, + last_event: Option<&'a WorkerEvent>, + updated_at: u64, + /// Seconds since last state transition. Clawhip uses this to detect + /// stalled workers without computing epoch deltas. + seconds_since_update: u64, + } + + let now = now_secs(); + let snapshot = StateSnapshot { + worker_id: &worker.worker_id, + status: worker.status, + is_ready: worker.status == WorkerStatus::ReadyForPrompt, + trust_gate_cleared: worker.trust_gate_cleared, + prompt_in_flight: worker.prompt_in_flight, + last_event: worker.events.last(), + updated_at: worker.updated_at, + seconds_since_update: now.saturating_sub(worker.updated_at), + }; + + if let Ok(json) = serde_json::to_string_pretty(&snapshot) { + let _ = std::fs::write(&tmp_path, json); + let _ = std::fs::rename(&tmp_path, &state_path); + } +} + +fn path_matches_allowlist(cwd: &str, trusted_root: &str) -> bool { + let cwd = normalize_path(cwd); + let trusted_root = normalize_path(trusted_root); + cwd == trusted_root || cwd.starts_with(&trusted_root) +} + +fn normalize_path(path: &str) -> PathBuf { + std::fs::canonicalize(path).unwrap_or_else(|_| Path::new(path).to_path_buf()) +} + +fn detect_trust_prompt(lowered: &str) -> bool { + [ + "do you trust the files in this folder", + "trust the files in this folder", + "trust this folder", + "allow and continue", + "yes, proceed", + ] + .iter() + .any(|needle| lowered.contains(needle)) +} + +fn detect_ready_for_prompt(screen_text: &str, lowered: &str) -> bool { + if [ + "ready for input", + "ready for your input", + "ready for prompt", + "send a message", + ] + .iter() + .any(|needle| lowered.contains(needle)) + { + return true; + } + + let Some(last_non_empty) = screen_text + .lines() + .rev() + .find(|line| !line.trim().is_empty()) + else { + return false; + }; + let trimmed = last_non_empty.trim(); + if is_shell_prompt(trimmed) { + return false; + } + + trimmed == ">" + || trimmed == "›" + || trimmed == "❯" + || trimmed.starts_with("> ") + || trimmed.starts_with("› ") + || trimmed.starts_with("❯ ") + || trimmed.contains("│ >") + || trimmed.contains("│ ›") + || trimmed.contains("│ ❯") +} + +fn detect_running_cue(lowered: &str) -> bool { + [ + "thinking", + "working", + "running tests", + "inspecting", + "analyzing", + ] + .iter() + .any(|needle| lowered.contains(needle)) +} + +fn is_shell_prompt(trimmed: &str) -> bool { + trimmed.ends_with('$') + || trimmed.ends_with('%') + || trimmed.ends_with('#') + || trimmed.starts_with('$') + || trimmed.starts_with('%') + || trimmed.starts_with('#') +} + +fn detect_prompt_misdelivery( + screen_text: &str, + lowered: &str, + prompt: Option<&str>, + expected_cwd: &str, +) -> Option { + let Some(prompt) = prompt else { + return None; + }; + + let prompt_snippet = prompt + .lines() + .find(|line| !line.trim().is_empty()) + .map(|line| line.trim().to_ascii_lowercase()) + .unwrap_or_default(); + if prompt_snippet.is_empty() { + return None; + } + let prompt_visible = lowered.contains(&prompt_snippet); + + if let Some(observed_cwd) = detect_observed_shell_cwd(screen_text) { + if prompt_visible && !cwd_matches_observed_target(expected_cwd, &observed_cwd) { + return Some(PromptDeliveryObservation { + target: WorkerPromptTarget::WrongTarget, + observed_cwd: Some(observed_cwd), + }); + } + } + + let shell_error = [ + "command not found", + "syntax error near unexpected token", + "parse error near", + "no such file or directory", + "unknown command", + ] + .iter() + .any(|needle| lowered.contains(needle)); + + (shell_error && prompt_visible).then_some(PromptDeliveryObservation { + target: WorkerPromptTarget::Shell, + observed_cwd: None, + }) +} + +fn prompt_preview(prompt: &str) -> String { + let trimmed = prompt.trim(); + if trimmed.chars().count() <= 48 { + return trimmed.to_string(); + } + let preview = trimmed.chars().take(48).collect::(); + format!("{}…", preview.trim_end()) +} + +fn prompt_misdelivery_detail(observation: &PromptDeliveryObservation) -> &'static str { + match observation.target { + WorkerPromptTarget::Shell => "shell misdelivery detected", + WorkerPromptTarget::WrongTarget => "prompt landed in wrong target", + WorkerPromptTarget::Unknown => "prompt delivery failure detected", + } +} + +fn detect_observed_shell_cwd(screen_text: &str) -> Option { + screen_text.lines().find_map(|line| { + let tokens = line.split_whitespace().collect::>(); + tokens + .iter() + .position(|token| is_shell_prompt_token(token)) + .and_then(|index| index.checked_sub(1).map(|cwd_index| tokens[cwd_index])) + .filter(|candidate| looks_like_cwd_label(candidate)) + .map(ToOwned::to_owned) + }) +} + +fn is_shell_prompt_token(token: &&str) -> bool { + matches!(*token, "$" | "%" | "#" | ">" | "›" | "❯") +} + +fn looks_like_cwd_label(candidate: &str) -> bool { + candidate.starts_with('/') + || candidate.starts_with('~') + || candidate.starts_with('.') + || candidate.contains('/') +} + +fn cwd_matches_observed_target(expected_cwd: &str, observed_cwd: &str) -> bool { + let expected = normalize_path(expected_cwd); + let expected_base = expected + .file_name() + .map(|segment| segment.to_string_lossy().into_owned()) + .unwrap_or_else(|| expected.to_string_lossy().into_owned()); + let observed_base = Path::new(observed_cwd) + .file_name() + .map(|segment| segment.to_string_lossy().into_owned()) + .unwrap_or_else(|| observed_cwd.trim_matches(':').to_string()); + + expected.to_string_lossy().ends_with(observed_cwd) + || observed_cwd.ends_with(expected.to_string_lossy().as_ref()) + || expected_base == observed_base +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn allowlisted_trust_prompt_auto_resolves_then_reaches_ready_state() { + let registry = WorkerRegistry::new(); + let worker = registry.create( + "/tmp/worktrees/repo-a", + &["/tmp/worktrees".to_string()], + true, + ); + + let after_trust = registry + .observe( + &worker.worker_id, + "Do you trust the files in this folder?\n1. Yes, proceed\n2. No", + ) + .expect("trust observe should succeed"); + assert_eq!(after_trust.status, WorkerStatus::Spawning); + assert!(after_trust.trust_gate_cleared); + let trust_required = after_trust + .events + .iter() + .find(|event| event.kind == WorkerEventKind::TrustRequired) + .expect("trust required event should exist"); + assert_eq!( + trust_required.payload, + Some(WorkerEventPayload::TrustPrompt { + cwd: "/tmp/worktrees/repo-a".to_string(), + resolution: None, + }) + ); + let trust_resolved = after_trust + .events + .iter() + .find(|event| event.kind == WorkerEventKind::TrustResolved) + .expect("trust resolved event should exist"); + assert_eq!( + trust_resolved.payload, + Some(WorkerEventPayload::TrustPrompt { + cwd: "/tmp/worktrees/repo-a".to_string(), + resolution: Some(WorkerTrustResolution::AutoAllowlisted), + }) + ); + + let ready = registry + .observe(&worker.worker_id, "Ready for your input\n>") + .expect("ready observe should succeed"); + assert_eq!(ready.status, WorkerStatus::ReadyForPrompt); + assert!(ready.last_error.is_none()); + } + + #[test] + fn trust_prompt_blocks_non_allowlisted_worker_until_resolved() { + let registry = WorkerRegistry::new(); + let worker = registry.create("/tmp/repo-b", &[], true); + + let blocked = registry + .observe( + &worker.worker_id, + "Do you trust the files in this folder?\n1. Yes, proceed\n2. No", + ) + .expect("trust observe should succeed"); + assert_eq!(blocked.status, WorkerStatus::TrustRequired); + assert_eq!( + blocked.last_error.expect("trust error should exist").kind, + WorkerFailureKind::TrustGate + ); + + let send_before_resolve = registry.send_prompt(&worker.worker_id, Some("ship it")); + assert!(send_before_resolve + .expect_err("prompt delivery should be gated") + .contains("not ready for prompt delivery")); + + let resolved = registry + .resolve_trust(&worker.worker_id) + .expect("manual trust resolution should succeed"); + assert_eq!(resolved.status, WorkerStatus::Spawning); + assert!(resolved.trust_gate_cleared); + let trust_resolved = resolved + .events + .iter() + .find(|event| event.kind == WorkerEventKind::TrustResolved) + .expect("manual trust resolve event should exist"); + assert_eq!( + trust_resolved.payload, + Some(WorkerEventPayload::TrustPrompt { + cwd: "/tmp/repo-b".to_string(), + resolution: Some(WorkerTrustResolution::ManualApproval), + }) + ); + } + + #[test] + fn ready_detection_ignores_plain_shell_prompts() { + assert!(!detect_ready_for_prompt("bellman@host %", "bellman@host %")); + assert!(!detect_ready_for_prompt("/tmp/repo $", "/tmp/repo $")); + assert!(detect_ready_for_prompt("│ >", "│ >")); + } + + #[test] + fn prompt_misdelivery_is_detected_and_replay_can_be_rearmed() { + let registry = WorkerRegistry::new(); + let worker = registry.create("/tmp/repo-c", &[], true); + registry + .observe(&worker.worker_id, "Ready for input\n>") + .expect("ready observe should succeed"); + + let running = registry + .send_prompt(&worker.worker_id, Some("Implement worker handshake")) + .expect("prompt send should succeed"); + assert_eq!(running.status, WorkerStatus::Running); + assert_eq!(running.prompt_delivery_attempts, 1); + assert!(running.prompt_in_flight); + + let recovered = registry + .observe( + &worker.worker_id, + "% Implement worker handshake\nzsh: command not found: Implement", + ) + .expect("misdelivery observe should succeed"); + assert_eq!(recovered.status, WorkerStatus::ReadyForPrompt); + assert_eq!( + recovered + .last_error + .expect("misdelivery error should exist") + .kind, + WorkerFailureKind::PromptDelivery + ); + assert_eq!( + recovered.replay_prompt.as_deref(), + Some("Implement worker handshake") + ); + let misdelivery = recovered + .events + .iter() + .find(|event| event.kind == WorkerEventKind::PromptMisdelivery) + .expect("misdelivery event should exist"); + assert_eq!(misdelivery.status, WorkerStatus::Failed); + assert_eq!( + misdelivery.payload, + Some(WorkerEventPayload::PromptDelivery { + prompt_preview: "Implement worker handshake".to_string(), + observed_target: WorkerPromptTarget::Shell, + observed_cwd: None, + recovery_armed: false, + }) + ); + let replay = recovered + .events + .iter() + .find(|event| event.kind == WorkerEventKind::PromptReplayArmed) + .expect("replay event should exist"); + assert_eq!(replay.status, WorkerStatus::ReadyForPrompt); + assert_eq!( + replay.payload, + Some(WorkerEventPayload::PromptDelivery { + prompt_preview: "Implement worker handshake".to_string(), + observed_target: WorkerPromptTarget::Shell, + observed_cwd: None, + recovery_armed: true, + }) + ); + + let replayed = registry + .send_prompt(&worker.worker_id, None) + .expect("replay send should succeed"); + assert_eq!(replayed.status, WorkerStatus::Running); + assert!(replayed.replay_prompt.is_none()); + assert_eq!(replayed.prompt_delivery_attempts, 2); + } + + #[test] + fn prompt_delivery_detects_wrong_target_and_replays_to_expected_worker() { + let registry = WorkerRegistry::new(); + let worker = registry.create("/tmp/repo-target-a", &[], true); + registry + .observe(&worker.worker_id, "Ready for input\n>") + .expect("ready observe should succeed"); + registry + .send_prompt(&worker.worker_id, Some("Run the worker bootstrap tests")) + .expect("prompt send should succeed"); + + let recovered = registry + .observe( + &worker.worker_id, + "/tmp/repo-target-b % Run the worker bootstrap tests\nzsh: command not found: Run", + ) + .expect("wrong target should be detected"); + + assert_eq!(recovered.status, WorkerStatus::ReadyForPrompt); + assert_eq!( + recovered.replay_prompt.as_deref(), + Some("Run the worker bootstrap tests") + ); + assert!(recovered + .last_error + .expect("wrong target error should exist") + .message + .contains("wrong target")); + let misdelivery = recovered + .events + .iter() + .find(|event| event.kind == WorkerEventKind::PromptMisdelivery) + .expect("wrong-target event should exist"); + assert_eq!( + misdelivery.payload, + Some(WorkerEventPayload::PromptDelivery { + prompt_preview: "Run the worker bootstrap tests".to_string(), + observed_target: WorkerPromptTarget::WrongTarget, + observed_cwd: Some("/tmp/repo-target-b".to_string()), + recovery_armed: false, + }) + ); + } + + #[test] + fn await_ready_surfaces_blocked_or_ready_worker_state() { + let registry = WorkerRegistry::new(); + let worker = registry.create("/tmp/repo-d", &[], false); + + let initial = registry + .await_ready(&worker.worker_id) + .expect("await should succeed"); + assert!(!initial.ready); + assert!(!initial.blocked); + + registry + .observe( + &worker.worker_id, + "Do you trust the files in this folder?\n1. Yes, proceed\n2. No", + ) + .expect("trust observe should succeed"); + let blocked = registry + .await_ready(&worker.worker_id) + .expect("await should succeed"); + assert!(!blocked.ready); + assert!(blocked.blocked); + + registry + .resolve_trust(&worker.worker_id) + .expect("manual trust resolution should succeed"); + registry + .observe(&worker.worker_id, "Ready for your input\n>") + .expect("ready observe should succeed"); + let ready = registry + .await_ready(&worker.worker_id) + .expect("await should succeed"); + assert!(ready.ready); + assert!(!ready.blocked); + assert!(ready.last_error.is_none()); + } + + #[test] + fn restart_and_terminate_reset_or_finish_worker() { + let registry = WorkerRegistry::new(); + let worker = registry.create("/tmp/repo-e", &[], true); + registry + .observe(&worker.worker_id, "Ready for input\n>") + .expect("ready observe should succeed"); + registry + .send_prompt(&worker.worker_id, Some("Run tests")) + .expect("prompt send should succeed"); + + let restarted = registry + .restart(&worker.worker_id) + .expect("restart should succeed"); + assert_eq!(restarted.status, WorkerStatus::Spawning); + assert_eq!(restarted.prompt_delivery_attempts, 0); + assert!(restarted.last_prompt.is_none()); + assert!(!restarted.prompt_in_flight); + + let finished = registry + .terminate(&worker.worker_id) + .expect("terminate should succeed"); + assert_eq!(finished.status, WorkerStatus::Finished); + assert!(finished + .events + .iter() + .any(|event| event.kind == WorkerEventKind::Finished)); + } + + #[test] + fn observe_completion_classifies_provider_failure_on_unknown_finish_zero_tokens() { + let registry = WorkerRegistry::new(); + let worker = registry.create("/tmp/repo-f", &[], true); + registry + .observe(&worker.worker_id, "Ready for input\n>") + .expect("ready observe should succeed"); + registry + .send_prompt(&worker.worker_id, Some("Run tests")) + .expect("prompt send should succeed"); + + let failed = registry + .observe_completion(&worker.worker_id, "unknown", 0) + .expect("completion observe should succeed"); + + assert_eq!(failed.status, WorkerStatus::Failed); + let error = failed.last_error.expect("provider error should exist"); + assert_eq!(error.kind, WorkerFailureKind::Provider); + assert!(error.message.contains("provider degraded")); + assert!(failed + .events + .iter() + .any(|event| event.kind == WorkerEventKind::Failed)); + } + + #[test] + fn emit_state_file_writes_worker_status_on_transition() { + let cwd_path = std::env::temp_dir().join(format!( + "claw-state-test-{}", + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos() + )); + std::fs::create_dir_all(&cwd_path).expect("test dir should create"); + let cwd = cwd_path.to_str().expect("test path should be utf8"); + let registry = WorkerRegistry::new(); + let worker = registry.create(cwd, &[], true); + + // After create the worker is Spawning — state file should exist + let state_path = cwd_path.join(".claw").join("worker-state.json"); + assert!( + state_path.exists(), + "state file should exist after worker creation" + ); + + let raw = std::fs::read_to_string(&state_path).expect("state file should be readable"); + let value: serde_json::Value = + serde_json::from_str(&raw).expect("state file should be valid JSON"); + assert_eq!( + value["status"].as_str(), + Some("spawning"), + "initial status should be spawning" + ); + assert_eq!(value["is_ready"].as_bool(), Some(false)); + + // Transition to ReadyForPrompt by observing trust-cleared text + registry + .observe(&worker.worker_id, "Ready for input\n>") + .expect("observe ready should succeed"); + + let raw = std::fs::read_to_string(&state_path) + .expect("state file should be readable after observe"); + let value: serde_json::Value = + serde_json::from_str(&raw).expect("state file should be valid JSON after observe"); + assert_eq!( + value["status"].as_str(), + Some("ready_for_prompt"), + "status should be ready_for_prompt after observe" + ); + assert_eq!( + value["is_ready"].as_bool(), + Some(true), + "is_ready should be true when ReadyForPrompt" + ); + } + + #[test] + fn observe_completion_accepts_normal_finish_with_tokens() { + let registry = WorkerRegistry::new(); + let worker = registry.create("/tmp/repo-g", &[], true); + registry + .observe(&worker.worker_id, "Ready for input\n>") + .expect("ready observe should succeed"); + registry + .send_prompt(&worker.worker_id, Some("Run tests")) + .expect("prompt send should succeed"); + + let finished = registry + .observe_completion(&worker.worker_id, "stop", 150) + .expect("completion observe should succeed"); + + assert_eq!(finished.status, WorkerStatus::Finished); + assert!(finished.last_error.is_none()); + assert!(finished + .events + .iter() + .any(|event| event.kind == WorkerEventKind::Finished)); + } +} diff --git a/crates/runtime/tests/integration_tests.rs b/crates/runtime/tests/integration_tests.rs new file mode 100644 index 0000000..49c6636 --- /dev/null +++ b/crates/runtime/tests/integration_tests.rs @@ -0,0 +1,386 @@ +#![allow(clippy::doc_markdown, clippy::uninlined_format_args, unused_imports)] +//! Integration tests for cross-module wiring. +//! +//! These tests verify that adjacent modules in the runtime crate actually +//! connect correctly — catching wiring gaps that unit tests miss. + +use std::time::Duration; + +use runtime::green_contract::{GreenContract, GreenContractOutcome, GreenLevel}; +use runtime::{ + apply_policy, BranchFreshness, DiffScope, LaneBlocker, LaneContext, PolicyAction, + PolicyCondition, PolicyEngine, PolicyRule, ReconcileReason, ReviewStatus, StaleBranchAction, + StaleBranchPolicy, +}; + +/// stale_branch + policy_engine integration: +/// When a branch is detected stale, does it correctly flow through +/// PolicyCondition::StaleBranch to generate the expected action? +#[test] +fn stale_branch_detection_flows_into_policy_engine() { + // given — a stale branch context (2 hours behind main, threshold is 1 hour) + let stale_context = LaneContext::new( + "stale-lane", + 0, + Duration::from_secs(2 * 60 * 60), // 2 hours stale + LaneBlocker::None, + ReviewStatus::Pending, + DiffScope::Full, + false, + ); + + let engine = PolicyEngine::new(vec![PolicyRule::new( + "stale-merge-forward", + PolicyCondition::StaleBranch, + PolicyAction::MergeForward, + 10, + )]); + + // when + let actions = engine.evaluate(&stale_context); + + // then + assert_eq!(actions, vec![PolicyAction::MergeForward]); +} + +/// stale_branch + policy_engine: Fresh branch does NOT trigger stale rules +#[test] +fn fresh_branch_does_not_trigger_stale_policy() { + let fresh_context = LaneContext::new( + "fresh-lane", + 0, + Duration::from_secs(30 * 60), // 30 min stale — under 1 hour threshold + LaneBlocker::None, + ReviewStatus::Pending, + DiffScope::Full, + false, + ); + + let engine = PolicyEngine::new(vec![PolicyRule::new( + "stale-merge-forward", + PolicyCondition::StaleBranch, + PolicyAction::MergeForward, + 10, + )]); + + let actions = engine.evaluate(&fresh_context); + assert!(actions.is_empty()); +} + +/// green_contract + policy_engine integration: +/// A lane that meets its green contract should be mergeable +#[test] +fn green_contract_satisfied_allows_merge() { + let contract = GreenContract::new(GreenLevel::Workspace); + let satisfied = contract.is_satisfied_by(GreenLevel::Workspace); + assert!(satisfied); + + let exceeded = contract.is_satisfied_by(GreenLevel::MergeReady); + assert!(exceeded); + + let insufficient = contract.is_satisfied_by(GreenLevel::Package); + assert!(!insufficient); +} + +/// green_contract + policy_engine: +/// Lane with green level below contract requirement gets blocked +#[test] +fn green_contract_unsatisfied_blocks_merge() { + let context = LaneContext::new( + "partial-green-lane", + 1, // GreenLevel::Package as u8 + Duration::from_secs(0), + LaneBlocker::None, + ReviewStatus::Pending, + DiffScope::Full, + false, + ); + + // This is a conceptual test — we need a way to express "requires workspace green" + // Currently LaneContext has raw green_level: u8, not a contract + // For now we just verify the policy condition works + let engine = PolicyEngine::new(vec![PolicyRule::new( + "workspace-green-required", + PolicyCondition::GreenAt { level: 3 }, // GreenLevel::Workspace + PolicyAction::MergeToDev, + 10, + )]); + + let actions = engine.evaluate(&context); + assert!(actions.is_empty()); // level 1 < 3, so no merge +} + +/// reconciliation + policy_engine integration: +/// A reconciled lane should be handled by reconcile rules, not generic closeout +#[test] +fn reconciled_lane_matches_reconcile_condition() { + let context = LaneContext::reconciled("reconciled-lane"); + + let engine = PolicyEngine::new(vec![ + PolicyRule::new( + "reconcile-first", + PolicyCondition::LaneReconciled, + PolicyAction::Reconcile { + reason: ReconcileReason::AlreadyMerged, + }, + 5, + ), + PolicyRule::new( + "generic-closeout", + PolicyCondition::LaneCompleted, + PolicyAction::CloseoutLane, + 30, + ), + ]); + + let actions = engine.evaluate(&context); + + // Both rules fire — reconcile (priority 5) first, then closeout (priority 30) + assert_eq!( + actions, + vec![ + PolicyAction::Reconcile { + reason: ReconcileReason::AlreadyMerged, + }, + PolicyAction::CloseoutLane, + ] + ); +} + +/// stale_branch module: apply_policy generates correct actions +#[test] +fn stale_branch_apply_policy_produces_rebase_action() { + let stale = BranchFreshness::Stale { + commits_behind: 5, + missing_fixes: vec!["fix-123".to_string()], + }; + + let action = apply_policy(&stale, StaleBranchPolicy::AutoRebase); + assert_eq!(action, StaleBranchAction::Rebase); +} + +#[test] +fn stale_branch_apply_policy_produces_merge_forward_action() { + let stale = BranchFreshness::Stale { + commits_behind: 3, + missing_fixes: vec![], + }; + + let action = apply_policy(&stale, StaleBranchPolicy::AutoMergeForward); + assert_eq!(action, StaleBranchAction::MergeForward); +} + +#[test] +fn stale_branch_apply_policy_warn_only() { + let stale = BranchFreshness::Stale { + commits_behind: 2, + missing_fixes: vec!["fix-456".to_string()], + }; + + let action = apply_policy(&stale, StaleBranchPolicy::WarnOnly); + match action { + StaleBranchAction::Warn { message } => { + assert!(message.contains("2 commit(s) behind main")); + assert!(message.contains("fix-456")); + } + _ => panic!("expected Warn action, got {:?}", action), + } +} + +#[test] +fn stale_branch_fresh_produces_noop() { + let fresh = BranchFreshness::Fresh; + let action = apply_policy(&fresh, StaleBranchPolicy::AutoRebase); + assert_eq!(action, StaleBranchAction::Noop); +} + +/// Combined flow: stale detection + policy + action +#[test] +fn end_to_end_stale_lane_gets_merge_forward_action() { + // Simulating what a harness would do: + // 1. Detect branch freshness + // 2. Build lane context from freshness + other signals + // 3. Run policy engine + // 4. Return actions + + // given: detected stale state + let _freshness = BranchFreshness::Stale { + commits_behind: 5, + missing_fixes: vec!["fix-123".to_string()], + }; + + // when: build context and evaluate policy + let context = LaneContext::new( + "lane-9411", + 3, // Workspace green + Duration::from_secs(5 * 60 * 60), // 5 hours stale, definitely over threshold + LaneBlocker::None, + ReviewStatus::Approved, + DiffScope::Scoped, + false, + ); + + let engine = PolicyEngine::new(vec![ + // Priority 5: Check if stale first + PolicyRule::new( + "auto-merge-forward-if-stale-and-approved", + PolicyCondition::And(vec![ + PolicyCondition::StaleBranch, + PolicyCondition::ReviewPassed, + ]), + PolicyAction::MergeForward, + 5, + ), + // Priority 10: Normal stale handling + PolicyRule::new( + "stale-warning", + PolicyCondition::StaleBranch, + PolicyAction::Notify { + channel: "#build-status".to_string(), + }, + 10, + ), + ]); + + let actions = engine.evaluate(&context); + + // then: both rules should fire (stale + approved matches both) + assert_eq!( + actions, + vec![ + PolicyAction::MergeForward, + PolicyAction::Notify { + channel: "#build-status".to_string(), + }, + ] + ); +} + +/// Fresh branch with approved review should merge (not stale-blocked) +#[test] +fn fresh_approved_lane_gets_merge_action() { + let context = LaneContext::new( + "fresh-approved-lane", + 3, // Workspace green + Duration::from_secs(30 * 60), // 30 min — under 1 hour threshold = fresh + LaneBlocker::None, + ReviewStatus::Approved, + DiffScope::Scoped, + false, + ); + + let engine = PolicyEngine::new(vec![PolicyRule::new( + "merge-if-green-approved-not-stale", + PolicyCondition::And(vec![ + PolicyCondition::GreenAt { level: 3 }, + PolicyCondition::ReviewPassed, + // NOT PolicyCondition::StaleBranch — fresh lanes bypass this + ]), + PolicyAction::MergeToDev, + 5, + )]); + + let actions = engine.evaluate(&context); + assert_eq!(actions, vec![PolicyAction::MergeToDev]); +} + +/// worker_boot + recovery_recipes + policy_engine integration: +/// When a session completes with a provider failure, does the worker +/// status transition trigger the correct recovery recipe, and does +/// the resulting recovery state feed into policy decisions? +#[test] +fn worker_provider_failure_flows_through_recovery_to_policy() { + use runtime::recovery_recipes::{ + attempt_recovery, FailureScenario, RecoveryContext, RecoveryResult, RecoveryStep, + }; + use runtime::worker_boot::{WorkerFailureKind, WorkerRegistry, WorkerStatus}; + + // given — a worker that encounters a provider failure during session completion + let registry = WorkerRegistry::new(); + let worker = registry.create("/tmp/repo-recovery-test", &[], true); + + // Worker reaches ready state + registry + .observe(&worker.worker_id, "Ready for your input\n>") + .expect("ready observe should succeed"); + registry + .send_prompt(&worker.worker_id, Some("Run analysis")) + .expect("prompt send should succeed"); + + // Session completes with provider failure (finish="unknown", tokens=0) + let failed_worker = registry + .observe_completion(&worker.worker_id, "unknown", 0) + .expect("completion observe should succeed"); + assert_eq!(failed_worker.status, WorkerStatus::Failed); + let failure = failed_worker + .last_error + .expect("worker should have recorded error"); + assert_eq!(failure.kind, WorkerFailureKind::Provider); + + // Bridge: WorkerFailureKind -> FailureScenario + let scenario = FailureScenario::from_worker_failure_kind(failure.kind); + assert_eq!(scenario, FailureScenario::ProviderFailure); + + // Recovery recipe lookup and execution + let mut ctx = RecoveryContext::new(); + let result = attempt_recovery(&scenario, &mut ctx); + + // then — recovery should recommend RestartWorker step + assert!( + matches!(result, RecoveryResult::Recovered { steps_taken: 1 }), + "provider failure should recover via single RestartWorker step, got: {result:?}" + ); + assert!( + ctx.events().iter().any(|e| { + matches!( + e, + runtime::recovery_recipes::RecoveryEvent::RecoveryAttempted { + result: RecoveryResult::Recovered { steps_taken: 1 }, + .. + } + ) + }), + "recovery should emit structured attempt event" + ); + + // Policy integration: recovery success + green status = merge-ready + // (Simulating the policy check that would happen after successful recovery) + let recovery_success = matches!(result, RecoveryResult::Recovered { .. }); + let green_level = 3; // Workspace green + let not_stale = Duration::from_secs(30 * 60); // 30 min — fresh + + let post_recovery_context = LaneContext::new( + "recovered-lane", + green_level, + not_stale, + LaneBlocker::None, + ReviewStatus::Approved, + DiffScope::Scoped, + false, + ); + + let policy_engine = PolicyEngine::new(vec![ + // Rule: if recovered from failure + green + approved -> merge + PolicyRule::new( + "merge-after-successful-recovery", + PolicyCondition::And(vec![ + PolicyCondition::GreenAt { level: 3 }, + PolicyCondition::ReviewPassed, + ]), + PolicyAction::MergeToDev, + 10, + ), + ]); + + // Recovery success is a pre-condition; policy evaluates post-recovery context + assert!( + recovery_success, + "recovery must succeed for lane to proceed" + ); + let actions = policy_engine.evaluate(&post_recovery_context); + assert_eq!( + actions, + vec![PolicyAction::MergeToDev], + "post-recovery green+approved lane should be merge-ready" + ); +} diff --git a/crates/rusty-claude-cli/Cargo.toml b/crates/rusty-claude-cli/Cargo.toml index 5e5186d..635fdb3 100644 --- a/crates/rusty-claude-cli/Cargo.toml +++ b/crates/rusty-claude-cli/Cargo.toml @@ -15,12 +15,20 @@ commands = { path = "../commands" } compat-harness = { path = "../compat-harness" } crossterm = "0.28" pulldown-cmark = "0.13" -plugins = { path = "../plugins" } +rustyline = "15" runtime = { path = "../runtime" } -serde_json = "1" +plugins = { path = "../plugins" } +serde = { version = "1", features = ["derive"] } +serde_json.workspace = true syntect = "5" -tokio = { version = "1", features = ["rt-multi-thread", "time"] } +tokio = { version = "1", features = ["rt-multi-thread", "signal", "time"] } tools = { path = "../tools" } [lints] workspace = true + +[dev-dependencies] +mock-anthropic-service = { path = "../mock-anthropic-service" } +serde_json.workspace = true +tokio = { version = "1", features = ["rt-multi-thread"] } + diff --git a/crates/rusty-claude-cli/build.rs b/crates/rusty-claude-cli/build.rs new file mode 100644 index 0000000..032948d --- /dev/null +++ b/crates/rusty-claude-cli/build.rs @@ -0,0 +1,38 @@ +use std::env; +use std::process::Command; + +fn main() { + // Get git SHA (short hash) + let git_sha = Command::new("git") + .args(["rev-parse", "--short", "HEAD"]) + .output() + .ok() + .and_then(|output| { + if output.status.success() { + String::from_utf8(output.stdout).ok() + } else { + None + } + }) + .map_or_else(|| "unknown".to_string(), |s| s.trim().to_string()); + + println!("cargo:rustc-env=GIT_SHA={git_sha}"); + + // TARGET is always set by Cargo during build + let target = env::var("TARGET").unwrap_or_else(|_| "unknown".to_string()); + println!("cargo:rustc-env=TARGET={target}"); + + // Build date from SOURCE_DATE_EPOCH (reproducible builds) or current date. + let build_date = std::env::var("SOURCE_DATE_EPOCH") + .ok() + .and_then(|epoch| epoch.parse::().ok()) + .map_or_else( + || "unknown".to_string(), + |_| std::env::var("BUILD_DATE").unwrap_or_else(|_| "unknown".to_string()), + ); + println!("cargo:rustc-env=BUILD_DATE={build_date}"); + + // Rerun if git state changes + println!("cargo:rerun-if-changed=.git/HEAD"); + println!("cargo:rerun-if-changed=.git/refs"); +} diff --git a/crates/rusty-claude-cli/src/app.rs b/crates/rusty-claude-cli/src/app.rs deleted file mode 100644 index b2864a3..0000000 --- a/crates/rusty-claude-cli/src/app.rs +++ /dev/null @@ -1,398 +0,0 @@ -use std::io::{self, Write}; -use std::path::PathBuf; - -use crate::args::{OutputFormat, PermissionMode}; -use crate::input::{LineEditor, ReadOutcome}; -use crate::render::{Spinner, TerminalRenderer}; -use runtime::{ConversationClient, ConversationMessage, RuntimeError, StreamEvent, UsageSummary}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct SessionConfig { - pub model: String, - pub permission_mode: PermissionMode, - pub config: Option, - pub output_format: OutputFormat, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct SessionState { - pub turns: usize, - pub compacted_messages: usize, - pub last_model: String, - pub last_usage: UsageSummary, -} - -impl SessionState { - #[must_use] - pub fn new(model: impl Into) -> Self { - Self { - turns: 0, - compacted_messages: 0, - last_model: model.into(), - last_usage: UsageSummary::default(), - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum CommandResult { - Continue, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum SlashCommand { - Help, - Status, - Compact, - Unknown(String), -} - -impl SlashCommand { - #[must_use] - pub fn parse(input: &str) -> Option { - let trimmed = input.trim(); - if !trimmed.starts_with('/') { - return None; - } - - let command = trimmed - .trim_start_matches('/') - .split_whitespace() - .next() - .unwrap_or_default(); - Some(match command { - "help" => Self::Help, - "status" => Self::Status, - "compact" => Self::Compact, - other => Self::Unknown(other.to_string()), - }) - } -} - -struct SlashCommandHandler { - command: SlashCommand, - summary: &'static str, -} - -const SLASH_COMMAND_HANDLERS: &[SlashCommandHandler] = &[ - SlashCommandHandler { - command: SlashCommand::Help, - summary: "Show command help", - }, - SlashCommandHandler { - command: SlashCommand::Status, - summary: "Show current session status", - }, - SlashCommandHandler { - command: SlashCommand::Compact, - summary: "Compact local session history", - }, -]; - -pub struct CliApp { - config: SessionConfig, - renderer: TerminalRenderer, - state: SessionState, - conversation_client: ConversationClient, - conversation_history: Vec, -} - -impl CliApp { - pub fn new(config: SessionConfig) -> Result { - let state = SessionState::new(config.model.clone()); - let conversation_client = ConversationClient::from_env(config.model.clone())?; - Ok(Self { - config, - renderer: TerminalRenderer::new(), - state, - conversation_client, - conversation_history: Vec::new(), - }) - } - - pub fn run_repl(&mut self) -> io::Result<()> { - let mut editor = LineEditor::new("› ", Vec::new()); - println!("Rusty Claude CLI interactive mode"); - println!("Type /help for commands. Shift+Enter or Ctrl+J inserts a newline."); - - loop { - match editor.read_line()? { - ReadOutcome::Submit(input) => { - if input.trim().is_empty() { - continue; - } - self.handle_submission(&input, &mut io::stdout())?; - } - ReadOutcome::Cancel => continue, - ReadOutcome::Exit => break, - } - } - - Ok(()) - } - - pub fn run_prompt(&mut self, prompt: &str, out: &mut impl Write) -> io::Result<()> { - self.render_response(prompt, out) - } - - pub fn handle_submission( - &mut self, - input: &str, - out: &mut impl Write, - ) -> io::Result { - if let Some(command) = SlashCommand::parse(input) { - return self.dispatch_slash_command(command, out); - } - - self.state.turns += 1; - self.render_response(input, out)?; - Ok(CommandResult::Continue) - } - - fn dispatch_slash_command( - &mut self, - command: SlashCommand, - out: &mut impl Write, - ) -> io::Result { - match command { - SlashCommand::Help => Self::handle_help(out), - SlashCommand::Status => self.handle_status(out), - SlashCommand::Compact => self.handle_compact(out), - SlashCommand::Unknown(name) => { - writeln!(out, "Unknown slash command: /{name}")?; - Ok(CommandResult::Continue) - } - } - } - - fn handle_help(out: &mut impl Write) -> io::Result { - writeln!(out, "Available commands:")?; - for handler in SLASH_COMMAND_HANDLERS { - let name = match handler.command { - SlashCommand::Help => "/help", - SlashCommand::Status => "/status", - SlashCommand::Compact => "/compact", - SlashCommand::Unknown(_) => continue, - }; - writeln!(out, " {name:<9} {}", handler.summary)?; - } - Ok(CommandResult::Continue) - } - - fn handle_status(&mut self, out: &mut impl Write) -> io::Result { - writeln!( - out, - "status: turns={} model={} permission-mode={:?} output-format={:?} last-usage={} in/{} out config={}", - self.state.turns, - self.state.last_model, - self.config.permission_mode, - self.config.output_format, - self.state.last_usage.input_tokens, - self.state.last_usage.output_tokens, - self.config - .config - .as_ref() - .map_or_else(|| String::from(""), |path| path.display().to_string()) - )?; - Ok(CommandResult::Continue) - } - - fn handle_compact(&mut self, out: &mut impl Write) -> io::Result { - self.state.compacted_messages += self.state.turns; - self.state.turns = 0; - self.conversation_history.clear(); - writeln!( - out, - "Compacted session history into a local summary ({} messages total compacted).", - self.state.compacted_messages - )?; - Ok(CommandResult::Continue) - } - - fn handle_stream_event( - renderer: &TerminalRenderer, - event: StreamEvent, - stream_spinner: &mut Spinner, - tool_spinner: &mut Spinner, - saw_text: &mut bool, - turn_usage: &mut UsageSummary, - out: &mut impl Write, - ) { - match event { - StreamEvent::TextDelta(delta) => { - if !*saw_text { - let _ = - stream_spinner.finish("Streaming response", renderer.color_theme(), out); - *saw_text = true; - } - let _ = write!(out, "{delta}"); - let _ = out.flush(); - } - StreamEvent::ToolCallStart { name, input } => { - if *saw_text { - let _ = writeln!(out); - } - let _ = tool_spinner.tick( - &format!("Running tool `{name}` with {input}"), - renderer.color_theme(), - out, - ); - } - StreamEvent::ToolCallResult { - name, - output, - is_error, - } => { - let label = if is_error { - format!("Tool `{name}` failed") - } else { - format!("Tool `{name}` completed") - }; - let _ = tool_spinner.finish(&label, renderer.color_theme(), out); - let rendered_output = format!("### Tool `{name}`\n\n```text\n{output}\n```\n"); - let _ = renderer.stream_markdown(&rendered_output, out); - } - StreamEvent::Usage(usage) => { - *turn_usage = usage; - } - } - } - - fn write_turn_output( - &self, - summary: &runtime::TurnSummary, - out: &mut impl Write, - ) -> io::Result<()> { - match self.config.output_format { - OutputFormat::Text => { - writeln!( - out, - "\nToken usage: {} input / {} output", - self.state.last_usage.input_tokens, self.state.last_usage.output_tokens - )?; - } - OutputFormat::Json => { - writeln!( - out, - "{}", - serde_json::json!({ - "message": summary.assistant_text, - "usage": { - "input_tokens": self.state.last_usage.input_tokens, - "output_tokens": self.state.last_usage.output_tokens, - } - }) - )?; - } - OutputFormat::Ndjson => { - writeln!( - out, - "{}", - serde_json::json!({ - "type": "message", - "text": summary.assistant_text, - "usage": { - "input_tokens": self.state.last_usage.input_tokens, - "output_tokens": self.state.last_usage.output_tokens, - } - }) - )?; - } - } - Ok(()) - } - - fn render_response(&mut self, input: &str, out: &mut impl Write) -> io::Result<()> { - let mut stream_spinner = Spinner::new(); - stream_spinner.tick( - "Opening conversation stream", - self.renderer.color_theme(), - out, - )?; - - let mut turn_usage = UsageSummary::default(); - let mut tool_spinner = Spinner::new(); - let mut saw_text = false; - let renderer = &self.renderer; - - let result = - self.conversation_client - .run_turn(&mut self.conversation_history, input, |event| { - Self::handle_stream_event( - renderer, - event, - &mut stream_spinner, - &mut tool_spinner, - &mut saw_text, - &mut turn_usage, - out, - ); - }); - - let summary = match result { - Ok(summary) => summary, - Err(error) => { - stream_spinner.fail( - "Streaming response failed", - self.renderer.color_theme(), - out, - )?; - return Err(io::Error::other(error)); - } - }; - self.state.last_usage = summary.usage.clone(); - if saw_text { - writeln!(out)?; - } else { - stream_spinner.finish("Streaming response", self.renderer.color_theme(), out)?; - } - - self.write_turn_output(&summary, out)?; - let _ = turn_usage; - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use std::path::PathBuf; - - use crate::args::{OutputFormat, PermissionMode}; - - use super::{CommandResult, SessionConfig, SlashCommand}; - - #[test] - fn parses_required_slash_commands() { - assert_eq!(SlashCommand::parse("/help"), Some(SlashCommand::Help)); - assert_eq!(SlashCommand::parse(" /status "), Some(SlashCommand::Status)); - assert_eq!( - SlashCommand::parse("/compact now"), - Some(SlashCommand::Compact) - ); - } - - #[test] - fn help_output_lists_commands() { - let mut out = Vec::new(); - let result = super::CliApp::handle_help(&mut out).expect("help succeeds"); - assert_eq!(result, CommandResult::Continue); - let output = String::from_utf8_lossy(&out); - assert!(output.contains("/help")); - assert!(output.contains("/status")); - assert!(output.contains("/compact")); - } - - #[test] - fn session_state_tracks_config_values() { - let config = SessionConfig { - model: "claude".into(), - permission_mode: PermissionMode::WorkspaceWrite, - config: Some(PathBuf::from("settings.toml")), - output_format: OutputFormat::Text, - }; - - assert_eq!(config.model, "claude"); - assert_eq!(config.permission_mode, PermissionMode::WorkspaceWrite); - assert_eq!(config.config, Some(PathBuf::from("settings.toml"))); - } -} diff --git a/crates/rusty-claude-cli/src/args.rs b/crates/rusty-claude-cli/src/args.rs deleted file mode 100644 index 990beb4..0000000 --- a/crates/rusty-claude-cli/src/args.rs +++ /dev/null @@ -1,102 +0,0 @@ -use std::path::PathBuf; - -use clap::{Parser, Subcommand, ValueEnum}; - -#[derive(Debug, Clone, Parser, PartialEq, Eq)] -#[command( - name = "rusty-claude-cli", - version, - about = "Rust Claude CLI prototype" -)] -pub struct Cli { - #[arg(long, default_value = "claude-opus-4-6")] - pub model: String, - - #[arg(long, value_enum, default_value_t = PermissionMode::WorkspaceWrite)] - pub permission_mode: PermissionMode, - - #[arg(long)] - pub config: Option, - - #[arg(long, value_enum, default_value_t = OutputFormat::Text)] - pub output_format: OutputFormat, - - #[command(subcommand)] - pub command: Option, -} - -#[derive(Debug, Clone, Subcommand, PartialEq, Eq)] -pub enum Command { - /// Read upstream TS sources and print extracted counts - DumpManifests, - /// Print the current bootstrap phase skeleton - BootstrapPlan, - /// Start the OAuth login flow - Login, - /// Clear saved OAuth credentials - Logout, - /// Run a non-interactive prompt and exit - Prompt { prompt: Vec }, -} - -#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)] -pub enum PermissionMode { - ReadOnly, - WorkspaceWrite, - DangerFullAccess, -} - -#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)] -pub enum OutputFormat { - Text, - Json, - Ndjson, -} - -#[cfg(test)] -mod tests { - use clap::Parser; - - use super::{Cli, Command, OutputFormat, PermissionMode}; - - #[test] - fn parses_requested_flags() { - let cli = Cli::parse_from([ - "rusty-claude-cli", - "--model", - "claude-3-5-haiku", - "--permission-mode", - "read-only", - "--config", - "/tmp/config.toml", - "--output-format", - "ndjson", - "prompt", - "hello", - "world", - ]); - - assert_eq!(cli.model, "claude-3-5-haiku"); - assert_eq!(cli.permission_mode, PermissionMode::ReadOnly); - assert_eq!( - cli.config.as_deref(), - Some(std::path::Path::new("/tmp/config.toml")) - ); - assert_eq!(cli.output_format, OutputFormat::Ndjson); - assert_eq!( - cli.command, - Some(Command::Prompt { - prompt: vec!["hello".into(), "world".into()] - }) - ); - } - - #[test] - fn parses_login_and_logout_commands() { - let login = Cli::parse_from(["rusty-claude-cli", "login"]); - assert_eq!(login.command, Some(Command::Login)); - - let logout = Cli::parse_from(["rusty-claude-cli", "logout"]); - assert_eq!(logout.command, Some(Command::Logout)); - } -} diff --git a/crates/rusty-claude-cli/src/init.rs b/crates/rusty-claude-cli/src/init.rs index 4847c0a..a5f0ff7 100644 --- a/crates/rusty-claude-cli/src/init.rs +++ b/crates/rusty-claude-cli/src/init.rs @@ -1,15 +1,15 @@ use std::fs; use std::path::{Path, PathBuf}; -const STARTER_CLAUDE_JSON: &str = concat!( +const STARTER_CLAW_JSON: &str = concat!( "{\n", " \"permissions\": {\n", - " \"defaultMode\": \"acceptEdits\"\n", + " \"defaultMode\": \"dontAsk\"\n", " }\n", "}\n", ); -const GITIGNORE_COMMENT: &str = "# Claude Code local artifacts"; -const GITIGNORE_ENTRIES: [&str; 2] = [".claude/settings.local.json", ".claude/sessions/"]; +const GITIGNORE_COMMENT: &str = "# Claw Code local artifacts"; +const GITIGNORE_ENTRIES: [&str; 2] = [".claw/settings.local.json", ".claw/sessions/"]; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum InitStatus { @@ -80,16 +80,16 @@ struct RepoDetection { pub(crate) fn initialize_repo(cwd: &Path) -> Result> { let mut artifacts = Vec::new(); - let claude_dir = cwd.join(".claude"); + let claw_dir = cwd.join(".claw"); artifacts.push(InitArtifact { - name: ".claude/", - status: ensure_dir(&claude_dir)?, + name: ".claw/", + status: ensure_dir(&claw_dir)?, }); - let claude_json = cwd.join(".claude.json"); + let claw_json = cwd.join(".claw.json"); artifacts.push(InitArtifact { - name: ".claude.json", - status: write_file_if_missing(&claude_json, STARTER_CLAUDE_JSON)?, + name: ".claw.json", + status: write_file_if_missing(&claw_json, STARTER_CLAW_JSON)?, }); let gitignore = cwd.join(".gitignore"); @@ -164,7 +164,7 @@ pub(crate) fn render_init_claude_md(cwd: &Path) -> String { let mut lines = vec![ "# CLAUDE.md".to_string(), String::new(), - "This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.".to_string(), + "This file provides guidance to Claw Code (clawcode.dev) when working with code in this repository.".to_string(), String::new(), ]; @@ -209,7 +209,7 @@ pub(crate) fn render_init_claude_md(cwd: &Path) -> String { lines.push("## Working agreement".to_string()); lines.push("- Prefer small, reviewable changes and keep generated bootstrap files aligned with actual repo workflows.".to_string()); - lines.push("- Keep shared defaults in `.claude.json`; reserve `.claude/settings.local.json` for machine-local overrides.".to_string()); + lines.push("- Keep shared defaults in `.claw.json`; reserve `.claw/settings.local.json` for machine-local overrides.".to_string()); lines.push("- Do not overwrite existing `CLAUDE.md` content automatically; update it intentionally when repo workflows change.".to_string()); lines.push(String::new()); @@ -354,26 +354,27 @@ mod tests { let report = initialize_repo(&root).expect("init should succeed"); let rendered = report.render(); - assert!(rendered.contains(".claude/ created")); - assert!(rendered.contains(".claude.json created")); + assert!(rendered.contains(".claw/")); + assert!(rendered.contains(".claw.json")); + assert!(rendered.contains("created")); assert!(rendered.contains(".gitignore created")); assert!(rendered.contains("CLAUDE.md created")); - assert!(root.join(".claude").is_dir()); - assert!(root.join(".claude.json").is_file()); + assert!(root.join(".claw").is_dir()); + assert!(root.join(".claw.json").is_file()); assert!(root.join("CLAUDE.md").is_file()); assert_eq!( - fs::read_to_string(root.join(".claude.json")).expect("read claude json"), + fs::read_to_string(root.join(".claw.json")).expect("read claw json"), concat!( "{\n", " \"permissions\": {\n", - " \"defaultMode\": \"acceptEdits\"\n", + " \"defaultMode\": \"dontAsk\"\n", " }\n", "}\n", ) ); let gitignore = fs::read_to_string(root.join(".gitignore")).expect("read gitignore"); - assert!(gitignore.contains(".claude/settings.local.json")); - assert!(gitignore.contains(".claude/sessions/")); + assert!(gitignore.contains(".claw/settings.local.json")); + assert!(gitignore.contains(".claw/sessions/")); let claude_md = fs::read_to_string(root.join("CLAUDE.md")).expect("read claude md"); assert!(claude_md.contains("Languages: Rust.")); assert!(claude_md.contains("cargo clippy --workspace --all-targets -- -D warnings")); @@ -386,8 +387,7 @@ mod tests { let root = temp_dir(); fs::create_dir_all(&root).expect("create root"); fs::write(root.join("CLAUDE.md"), "custom guidance\n").expect("write existing claude md"); - fs::write(root.join(".gitignore"), ".claude/settings.local.json\n") - .expect("write gitignore"); + fs::write(root.join(".gitignore"), ".claw/settings.local.json\n").expect("write gitignore"); let first = initialize_repo(&root).expect("first init should succeed"); assert!(first @@ -395,8 +395,9 @@ mod tests { .contains("CLAUDE.md skipped (already exists)")); let second = initialize_repo(&root).expect("second init should succeed"); let second_rendered = second.render(); - assert!(second_rendered.contains(".claude/ skipped (already exists)")); - assert!(second_rendered.contains(".claude.json skipped (already exists)")); + assert!(second_rendered.contains(".claw/")); + assert!(second_rendered.contains(".claw.json")); + assert!(second_rendered.contains("skipped (already exists)")); assert!(second_rendered.contains(".gitignore skipped (already exists)")); assert!(second_rendered.contains("CLAUDE.md skipped (already exists)")); assert_eq!( @@ -404,8 +405,8 @@ mod tests { "custom guidance\n" ); let gitignore = fs::read_to_string(root.join(".gitignore")).expect("read gitignore"); - assert_eq!(gitignore.matches(".claude/settings.local.json").count(), 1); - assert_eq!(gitignore.matches(".claude/sessions/").count(), 1); + assert_eq!(gitignore.matches(".claw/settings.local.json").count(), 1); + assert_eq!(gitignore.matches(".claw/sessions/").count(), 1); fs::remove_dir_all(root).expect("cleanup temp dir"); } diff --git a/crates/rusty-claude-cli/src/input.rs b/crates/rusty-claude-cli/src/input.rs index bca3791..b0664da 100644 --- a/crates/rusty-claude-cli/src/input.rs +++ b/crates/rusty-claude-cli/src/input.rs @@ -1,166 +1,17 @@ +use std::borrow::Cow; +use std::cell::RefCell; +use std::collections::BTreeSet; use std::io::{self, IsTerminal, Write}; -use crossterm::cursor::{MoveDown, MoveToColumn, MoveUp}; -use crossterm::event::{self, Event, KeyCode, KeyEvent, KeyModifiers}; -use crossterm::queue; -use crossterm::terminal::{disable_raw_mode, enable_raw_mode, Clear, ClearType}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct InputBuffer { - buffer: String, - cursor: usize, -} - -impl InputBuffer { - #[must_use] - pub fn new() -> Self { - Self { - buffer: String::new(), - cursor: 0, - } - } - - pub fn insert(&mut self, ch: char) { - self.buffer.insert(self.cursor, ch); - self.cursor += ch.len_utf8(); - } - - pub fn insert_newline(&mut self) { - self.insert('\n'); - } - - pub fn backspace(&mut self) { - if self.cursor == 0 { - return; - } - - let previous = self.buffer[..self.cursor] - .char_indices() - .last() - .map_or(0, |(idx, _)| idx); - self.buffer.drain(previous..self.cursor); - self.cursor = previous; - } - - pub fn move_left(&mut self) { - if self.cursor == 0 { - return; - } - self.cursor = self.buffer[..self.cursor] - .char_indices() - .last() - .map_or(0, |(idx, _)| idx); - } - - pub fn move_right(&mut self) { - if self.cursor >= self.buffer.len() { - return; - } - if let Some(next) = self.buffer[self.cursor..].chars().next() { - self.cursor += next.len_utf8(); - } - } - - pub fn move_home(&mut self) { - self.cursor = 0; - } - - pub fn move_end(&mut self) { - self.cursor = self.buffer.len(); - } - - #[must_use] - pub fn as_str(&self) -> &str { - &self.buffer - } - - #[cfg(test)] - #[must_use] - pub fn cursor(&self) -> usize { - self.cursor - } - - pub fn clear(&mut self) { - self.buffer.clear(); - self.cursor = 0; - } - - pub fn replace(&mut self, value: impl Into) { - self.buffer = value.into(); - self.cursor = self.buffer.len(); - } - - #[must_use] - fn current_command_prefix(&self) -> Option<&str> { - if self.cursor != self.buffer.len() { - return None; - } - let prefix = &self.buffer[..self.cursor]; - if prefix.contains(char::is_whitespace) || !prefix.starts_with('/') { - return None; - } - Some(prefix) - } - - pub fn complete_slash_command(&mut self, candidates: &[String]) -> bool { - let Some(prefix) = self.current_command_prefix() else { - return false; - }; - - let matches = candidates - .iter() - .filter(|candidate| candidate.starts_with(prefix)) - .map(String::as_str) - .collect::>(); - if matches.is_empty() { - return false; - } - - let replacement = longest_common_prefix(&matches); - if replacement == prefix { - return false; - } - - self.replace(replacement); - true - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct RenderedBuffer { - lines: Vec, - cursor_row: u16, - cursor_col: u16, -} - -impl RenderedBuffer { - #[must_use] - pub fn line_count(&self) -> usize { - self.lines.len() - } - - fn write(&self, out: &mut impl Write) -> io::Result<()> { - for (index, line) in self.lines.iter().enumerate() { - if index > 0 { - writeln!(out)?; - } - write!(out, "{line}")?; - } - Ok(()) - } - - #[cfg(test)] - #[must_use] - pub fn lines(&self) -> &[String] { - &self.lines - } - - #[cfg(test)] - #[must_use] - pub fn cursor_position(&self) -> (u16, u16) { - (self.cursor_row, self.cursor_col) - } -} +use rustyline::completion::{Completer, Pair}; +use rustyline::error::ReadlineError; +use rustyline::highlight::{CmdKind, Highlighter}; +use rustyline::hint::Hinter; +use rustyline::history::DefaultHistory; +use rustyline::validate::Validator; +use rustyline::{ + Cmd, CompletionType, Config, Context, EditMode, Editor, Helper, KeyCode, KeyEvent, Modifiers, +}; #[derive(Debug, Clone, PartialEq, Eq)] pub enum ReadOutcome { @@ -169,25 +20,105 @@ pub enum ReadOutcome { Exit, } +struct SlashCommandHelper { + completions: Vec, + current_line: RefCell, +} + +impl SlashCommandHelper { + fn new(completions: Vec) -> Self { + Self { + completions: normalize_completions(completions), + current_line: RefCell::new(String::new()), + } + } + + fn reset_current_line(&self) { + self.current_line.borrow_mut().clear(); + } + + fn current_line(&self) -> String { + self.current_line.borrow().clone() + } + + fn set_current_line(&self, line: &str) { + let mut current = self.current_line.borrow_mut(); + current.clear(); + current.push_str(line); + } + + fn set_completions(&mut self, completions: Vec) { + self.completions = normalize_completions(completions); + } +} + +impl Completer for SlashCommandHelper { + type Candidate = Pair; + + fn complete( + &self, + line: &str, + pos: usize, + _ctx: &Context<'_>, + ) -> rustyline::Result<(usize, Vec)> { + let Some(prefix) = slash_command_prefix(line, pos) else { + return Ok((0, Vec::new())); + }; + + let matches = self + .completions + .iter() + .filter(|candidate| candidate.starts_with(prefix)) + .map(|candidate| Pair { + display: candidate.clone(), + replacement: candidate.clone(), + }) + .collect(); + + Ok((0, matches)) + } +} + +impl Hinter for SlashCommandHelper { + type Hint = String; +} + +impl Highlighter for SlashCommandHelper { + fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> { + self.set_current_line(line); + Cow::Borrowed(line) + } + + fn highlight_char(&self, line: &str, _pos: usize, _kind: CmdKind) -> bool { + self.set_current_line(line); + false + } +} + +impl Validator for SlashCommandHelper {} +impl Helper for SlashCommandHelper {} + pub struct LineEditor { prompt: String, - continuation_prompt: String, - history: Vec, - history_index: Option, - draft: Option, - completions: Vec, + editor: Editor, } impl LineEditor { #[must_use] pub fn new(prompt: impl Into, completions: Vec) -> Self { + let config = Config::builder() + .completion_type(CompletionType::List) + .edit_mode(EditMode::Emacs) + .build(); + let mut editor = Editor::::with_config(config) + .expect("rustyline editor should initialize"); + editor.set_helper(Some(SlashCommandHelper::new(completions))); + editor.bind_sequence(KeyEvent(KeyCode::Char('J'), Modifiers::CTRL), Cmd::Newline); + editor.bind_sequence(KeyEvent(KeyCode::Enter, Modifiers::SHIFT), Cmd::Newline); + Self { prompt: prompt.into(), - continuation_prompt: String::from("> "), - history: Vec::new(), - history_index: None, - draft: None, - completions, + editor, } } @@ -196,9 +127,14 @@ impl LineEditor { if entry.trim().is_empty() { return; } - self.history.push(entry); - self.history_index = None; - self.draft = None; + + let _ = self.editor.add_history_entry(entry); + } + + pub fn set_completions(&mut self, completions: Vec) { + if let Some(helper) = self.editor.helper_mut() { + helper.set_completions(completions); + } } pub fn read_line(&mut self) -> io::Result { @@ -206,45 +142,43 @@ impl LineEditor { return self.read_line_fallback(); } - enable_raw_mode()?; - let mut stdout = io::stdout(); - let mut input = InputBuffer::new(); - let mut rendered_lines = 1usize; - self.redraw(&mut stdout, &input, rendered_lines)?; + if let Some(helper) = self.editor.helper_mut() { + helper.reset_current_line(); + } - loop { - let event = event::read()?; - if let Event::Key(key) = event { - match self.handle_key(key, &mut input) { - EditorAction::Continue => { - rendered_lines = self.redraw(&mut stdout, &input, rendered_lines)?; - } - EditorAction::Submit => { - disable_raw_mode()?; - writeln!(stdout)?; - self.history_index = None; - self.draft = None; - return Ok(ReadOutcome::Submit(input.as_str().to_owned())); - } - EditorAction::Cancel => { - disable_raw_mode()?; - writeln!(stdout)?; - self.history_index = None; - self.draft = None; - return Ok(ReadOutcome::Cancel); - } - EditorAction::Exit => { - disable_raw_mode()?; - writeln!(stdout)?; - self.history_index = None; - self.draft = None; - return Ok(ReadOutcome::Exit); - } + match self.editor.readline(&self.prompt) { + Ok(line) => Ok(ReadOutcome::Submit(line)), + Err(ReadlineError::Interrupted) => { + let has_input = !self.current_line().is_empty(); + self.finish_interrupted_read()?; + if has_input { + Ok(ReadOutcome::Cancel) + } else { + Ok(ReadOutcome::Exit) } } + Err(ReadlineError::Eof) => { + self.finish_interrupted_read()?; + Ok(ReadOutcome::Exit) + } + Err(error) => Err(io::Error::other(error)), } } + fn current_line(&self) -> String { + self.editor + .helper() + .map_or_else(String::new, SlashCommandHelper::current_line) + } + + fn finish_interrupted_read(&mut self) -> io::Result<()> { + if let Some(helper) = self.editor.helper_mut() { + helper.reset_current_line(); + } + let mut stdout = io::stdout(); + writeln!(stdout) + } + fn read_line_fallback(&self) -> io::Result { let mut stdout = io::stdout(); write!(stdout, "{}", self.prompt)?; @@ -261,388 +195,136 @@ impl LineEditor { } Ok(ReadOutcome::Submit(buffer)) } - - #[allow(clippy::too_many_lines)] - fn handle_key(&mut self, key: KeyEvent, input: &mut InputBuffer) -> EditorAction { - match key { - KeyEvent { - code: KeyCode::Char('c'), - modifiers, - .. - } if modifiers.contains(KeyModifiers::CONTROL) => { - if input.as_str().is_empty() { - EditorAction::Exit - } else { - input.clear(); - self.history_index = None; - self.draft = None; - EditorAction::Cancel - } - } - KeyEvent { - code: KeyCode::Char('j'), - modifiers, - .. - } if modifiers.contains(KeyModifiers::CONTROL) => { - input.insert_newline(); - EditorAction::Continue - } - KeyEvent { - code: KeyCode::Enter, - modifiers, - .. - } if modifiers.contains(KeyModifiers::SHIFT) => { - input.insert_newline(); - EditorAction::Continue - } - KeyEvent { - code: KeyCode::Enter, - .. - } => EditorAction::Submit, - KeyEvent { - code: KeyCode::Backspace, - .. - } => { - input.backspace(); - EditorAction::Continue - } - KeyEvent { - code: KeyCode::Left, - .. - } => { - input.move_left(); - EditorAction::Continue - } - KeyEvent { - code: KeyCode::Right, - .. - } => { - input.move_right(); - EditorAction::Continue - } - KeyEvent { - code: KeyCode::Up, .. - } => { - self.navigate_history_up(input); - EditorAction::Continue - } - KeyEvent { - code: KeyCode::Down, - .. - } => { - self.navigate_history_down(input); - EditorAction::Continue - } - KeyEvent { - code: KeyCode::Tab, .. - } => { - input.complete_slash_command(&self.completions); - EditorAction::Continue - } - KeyEvent { - code: KeyCode::Home, - .. - } => { - input.move_home(); - EditorAction::Continue - } - KeyEvent { - code: KeyCode::End, .. - } => { - input.move_end(); - EditorAction::Continue - } - KeyEvent { - code: KeyCode::Esc, .. - } => { - input.clear(); - self.history_index = None; - self.draft = None; - EditorAction::Cancel - } - KeyEvent { - code: KeyCode::Char(ch), - modifiers, - .. - } if modifiers.is_empty() || modifiers == KeyModifiers::SHIFT => { - input.insert(ch); - self.history_index = None; - self.draft = None; - EditorAction::Continue - } - _ => EditorAction::Continue, - } - } - - fn navigate_history_up(&mut self, input: &mut InputBuffer) { - if self.history.is_empty() { - return; - } - - match self.history_index { - Some(0) => {} - Some(index) => { - let next_index = index - 1; - input.replace(self.history[next_index].clone()); - self.history_index = Some(next_index); - } - None => { - self.draft = Some(input.as_str().to_owned()); - let next_index = self.history.len() - 1; - input.replace(self.history[next_index].clone()); - self.history_index = Some(next_index); - } - } - } - - fn navigate_history_down(&mut self, input: &mut InputBuffer) { - let Some(index) = self.history_index else { - return; - }; - - if index + 1 < self.history.len() { - let next_index = index + 1; - input.replace(self.history[next_index].clone()); - self.history_index = Some(next_index); - return; - } - - input.replace(self.draft.take().unwrap_or_default()); - self.history_index = None; - } - - fn redraw( - &self, - out: &mut impl Write, - input: &InputBuffer, - previous_line_count: usize, - ) -> io::Result { - let rendered = render_buffer(&self.prompt, &self.continuation_prompt, input); - if previous_line_count > 1 { - queue!(out, MoveUp(saturating_u16(previous_line_count - 1)))?; - } - queue!(out, MoveToColumn(0), Clear(ClearType::FromCursorDown),)?; - rendered.write(out)?; - queue!( - out, - MoveUp(saturating_u16(rendered.line_count().saturating_sub(1))), - MoveToColumn(0), - )?; - if rendered.cursor_row > 0 { - queue!(out, MoveDown(rendered.cursor_row))?; - } - queue!(out, MoveToColumn(rendered.cursor_col))?; - out.flush()?; - Ok(rendered.line_count()) - } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum EditorAction { - Continue, - Submit, - Cancel, - Exit, +fn slash_command_prefix(line: &str, pos: usize) -> Option<&str> { + if pos != line.len() { + return None; + } + + let prefix = &line[..pos]; + if !prefix.starts_with('/') { + return None; + } + + Some(prefix) } -#[must_use] -pub fn render_buffer( - prompt: &str, - continuation_prompt: &str, - input: &InputBuffer, -) -> RenderedBuffer { - let before_cursor = &input.as_str()[..input.cursor]; - let cursor_row = saturating_u16(before_cursor.chars().filter(|ch| *ch == '\n').count()); - let cursor_line = before_cursor.rsplit('\n').next().unwrap_or_default(); - let cursor_prompt = if cursor_row == 0 { - prompt - } else { - continuation_prompt - }; - let cursor_col = saturating_u16(cursor_prompt.chars().count() + cursor_line.chars().count()); - - let mut lines = Vec::new(); - for (index, line) in input.as_str().split('\n').enumerate() { - let prefix = if index == 0 { - prompt - } else { - continuation_prompt - }; - lines.push(format!("{prefix}{line}")); - } - if lines.is_empty() { - lines.push(prompt.to_string()); - } - - RenderedBuffer { - lines, - cursor_row, - cursor_col, - } -} - -#[must_use] -fn longest_common_prefix(values: &[&str]) -> String { - let Some(first) = values.first() else { - return String::new(); - }; - - let mut prefix = (*first).to_string(); - for value in values.iter().skip(1) { - while !value.starts_with(&prefix) { - prefix.pop(); - if prefix.is_empty() { - break; - } - } - } - prefix -} - -#[must_use] -fn saturating_u16(value: usize) -> u16 { - u16::try_from(value).unwrap_or(u16::MAX) +fn normalize_completions(completions: Vec) -> Vec { + let mut seen = BTreeSet::new(); + completions + .into_iter() + .filter(|candidate| candidate.starts_with('/')) + .filter(|candidate| seen.insert(candidate.clone())) + .collect() } #[cfg(test)] mod tests { - use super::{render_buffer, InputBuffer, LineEditor}; - use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + use super::{slash_command_prefix, LineEditor, SlashCommandHelper}; + use rustyline::completion::Completer; + use rustyline::highlight::Highlighter; + use rustyline::history::{DefaultHistory, History}; + use rustyline::Context; - fn key(code: KeyCode) -> KeyEvent { - KeyEvent::new(code, KeyModifiers::NONE) + #[test] + fn extracts_terminal_slash_command_prefixes_with_arguments() { + assert_eq!(slash_command_prefix("/he", 3), Some("/he")); + assert_eq!(slash_command_prefix("/help me", 8), Some("/help me")); + assert_eq!( + slash_command_prefix("/session switch ses", 19), + Some("/session switch ses") + ); + assert_eq!(slash_command_prefix("hello", 5), None); + assert_eq!(slash_command_prefix("/help", 2), None); } #[test] - fn supports_basic_line_editing() { - let mut input = InputBuffer::new(); - input.insert('h'); - input.insert('i'); - input.move_end(); - input.insert_newline(); - input.insert('x'); - - assert_eq!(input.as_str(), "hi\nx"); - assert_eq!(input.cursor(), 4); - - input.move_left(); - input.backspace(); - assert_eq!(input.as_str(), "hix"); - assert_eq!(input.cursor(), 2); - } - - #[test] - fn completes_unique_slash_command() { - let mut input = InputBuffer::new(); - for ch in "/he".chars() { - input.insert(ch); - } - - assert!(input.complete_slash_command(&[ + fn completes_matching_slash_commands() { + let helper = SlashCommandHelper::new(vec![ "/help".to_string(), "/hello".to_string(), "/status".to_string(), - ])); - assert_eq!(input.as_str(), "/hel"); + ]); + let history = DefaultHistory::new(); + let ctx = Context::new(&history); + let (start, matches) = helper + .complete("/he", 3, &ctx) + .expect("completion should work"); - assert!(input.complete_slash_command(&["/help".to_string(), "/status".to_string()])); - assert_eq!(input.as_str(), "/help"); - } - - #[test] - fn ignores_completion_when_prefix_is_not_a_slash_command() { - let mut input = InputBuffer::new(); - for ch in "hello".chars() { - input.insert(ch); - } - - assert!(!input.complete_slash_command(&["/help".to_string()])); - assert_eq!(input.as_str(), "hello"); - } - - #[test] - fn history_navigation_restores_current_draft() { - let mut editor = LineEditor::new("› ", vec![]); - editor.push_history("/help"); - editor.push_history("status report"); - - let mut input = InputBuffer::new(); - for ch in "draft".chars() { - input.insert(ch); - } - - let _ = editor.handle_key(key(KeyCode::Up), &mut input); - assert_eq!(input.as_str(), "status report"); - - let _ = editor.handle_key(key(KeyCode::Up), &mut input); - assert_eq!(input.as_str(), "/help"); - - let _ = editor.handle_key(key(KeyCode::Down), &mut input); - assert_eq!(input.as_str(), "status report"); - - let _ = editor.handle_key(key(KeyCode::Down), &mut input); - assert_eq!(input.as_str(), "draft"); - } - - #[test] - fn tab_key_completes_from_editor_candidates() { - let mut editor = LineEditor::new( - "› ", - vec![ - "/help".to_string(), - "/status".to_string(), - "/session".to_string(), - ], - ); - let mut input = InputBuffer::new(); - for ch in "/st".chars() { - input.insert(ch); - } - - let _ = editor.handle_key(key(KeyCode::Tab), &mut input); - assert_eq!(input.as_str(), "/status"); - } - - #[test] - fn renders_multiline_buffers_with_continuation_prompt() { - let mut input = InputBuffer::new(); - for ch in "hello\nworld".chars() { - if ch == '\n' { - input.insert_newline(); - } else { - input.insert(ch); - } - } - - let rendered = render_buffer("› ", "> ", &input); + assert_eq!(start, 0); assert_eq!( - rendered.lines(), - &["› hello".to_string(), "> world".to_string()] + matches + .into_iter() + .map(|candidate| candidate.replacement) + .collect::>(), + vec!["/help".to_string(), "/hello".to_string()] ); - assert_eq!(rendered.cursor_position(), (1, 7)); } #[test] - fn ctrl_c_exits_only_when_buffer_is_empty() { - let mut editor = LineEditor::new("› ", vec![]); - let mut empty = InputBuffer::new(); - assert!(matches!( - editor.handle_key( - KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL), - &mut empty, - ), - super::EditorAction::Exit - )); + fn completes_matching_slash_command_arguments() { + let helper = SlashCommandHelper::new(vec![ + "/model".to_string(), + "/model opus".to_string(), + "/model sonnet".to_string(), + "/session switch alpha".to_string(), + ]); + let history = DefaultHistory::new(); + let ctx = Context::new(&history); + let (start, matches) = helper + .complete("/model o", 8, &ctx) + .expect("completion should work"); - let mut filled = InputBuffer::new(); - filled.insert('x'); - assert!(matches!( - editor.handle_key( - KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL), - &mut filled, - ), - super::EditorAction::Cancel - )); - assert!(filled.as_str().is_empty()); + assert_eq!(start, 0); + assert_eq!( + matches + .into_iter() + .map(|candidate| candidate.replacement) + .collect::>(), + vec!["/model opus".to_string()] + ); + } + + #[test] + fn ignores_non_slash_command_completion_requests() { + let helper = SlashCommandHelper::new(vec!["/help".to_string()]); + let history = DefaultHistory::new(); + let ctx = Context::new(&history); + let (_, matches) = helper + .complete("hello", 5, &ctx) + .expect("completion should work"); + + assert!(matches.is_empty()); + } + + #[test] + fn tracks_current_buffer_through_highlighter() { + let helper = SlashCommandHelper::new(Vec::new()); + let _ = helper.highlight("draft", 5); + + assert_eq!(helper.current_line(), "draft"); + } + + #[test] + fn push_history_ignores_blank_entries() { + let mut editor = LineEditor::new("> ", vec!["/help".to_string()]); + editor.push_history(" "); + editor.push_history("/help"); + + assert_eq!(editor.editor.history().len(), 1); + } + + #[test] + fn set_completions_replaces_and_normalizes_candidates() { + let mut editor = LineEditor::new("> ", vec!["/help".to_string()]); + editor.set_completions(vec![ + "/model opus".to_string(), + "/model opus".to_string(), + "status".to_string(), + ]); + + let helper = editor.editor.helper().expect("helper should exist"); + assert_eq!(helper.completions, vec!["/model opus".to_string()]); } } diff --git a/crates/rusty-claude-cli/src/main.rs b/crates/rusty-claude-cli/src/main.rs index bf16026..7caef9d 100644 --- a/crates/rusty-claude-cli/src/main.rs +++ b/crates/rusty-claude-cli/src/main.rs @@ -1,107 +1,325 @@ +#![allow( + dead_code, + unused_imports, + unused_variables, + clippy::unneeded_struct_pattern, + clippy::unnecessary_wraps, + clippy::unused_self +)] mod init; mod input; mod render; -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::BTreeSet; use std::env; use std::fs; -use std::io::{self, Read, Write}; +use std::io::{self, IsTerminal, Read, Write}; use std::net::TcpListener; +use std::ops::{Deref, DerefMut}; use std::path::{Path, PathBuf}; use std::process::Command; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::sync::mpsc::{self, Receiver, RecvTimeoutError, Sender}; +use std::sync::{Arc, Mutex}; +use std::thread::{self, JoinHandle}; +use std::time::{Duration, Instant, UNIX_EPOCH}; use api::{ - resolve_startup_auth_source, AuthSource, ClawApiClient, ContentBlockDelta, InputContentBlock, - InputMessage, MessageRequest, MessageResponse, OutputContentBlock, - StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, + detect_provider_kind, oauth_token_is_expired, resolve_startup_auth_source, AnthropicClient, + AuthSource, ContentBlockDelta, InputContentBlock, InputMessage, MessageRequest, + MessageResponse, OutputContentBlock, PromptCache, ProviderClient as ApiProviderClient, + ProviderKind, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, + ToolResultContentBlock, }; use commands::{ - handle_agents_slash_command, handle_branch_slash_command, handle_commit_slash_command, - handle_plugins_slash_command, handle_skills_slash_command, handle_worktree_slash_command, - render_slash_command_help, resume_supported_slash_commands, slash_command_specs, SlashCommand, + classify_skills_slash_command, handle_agents_slash_command, handle_agents_slash_command_json, + handle_mcp_slash_command, handle_mcp_slash_command_json, handle_plugins_slash_command, + handle_skills_slash_command, handle_skills_slash_command_json, render_slash_command_help, + render_slash_command_help_filtered, resolve_skill_invocation, resume_supported_slash_commands, + slash_command_specs, validate_slash_command_input, SkillSlashDispatch, SlashCommand, }; use compat_harness::{extract_manifest, UpstreamPaths}; use init::initialize_repo; -use render::{Spinner, TerminalRenderer}; +use plugins::{PluginHooks, PluginManager, PluginManagerConfig, PluginRegistry}; +use render::{MarkdownStreamState, Spinner, TerminalRenderer}; use runtime::{ - clear_oauth_credentials, generate_pkce_pair, generate_state, load_system_prompt, - parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest, - AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, - ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, - OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, RuntimeError, - Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, + check_base_commit, clear_oauth_credentials, format_stale_base_warning, format_usd, + generate_pkce_pair, generate_state, load_oauth_credentials, load_system_prompt, + parse_oauth_callback_request_target, pricing_for_model, resolve_expected_base, + resolve_sandbox_status, save_oauth_credentials, ApiClient, ApiRequest, AssistantEvent, + CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, ConversationMessage, + ConversationRuntime, McpServer, McpServerManager, McpServerSpec, McpTool, MessageRole, + ModelPricing, OAuthAuthorizationRequest, OAuthConfig, OAuthTokenExchangeRequest, + PermissionMode, PermissionPolicy, ProjectContext, PromptCacheEvent, ResolvedPermissionMode, + RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, +}; +use serde::Deserialize; +use serde_json::{json, Map, Value}; +use tools::{ + execute_tool, mvp_tool_specs, GlobalToolRegistry, RuntimeToolDefinition, ToolSearchOutput, }; -use serde_json::json; -use tools::{execute_tool, mvp_tool_specs, ToolSpec}; -use plugins::{self}; const DEFAULT_MODEL: &str = "claude-opus-4-6"; -const DEFAULT_MAX_TOKENS: u32 = 32; -const DEFAULT_DATE: &str = "2026-03-31"; +fn max_tokens_for_model(model: &str) -> u32 { + if model.contains("opus") { + 32_000 + } else { + 64_000 + } +} +// Build-time constants injected by build.rs (fall back to static values when +// build.rs hasn't run, e.g. in doc-test or unusual toolchain environments). +const DEFAULT_DATE: &str = match option_env!("BUILD_DATE") { + Some(d) => d, + None => "unknown", +}; const DEFAULT_OAUTH_CALLBACK_PORT: u16 = 4545; const VERSION: &str = env!("CARGO_PKG_VERSION"); const BUILD_TARGET: Option<&str> = option_env!("TARGET"); const GIT_SHA: Option<&str> = option_env!("GIT_SHA"); +const INTERNAL_PROGRESS_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(3); +const POST_TOOL_STALL_TIMEOUT: Duration = Duration::from_secs(10); +const PRIMARY_SESSION_EXTENSION: &str = "jsonl"; +const LEGACY_SESSION_EXTENSION: &str = "json"; +const LATEST_SESSION_REFERENCE: &str = "latest"; +const SESSION_REFERENCE_ALIASES: &[&str] = &[LATEST_SESSION_REFERENCE, "last", "recent"]; +const CLI_OPTION_SUGGESTIONS: &[&str] = &[ + "--help", + "-h", + "--version", + "-V", + "--model", + "--output-format", + "--permission-mode", + "--dangerously-skip-permissions", + "--allowedTools", + "--allowed-tools", + "--resume", + "--print", + "--compact", + "--base-commit", + "-p", +]; type AllowedToolSet = BTreeSet; +type RuntimePluginStateBuildOutput = ( + Option>>, + Vec, +); fn main() { if let Err(error) = run() { - eprintln!( - "error: {error} + let message = error.to_string(); + // When --output-format json is active, emit errors as JSON so downstream + // tools can parse failures the same way they parse successes (ROADMAP #42). + let argv: Vec = std::env::args().collect(); + let json_output = argv + .windows(2) + .any(|w| w[0] == "--output-format" && w[1] == "json") + || argv.iter().any(|a| a == "--output-format=json"); + if json_output { + eprintln!( + "{}", + serde_json::json!({ + "type": "error", + "error": message, + }) + ); + } else if message.contains("`claw --help`") { + eprintln!("error: {message}"); + } else { + eprintln!( + "error: {message} Run `claw --help` for usage." - ); + ); + } std::process::exit(1); } } +/// Read piped stdin content when stdin is not a terminal. +/// +/// Returns `None` when stdin is attached to a terminal (interactive REPL use), +/// when reading fails, or when the piped content is empty after trimming. +/// Returns `Some(raw_content)` when a pipe delivered non-empty content. +fn read_piped_stdin() -> Option { + if io::stdin().is_terminal() { + return None; + } + let mut buffer = String::new(); + if io::stdin().read_to_string(&mut buffer).is_err() { + return None; + } + if buffer.trim().is_empty() { + return None; + } + Some(buffer) +} + +/// Merge a piped stdin payload into a prompt argument. +/// +/// When `stdin_content` is `None` or empty after trimming, the prompt is +/// returned unchanged. Otherwise the trimmed stdin content is appended to the +/// prompt separated by a blank line so the model sees the prompt first and the +/// piped context immediately after it. +fn merge_prompt_with_stdin(prompt: &str, stdin_content: Option<&str>) -> String { + let Some(raw) = stdin_content else { + return prompt.to_string(); + }; + let trimmed = raw.trim(); + if trimmed.is_empty() { + return prompt.to_string(); + } + if prompt.is_empty() { + return trimmed.to_string(); + } + format!("{prompt}\n\n{trimmed}") +} + fn run() -> Result<(), Box> { let args: Vec = env::args().skip(1).collect(); match parse_args(&args)? { - CliAction::DumpManifests => dump_manifests(), - CliAction::BootstrapPlan => print_bootstrap_plan(), - CliAction::PrintSystemPrompt { cwd, date } => print_system_prompt(cwd, date), - CliAction::Version => print_version(), + CliAction::DumpManifests { output_format } => dump_manifests(output_format)?, + CliAction::BootstrapPlan { output_format } => print_bootstrap_plan(output_format)?, + CliAction::Agents { + args, + output_format, + } => LiveCli::print_agents(args.as_deref(), output_format)?, + CliAction::Mcp { + args, + output_format, + } => LiveCli::print_mcp(args.as_deref(), output_format)?, + CliAction::Skills { + args, + output_format, + } => LiveCli::print_skills(args.as_deref(), output_format)?, + CliAction::Plugins { + action, + target, + output_format, + } => LiveCli::print_plugins(action.as_deref(), target.as_deref(), output_format)?, + CliAction::PrintSystemPrompt { + cwd, + date, + output_format, + } => print_system_prompt(cwd, date, output_format)?, + CliAction::Version { output_format } => print_version(output_format)?, CliAction::ResumeSession { session_path, commands, - } => resume_session(&session_path, &commands), + output_format, + } => resume_session(&session_path, &commands, output_format), + CliAction::Status { + model, + permission_mode, + output_format, + } => print_status_snapshot(&model, permission_mode, output_format)?, + CliAction::Sandbox { output_format } => print_sandbox_status_snapshot(output_format)?, CliAction::Prompt { prompt, model, output_format, allowed_tools, permission_mode, - } => LiveCli::new(model, false, allowed_tools, permission_mode)? - .run_turn_with_output(&prompt, output_format)?, - CliAction::Login => run_login()?, - CliAction::Logout => run_logout()?, - CliAction::Init => run_init()?, + compact, + base_commit, + reasoning_effort, + allow_broad_cwd, + } => { + enforce_broad_cwd_policy(allow_broad_cwd, output_format)?; + run_stale_base_preflight(base_commit.as_deref()); + // Only consume piped stdin as prompt context when the permission + // mode is fully unattended. In modes where the permission + // prompter may invoke CliPermissionPrompter::decide(), stdin + // must remain available for interactive approval; otherwise the + // prompter's read_line() would hit EOF and deny every request. + let stdin_context = if matches!(permission_mode, PermissionMode::DangerFullAccess) { + read_piped_stdin() + } else { + None + }; + let effective_prompt = merge_prompt_with_stdin(&prompt, stdin_context.as_deref()); + let mut cli = LiveCli::new(model, true, allowed_tools, permission_mode)?; + cli.set_reasoning_effort(reasoning_effort); + cli.run_turn_with_output(&effective_prompt, output_format, compact)?; + } + CliAction::Login { output_format } => run_login(output_format)?, + CliAction::Logout { output_format } => run_logout(output_format)?, + CliAction::Doctor { output_format } => run_doctor(output_format)?, + CliAction::State { output_format } => run_worker_state(output_format)?, + CliAction::Init { output_format } => run_init(output_format)?, + CliAction::Export { + session_reference, + output_path, + output_format, + } => run_export(&session_reference, output_path.as_deref(), output_format)?, CliAction::Repl { model, allowed_tools, permission_mode, - } => run_repl(model, allowed_tools, permission_mode)?, - CliAction::Help => print_help(), + base_commit, + reasoning_effort, + allow_broad_cwd, + } => run_repl( + model, + allowed_tools, + permission_mode, + base_commit, + reasoning_effort, + allow_broad_cwd, + )?, + CliAction::HelpTopic(topic) => print_help_topic(topic), + CliAction::Help { output_format } => print_help(output_format)?, } Ok(()) } #[derive(Debug, Clone, PartialEq, Eq)] enum CliAction { - DumpManifests, - BootstrapPlan, + DumpManifests { + output_format: CliOutputFormat, + }, + BootstrapPlan { + output_format: CliOutputFormat, + }, + Agents { + args: Option, + output_format: CliOutputFormat, + }, + Mcp { + args: Option, + output_format: CliOutputFormat, + }, + Skills { + args: Option, + output_format: CliOutputFormat, + }, + Plugins { + action: Option, + target: Option, + output_format: CliOutputFormat, + }, PrintSystemPrompt { cwd: PathBuf, date: String, + output_format: CliOutputFormat, + }, + Version { + output_format: CliOutputFormat, }, - Version, ResumeSession { session_path: PathBuf, commands: Vec, + output_format: CliOutputFormat, + }, + Status { + model: String, + permission_mode: PermissionMode, + output_format: CliOutputFormat, + }, + Sandbox { + output_format: CliOutputFormat, }, Prompt { prompt: String, @@ -109,17 +327,51 @@ enum CliAction { output_format: CliOutputFormat, allowed_tools: Option, permission_mode: PermissionMode, + compact: bool, + base_commit: Option, + reasoning_effort: Option, + allow_broad_cwd: bool, + }, + Login { + output_format: CliOutputFormat, + }, + Logout { + output_format: CliOutputFormat, + }, + Doctor { + output_format: CliOutputFormat, + }, + State { + output_format: CliOutputFormat, + }, + Init { + output_format: CliOutputFormat, + }, + Export { + session_reference: String, + output_path: Option, + output_format: CliOutputFormat, }, - Login, - Logout, - Init, Repl { model: String, allowed_tools: Option, permission_mode: PermissionMode, + base_commit: Option, + reasoning_effort: Option, + allow_broad_cwd: bool, }, + HelpTopic(LocalHelpTopic), // prompt-mode formatting is only supported for non-interactive runs - Help, + Help { + output_format: CliOutputFormat, + }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum LocalHelpTopic { + Status, + Sandbox, + Doctor, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -144,14 +396,48 @@ impl CliOutputFormat { fn parse_args(args: &[String]) -> Result { let mut model = DEFAULT_MODEL.to_string(); let mut output_format = CliOutputFormat::Text; - let mut permission_mode = default_permission_mode(); + let mut permission_mode_override = None; + let mut wants_help = false; let mut wants_version = false; let mut allowed_tool_values = Vec::new(); - let mut rest = Vec::new(); + let mut compact = false; + let mut base_commit: Option = None; + let mut reasoning_effort: Option = None; + let mut allow_broad_cwd = false; + let mut rest: Vec = Vec::new(); let mut index = 0; while index < args.len() { match args[index].as_str() { + "--help" | "-h" if rest.is_empty() => { + wants_help = true; + index += 1; + } + "--help" | "-h" + if !rest.is_empty() + && matches!( + rest[0].as_str(), + "prompt" + | "login" + | "logout" + | "version" + | "state" + | "init" + | "export" + | "commit" + | "pr" + | "issue" + ) => + { + // `--help` following a subcommand that would otherwise forward + // the arg to the API (e.g. `claw prompt --help`) should show + // top-level help instead. Subcommands that consume their own + // args (agents, mcp, plugins, skills) and local help-topic + // subcommands (status, sandbox, doctor) must NOT be intercepted + // here — they handle --help in their own dispatch paths. + wants_help = true; + index += 1; + } "--version" | "-V" => { wants_version = true; index += 1; @@ -160,11 +446,11 @@ fn parse_args(args: &[String]) -> Result { let value = args .get(index + 1) .ok_or_else(|| "missing value for --model".to_string())?; - model.clone_from(value); + model = resolve_model_alias_with_config(value); index += 2; } flag if flag.starts_with("--model=") => { - model = flag[8..].to_string(); + model = resolve_model_alias_with_config(&flag[8..]); index += 1; } "--output-format" => { @@ -178,7 +464,7 @@ fn parse_args(args: &[String]) -> Result { let value = args .get(index + 1) .ok_or_else(|| "missing value for --permission-mode".to_string())?; - permission_mode = parse_permission_mode_arg(value)?; + permission_mode_override = Some(parse_permission_mode_arg(value)?); index += 2; } flag if flag.starts_with("--output-format=") => { @@ -186,7 +472,85 @@ fn parse_args(args: &[String]) -> Result { index += 1; } flag if flag.starts_with("--permission-mode=") => { - permission_mode = parse_permission_mode_arg(&flag[18..])?; + permission_mode_override = Some(parse_permission_mode_arg(&flag[18..])?); + index += 1; + } + "--dangerously-skip-permissions" => { + permission_mode_override = Some(PermissionMode::DangerFullAccess); + index += 1; + } + "--compact" => { + compact = true; + index += 1; + } + "--base-commit" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --base-commit".to_string())?; + base_commit = Some(value.clone()); + index += 2; + } + flag if flag.starts_with("--base-commit=") => { + base_commit = Some(flag[14..].to_string()); + index += 1; + } + "--reasoning-effort" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --reasoning-effort".to_string())?; + if !matches!(value.as_str(), "low" | "medium" | "high") { + return Err(format!( + "invalid value for --reasoning-effort: '{value}'; must be low, medium, or high" + )); + } + reasoning_effort = Some(value.clone()); + index += 2; + } + flag if flag.starts_with("--reasoning-effort=") => { + let value = &flag[19..]; + if !matches!(value, "low" | "medium" | "high") { + return Err(format!( + "invalid value for --reasoning-effort: '{value}'; must be low, medium, or high" + )); + } + reasoning_effort = Some(value.to_string()); + index += 1; + } + "--allow-broad-cwd" => { + allow_broad_cwd = true; + index += 1; + } + "-p" => { + // Claw Code compat: -p "prompt" = one-shot prompt + let prompt = args[index + 1..].join(" "); + if prompt.trim().is_empty() { + return Err("-p requires a prompt string".to_string()); + } + return Ok(CliAction::Prompt { + prompt, + model: resolve_model_alias_with_config(&model), + output_format, + allowed_tools: normalize_allowed_tools(&allowed_tool_values)?, + permission_mode: permission_mode_override + .unwrap_or_else(default_permission_mode), + compact, + base_commit: base_commit.clone(), + reasoning_effort: reasoning_effort.clone(), + allow_broad_cwd, + }); + } + "--print" => { + // Claw Code compat: --print makes output non-interactive + output_format = CliOutputFormat::Text; + index += 1; + } + "--resume" if rest.is_empty() => { + rest.push("--resume".to_string()); + index += 1; + } + flag if rest.is_empty() && flag.starts_with("--resume=") => { + rest.push("--resume".to_string()); + rest.push(flag[9..].to_string()); index += 1; } "--allowedTools" | "--allowed-tools" => { @@ -204,6 +568,9 @@ fn parse_args(args: &[String]) -> Result { allowed_tool_values.push(flag[16..].to_string()); index += 1; } + other if rest.is_empty() && other.starts_with('-') => { + return Err(format_unknown_option(other)) + } other => { rest.push(other.to_string()); index += 1; @@ -211,33 +578,99 @@ fn parse_args(args: &[String]) -> Result { } } + if wants_help { + return Ok(CliAction::Help { output_format }); + } + if wants_version { - return Ok(CliAction::Version); + return Ok(CliAction::Version { output_format }); } let allowed_tools = normalize_allowed_tools(&allowed_tool_values)?; if rest.is_empty() { + let permission_mode = permission_mode_override.unwrap_or_else(default_permission_mode); + // When stdin is not a terminal (pipe/redirect) and no prompt is given on the + // command line, read stdin as the prompt and dispatch as a one-shot Prompt + // rather than starting the interactive REPL (which would consume the pipe and + // print the startup banner, then exit without sending anything to the API). + if !std::io::stdin().is_terminal() { + let mut buf = String::new(); + let _ = std::io::Read::read_to_string(&mut std::io::stdin(), &mut buf); + let piped = buf.trim().to_string(); + if !piped.is_empty() { + return Ok(CliAction::Prompt { + model, + prompt: piped, + allowed_tools, + permission_mode, + output_format, + compact: false, + base_commit, + reasoning_effort, + allow_broad_cwd, + }); + } + } return Ok(CliAction::Repl { model, allowed_tools, permission_mode, + base_commit, + reasoning_effort: reasoning_effort.clone(), + allow_broad_cwd, }); } - if matches!(rest.first().map(String::as_str), Some("--help" | "-h")) { - return Ok(CliAction::Help); - } if rest.first().map(String::as_str) == Some("--resume") { - return parse_resume_args(&rest[1..]); + return parse_resume_args(&rest[1..], output_format); + } + if let Some(action) = parse_local_help_action(&rest) { + return action; + } + if let Some(action) = + parse_single_word_command_alias(&rest, &model, permission_mode_override, output_format) + { + return action; } + let permission_mode = permission_mode_override.unwrap_or_else(default_permission_mode); + match rest[0].as_str() { - "dump-manifests" => Ok(CliAction::DumpManifests), - "bootstrap-plan" => Ok(CliAction::BootstrapPlan), - "system-prompt" => parse_system_prompt_args(&rest[1..]), - "login" => Ok(CliAction::Login), - "logout" => Ok(CliAction::Logout), - "init" => Ok(CliAction::Init), + "dump-manifests" => Ok(CliAction::DumpManifests { output_format }), + "bootstrap-plan" => Ok(CliAction::BootstrapPlan { output_format }), + "agents" => Ok(CliAction::Agents { + args: join_optional_args(&rest[1..]), + output_format, + }), + "mcp" => Ok(CliAction::Mcp { + args: join_optional_args(&rest[1..]), + output_format, + }), + "skills" => { + let args = join_optional_args(&rest[1..]); + match classify_skills_slash_command(args.as_deref()) { + SkillSlashDispatch::Invoke(prompt) => Ok(CliAction::Prompt { + prompt, + model, + output_format, + allowed_tools, + permission_mode, + compact, + base_commit, + reasoning_effort: reasoning_effort.clone(), + allow_broad_cwd, + }), + SkillSlashDispatch::Local => Ok(CliAction::Skills { + args, + output_format, + }), + } + } + "system-prompt" => parse_system_prompt_args(&rest[1..], output_format), + "login" => Ok(CliAction::Login { output_format }), + "logout" => Ok(CliAction::Logout { output_format }), + "init" => Ok(CliAction::Init { output_format }), + "export" => parse_export_args(&rest[1..], output_format), "prompt" => { let prompt = rest[1..].join(" "); if prompt.trim().is_empty() { @@ -249,65 +682,354 @@ fn parse_args(args: &[String]) -> Result { output_format, allowed_tools, permission_mode, + compact, + base_commit: base_commit.clone(), + reasoning_effort: reasoning_effort.clone(), + allow_broad_cwd, }) } - other if !other.starts_with('/') => Ok(CliAction::Prompt { + other if other.starts_with('/') => parse_direct_slash_cli_action( + &rest, + model, + output_format, + allowed_tools, + permission_mode, + compact, + base_commit, + reasoning_effort, + allow_broad_cwd, + ), + _other => Ok(CliAction::Prompt { prompt: rest.join(" "), model, output_format, allowed_tools, permission_mode, + compact, + base_commit, + reasoning_effort: reasoning_effort.clone(), + allow_broad_cwd, }), - other => Err(format!("unknown subcommand: {other}")), } } +fn parse_local_help_action(rest: &[String]) -> Option> { + if rest.len() != 2 || !is_help_flag(&rest[1]) { + return None; + } + + let topic = match rest[0].as_str() { + "status" => LocalHelpTopic::Status, + "sandbox" => LocalHelpTopic::Sandbox, + "doctor" => LocalHelpTopic::Doctor, + _ => return None, + }; + Some(Ok(CliAction::HelpTopic(topic))) +} + +fn is_help_flag(value: &str) -> bool { + matches!(value, "--help" | "-h") +} + +fn parse_single_word_command_alias( + rest: &[String], + model: &str, + permission_mode_override: Option, + output_format: CliOutputFormat, +) -> Option> { + if rest.len() != 1 { + return None; + } + + match rest[0].as_str() { + "help" => Some(Ok(CliAction::Help { output_format })), + "version" => Some(Ok(CliAction::Version { output_format })), + "status" => Some(Ok(CliAction::Status { + model: model.to_string(), + permission_mode: permission_mode_override.unwrap_or_else(default_permission_mode), + output_format, + })), + "sandbox" => Some(Ok(CliAction::Sandbox { output_format })), + "doctor" => Some(Ok(CliAction::Doctor { output_format })), + "state" => Some(Ok(CliAction::State { output_format })), + other => bare_slash_command_guidance(other).map(Err), + } +} + +fn bare_slash_command_guidance(command_name: &str) -> Option { + if matches!( + command_name, + "dump-manifests" + | "bootstrap-plan" + | "agents" + | "mcp" + | "skills" + | "system-prompt" + | "login" + | "logout" + | "init" + | "prompt" + | "export" + ) { + return None; + } + let slash_command = slash_command_specs() + .iter() + .find(|spec| spec.name == command_name)?; + let guidance = if slash_command.resume_supported { + format!( + "`claw {command_name}` is a slash command. Use `claw --resume SESSION.jsonl /{command_name}` or start `claw` and run `/{command_name}`." + ) + } else { + format!( + "`claw {command_name}` is a slash command. Start `claw` and run `/{command_name}` inside the REPL." + ) + }; + Some(guidance) +} + +fn join_optional_args(args: &[String]) -> Option { + let joined = args.join(" "); + let trimmed = joined.trim(); + (!trimmed.is_empty()).then(|| trimmed.to_string()) +} + +fn parse_direct_slash_cli_action( + rest: &[String], + model: String, + output_format: CliOutputFormat, + allowed_tools: Option, + permission_mode: PermissionMode, + compact: bool, + base_commit: Option, + reasoning_effort: Option, + allow_broad_cwd: bool, +) -> Result { + let raw = rest.join(" "); + match SlashCommand::parse(&raw) { + Ok(Some(SlashCommand::Help)) => Ok(CliAction::Help { output_format }), + Ok(Some(SlashCommand::Agents { args })) => Ok(CliAction::Agents { + args, + output_format, + }), + Ok(Some(SlashCommand::Mcp { action, target })) => Ok(CliAction::Mcp { + args: match (action, target) { + (None, None) => None, + (Some(action), None) => Some(action), + (Some(action), Some(target)) => Some(format!("{action} {target}")), + (None, Some(target)) => Some(target), + }, + output_format, + }), + Ok(Some(SlashCommand::Skills { args })) => { + match classify_skills_slash_command(args.as_deref()) { + SkillSlashDispatch::Invoke(prompt) => Ok(CliAction::Prompt { + prompt, + model, + output_format, + allowed_tools, + permission_mode, + compact, + base_commit, + reasoning_effort: reasoning_effort.clone(), + allow_broad_cwd, + }), + SkillSlashDispatch::Local => Ok(CliAction::Skills { + args, + output_format, + }), + } + } + Ok(Some(SlashCommand::Unknown(name))) => Err(format_unknown_direct_slash_command(&name)), + Ok(Some(command)) => Err({ + let _ = command; + format!( + "slash command {command_name} is interactive-only. Start `claw` and run it there, or use `claw --resume SESSION.jsonl {command_name}` / `claw --resume {latest} {command_name}` when the command is marked [resume] in /help.", + command_name = rest[0], + latest = LATEST_SESSION_REFERENCE, + ) + }), + Ok(None) => Err(format!("unknown subcommand: {}", rest[0])), + Err(error) => Err(error.to_string()), + } +} + +fn format_unknown_option(option: &str) -> String { + let mut message = format!("unknown option: {option}"); + if let Some(suggestion) = suggest_closest_term(option, CLI_OPTION_SUGGESTIONS) { + message.push_str("\nDid you mean "); + message.push_str(suggestion); + message.push('?'); + } + message.push_str("\nRun `claw --help` for usage."); + message +} + +fn format_unknown_direct_slash_command(name: &str) -> String { + let mut message = format!("unknown slash command outside the REPL: /{name}"); + if let Some(suggestions) = render_suggestion_line("Did you mean", &suggest_slash_commands(name)) + { + message.push('\n'); + message.push_str(&suggestions); + } + if let Some(note) = omc_compatibility_note_for_unknown_slash_command(name) { + message.push('\n'); + message.push_str(note); + } + message.push_str("\nRun `claw --help` for CLI usage, or start `claw` and use /help."); + message +} + +fn format_unknown_slash_command(name: &str) -> String { + let mut message = format!("Unknown slash command: /{name}"); + if let Some(suggestions) = render_suggestion_line("Did you mean", &suggest_slash_commands(name)) + { + message.push('\n'); + message.push_str(&suggestions); + } + if let Some(note) = omc_compatibility_note_for_unknown_slash_command(name) { + message.push('\n'); + message.push_str(note); + } + message.push_str("\n Help /help lists available slash commands"); + message +} + +fn omc_compatibility_note_for_unknown_slash_command(name: &str) -> Option<&'static str> { + name.starts_with("oh-my-claudecode:") + .then_some( + "Compatibility note: `/oh-my-claudecode:*` is a Claude Code/OMC plugin command. `claw` does not yet load plugin slash commands, Claude statusline stdin, or OMC session hooks.", + ) +} + +fn render_suggestion_line(label: &str, suggestions: &[String]) -> Option { + (!suggestions.is_empty()).then(|| format!(" {label:<16} {}", suggestions.join(", "),)) +} + +fn suggest_slash_commands(input: &str) -> Vec { + let mut candidates = slash_command_specs() + .iter() + .flat_map(|spec| { + std::iter::once(spec.name) + .chain(spec.aliases.iter().copied()) + .map(|name| format!("/{name}")) + .collect::>() + }) + .collect::>(); + candidates.sort(); + candidates.dedup(); + let candidate_refs = candidates.iter().map(String::as_str).collect::>(); + ranked_suggestions(input.trim_start_matches('/'), &candidate_refs) + .into_iter() + .map(str::to_string) + .collect() +} + +fn suggest_closest_term<'a>(input: &str, candidates: &'a [&'a str]) -> Option<&'a str> { + ranked_suggestions(input, candidates).into_iter().next() +} + +fn ranked_suggestions<'a>(input: &str, candidates: &'a [&'a str]) -> Vec<&'a str> { + let normalized_input = input.trim_start_matches('/').to_ascii_lowercase(); + let mut ranked = candidates + .iter() + .filter_map(|candidate| { + let normalized_candidate = candidate.trim_start_matches('/').to_ascii_lowercase(); + let distance = levenshtein_distance(&normalized_input, &normalized_candidate); + let prefix_bonus = usize::from( + !(normalized_candidate.starts_with(&normalized_input) + || normalized_input.starts_with(&normalized_candidate)), + ); + let score = distance + prefix_bonus; + (score <= 4).then_some((score, *candidate)) + }) + .collect::>(); + ranked.sort_by(|left, right| left.cmp(right).then_with(|| left.1.cmp(right.1))); + ranked + .into_iter() + .map(|(_, candidate)| candidate) + .take(3) + .collect() +} + +fn levenshtein_distance(left: &str, right: &str) -> usize { + if left.is_empty() { + return right.chars().count(); + } + if right.is_empty() { + return left.chars().count(); + } + + let right_chars = right.chars().collect::>(); + let mut previous = (0..=right_chars.len()).collect::>(); + let mut current = vec![0; right_chars.len() + 1]; + + for (left_index, left_char) in left.chars().enumerate() { + current[0] = left_index + 1; + for (right_index, right_char) in right_chars.iter().enumerate() { + let substitution_cost = usize::from(left_char != *right_char); + current[right_index + 1] = (previous[right_index + 1] + 1) + .min(current[right_index] + 1) + .min(previous[right_index] + substitution_cost); + } + previous.clone_from(¤t); + } + + previous[right_chars.len()] +} + +fn resolve_model_alias(model: &str) -> &str { + match model { + "opus" => "claude-opus-4-6", + "sonnet" => "claude-sonnet-4-6", + "haiku" => "claude-haiku-4-5-20251213", + _ => model, + } +} + +/// Resolve a model name through user-defined config aliases first, then fall +/// back to the built-in alias table. This is the entry point used wherever a +/// user-supplied model string is about to be dispatched to a provider. +fn resolve_model_alias_with_config(model: &str) -> String { + let trimmed = model.trim(); + if let Some(resolved) = config_alias_for_current_dir(trimmed) { + return resolve_model_alias(&resolved).to_string(); + } + resolve_model_alias(trimmed).to_string() +} + +fn config_alias_for_current_dir(alias: &str) -> Option { + if alias.is_empty() { + return None; + } + let cwd = env::current_dir().ok()?; + let loader = ConfigLoader::default_for(&cwd); + let config = loader.load().ok()?; + config.aliases().get(alias).cloned() +} + fn normalize_allowed_tools(values: &[String]) -> Result, String> { if values.is_empty() { return Ok(None); } - - let canonical_names = mvp_tool_specs() - .into_iter() - .map(|spec| spec.name.to_string()) - .collect::>(); - let mut name_map = canonical_names - .iter() - .map(|name| (normalize_tool_name(name), name.clone())) - .collect::>(); - - for (alias, canonical) in [ - ("read", "read_file"), - ("write", "write_file"), - ("edit", "edit_file"), - ("glob", "glob_search"), - ("grep", "grep_search"), - ] { - name_map.insert(alias.to_string(), canonical.to_string()); - } - - let mut allowed = AllowedToolSet::new(); - for value in values { - for token in value - .split(|ch: char| ch == ',' || ch.is_whitespace()) - .filter(|token| !token.is_empty()) - { - let normalized = normalize_tool_name(token); - let canonical = name_map.get(&normalized).ok_or_else(|| { - format!( - "unsupported tool in --allowedTools: {token} (expected one of: {})", - canonical_names.join(", ") - ) - })?; - allowed.insert(canonical.clone()); - } - } - - Ok(Some(allowed)) + current_tool_registry()?.normalize_allowed_tools(values) } -fn normalize_tool_name(value: &str) -> String { - value.trim().replace('-', "_").to_ascii_lowercase() +fn current_tool_registry() -> Result { + let cwd = env::current_dir().map_err(|error| error.to_string())?; + let loader = ConfigLoader::default_for(&cwd); + let runtime_config = loader.load().map_err(|error| error.to_string())?; + let state = build_runtime_plugin_state_with_loader(&cwd, &loader, &runtime_config) + .map_err(|error| error.to_string())?; + let registry = state.tool_registry.clone(); + if let Some(mcp_state) = state.mcp_state { + mcp_state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .shutdown() + .map_err(|error| error.to_string())?; + } + Ok(registry) } fn parse_permission_mode_arg(value: &str) -> Result { @@ -329,22 +1051,81 @@ fn permission_mode_from_label(mode: &str) -> PermissionMode { } } +fn permission_mode_from_resolved(mode: ResolvedPermissionMode) -> PermissionMode { + match mode { + ResolvedPermissionMode::ReadOnly => PermissionMode::ReadOnly, + ResolvedPermissionMode::WorkspaceWrite => PermissionMode::WorkspaceWrite, + ResolvedPermissionMode::DangerFullAccess => PermissionMode::DangerFullAccess, + } +} + fn default_permission_mode() -> PermissionMode { env::var("RUSTY_CLAUDE_PERMISSION_MODE") .ok() .as_deref() .and_then(normalize_permission_mode) - .map_or(PermissionMode::WorkspaceWrite, permission_mode_from_label) + .map(permission_mode_from_label) + .or_else(config_permission_mode_for_current_dir) + .unwrap_or(PermissionMode::DangerFullAccess) } -fn filter_tool_specs(allowed_tools: Option<&AllowedToolSet>) -> Vec { - mvp_tool_specs() - .into_iter() - .filter(|spec| allowed_tools.is_none_or(|allowed| allowed.contains(spec.name))) - .collect() +fn config_permission_mode_for_current_dir() -> Option { + let cwd = env::current_dir().ok()?; + let loader = ConfigLoader::default_for(&cwd); + loader + .load() + .ok()? + .permission_mode() + .map(permission_mode_from_resolved) } -fn parse_system_prompt_args(args: &[String]) -> Result { +fn config_model_for_current_dir() -> Option { + let cwd = env::current_dir().ok()?; + let loader = ConfigLoader::default_for(&cwd); + loader.load().ok()?.model().map(ToOwned::to_owned) +} + +fn resolve_repl_model(cli_model: String) -> String { + if cli_model != DEFAULT_MODEL { + return cli_model; + } + if let Some(env_model) = env::var("ANTHROPIC_MODEL") + .ok() + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()) + { + return resolve_model_alias_with_config(&env_model); + } + if let Some(config_model) = config_model_for_current_dir() { + return resolve_model_alias_with_config(&config_model); + } + cli_model +} + +fn provider_label(kind: ProviderKind) -> &'static str { + match kind { + ProviderKind::Anthropic => "anthropic", + ProviderKind::Xai => "xai", + ProviderKind::OpenAi => "openai", + } +} + +fn format_connected_line(model: &str) -> String { + let provider = provider_label(detect_provider_kind(model)); + format!("Connected: {model} via {provider}") +} + +fn filter_tool_specs( + tool_registry: &GlobalToolRegistry, + allowed_tools: Option<&AllowedToolSet>, +) -> Vec { + tool_registry.definitions(allowed_tools) +} + +fn parse_system_prompt_args( + args: &[String], + output_format: CliOutputFormat, +) -> Result { let mut cwd = env::current_dir().map_err(|error| error.to_string())?; let mut date = DEFAULT_DATE.to_string(); let mut index = 0; @@ -369,58 +1150,869 @@ fn parse_system_prompt_args(args: &[String]) -> Result { } } - Ok(CliAction::PrintSystemPrompt { cwd, date }) -} - -fn parse_resume_args(args: &[String]) -> Result { - let session_path = args - .first() - .ok_or_else(|| "missing session path for --resume".to_string()) - .map(PathBuf::from)?; - let commands = args[1..].to_vec(); - if commands - .iter() - .any(|command| !command.trim_start().starts_with('/')) - { - return Err("--resume trailing arguments must be slash commands".to_string()); - } - Ok(CliAction::ResumeSession { - session_path, - commands, + Ok(CliAction::PrintSystemPrompt { + cwd, + date, + output_format, }) } -fn dump_manifests() { +fn parse_export_args(args: &[String], output_format: CliOutputFormat) -> Result { + let mut session_reference = LATEST_SESSION_REFERENCE.to_string(); + let mut output_path: Option = None; + let mut index = 0; + + while index < args.len() { + match args[index].as_str() { + "--session" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --session".to_string())?; + session_reference = value.clone(); + index += 2; + } + flag if flag.starts_with("--session=") => { + session_reference = flag[10..].to_string(); + index += 1; + } + "--output" | "-o" => { + let value = args + .get(index + 1) + .ok_or_else(|| format!("missing value for {}", args[index]))?; + output_path = Some(PathBuf::from(value)); + index += 2; + } + flag if flag.starts_with("--output=") => { + output_path = Some(PathBuf::from(&flag[9..])); + index += 1; + } + other if other.starts_with('-') => { + return Err(format!("unknown export option: {other}")); + } + other if output_path.is_none() => { + output_path = Some(PathBuf::from(other)); + index += 1; + } + other => { + return Err(format!("unexpected export argument: {other}")); + } + } + } + + Ok(CliAction::Export { + session_reference, + output_path, + output_format, + }) +} + +fn parse_resume_args(args: &[String], output_format: CliOutputFormat) -> Result { + let (session_path, command_tokens): (PathBuf, &[String]) = match args.first() { + None => (PathBuf::from(LATEST_SESSION_REFERENCE), &[]), + Some(first) if looks_like_slash_command_token(first) => { + (PathBuf::from(LATEST_SESSION_REFERENCE), args) + } + Some(first) => (PathBuf::from(first), &args[1..]), + }; + let mut commands = Vec::new(); + let mut current_command = String::new(); + + for token in command_tokens { + if token.trim_start().starts_with('/') { + if resume_command_can_absorb_token(¤t_command, token) { + current_command.push(' '); + current_command.push_str(token); + continue; + } + if !current_command.is_empty() { + commands.push(current_command); + } + current_command = String::from(token.as_str()); + continue; + } + + if current_command.is_empty() { + return Err("--resume trailing arguments must be slash commands".to_string()); + } + + current_command.push(' '); + current_command.push_str(token); + } + + if !current_command.is_empty() { + commands.push(current_command); + } + + Ok(CliAction::ResumeSession { + session_path, + commands, + output_format, + }) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum DiagnosticLevel { + Ok, + Warn, + Fail, +} + +impl DiagnosticLevel { + fn label(self) -> &'static str { + match self { + Self::Ok => "ok", + Self::Warn => "warn", + Self::Fail => "fail", + } + } + + fn is_failure(self) -> bool { + matches!(self, Self::Fail) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct DiagnosticCheck { + name: &'static str, + level: DiagnosticLevel, + summary: String, + details: Vec, + data: Map, +} + +impl DiagnosticCheck { + fn new(name: &'static str, level: DiagnosticLevel, summary: impl Into) -> Self { + Self { + name, + level, + summary: summary.into(), + details: Vec::new(), + data: Map::new(), + } + } + + fn with_details(mut self, details: Vec) -> Self { + self.details = details; + self + } + + fn with_data(mut self, data: Map) -> Self { + self.data = data; + self + } + + fn json_value(&self) -> Value { + let mut value = Map::from_iter([ + ( + "name".to_string(), + Value::String(self.name.to_ascii_lowercase()), + ), + ( + "status".to_string(), + Value::String(self.level.label().to_string()), + ), + ("summary".to_string(), Value::String(self.summary.clone())), + ( + "details".to_string(), + Value::Array( + self.details + .iter() + .cloned() + .map(Value::String) + .collect::>(), + ), + ), + ]); + value.extend(self.data.clone()); + Value::Object(value) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct DoctorReport { + checks: Vec, +} + +impl DoctorReport { + fn counts(&self) -> (usize, usize, usize) { + ( + self.checks + .iter() + .filter(|check| check.level == DiagnosticLevel::Ok) + .count(), + self.checks + .iter() + .filter(|check| check.level == DiagnosticLevel::Warn) + .count(), + self.checks + .iter() + .filter(|check| check.level == DiagnosticLevel::Fail) + .count(), + ) + } + + fn has_failures(&self) -> bool { + self.checks.iter().any(|check| check.level.is_failure()) + } + + fn render(&self) -> String { + let (ok_count, warn_count, fail_count) = self.counts(); + let mut lines = vec![ + "Doctor".to_string(), + format!( + "Summary\n OK {ok_count}\n Warnings {warn_count}\n Failures {fail_count}" + ), + ]; + lines.extend(self.checks.iter().map(render_diagnostic_check)); + lines.join("\n\n") + } + + fn json_value(&self) -> Value { + let report = self.render(); + let (ok_count, warn_count, fail_count) = self.counts(); + json!({ + "kind": "doctor", + "message": report, + "report": report, + "has_failures": self.has_failures(), + "summary": { + "total": self.checks.len(), + "ok": ok_count, + "warnings": warn_count, + "failures": fail_count, + }, + "checks": self + .checks + .iter() + .map(DiagnosticCheck::json_value) + .collect::>(), + }) + } +} + +fn render_diagnostic_check(check: &DiagnosticCheck) -> String { + let mut lines = vec![format!( + "{}\n Status {}\n Summary {}", + check.name, + check.level.label(), + check.summary + )]; + if !check.details.is_empty() { + lines.push(" Details".to_string()); + lines.extend(check.details.iter().map(|detail| format!(" - {detail}"))); + } + lines.join("\n") +} + +fn render_doctor_report() -> Result> { + let cwd = env::current_dir()?; + let config_loader = ConfigLoader::default_for(&cwd); + let config = config_loader.load(); + let discovered_config = config_loader.discover(); + let project_context = ProjectContext::discover_with_git(&cwd, DEFAULT_DATE)?; + let (project_root, git_branch) = + parse_git_status_metadata(project_context.git_status.as_deref()); + let git_summary = parse_git_workspace_summary(project_context.git_status.as_deref()); + let empty_config = runtime::RuntimeConfig::empty(); + let sandbox_config = config.as_ref().ok().unwrap_or(&empty_config); + let context = StatusContext { + cwd: cwd.clone(), + session_path: None, + loaded_config_files: config + .as_ref() + .ok() + .map_or(0, |runtime_config| runtime_config.loaded_entries().len()), + discovered_config_files: discovered_config.len(), + memory_file_count: project_context.instruction_files.len(), + project_root, + git_branch, + git_summary, + sandbox_status: resolve_sandbox_status(sandbox_config.sandbox(), &cwd), + }; + Ok(DoctorReport { + checks: vec![ + check_auth_health(), + check_config_health(&config_loader, config.as_ref()), + check_workspace_health(&context), + check_sandbox_health(&context.sandbox_status), + check_system_health(&cwd, config.as_ref().ok()), + ], + }) +} + +fn run_doctor(output_format: CliOutputFormat) -> Result<(), Box> { + let report = render_doctor_report()?; + let message = report.render(); + match output_format { + CliOutputFormat::Text => println!("{message}"), + CliOutputFormat::Json => { + println!("{}", serde_json::to_string_pretty(&report.json_value())?); + } + } + if report.has_failures() { + return Err("doctor found failing checks".into()); + } + Ok(()) +} + +/// Starts a minimal Model Context Protocol server that exposes claw's +/// built-in tools over stdio. +/// +/// Tool descriptors come from [`tools::mvp_tool_specs`] and calls are +/// dispatched through [`tools::execute_tool`], so this server exposes exactly +/// Read `.claw/worker-state.json` from the current working directory and print it. +/// This is the file-based worker observability surface: `push_event()` in `worker_boot.rs` +/// atomically writes state transitions here so external observers (clawhip, orchestrators) +/// can poll current `WorkerStatus` without needing an HTTP route on the opencode binary. +fn run_worker_state(output_format: CliOutputFormat) -> Result<(), Box> { + let cwd = env::current_dir()?; + let state_path = cwd.join(".claw").join("worker-state.json"); + if !state_path.exists() { + // Emit a structured error, then return Err so the process exits 1. + // Callers (scripts, CI) need a non-zero exit to detect "no state" without + // parsing prose output. + // Let the error propagate to main() which will format it correctly + // (prose for text mode, JSON envelope for --output-format json). + return Err(format!( + "no worker state file found at {} — run a worker first", + state_path.display() + ) + .into()); + } + let raw = std::fs::read_to_string(&state_path)?; + match output_format { + CliOutputFormat::Text => println!("{raw}"), + CliOutputFormat::Json => { + // Validate it parses as JSON before re-emitting + let _: serde_json::Value = serde_json::from_str(&raw)?; + println!("{raw}"); + } + } + Ok(()) +} + +/// the same surface the in-process agent loop uses. +fn run_mcp_serve() -> Result<(), Box> { + let tools = mvp_tool_specs() + .into_iter() + .map(|spec| McpTool { + name: spec.name.to_string(), + description: Some(spec.description.to_string()), + input_schema: Some(spec.input_schema), + annotations: None, + meta: None, + }) + .collect(); + + let spec = McpServerSpec { + server_name: "claw".to_string(), + server_version: VERSION.to_string(), + tools, + tool_handler: Box::new(execute_tool), + }; + + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + runtime.block_on(async move { + let mut server = McpServer::new(spec); + server.run().await + })?; + Ok(()) +} + +#[allow(clippy::too_many_lines)] +fn check_auth_health() -> DiagnosticCheck { + let api_key_present = env::var("ANTHROPIC_API_KEY") + .ok() + .is_some_and(|value| !value.trim().is_empty()); + let auth_token_present = env::var("ANTHROPIC_AUTH_TOKEN") + .ok() + .is_some_and(|value| !value.trim().is_empty()); + + match load_oauth_credentials() { + Ok(Some(token_set)) => { + let expired = oauth_token_is_expired(&api::OAuthTokenSet { + access_token: token_set.access_token.clone(), + refresh_token: token_set.refresh_token.clone(), + expires_at: token_set.expires_at, + scopes: token_set.scopes.clone(), + }); + let mut details = vec![ + format!( + "Environment api_key={} auth_token={}", + if api_key_present { "present" } else { "absent" }, + if auth_token_present { + "present" + } else { + "absent" + } + ), + format!( + "Saved OAuth expires_at={} refresh_token={} scopes={}", + token_set + .expires_at + .map_or_else(|| "".to_string(), |value| value.to_string()), + if token_set.refresh_token.is_some() { + "present" + } else { + "absent" + }, + if token_set.scopes.is_empty() { + "".to_string() + } else { + token_set.scopes.join(",") + } + ), + ]; + if expired { + details.push( + "Suggested action claw login to refresh local OAuth credentials".to_string(), + ); + } + DiagnosticCheck::new( + "Auth", + if expired { + DiagnosticLevel::Warn + } else { + DiagnosticLevel::Ok + }, + if expired { + "saved OAuth credentials are present but expired" + } else if api_key_present || auth_token_present { + "environment and saved credentials are available" + } else { + "saved OAuth credentials are available" + }, + ) + .with_details(details) + .with_data(Map::from_iter([ + ("api_key_present".to_string(), json!(api_key_present)), + ("auth_token_present".to_string(), json!(auth_token_present)), + ("saved_oauth_present".to_string(), json!(true)), + ("saved_oauth_expired".to_string(), json!(expired)), + ( + "saved_oauth_expires_at".to_string(), + json!(token_set.expires_at), + ), + ( + "refresh_token_present".to_string(), + json!(token_set.refresh_token.is_some()), + ), + ("scopes".to_string(), json!(token_set.scopes)), + ])) + } + Ok(None) => DiagnosticCheck::new( + "Auth", + if api_key_present || auth_token_present { + DiagnosticLevel::Ok + } else { + DiagnosticLevel::Warn + }, + if api_key_present || auth_token_present { + "environment credentials are configured" + } else { + "no API key or saved OAuth credentials were found" + }, + ) + .with_details(vec![format!( + "Environment api_key={} auth_token={}", + if api_key_present { "present" } else { "absent" }, + if auth_token_present { + "present" + } else { + "absent" + } + )]) + .with_data(Map::from_iter([ + ("api_key_present".to_string(), json!(api_key_present)), + ("auth_token_present".to_string(), json!(auth_token_present)), + ("saved_oauth_present".to_string(), json!(false)), + ("saved_oauth_expired".to_string(), json!(false)), + ("saved_oauth_expires_at".to_string(), Value::Null), + ("refresh_token_present".to_string(), json!(false)), + ("scopes".to_string(), json!(Vec::::new())), + ])), + Err(error) => DiagnosticCheck::new( + "Auth", + DiagnosticLevel::Fail, + format!("failed to inspect saved credentials: {error}"), + ) + .with_data(Map::from_iter([ + ("api_key_present".to_string(), json!(api_key_present)), + ("auth_token_present".to_string(), json!(auth_token_present)), + ("saved_oauth_present".to_string(), Value::Null), + ("saved_oauth_expired".to_string(), Value::Null), + ("saved_oauth_expires_at".to_string(), Value::Null), + ("refresh_token_present".to_string(), Value::Null), + ("scopes".to_string(), Value::Null), + ("saved_oauth_error".to_string(), json!(error.to_string())), + ])), + } +} + +fn check_config_health( + config_loader: &ConfigLoader, + config: Result<&runtime::RuntimeConfig, &runtime::ConfigError>, +) -> DiagnosticCheck { + let discovered = config_loader.discover(); + let discovered_count = discovered.len(); + // Separate candidate paths that actually exist from those that don't. + // Showing non-existent paths as "Discovered file" implies they loaded + // but something went wrong, which is confusing. We only surface paths + // that exist on disk as discovered; non-existent ones are silently + // omitted from the display (they are just the standard search locations). + let present_paths: Vec = discovered + .iter() + .filter(|e| e.path.exists()) + .map(|e| e.path.display().to_string()) + .collect(); + let discovered_paths = discovered + .iter() + .map(|entry| entry.path.display().to_string()) + .collect::>(); + match config { + Ok(runtime_config) => { + let loaded_entries = runtime_config.loaded_entries(); + let loaded_count = loaded_entries.len(); + let present_count = present_paths.len(); + let mut details = vec![format!( + "Config files loaded {}/{}", + loaded_count, present_count + )]; + if let Some(model) = runtime_config.model() { + details.push(format!("Resolved model {model}")); + } + details.push(format!( + "MCP servers {}", + runtime_config.mcp().servers().len() + )); + if present_paths.is_empty() { + details.push("Discovered files (defaults active)".to_string()); + } else { + details.extend( + present_paths + .iter() + .map(|path| format!("Discovered file {path}")), + ); + } + DiagnosticCheck::new( + "Config", + DiagnosticLevel::Ok, + if present_count == 0 { + "no config files present; defaults are active" + } else { + "runtime config loaded successfully" + }, + ) + .with_details(details) + .with_data(Map::from_iter([ + ("discovered_files".to_string(), json!(present_paths)), + ("discovered_files_count".to_string(), json!(present_count)), + ("loaded_config_files".to_string(), json!(loaded_count)), + ("resolved_model".to_string(), json!(runtime_config.model())), + ( + "mcp_servers".to_string(), + json!(runtime_config.mcp().servers().len()), + ), + ])) + } + Err(error) => DiagnosticCheck::new( + "Config", + DiagnosticLevel::Fail, + format!("runtime config failed to load: {error}"), + ) + .with_details(if discovered_paths.is_empty() { + vec!["Discovered files ".to_string()] + } else { + discovered_paths + .iter() + .map(|path| format!("Discovered file {path}")) + .collect() + }) + .with_data(Map::from_iter([ + ("discovered_files".to_string(), json!(discovered_paths)), + ( + "discovered_files_count".to_string(), + json!(discovered_count), + ), + ("loaded_config_files".to_string(), json!(0)), + ("resolved_model".to_string(), Value::Null), + ("mcp_servers".to_string(), Value::Null), + ("load_error".to_string(), json!(error.to_string())), + ])), + } +} + +fn check_workspace_health(context: &StatusContext) -> DiagnosticCheck { + let in_repo = context.project_root.is_some(); + DiagnosticCheck::new( + "Workspace", + if in_repo { + DiagnosticLevel::Ok + } else { + DiagnosticLevel::Warn + }, + if in_repo { + format!( + "project root detected on branch {}", + context.git_branch.as_deref().unwrap_or("unknown") + ) + } else { + "current directory is not inside a git project".to_string() + }, + ) + .with_details(vec![ + format!("Cwd {}", context.cwd.display()), + format!( + "Project root {}", + context + .project_root + .as_ref() + .map_or_else(|| "".to_string(), |path| path.display().to_string()) + ), + format!( + "Git branch {}", + context.git_branch.as_deref().unwrap_or("unknown") + ), + format!("Git state {}", context.git_summary.headline()), + format!("Changed files {}", context.git_summary.changed_files), + format!( + "Memory files {} · config files loaded {}/{}", + context.memory_file_count, context.loaded_config_files, context.discovered_config_files + ), + ]) + .with_data(Map::from_iter([ + ("cwd".to_string(), json!(context.cwd.display().to_string())), + ( + "project_root".to_string(), + json!(context + .project_root + .as_ref() + .map(|path| path.display().to_string())), + ), + ("in_git_repo".to_string(), json!(in_repo)), + ("git_branch".to_string(), json!(context.git_branch)), + ( + "git_state".to_string(), + json!(context.git_summary.headline()), + ), + ( + "changed_files".to_string(), + json!(context.git_summary.changed_files), + ), + ( + "memory_file_count".to_string(), + json!(context.memory_file_count), + ), + ( + "loaded_config_files".to_string(), + json!(context.loaded_config_files), + ), + ( + "discovered_config_files".to_string(), + json!(context.discovered_config_files), + ), + ])) +} + +fn check_sandbox_health(status: &runtime::SandboxStatus) -> DiagnosticCheck { + let degraded = status.enabled && !status.active; + let mut details = vec![ + format!("Enabled {}", status.enabled), + format!("Active {}", status.active), + format!("Supported {}", status.supported), + format!("Filesystem mode {}", status.filesystem_mode.as_str()), + format!("Filesystem live {}", status.filesystem_active), + ]; + if let Some(reason) = &status.fallback_reason { + details.push(format!("Fallback reason {reason}")); + } + DiagnosticCheck::new( + "Sandbox", + if degraded { + DiagnosticLevel::Warn + } else { + DiagnosticLevel::Ok + }, + if degraded { + "sandbox was requested but is not currently active" + } else if status.active { + "sandbox protections are active" + } else { + "sandbox is not active for this session" + }, + ) + .with_details(details) + .with_data(Map::from_iter([ + ("enabled".to_string(), json!(status.enabled)), + ("active".to_string(), json!(status.active)), + ("supported".to_string(), json!(status.supported)), + ( + "namespace_supported".to_string(), + json!(status.namespace_supported), + ), + ( + "namespace_active".to_string(), + json!(status.namespace_active), + ), + ( + "network_supported".to_string(), + json!(status.network_supported), + ), + ("network_active".to_string(), json!(status.network_active)), + ( + "filesystem_mode".to_string(), + json!(status.filesystem_mode.as_str()), + ), + ( + "filesystem_active".to_string(), + json!(status.filesystem_active), + ), + ("allowed_mounts".to_string(), json!(status.allowed_mounts)), + ("in_container".to_string(), json!(status.in_container)), + ( + "container_markers".to_string(), + json!(status.container_markers), + ), + ("fallback_reason".to_string(), json!(status.fallback_reason)), + ])) +} + +fn check_system_health(cwd: &Path, config: Option<&runtime::RuntimeConfig>) -> DiagnosticCheck { + let default_model = config.and_then(runtime::RuntimeConfig::model); + let mut details = vec![ + format!("OS {} {}", env::consts::OS, env::consts::ARCH), + format!("Working dir {}", cwd.display()), + format!("Version {}", VERSION), + format!("Build target {}", BUILD_TARGET.unwrap_or("")), + format!("Git SHA {}", GIT_SHA.unwrap_or("")), + ]; + if let Some(model) = default_model { + details.push(format!("Default model {model}")); + } + DiagnosticCheck::new( + "System", + DiagnosticLevel::Ok, + "captured local runtime metadata", + ) + .with_details(details) + .with_data(Map::from_iter([ + ("os".to_string(), json!(env::consts::OS)), + ("arch".to_string(), json!(env::consts::ARCH)), + ("working_dir".to_string(), json!(cwd.display().to_string())), + ("version".to_string(), json!(VERSION)), + ("build_target".to_string(), json!(BUILD_TARGET)), + ("git_sha".to_string(), json!(GIT_SHA)), + ("default_model".to_string(), json!(default_model)), + ])) +} + +fn resume_command_can_absorb_token(current_command: &str, token: &str) -> bool { + matches!( + SlashCommand::parse(current_command), + Ok(Some(SlashCommand::Export { path: None })) + ) && !looks_like_slash_command_token(token) +} + +fn looks_like_slash_command_token(token: &str) -> bool { + let trimmed = token.trim_start(); + let Some(name) = trimmed.strip_prefix('/').and_then(|value| { + value + .split_whitespace() + .next() + .map(str::trim) + .filter(|value| !value.is_empty()) + }) else { + return false; + }; + + slash_command_specs() + .iter() + .any(|spec| spec.name == name || spec.aliases.contains(&name)) +} + +fn dump_manifests(output_format: CliOutputFormat) -> Result<(), Box> { let workspace_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../.."); + // Surface the resolved path in the error so users can diagnose missing + // manifest files without guessing what path the binary expected. + // ROADMAP #45: this path is only correct when running from the build tree; + // a proper fix would ship manifests alongside the binary. + let resolved = workspace_dir + .canonicalize() + .unwrap_or_else(|_| workspace_dir.clone()); let paths = UpstreamPaths::from_workspace_dir(&workspace_dir); match extract_manifest(&paths) { Ok(manifest) => { - println!("commands: {}", manifest.commands.entries().len()); - println!("tools: {}", manifest.tools.entries().len()); - println!("bootstrap phases: {}", manifest.bootstrap.phases().len()); - } - Err(error) => { - eprintln!("failed to extract manifests: {error}"); - std::process::exit(1); + match output_format { + CliOutputFormat::Text => { + println!("commands: {}", manifest.commands.entries().len()); + println!("tools: {}", manifest.tools.entries().len()); + println!("bootstrap phases: {}", manifest.bootstrap.phases().len()); + } + CliOutputFormat::Json => println!( + "{}", + serde_json::to_string_pretty(&json!({ + "kind": "dump-manifests", + "commands": manifest.commands.entries().len(), + "tools": manifest.tools.entries().len(), + "bootstrap_phases": manifest.bootstrap.phases().len(), + }))? + ), + } + Ok(()) } + Err(error) => Err(format!( + "failed to extract manifests: {error}\n looked in: {}", + resolved.display() + ) + .into()), } } -fn print_bootstrap_plan() { - for phase in runtime::BootstrapPlan::claw_default().phases() { - println!("- {phase:?}"); +fn print_bootstrap_plan(output_format: CliOutputFormat) -> Result<(), Box> { + let phases = runtime::BootstrapPlan::claude_code_default() + .phases() + .iter() + .map(|phase| format!("{phase:?}")) + .collect::>(); + match output_format { + CliOutputFormat::Text => { + for phase in &phases { + println!("- {phase}"); + } + } + CliOutputFormat::Json => println!( + "{}", + serde_json::to_string_pretty(&json!({ + "kind": "bootstrap-plan", + "phases": phases, + }))? + ), + } + Ok(()) +} + +fn default_oauth_config() -> OAuthConfig { + OAuthConfig { + client_id: String::from("9d1c250a-e61b-44d9-88ed-5944d1962f5e"), + authorize_url: String::from("https://platform.claude.com/oauth/authorize"), + token_url: String::from("https://platform.claude.com/v1/oauth/token"), + callback_port: None, + manual_redirect_url: None, + scopes: vec![ + String::from("user:profile"), + String::from("user:inference"), + String::from("user:sessions:claude_code"), + ], } } -fn run_login() -> Result<(), Box> { +fn run_login(output_format: CliOutputFormat) -> Result<(), Box> { let cwd = env::current_dir()?; let config = ConfigLoader::default_for(&cwd).load()?; - let oauth = config.oauth().ok_or_else(|| { - io::Error::new( - io::ErrorKind::NotFound, - "OAuth config is missing. Add settings.oauth.clientId/authorizeUrl/tokenUrl first.", - ) - })?; + let default_oauth = default_oauth_config(); + let oauth = config.oauth().unwrap_or(&default_oauth); let callback_port = oauth.callback_port.unwrap_or(DEFAULT_OAUTH_CALLBACK_PORT); let redirect_uri = runtime::loopback_redirect_uri(callback_port); let pkce = generate_pkce_pair()?; @@ -429,11 +2021,18 @@ fn run_login() -> Result<(), Box> { OAuthAuthorizationRequest::from_config(oauth, redirect_uri.clone(), state.clone(), &pkce) .build_url(); - println!("Starting Claude OAuth login..."); - println!("Listening for callback on {redirect_uri}"); + if output_format == CliOutputFormat::Text { + println!("Starting Claude OAuth login..."); + println!("Listening for callback on {redirect_uri}"); + } if let Err(error) = open_browser(&authorize_url) { - eprintln!("warning: failed to open browser automatically: {error}"); - println!("Open this URL manually:\n{authorize_url}"); + emit_login_browser_open_failure( + output_format, + &authorize_url, + &error, + &mut io::stdout(), + &mut io::stderr(), + )?; } let callback = wait_for_oauth_callback(callback_port)?; @@ -453,9 +2052,14 @@ fn run_login() -> Result<(), Box> { return Err(io::Error::new(io::ErrorKind::InvalidData, "oauth state mismatch").into()); } - let client = ClawApiClient::from_auth(AuthSource::None).with_base_url(api::read_base_url()); - let exchange_request = - OAuthTokenExchangeRequest::from_config(oauth, code, state, pkce.verifier, redirect_uri); + let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(api::read_base_url()); + let exchange_request = OAuthTokenExchangeRequest::from_config( + oauth, + code, + state, + pkce.verifier, + redirect_uri.clone(), + ); let runtime = tokio::runtime::Runtime::new()?; let token_set = runtime.block_on(client.exchange_oauth_code(oauth, &exchange_request))?; save_oauth_credentials(&runtime::OAuthTokenSet { @@ -464,13 +2068,50 @@ fn run_login() -> Result<(), Box> { expires_at: token_set.expires_at, scopes: token_set.scopes, })?; - println!("Claude OAuth login complete."); + match output_format { + CliOutputFormat::Text => println!("Claude OAuth login complete."), + CliOutputFormat::Json => println!( + "{}", + serde_json::to_string_pretty(&json!({ + "kind": "login", + "callback_port": callback_port, + "redirect_uri": redirect_uri, + "message": "Claude OAuth login complete.", + }))? + ), + } Ok(()) } -fn run_logout() -> Result<(), Box> { +fn emit_login_browser_open_failure( + output_format: CliOutputFormat, + authorize_url: &str, + error: &io::Error, + stdout: &mut impl Write, + stderr: &mut impl Write, +) -> io::Result<()> { + writeln!( + stderr, + "warning: failed to open browser automatically: {error}" + )?; + match output_format { + CliOutputFormat::Text => writeln!(stdout, "Open this URL manually:\n{authorize_url}"), + CliOutputFormat::Json => writeln!(stderr, "Open this URL manually:\n{authorize_url}"), + } +} + +fn run_logout(output_format: CliOutputFormat) -> Result<(), Box> { clear_oauth_credentials()?; - println!("Claude OAuth credentials cleared."); + match output_format { + CliOutputFormat::Text => println!("Claude OAuth credentials cleared."), + CliOutputFormat::Json => println!( + "{}", + serde_json::to_string_pretty(&json!({ + "kind": "logout", + "message": "Claude OAuth credentials cleared.", + }))? + ), + } Ok(()) } @@ -528,56 +2169,209 @@ fn wait_for_oauth_callback( Ok(callback) } -fn print_system_prompt(cwd: PathBuf, date: String) { - match load_system_prompt(cwd, date, env::consts::OS, "unknown") { - Ok(sections) => println!("{}", sections.join("\n\n")), - Err(error) => { - eprintln!("failed to build system prompt: {error}"); - std::process::exit(1); +fn print_system_prompt( + cwd: PathBuf, + date: String, + output_format: CliOutputFormat, +) -> Result<(), Box> { + let sections = load_system_prompt(cwd, date, env::consts::OS, "unknown")?; + let message = sections.join( + " + +", + ); + match output_format { + CliOutputFormat::Text => println!("{message}"), + CliOutputFormat::Json => println!( + "{}", + serde_json::to_string_pretty(&json!({ + "kind": "system-prompt", + "message": message, + "sections": sections, + }))? + ), + } + Ok(()) +} + +fn print_version(output_format: CliOutputFormat) -> Result<(), Box> { + match output_format { + CliOutputFormat::Text => println!("{}", render_version_report()), + CliOutputFormat::Json => { + println!("{}", serde_json::to_string_pretty(&version_json_value())?); } } + Ok(()) } -fn print_version() { - println!("{}", render_version_report()); +fn version_json_value() -> serde_json::Value { + json!({ + "kind": "version", + "message": render_version_report(), + "version": VERSION, + "git_sha": GIT_SHA, + "target": BUILD_TARGET, + }) } -fn resume_session(session_path: &Path, commands: &[String]) { - let session = match Session::load_from_path(session_path) { +fn resume_session(session_path: &Path, commands: &[String], output_format: CliOutputFormat) { + let resolved_path = if session_path.exists() { + session_path.to_path_buf() + } else { + match resolve_session_reference(&session_path.display().to_string()) { + Ok(handle) => handle.path, + Err(error) => { + if output_format == CliOutputFormat::Json { + eprintln!( + "{}", + serde_json::json!({ + "type": "error", + "error": format!("failed to restore session: {error}"), + }) + ); + } else { + eprintln!("failed to restore session: {error}"); + } + std::process::exit(1); + } + } + }; + + let session = match Session::load_from_path(&resolved_path) { Ok(session) => session, Err(error) => { - eprintln!("failed to restore session: {error}"); + if output_format == CliOutputFormat::Json { + eprintln!( + "{}", + serde_json::json!({ + "type": "error", + "error": format!("failed to restore session: {error}"), + }) + ); + } else { + eprintln!("failed to restore session: {error}"); + } std::process::exit(1); } }; if commands.is_empty() { - println!( - "Restored session from {} ({} messages).", - session_path.display(), - session.messages.len() - ); + if output_format == CliOutputFormat::Json { + println!( + "{}", + serde_json::json!({ + "kind": "restored", + "session_id": session.session_id, + "path": resolved_path.display().to_string(), + "message_count": session.messages.len(), + }) + ); + } else { + println!( + "Restored session from {} ({} messages).", + resolved_path.display(), + session.messages.len() + ); + } return; } let mut session = session; for raw_command in commands { - let Some(command) = SlashCommand::parse(raw_command) else { - eprintln!("unsupported resumed command: {raw_command}"); - std::process::exit(2); + // Intercept spec commands that have no parse arm before calling + // SlashCommand::parse — they return Err(SlashCommandParseError) which + // formats as the confusing circular "Did you mean /X?" message. + // STUB_COMMANDS covers both completions-filtered stubs and parse-less + // spec entries; treat both as unsupported in resume mode. + { + let cmd_root = raw_command + .trim_start_matches('/') + .split_whitespace() + .next() + .unwrap_or(""); + if STUB_COMMANDS.contains(&cmd_root) { + if output_format == CliOutputFormat::Json { + eprintln!( + "{}", + serde_json::json!({ + "type": "error", + "error": format!("/{cmd_root} is not yet implemented in this build"), + "command": raw_command, + }) + ); + } else { + eprintln!("/{cmd_root} is not yet implemented in this build"); + } + std::process::exit(2); + } + } + let command = match SlashCommand::parse(raw_command) { + Ok(Some(command)) => command, + Ok(None) => { + if output_format == CliOutputFormat::Json { + eprintln!( + "{}", + serde_json::json!({ + "type": "error", + "error": format!("unsupported resumed command: {raw_command}"), + "command": raw_command, + }) + ); + } else { + eprintln!("unsupported resumed command: {raw_command}"); + } + std::process::exit(2); + } + Err(error) => { + if output_format == CliOutputFormat::Json { + eprintln!( + "{}", + serde_json::json!({ + "type": "error", + "error": error.to_string(), + "command": raw_command, + }) + ); + } else { + eprintln!("{error}"); + } + std::process::exit(2); + } }; - match run_resume_command(session_path, &session, &command) { + match run_resume_command(&resolved_path, &session, &command) { Ok(ResumeCommandOutcome { session: next_session, message, + json, }) => { session = next_session; - if let Some(message) = message { + if output_format == CliOutputFormat::Json { + if let Some(value) = json { + println!( + "{}", + serde_json::to_string_pretty(&value) + .expect("resume command json output") + ); + } else if let Some(message) = message { + println!("{message}"); + } + } else if let Some(message) = message { println!("{message}"); } } Err(error) => { - eprintln!("{error}"); + if output_format == CliOutputFormat::Json { + eprintln!( + "{}", + serde_json::json!({ + "type": "error", + "error": error.to_string(), + "command": raw_command, + }) + ); + } else { + eprintln!("{error}"); + } std::process::exit(2); } } @@ -588,6 +2382,7 @@ fn resume_session(session_path: &Path, commands: &[String]) { struct ResumeCommandOutcome { session: Session, message: Option, + json: Option, } #[derive(Debug, Clone)] @@ -599,6 +2394,8 @@ struct StatusContext { memory_file_count: usize, project_root: Option, git_branch: Option, + git_summary: GitWorkspaceSummary, + sandbox_status: runtime::SandboxStatus, } #[derive(Debug, Clone, Copy)] @@ -610,6 +2407,64 @@ struct StatusUsage { estimated_tokens: usize, } +#[allow(clippy::struct_field_names)] +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +struct GitWorkspaceSummary { + changed_files: usize, + staged_files: usize, + unstaged_files: usize, + untracked_files: usize, + conflicted_files: usize, +} + +impl GitWorkspaceSummary { + fn is_clean(self) -> bool { + self.changed_files == 0 + } + + fn headline(self) -> String { + if self.is_clean() { + "clean".to_string() + } else { + let mut details = Vec::new(); + if self.staged_files > 0 { + details.push(format!("{} staged", self.staged_files)); + } + if self.unstaged_files > 0 { + details.push(format!("{} unstaged", self.unstaged_files)); + } + if self.untracked_files > 0 { + details.push(format!("{} untracked", self.untracked_files)); + } + if self.conflicted_files > 0 { + details.push(format!("{} conflicted", self.conflicted_files)); + } + format!( + "dirty · {} files · {}", + self.changed_files, + details.join(", ") + ) + } + } +} + +#[cfg(test)] +fn format_unknown_slash_command_message(name: &str) -> String { + let suggestions = suggest_slash_commands(name); + let mut message = format!("unknown slash command: /{name}."); + if !suggestions.is_empty() { + message.push_str(" Did you mean "); + message.push_str(&suggestions.join(", ")); + message.push('?'); + } + if let Some(note) = omc_compatibility_note_for_unknown_slash_command(name) { + message.push(' '); + message.push_str(note); + } + message.push_str(" Use /help to list available commands."); + message +} + fn format_model_report(model: &str, message_count: usize, turns: u32) -> String { format!( "Model @@ -711,6 +2566,15 @@ fn format_resume_report(session_path: &str, message_count: usize, turns: u32) -> ) } +fn render_resume_usage() -> String { + format!( + "Resume + Usage /resume + Auto-save .claw/sessions/.{PRIMARY_SESSION_EXTENSION} + Tip use /session list to inspect saved sessions" + ) +} + fn format_compact_report(removed: usize, resulting_messages: usize, skipped: bool) -> String { if skipped { format!( @@ -729,28 +2593,104 @@ fn format_compact_report(removed: usize, resulting_messages: usize, skipped: boo } } -fn parse_git_status_metadata(status: Option<&str>) -> (Option, Option) { - let Some(status) = status else { - return (None, None); - }; - let branch = status.lines().next().and_then(|line| { - line.strip_prefix("## ") - .map(|line| { - line.split(['.', ' ']) - .next() - .unwrap_or_default() - .to_string() - }) - .filter(|value| !value.is_empty()) - }); - let project_root = find_git_root().ok(); - (project_root, branch) +fn format_auto_compaction_notice(removed: usize) -> String { + format!("[auto-compacted: removed {removed} messages]") } -fn find_git_root() -> Result> { +fn parse_git_status_metadata(status: Option<&str>) -> (Option, Option) { + parse_git_status_metadata_for( + &env::current_dir().unwrap_or_else(|_| PathBuf::from(".")), + status, + ) +} + +fn parse_git_status_branch(status: Option<&str>) -> Option { + let status = status?; + let first_line = status.lines().next()?; + let line = first_line.strip_prefix("## ")?; + if line.starts_with("HEAD") { + return Some("detached HEAD".to_string()); + } + let branch = line.split(['.', ' ']).next().unwrap_or_default().trim(); + if branch.is_empty() { + None + } else { + Some(branch.to_string()) + } +} + +fn parse_git_workspace_summary(status: Option<&str>) -> GitWorkspaceSummary { + let mut summary = GitWorkspaceSummary::default(); + let Some(status) = status else { + return summary; + }; + + for line in status.lines() { + if line.starts_with("## ") || line.trim().is_empty() { + continue; + } + + summary.changed_files += 1; + let mut chars = line.chars(); + let index_status = chars.next().unwrap_or(' '); + let worktree_status = chars.next().unwrap_or(' '); + + if index_status == '?' && worktree_status == '?' { + summary.untracked_files += 1; + continue; + } + + if index_status != ' ' { + summary.staged_files += 1; + } + if worktree_status != ' ' { + summary.unstaged_files += 1; + } + if (matches!(index_status, 'U' | 'A') && matches!(worktree_status, 'U' | 'A')) + || index_status == 'U' + || worktree_status == 'U' + { + summary.conflicted_files += 1; + } + } + + summary +} + +fn resolve_git_branch_for(cwd: &Path) -> Option { + let branch = run_git_capture_in(cwd, &["branch", "--show-current"])?; + let branch = branch.trim(); + if !branch.is_empty() { + return Some(branch.to_string()); + } + + let fallback = run_git_capture_in(cwd, &["rev-parse", "--abbrev-ref", "HEAD"])?; + let fallback = fallback.trim(); + if fallback.is_empty() { + None + } else if fallback == "HEAD" { + Some("detached HEAD".to_string()) + } else { + Some(fallback.to_string()) + } +} + +fn run_git_capture_in(cwd: &Path, args: &[&str]) -> Option { + let output = std::process::Command::new("git") + .args(args) + .current_dir(cwd) + .output() + .ok()?; + if !output.status.success() { + return None; + } + String::from_utf8(output.stdout).ok() +} + +fn find_git_root_in(cwd: &Path) -> Result> { let output = std::process::Command::new("git") .args(["rev-parse", "--show-toplevel"]) - .current_dir(env::current_dir()?) + .current_dir(cwd) .output()?; if !output.status.success() { return Err("not a git repository".into()); @@ -762,6 +2702,15 @@ fn find_git_root() -> Result> { Ok(PathBuf::from(path)) } +fn parse_git_status_metadata_for( + cwd: &Path, + status: Option<&str>, +) -> (Option, Option) { + let branch = resolve_git_branch_for(cwd).or_else(|| parse_git_status_branch(status)); + let project_root = find_git_root_in(cwd).ok(); + (project_root, branch) +} + #[allow(clippy::too_many_lines)] fn run_resume_command( session_path: &Path, @@ -772,6 +2721,7 @@ fn run_resume_command( SlashCommand::Help => Ok(ResumeCommandOutcome { session: session.clone(), message: Some(render_repl_help()), + json: Some(serde_json::json!({ "kind": "help", "text": render_repl_help() })), }), SlashCommand::Compact => { let result = runtime::compact_session( @@ -788,6 +2738,12 @@ fn run_resume_command( Ok(ResumeCommandOutcome { session: result.compacted_session, message: Some(format_compact_report(removed, kept, skipped)), + json: Some(serde_json::json!({ + "kind": "compact", + "skipped": skipped, + "removed_messages": removed, + "kept_messages": kept, + })), }) } SlashCommand::Clear { confirm } => { @@ -797,25 +2753,43 @@ fn run_resume_command( message: Some( "clear: confirmation required; rerun with /clear --confirm".to_string(), ), + json: Some(serde_json::json!({ + "kind": "error", + "error": "confirmation required", + "hint": "rerun with /clear --confirm", + })), }); } + let backup_path = write_session_clear_backup(session, session_path)?; + let previous_session_id = session.session_id.clone(); let cleared = Session::new(); + let new_session_id = cleared.session_id.clone(); cleared.save_to_path(session_path)?; Ok(ResumeCommandOutcome { session: cleared, message: Some(format!( - "Cleared resumed session file {}.", + "Session cleared\n Mode resumed session reset\n Previous session {previous_session_id}\n Backup {}\n Resume previous claw --resume {}\n New session {new_session_id}\n Session file {}", + backup_path.display(), + backup_path.display(), session_path.display() )), + json: Some(serde_json::json!({ + "kind": "clear", + "previous_session_id": previous_session_id, + "new_session_id": new_session_id, + "backup": backup_path.display().to_string(), + "session_file": session_path.display().to_string(), + })), }) } SlashCommand::Status => { let tracker = UsageTracker::from_session(session); let usage = tracker.cumulative_usage(); + let context = status_context(Some(session_path))?; Ok(ResumeCommandOutcome { session: session.clone(), message: Some(format_status_report( - "restored-session", + session.model.as_deref().unwrap_or("restored-session"), StatusUsage { message_count: session.messages.len(), turns: tracker.turns(), @@ -824,8 +2798,31 @@ fn run_resume_command( estimated_tokens: 0, }, default_permission_mode().as_str(), - &status_context(Some(session_path))?, + &context, )), + json: Some(status_json_value( + session.model.as_deref(), + StatusUsage { + message_count: session.messages.len(), + turns: tracker.turns(), + latest: tracker.current_turn_usage(), + cumulative: usage, + estimated_tokens: 0, + }, + default_permission_mode().as_str(), + &context, + )), + }) + } + SlashCommand::Sandbox => { + let cwd = env::current_dir()?; + let loader = ConfigLoader::default_for(&cwd); + let runtime_config = loader.load()?; + let status = resolve_sandbox_status(runtime_config.sandbox(), &cwd); + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(format_sandbox_report(&status)), + json: Some(sandbox_json_value(&status)), }) } SlashCommand::Cost => { @@ -833,70 +2830,313 @@ fn run_resume_command( Ok(ResumeCommandOutcome { session: session.clone(), message: Some(format_cost_report(usage)), + json: Some(serde_json::json!({ + "kind": "cost", + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + "cache_creation_input_tokens": usage.cache_creation_input_tokens, + "cache_read_input_tokens": usage.cache_read_input_tokens, + "total_tokens": usage.total_tokens(), + })), + }) + } + SlashCommand::Config { section } => { + let message = render_config_report(section.as_deref())?; + let json = render_config_json(section.as_deref())?; + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(message), + json: Some(json), + }) + } + SlashCommand::Mcp { action, target } => { + let cwd = env::current_dir()?; + let args = match (action.as_deref(), target.as_deref()) { + (None, None) => None, + (Some(action), None) => Some(action.to_string()), + (Some(action), Some(target)) => Some(format!("{action} {target}")), + (None, Some(target)) => Some(target.to_string()), + }; + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(handle_mcp_slash_command(args.as_deref(), &cwd)?), + json: Some(handle_mcp_slash_command_json(args.as_deref(), &cwd)?), }) } - SlashCommand::Config { section } => Ok(ResumeCommandOutcome { - session: session.clone(), - message: Some(render_config_report(section.as_deref())?), - }), SlashCommand::Memory => Ok(ResumeCommandOutcome { session: session.clone(), message: Some(render_memory_report()?), + json: Some(render_memory_json()?), }), - SlashCommand::Init => Ok(ResumeCommandOutcome { - session: session.clone(), - message: Some(init_claude_md()?), - }), - SlashCommand::Diff => Ok(ResumeCommandOutcome { - session: session.clone(), - message: Some(render_diff_report()?), - }), + SlashCommand::Init => { + let message = init_claude_md()?; + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(message.clone()), + json: Some(init_json_value(&message)), + }) + } + SlashCommand::Diff => { + let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")); + let message = render_diff_report_for(&cwd)?; + let json = render_diff_json_for(&cwd)?; + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(message), + json: Some(json), + }) + } SlashCommand::Version => Ok(ResumeCommandOutcome { session: session.clone(), message: Some(render_version_report()), + json: Some(version_json_value()), }), SlashCommand::Export { path } => { let export_path = resolve_export_path(path.as_deref(), session)?; fs::write(&export_path, render_export_text(session))?; + let msg_count = session.messages.len(); Ok(ResumeCommandOutcome { session: session.clone(), message: Some(format!( "Export\n Result wrote transcript\n File {}\n Messages {}", export_path.display(), - session.messages.len(), + msg_count, )), + json: Some(serde_json::json!({ + "kind": "export", + "file": export_path.display().to_string(), + "message_count": msg_count, + })), }) } - SlashCommand::Agents { args } => Ok(ResumeCommandOutcome { - session: session.clone(), - message: Some( - handle_agents_slash_command(args.as_deref(), &env::current_dir()?) - .map_err(|error| error.to_string())?, - ), - }), - SlashCommand::Skills { args } => Ok(ResumeCommandOutcome { - session: session.clone(), - message: Some( - handle_skills_slash_command(args.as_deref(), &env::current_dir()?) - .map_err(|error| error.to_string())?, - ), - }), - SlashCommand::Branch { .. } - | SlashCommand::Bughunter { .. } - | SlashCommand::Worktree { .. } - | SlashCommand::Commit - | SlashCommand::CommitPushPr { .. } + SlashCommand::Agents { args } => { + let cwd = env::current_dir()?; + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(handle_agents_slash_command(args.as_deref(), &cwd)?), + json: Some(serde_json::json!({ + "kind": "agents", + "text": handle_agents_slash_command(args.as_deref(), &cwd)?, + })), + }) + } + SlashCommand::Skills { args } => { + if let SkillSlashDispatch::Invoke(_) = classify_skills_slash_command(args.as_deref()) { + return Err( + "resumed /skills invocations are interactive-only; start `claw` and run `/skills ` in the REPL".into(), + ); + } + let cwd = env::current_dir()?; + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(handle_skills_slash_command(args.as_deref(), &cwd)?), + json: Some(handle_skills_slash_command_json(args.as_deref(), &cwd)?), + }) + } + SlashCommand::Doctor => { + let report = render_doctor_report()?; + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(report.render()), + json: Some(report.json_value()), + }) + } + SlashCommand::Stats => { + let usage = UsageTracker::from_session(session).cumulative_usage(); + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(format_cost_report(usage)), + json: Some(serde_json::json!({ + "kind": "stats", + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + "cache_creation_input_tokens": usage.cache_creation_input_tokens, + "cache_read_input_tokens": usage.cache_read_input_tokens, + "total_tokens": usage.total_tokens(), + })), + }) + } + SlashCommand::History { count } => { + let limit = parse_history_count(count.as_deref()) + .map_err(|error| -> Box { error.into() })?; + let entries = collect_session_prompt_history(session); + let shown: Vec<_> = entries.iter().rev().take(limit).rev().collect(); + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(render_prompt_history_report(&entries, limit)), + json: Some(serde_json::json!({ + "kind": "history", + "total": entries.len(), + "showing": shown.len(), + "entries": shown.iter().map(|e| serde_json::json!({ + "timestamp_ms": e.timestamp_ms, + "text": e.text, + })).collect::>(), + })), + }) + } + SlashCommand::Unknown(name) => Err(format_unknown_slash_command(name).into()), + // /session list can be served from the sessions directory without a live session. + SlashCommand::Session { + action: Some(ref act), + .. + } if act == "list" => { + let sessions = list_managed_sessions().unwrap_or_default(); + let session_ids: Vec = sessions.iter().map(|s| s.id.clone()).collect(); + let active_id = session.session_id.clone(); + let text = render_session_list(&active_id).unwrap_or_else(|e| format!("error: {e}")); + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(text), + json: Some(serde_json::json!({ + "kind": "session_list", + "sessions": session_ids, + "active": active_id, + })), + }) + } + SlashCommand::Bughunter { .. } + | SlashCommand::Commit { .. } | SlashCommand::Pr { .. } | SlashCommand::Issue { .. } | SlashCommand::Ultraplan { .. } | SlashCommand::Teleport { .. } - | SlashCommand::DebugToolCall + | SlashCommand::DebugToolCall { .. } | SlashCommand::Resume { .. } | SlashCommand::Model { .. } | SlashCommand::Permissions { .. } | SlashCommand::Session { .. } | SlashCommand::Plugins { .. } - | SlashCommand::Unknown(_) => Err("unsupported resumed slash command".into()), + | SlashCommand::Login + | SlashCommand::Logout + | SlashCommand::Vim + | SlashCommand::Upgrade + | SlashCommand::Share + | SlashCommand::Feedback + | SlashCommand::Files + | SlashCommand::Fast + | SlashCommand::Exit + | SlashCommand::Summary + | SlashCommand::Desktop + | SlashCommand::Brief + | SlashCommand::Advisor + | SlashCommand::Stickers + | SlashCommand::Insights + | SlashCommand::Thinkback + | SlashCommand::ReleaseNotes + | SlashCommand::SecurityReview + | SlashCommand::Keybindings + | SlashCommand::PrivacySettings + | SlashCommand::Plan { .. } + | SlashCommand::Review { .. } + | SlashCommand::Tasks { .. } + | SlashCommand::Theme { .. } + | SlashCommand::Voice { .. } + | SlashCommand::Usage { .. } + | SlashCommand::Rename { .. } + | SlashCommand::Copy { .. } + | SlashCommand::Hooks { .. } + | SlashCommand::Context { .. } + | SlashCommand::Color { .. } + | SlashCommand::Effort { .. } + | SlashCommand::Branch { .. } + | SlashCommand::Rewind { .. } + | SlashCommand::Ide { .. } + | SlashCommand::Tag { .. } + | SlashCommand::OutputStyle { .. } + | SlashCommand::AddDir { .. } => Err("unsupported resumed slash command".into()), + } +} + +/// Detect if the current working directory is "broad" (home directory or +/// filesystem root). Returns the cwd path if broad, None otherwise. +fn detect_broad_cwd() -> Option { + let Ok(cwd) = env::current_dir() else { + return None; + }; + let is_home = env::var_os("HOME") + .or_else(|| env::var_os("USERPROFILE")) + .map(|h| PathBuf::from(h) == cwd) + .unwrap_or(false); + let is_root = cwd.parent().is_none(); + if is_home || is_root { + Some(cwd) + } else { + None + } +} + +/// Enforce the broad-CWD policy: when running from home or root, either +/// require the --allow-broad-cwd flag, or prompt for confirmation (interactive), +/// or exit with an error (non-interactive). +fn enforce_broad_cwd_policy( + allow_broad_cwd: bool, + output_format: CliOutputFormat, +) -> Result<(), Box> { + if allow_broad_cwd { + return Ok(()); + } + let Some(cwd) = detect_broad_cwd() else { + return Ok(()); + }; + + let is_interactive = io::stdin().is_terminal(); + + if is_interactive { + // Interactive mode: print warning and ask for confirmation + eprintln!( + "Warning: claw is running from a very broad directory ({}).\n\ + The agent can read and search everything under this path.\n\ + Consider running from inside your project: cd /path/to/project && claw", + cwd.display() + ); + eprint!("Continue anyway? [y/N]: "); + io::stderr().flush()?; + + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + let trimmed = input.trim().to_lowercase(); + if trimmed != "y" && trimmed != "yes" { + eprintln!("Aborted."); + std::process::exit(0); + } + Ok(()) + } else { + // Non-interactive mode: exit with error (JSON or text) + let message = format!( + "claw is running from a very broad directory ({}). \ + The agent can read and search everything under this path. \ + Use --allow-broad-cwd to proceed anyway, \ + or run from inside your project: cd /path/to/project && claw", + cwd.display() + ); + match output_format { + CliOutputFormat::Json => { + eprintln!( + "{}", + serde_json::json!({ + "type": "error", + "error": message, + }) + ); + } + CliOutputFormat::Text => { + eprintln!("error: {message}"); + } + } + std::process::exit(1); + } +} + +fn run_stale_base_preflight(flag_value: Option<&str>) { + let cwd = match env::current_dir() { + Ok(cwd) => cwd, + Err(_) => return, + }; + let source = resolve_expected_base(flag_value, &cwd); + let state = check_base_commit(&cwd, source.as_ref()); + if let Some(warning) = format_stale_base_warning(&state) { + eprintln!("{warning}"); } } @@ -904,12 +3144,22 @@ fn run_repl( model: String, allowed_tools: Option, permission_mode: PermissionMode, + base_commit: Option, + reasoning_effort: Option, + allow_broad_cwd: bool, ) -> Result<(), Box> { - let mut cli = LiveCli::new(model, true, allowed_tools, permission_mode)?; - let mut editor = input::LineEditor::new("> ", slash_command_completion_candidates()); + enforce_broad_cwd_policy(allow_broad_cwd, CliOutputFormat::Text)?; + run_stale_base_preflight(base_commit.as_deref()); + let resolved_model = resolve_repl_model(model); + let mut cli = LiveCli::new(resolved_model, true, allowed_tools, permission_mode)?; + cli.set_reasoning_effort(reasoning_effort); + let mut editor = + input::LineEditor::new("> ", cli.repl_completion_candidates().unwrap_or_default()); println!("{}", cli.startup_banner()); + println!("{}", format_connected_line(&cli.model)); loop { + editor.set_completions(cli.repl_completion_candidates().unwrap_or_default()); match editor.read_line()? { input::ReadOutcome::Submit(input) => { let trimmed = input.trim().to_string(); @@ -920,13 +3170,41 @@ fn run_repl( cli.persist_session()?; break; } - if let Some(command) = SlashCommand::parse(&trimmed) { - if cli.handle_repl_command(command)? { - cli.persist_session()?; + match SlashCommand::parse(&trimmed) { + Ok(Some(command)) => { + if cli.handle_repl_command(command)? { + cli.persist_session()?; + } + continue; + } + Ok(None) => {} + Err(error) => { + eprintln!("{error}"); + continue; + } + } + // Bare-word skill dispatch: if the first token of the input + // matches a known skill name, invoke it as `/skills ` + // rather than forwarding raw text to the LLM (ROADMAP #36). + let bare_first_token = trimmed.split_whitespace().next().unwrap_or_default(); + let looks_like_skill_name = !bare_first_token.is_empty() + && !bare_first_token.starts_with('/') + && bare_first_token + .chars() + .all(|c| c.is_alphanumeric() || c == '-' || c == '_'); + if looks_like_skill_name { + let cwd = std::env::current_dir().unwrap_or_default(); + if let Ok(SkillSlashDispatch::Invoke(prompt)) = + resolve_skill_invocation(&cwd, Some(&trimmed)) + { + editor.push_history(input); + cli.record_prompt_history(&trimmed); + cli.run_turn(&prompt)?; + continue; } - continue; } editor.push_history(input); + cli.record_prompt_history(&trimmed); cli.run_turn(&trimmed)?; } input::ReadOutcome::Cancel => {} @@ -950,8 +3228,10 @@ struct SessionHandle { struct ManagedSessionSummary { id: String, path: PathBuf, - modified_epoch_secs: u64, + modified_epoch_millis: u128, message_count: usize, + parent_session_id: Option, + branch_name: Option, } struct LiveCli { @@ -959,8 +3239,486 @@ struct LiveCli { allowed_tools: Option, permission_mode: PermissionMode, system_prompt: Vec, - runtime: ConversationRuntime, + runtime: BuiltRuntime, session: SessionHandle, + prompt_history: Vec, +} + +#[derive(Debug, Clone)] +struct PromptHistoryEntry { + timestamp_ms: u64, + text: String, +} + +struct RuntimePluginState { + feature_config: runtime::RuntimeFeatureConfig, + tool_registry: GlobalToolRegistry, + plugin_registry: PluginRegistry, + mcp_state: Option>>, +} + +struct RuntimeMcpState { + runtime: tokio::runtime::Runtime, + manager: McpServerManager, + pending_servers: Vec, + degraded_report: Option, +} + +struct BuiltRuntime { + runtime: Option>, + plugin_registry: PluginRegistry, + plugins_active: bool, + mcp_state: Option>>, + mcp_active: bool, +} + +impl BuiltRuntime { + fn new( + runtime: ConversationRuntime, + plugin_registry: PluginRegistry, + mcp_state: Option>>, + ) -> Self { + Self { + runtime: Some(runtime), + plugin_registry, + plugins_active: true, + mcp_state, + mcp_active: true, + } + } + + fn with_hook_abort_signal(mut self, hook_abort_signal: runtime::HookAbortSignal) -> Self { + let runtime = self + .runtime + .take() + .expect("runtime should exist before installing hook abort signal"); + self.runtime = Some(runtime.with_hook_abort_signal(hook_abort_signal)); + self + } + + fn shutdown_plugins(&mut self) -> Result<(), Box> { + if self.plugins_active { + self.plugin_registry.shutdown()?; + self.plugins_active = false; + } + Ok(()) + } + + fn shutdown_mcp(&mut self) -> Result<(), Box> { + if self.mcp_active { + if let Some(mcp_state) = &self.mcp_state { + mcp_state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .shutdown()?; + } + self.mcp_active = false; + } + Ok(()) + } +} + +impl Deref for BuiltRuntime { + type Target = ConversationRuntime; + + fn deref(&self) -> &Self::Target { + self.runtime + .as_ref() + .expect("runtime should exist while built runtime is alive") + } +} + +impl DerefMut for BuiltRuntime { + fn deref_mut(&mut self) -> &mut Self::Target { + self.runtime + .as_mut() + .expect("runtime should exist while built runtime is alive") + } +} + +impl Drop for BuiltRuntime { + fn drop(&mut self) { + let _ = self.shutdown_mcp(); + let _ = self.shutdown_plugins(); + } +} + +#[derive(Debug, Deserialize)] +struct ToolSearchRequest { + query: String, + max_results: Option, +} + +#[derive(Debug, Deserialize)] +struct McpToolRequest { + #[serde(rename = "qualifiedName")] + qualified_name: Option, + tool: Option, + arguments: Option, +} + +#[derive(Debug, Deserialize)] +struct ListMcpResourcesRequest { + server: Option, +} + +#[derive(Debug, Deserialize)] +struct ReadMcpResourceRequest { + server: String, + uri: String, +} + +impl RuntimeMcpState { + fn new( + runtime_config: &runtime::RuntimeConfig, + ) -> Result, Box> { + let mut manager = McpServerManager::from_runtime_config(runtime_config); + if manager.server_names().is_empty() && manager.unsupported_servers().is_empty() { + return Ok(None); + } + + let runtime = tokio::runtime::Runtime::new()?; + let discovery = runtime.block_on(manager.discover_tools_best_effort()); + let pending_servers = discovery + .failed_servers + .iter() + .map(|failure| failure.server_name.clone()) + .chain( + discovery + .unsupported_servers + .iter() + .map(|server| server.server_name.clone()), + ) + .collect::>() + .into_iter() + .collect::>(); + let available_tools = discovery + .tools + .iter() + .map(|tool| tool.qualified_name.clone()) + .collect::>(); + let failed_server_names = pending_servers.iter().cloned().collect::>(); + let working_servers = manager + .server_names() + .into_iter() + .filter(|server_name| !failed_server_names.contains(server_name)) + .collect::>(); + let failed_servers = + discovery + .failed_servers + .iter() + .map(|failure| runtime::McpFailedServer { + server_name: failure.server_name.clone(), + phase: runtime::McpLifecyclePhase::ToolDiscovery, + error: runtime::McpErrorSurface::new( + runtime::McpLifecyclePhase::ToolDiscovery, + Some(failure.server_name.clone()), + failure.error.clone(), + std::collections::BTreeMap::new(), + true, + ), + }) + .chain(discovery.unsupported_servers.iter().map(|server| { + runtime::McpFailedServer { + server_name: server.server_name.clone(), + phase: runtime::McpLifecyclePhase::ServerRegistration, + error: runtime::McpErrorSurface::new( + runtime::McpLifecyclePhase::ServerRegistration, + Some(server.server_name.clone()), + server.reason.clone(), + std::collections::BTreeMap::from([( + "transport".to_string(), + format!("{:?}", server.transport).to_ascii_lowercase(), + )]), + false, + ), + } + })) + .collect::>(); + let degraded_report = (!failed_servers.is_empty()).then(|| { + runtime::McpDegradedReport::new( + working_servers, + failed_servers, + available_tools.clone(), + available_tools, + ) + }); + + Ok(Some(( + Self { + runtime, + manager, + pending_servers, + degraded_report, + }, + discovery, + ))) + } + + fn shutdown(&mut self) -> Result<(), Box> { + self.runtime.block_on(self.manager.shutdown())?; + Ok(()) + } + + fn pending_servers(&self) -> Option> { + (!self.pending_servers.is_empty()).then(|| self.pending_servers.clone()) + } + + fn degraded_report(&self) -> Option { + self.degraded_report.clone() + } + + fn server_names(&self) -> Vec { + self.manager.server_names() + } + + fn call_tool( + &mut self, + qualified_tool_name: &str, + arguments: Option, + ) -> Result { + let response = self + .runtime + .block_on(self.manager.call_tool(qualified_tool_name, arguments)) + .map_err(|error| ToolError::new(error.to_string()))?; + if let Some(error) = response.error { + return Err(ToolError::new(format!( + "MCP tool `{qualified_tool_name}` returned JSON-RPC error: {} ({})", + error.message, error.code + ))); + } + + let result = response.result.ok_or_else(|| { + ToolError::new(format!( + "MCP tool `{qualified_tool_name}` returned no result payload" + )) + })?; + serde_json::to_string_pretty(&result).map_err(|error| ToolError::new(error.to_string())) + } + + fn list_resources_for_server(&mut self, server_name: &str) -> Result { + let result = self + .runtime + .block_on(self.manager.list_resources(server_name)) + .map_err(|error| ToolError::new(error.to_string()))?; + serde_json::to_string_pretty(&json!({ + "server": server_name, + "resources": result.resources, + })) + .map_err(|error| ToolError::new(error.to_string())) + } + + fn list_resources_for_all_servers(&mut self) -> Result { + let mut resources = Vec::new(); + let mut failures = Vec::new(); + + for server_name in self.server_names() { + match self + .runtime + .block_on(self.manager.list_resources(&server_name)) + { + Ok(result) => resources.push(json!({ + "server": server_name, + "resources": result.resources, + })), + Err(error) => failures.push(json!({ + "server": server_name, + "error": error.to_string(), + })), + } + } + + if resources.is_empty() && !failures.is_empty() { + let message = failures + .iter() + .filter_map(|failure| failure.get("error").and_then(serde_json::Value::as_str)) + .collect::>() + .join("; "); + return Err(ToolError::new(message)); + } + + serde_json::to_string_pretty(&json!({ + "resources": resources, + "failures": failures, + })) + .map_err(|error| ToolError::new(error.to_string())) + } + + fn read_resource(&mut self, server_name: &str, uri: &str) -> Result { + let result = self + .runtime + .block_on(self.manager.read_resource(server_name, uri)) + .map_err(|error| ToolError::new(error.to_string()))?; + serde_json::to_string_pretty(&json!({ + "server": server_name, + "contents": result.contents, + })) + .map_err(|error| ToolError::new(error.to_string())) + } +} + +fn build_runtime_mcp_state( + runtime_config: &runtime::RuntimeConfig, +) -> Result> { + let Some((mcp_state, discovery)) = RuntimeMcpState::new(runtime_config)? else { + return Ok((None, Vec::new())); + }; + + let mut runtime_tools = discovery + .tools + .iter() + .map(mcp_runtime_tool_definition) + .collect::>(); + if !mcp_state.server_names().is_empty() { + runtime_tools.extend(mcp_wrapper_tool_definitions()); + } + + Ok((Some(Arc::new(Mutex::new(mcp_state))), runtime_tools)) +} + +fn mcp_runtime_tool_definition(tool: &runtime::ManagedMcpTool) -> RuntimeToolDefinition { + RuntimeToolDefinition { + name: tool.qualified_name.clone(), + description: Some( + tool.tool + .description + .clone() + .unwrap_or_else(|| format!("Invoke MCP tool `{}`.", tool.qualified_name)), + ), + input_schema: tool + .tool + .input_schema + .clone() + .unwrap_or_else(|| json!({ "type": "object", "additionalProperties": true })), + required_permission: permission_mode_for_mcp_tool(&tool.tool), + } +} + +fn mcp_wrapper_tool_definitions() -> Vec { + vec![ + RuntimeToolDefinition { + name: "MCPTool".to_string(), + description: Some( + "Call a configured MCP tool by its qualified name and JSON arguments.".to_string(), + ), + input_schema: json!({ + "type": "object", + "properties": { + "qualifiedName": { "type": "string" }, + "arguments": {} + }, + "required": ["qualifiedName"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + RuntimeToolDefinition { + name: "ListMcpResourcesTool".to_string(), + description: Some( + "List MCP resources from one configured server or from every connected server." + .to_string(), + ), + input_schema: json!({ + "type": "object", + "properties": { + "server": { "type": "string" } + }, + "additionalProperties": false + }), + required_permission: PermissionMode::ReadOnly, + }, + RuntimeToolDefinition { + name: "ReadMcpResourceTool".to_string(), + description: Some("Read a specific MCP resource from a configured server.".to_string()), + input_schema: json!({ + "type": "object", + "properties": { + "server": { "type": "string" }, + "uri": { "type": "string" } + }, + "required": ["server", "uri"], + "additionalProperties": false + }), + required_permission: PermissionMode::ReadOnly, + }, + ] +} + +fn permission_mode_for_mcp_tool(tool: &McpTool) -> PermissionMode { + let read_only = mcp_annotation_flag(tool, "readOnlyHint"); + let destructive = mcp_annotation_flag(tool, "destructiveHint"); + let open_world = mcp_annotation_flag(tool, "openWorldHint"); + + if read_only && !destructive && !open_world { + PermissionMode::ReadOnly + } else if destructive || open_world { + PermissionMode::DangerFullAccess + } else { + PermissionMode::WorkspaceWrite + } +} + +fn mcp_annotation_flag(tool: &McpTool, key: &str) -> bool { + tool.annotations + .as_ref() + .and_then(|annotations| annotations.get(key)) + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) +} + +struct HookAbortMonitor { + stop_tx: Option>, + join_handle: Option>, +} + +impl HookAbortMonitor { + fn spawn(abort_signal: runtime::HookAbortSignal) -> Self { + Self::spawn_with_waiter(abort_signal, move |stop_rx, abort_signal| { + let Ok(runtime) = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + else { + return; + }; + + runtime.block_on(async move { + let wait_for_stop = tokio::task::spawn_blocking(move || { + let _ = stop_rx.recv(); + }); + + tokio::select! { + result = tokio::signal::ctrl_c() => { + if result.is_ok() { + abort_signal.abort(); + } + } + _ = wait_for_stop => {} + } + }); + }) + } + + fn spawn_with_waiter(abort_signal: runtime::HookAbortSignal, wait_for_interrupt: F) -> Self + where + F: FnOnce(Receiver<()>, runtime::HookAbortSignal) + Send + 'static, + { + let (stop_tx, stop_rx) = mpsc::channel(); + let join_handle = thread::spawn(move || wait_for_interrupt(stop_rx, abort_signal)); + + Self { + stop_tx: Some(stop_tx), + join_handle: Some(join_handle), + } + } + + fn stop(mut self) { + if let Some(stop_tx) = self.stop_tx.take() { + let _ = stop_tx.send(()); + } + if let Some(join_handle) = self.join_handle.take() { + let _ = join_handle.join(); + } + } } impl LiveCli { @@ -971,14 +3729,18 @@ impl LiveCli { permission_mode: PermissionMode, ) -> Result> { let system_prompt = build_system_prompt()?; - let session = create_managed_session_handle()?; + let session_state = Session::new(); + let session = create_managed_session_handle(&session_state.session_id)?; let runtime = build_runtime( - Session::new(), + session_state.with_persistence_path(session.path.clone()), + &session.id, model.clone(), system_prompt.clone(), enable_tools, + true, allowed_tools.clone(), permission_mode, + None, )?; let cli = Self { model, @@ -987,16 +3749,36 @@ impl LiveCli { system_prompt, runtime, session, + prompt_history: Vec::new(), }; cli.persist_session()?; Ok(cli) } + fn set_reasoning_effort(&mut self, effort: Option) { + if let Some(rt) = self.runtime.runtime.as_mut() { + rt.api_client_mut().set_reasoning_effort(effort); + } + } + fn startup_banner(&self) -> String { let cwd = env::current_dir().map_or_else( |_| "".to_string(), |path| path.display().to_string(), ); + let status = status_context(None).ok(); + let git_branch = status + .as_ref() + .and_then(|context| context.git_branch.as_deref()) + .unwrap_or("unknown"); + let workspace = status.as_ref().map_or_else( + || "unknown".to_string(), + |context| context.git_summary.headline(), + ); + let session_path = self.session.path.strip_prefix(Path::new(&cwd)).map_or_else( + |_| self.session.path.display().to_string(), + |path| path.display().to_string(), + ); format!( "\x1b[38;5;196m\ ██████╗██╗ █████╗ ██╗ ██╗\n\ @@ -1007,17 +3789,63 @@ impl LiveCli { ╚═════╝╚══════╝╚═╝ ╚═╝ ╚══╝╚══╝\x1b[0m \x1b[38;5;208mCode\x1b[0m 🦞\n\n\ \x1b[2mModel\x1b[0m {}\n\ \x1b[2mPermissions\x1b[0m {}\n\ + \x1b[2mBranch\x1b[0m {}\n\ + \x1b[2mWorkspace\x1b[0m {}\n\ \x1b[2mDirectory\x1b[0m {}\n\ - \x1b[2mSession\x1b[0m {}\n\n\ - Type \x1b[1m/help\x1b[0m for commands · \x1b[2mShift+Enter\x1b[0m for newline", + \x1b[2mSession\x1b[0m {}\n\ + \x1b[2mAuto-save\x1b[0m {}\n\n\ + Type \x1b[1m/help\x1b[0m for commands · \x1b[1m/status\x1b[0m for live context · \x1b[2m/resume latest\x1b[0m jumps back to the newest session · \x1b[1m/diff\x1b[0m then \x1b[1m/commit\x1b[0m to ship · \x1b[2mTab\x1b[0m for workflow completions · \x1b[2mShift+Enter\x1b[0m for newline", self.model, self.permission_mode.as_str(), + git_branch, + workspace, cwd, self.session.id, + session_path, ) } + fn repl_completion_candidates(&self) -> Result, Box> { + Ok(slash_command_completion_candidates_with_sessions( + &self.model, + Some(&self.session.id), + list_managed_sessions()? + .into_iter() + .map(|session| session.id) + .collect(), + )) + } + + fn prepare_turn_runtime( + &self, + emit_output: bool, + ) -> Result<(BuiltRuntime, HookAbortMonitor), Box> { + let hook_abort_signal = runtime::HookAbortSignal::new(); + let runtime = build_runtime( + self.runtime.session().clone(), + &self.session.id, + self.model.clone(), + self.system_prompt.clone(), + true, + emit_output, + self.allowed_tools.clone(), + self.permission_mode, + None, + )? + .with_hook_abort_signal(hook_abort_signal.clone()); + let hook_abort_monitor = HookAbortMonitor::spawn(hook_abort_signal); + + Ok((runtime, hook_abort_monitor)) + } + + fn replace_runtime(&mut self, runtime: BuiltRuntime) -> Result<(), Box> { + self.runtime.shutdown_plugins()?; + self.runtime = runtime; + Ok(()) + } + fn run_turn(&mut self, input: &str) -> Result<(), Box> { + let (mut runtime, hook_abort_monitor) = self.prepare_turn_runtime(true)?; let mut spinner = Spinner::new(); let mut stdout = io::stdout(); spinner.tick( @@ -1026,19 +3854,28 @@ impl LiveCli { &mut stdout, )?; let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); - let result = self.runtime.run_turn(input, Some(&mut permission_prompter)); + let result = runtime.run_turn(input, Some(&mut permission_prompter)); + hook_abort_monitor.stop(); match result { - Ok(_) => { + Ok(summary) => { + self.replace_runtime(runtime)?; spinner.finish( "✨ Done", TerminalRenderer::new().color_theme(), &mut stdout, )?; println!(); + if let Some(event) = summary.auto_compaction { + println!( + "{}", + format_auto_compaction_notice(event.removed_message_count) + ); + } self.persist_session()?; Ok(()) } Err(error) => { + runtime.shutdown_plugins()?; spinner.fail( "❌ Request failed", TerminalRenderer::new().color_theme(), @@ -1053,58 +3890,67 @@ impl LiveCli { &mut self, input: &str, output_format: CliOutputFormat, + compact: bool, ) -> Result<(), Box> { match output_format { + CliOutputFormat::Text if compact => self.run_prompt_compact(input), CliOutputFormat::Text => self.run_turn(input), CliOutputFormat::Json => self.run_prompt_json(input), } } + fn run_prompt_compact(&mut self, input: &str) -> Result<(), Box> { + let (mut runtime, hook_abort_monitor) = self.prepare_turn_runtime(false)?; + let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); + let result = runtime.run_turn(input, Some(&mut permission_prompter)); + hook_abort_monitor.stop(); + let summary = result?; + self.replace_runtime(runtime)?; + self.persist_session()?; + let final_text = final_assistant_text(&summary); + println!("{final_text}"); + Ok(()) + } + fn run_prompt_json(&mut self, input: &str) -> Result<(), Box> { - let client = ClawApiClient::from_auth(resolve_cli_auth_source()?).with_base_url(api::read_base_url()); - let request = MessageRequest { - model: self.model.clone(), - max_tokens: DEFAULT_MAX_TOKENS, - messages: vec![InputMessage { - role: "user".to_string(), - content: vec![InputContentBlock::Text { - text: input.to_string(), - }], - }], - system: (!self.system_prompt.is_empty()).then(|| self.system_prompt.join("\n\n")), - tools: None, - tool_choice: None, - stream: false, - }; - let runtime = tokio::runtime::Runtime::new()?; - let response = runtime.block_on(client.send_message(&request))?; - let text = response - .content - .iter() - .filter_map(|block| match block { - OutputContentBlock::Text { text } => Some(text.as_str()), - OutputContentBlock::ToolUse { .. } - | OutputContentBlock::Thinking { .. } - | OutputContentBlock::RedactedThinking { .. } => None, - }) - .collect::>() - .join(""); + let (mut runtime, hook_abort_monitor) = self.prepare_turn_runtime(false)?; + let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); + let result = runtime.run_turn(input, Some(&mut permission_prompter)); + hook_abort_monitor.stop(); + let summary = result?; + self.replace_runtime(runtime)?; + self.persist_session()?; println!( "{}", json!({ - "message": text, + "message": final_assistant_text(&summary), "model": self.model, + "iterations": summary.iterations, + "auto_compaction": summary.auto_compaction.map(|event| json!({ + "removed_messages": event.removed_message_count, + "notice": format_auto_compaction_notice(event.removed_message_count), + })), + "tool_uses": collect_tool_uses(&summary), + "tool_results": collect_tool_results(&summary), + "prompt_cache_events": collect_prompt_cache_events(&summary), "usage": { - "input_tokens": response.usage.input_tokens, - "output_tokens": response.usage.output_tokens, - "cache_creation_input_tokens": response.usage.cache_creation_input_tokens, - "cache_read_input_tokens": response.usage.cache_read_input_tokens, - } + "input_tokens": summary.usage.input_tokens, + "output_tokens": summary.usage.output_tokens, + "cache_creation_input_tokens": summary.usage.cache_creation_input_tokens, + "cache_read_input_tokens": summary.usage.cache_read_input_tokens, + }, + "estimated_cost": format_usd( + summary.usage.estimate_cost_usd_with_pricing( + pricing_for_model(&self.model) + .unwrap_or_else(runtime::ModelPricing::default_sonnet_tier) + ).total_cost_usd() + ) }) ); Ok(()) } + #[allow(clippy::too_many_lines)] fn handle_repl_command( &mut self, command: SlashCommand, @@ -1118,6 +3964,38 @@ impl LiveCli { self.print_status(); false } + SlashCommand::Bughunter { scope } => { + self.run_bughunter(scope.as_deref())?; + false + } + SlashCommand::Commit => { + self.run_commit(None)?; + false + } + SlashCommand::Pr { context } => { + self.run_pr(context.as_deref())?; + false + } + SlashCommand::Issue { context } => { + self.run_issue(context.as_deref())?; + false + } + SlashCommand::Ultraplan { task } => { + self.run_ultraplan(task.as_deref())?; + false + } + SlashCommand::Teleport { target } => { + Self::run_teleport(target.as_deref())?; + false + } + SlashCommand::DebugToolCall => { + self.run_debug_tool_call(None)?; + false + } + SlashCommand::Sandbox => { + Self::print_sandbox_status(); + false + } SlashCommand::Compact => { self.compact()?; false @@ -1134,12 +4012,22 @@ impl LiveCli { Self::print_config(section.as_deref())?; false } + SlashCommand::Mcp { action, target } => { + let args = match (action.as_deref(), target.as_deref()) { + (None, None) => None, + (Some(action), None) => Some(action.to_string()), + (Some(action), Some(target)) => Some(format!("{action} {target}")), + (None, Some(target)) => Some(target.to_string()), + }; + Self::print_mcp(args.as_deref(), CliOutputFormat::Text)?; + false + } SlashCommand::Memory => { Self::print_memory()?; false } SlashCommand::Init => { - run_init()?; + run_init(CliOutputFormat::Text)?; false } SlashCommand::Diff => { @@ -1147,87 +4035,89 @@ impl LiveCli { false } SlashCommand::Version => { - Self::print_version(); + Self::print_version(CliOutputFormat::Text); false } SlashCommand::Export { path } => { self.export_session(path.as_deref())?; false } - SlashCommand::Branch { action, target } => { - println!( - "{}", - handle_branch_slash_command( - action.as_deref(), - target.as_deref(), - &env::current_dir()? - )? - ); - false - } - SlashCommand::Worktree { - action, - path, - branch, - } => { - println!( - "{}", - handle_worktree_slash_command( - action.as_deref(), - path.as_deref(), - branch.as_deref(), - &env::current_dir()? - )? - ); - false - } - SlashCommand::Commit => { - println!( - "{}", - handle_commit_slash_command("resume commit", &env::current_dir()?)? - ); - false - } - SlashCommand::Agents { args } => { - println!( - "{}", - handle_agents_slash_command(args.as_deref(), &env::current_dir()?)? - ); - false - } - SlashCommand::Skills { args } => { - println!( - "{}", - handle_skills_slash_command(args.as_deref(), &env::current_dir()?)? - ); - false - } - SlashCommand::Plugins { action, target } => { - let config = plugins::PluginManagerConfig::new(env::current_dir()?); - let mut manager = plugins::PluginManager::new(config); - let result = handle_plugins_slash_command( - action.as_deref(), - target.as_deref(), - &mut manager, - )?; - println!("{}", result.message); - result.reload_runtime - } - SlashCommand::Bughunter { .. } - | SlashCommand::CommitPushPr { .. } - | SlashCommand::Pr { .. } - | SlashCommand::Issue { .. } - | SlashCommand::Ultraplan { .. } - | SlashCommand::Teleport { .. } - | SlashCommand::DebugToolCall => { - eprintln!("slash command not yet implemented in REPL: {command:?}"); - false - } SlashCommand::Session { action, target } => { self.handle_session_command(action.as_deref(), target.as_deref())? } + SlashCommand::Plugins { action, target } => { + self.handle_plugins_command(action.as_deref(), target.as_deref())? + } + SlashCommand::Agents { args } => { + Self::print_agents(args.as_deref(), CliOutputFormat::Text)?; + false + } + SlashCommand::Skills { args } => { + match classify_skills_slash_command(args.as_deref()) { + SkillSlashDispatch::Invoke(prompt) => self.run_turn(&prompt)?, + SkillSlashDispatch::Local => { + Self::print_skills(args.as_deref(), CliOutputFormat::Text)?; + } + } + false + } + SlashCommand::Doctor => { + println!("{}", render_doctor_report()?.render()); + false + } + SlashCommand::History { count } => { + self.print_prompt_history(count.as_deref()); + false + } + SlashCommand::Stats => { + let usage = UsageTracker::from_session(self.runtime.session()).cumulative_usage(); + println!("{}", format_cost_report(usage)); + false + } + SlashCommand::Login + | SlashCommand::Logout + | SlashCommand::Vim + | SlashCommand::Upgrade + | SlashCommand::Share + | SlashCommand::Feedback + | SlashCommand::Files + | SlashCommand::Fast + | SlashCommand::Exit + | SlashCommand::Summary + | SlashCommand::Desktop + | SlashCommand::Brief + | SlashCommand::Advisor + | SlashCommand::Stickers + | SlashCommand::Insights + | SlashCommand::Thinkback + | SlashCommand::ReleaseNotes + | SlashCommand::SecurityReview + | SlashCommand::Keybindings + | SlashCommand::PrivacySettings + | SlashCommand::Plan { .. } + | SlashCommand::Review { .. } + | SlashCommand::Tasks { .. } + | SlashCommand::Theme { .. } + | SlashCommand::Voice { .. } + | SlashCommand::Usage { .. } + | SlashCommand::Rename { .. } + | SlashCommand::Copy { .. } + | SlashCommand::Hooks { .. } + | SlashCommand::Context { .. } + | SlashCommand::Color { .. } + | SlashCommand::Effort { .. } + | SlashCommand::Branch { .. } + | SlashCommand::Rewind { .. } + | SlashCommand::Ide { .. } + | SlashCommand::Tag { .. } + | SlashCommand::OutputStyle { .. } + | SlashCommand::AddDir { .. } => { + let cmd_name = command.slash_name(); + eprintln!("{cmd_name} is not yet implemented in this build."); + false + } SlashCommand::Unknown(name) => { - eprintln!("unknown slash command: /{name}"); + eprintln!("{}", format_unknown_slash_command(&name)); false } }) @@ -1258,6 +4148,68 @@ impl LiveCli { ); } + fn record_prompt_history(&mut self, prompt: &str) { + let timestamp_ms = std::time::SystemTime::now() + .duration_since(UNIX_EPOCH) + .ok() + .map_or(self.runtime.session().updated_at_ms, |duration| { + u64::try_from(duration.as_millis()).unwrap_or(u64::MAX) + }); + let entry = PromptHistoryEntry { + timestamp_ms, + text: prompt.to_string(), + }; + self.prompt_history.push(entry); + if let Err(error) = self.runtime.session_mut().push_prompt_entry(prompt) { + eprintln!("warning: failed to persist prompt history: {error}"); + } + } + + fn print_prompt_history(&self, count: Option<&str>) { + let limit = match parse_history_count(count) { + Ok(limit) => limit, + Err(message) => { + eprintln!("{message}"); + return; + } + }; + let session_entries = &self.runtime.session().prompt_history; + let entries = if session_entries.is_empty() { + if self.prompt_history.is_empty() { + collect_session_prompt_history(self.runtime.session()) + } else { + self.prompt_history + .iter() + .map(|entry| PromptHistoryEntry { + timestamp_ms: entry.timestamp_ms, + text: entry.text.clone(), + }) + .collect() + } + } else { + session_entries + .iter() + .map(|entry| PromptHistoryEntry { + timestamp_ms: entry.timestamp_ms, + text: entry.text.clone(), + }) + .collect() + }; + println!("{}", render_prompt_history_report(&entries, limit)); + } + + fn print_sandbox_status() { + let cwd = env::current_dir().expect("current dir"); + let loader = ConfigLoader::default_for(&cwd); + let runtime_config = loader + .load() + .unwrap_or_else(|_| runtime::RuntimeConfig::empty()); + println!( + "{}", + format_sandbox_report(&resolve_sandbox_status(runtime_config.sandbox(), &cwd)) + ); + } + fn set_model(&mut self, model: Option) -> Result> { let Some(model) = model else { println!( @@ -1271,6 +4223,8 @@ impl LiveCli { return Ok(false); }; + let model = resolve_model_alias_with_config(&model); + if model == self.model { println!( "{}", @@ -1286,14 +4240,18 @@ impl LiveCli { let previous = self.model.clone(); let session = self.runtime.session().clone(); let message_count = session.messages.len(); - self.runtime = build_runtime( + let runtime = build_runtime( session, + &self.session.id, model.clone(), self.system_prompt.clone(), true, + true, self.allowed_tools.clone(), self.permission_mode, + None, )?; + self.replace_runtime(runtime)?; self.model.clone_from(&model); println!( "{}", @@ -1328,14 +4286,18 @@ impl LiveCli { let previous = self.permission_mode.as_str().to_string(); let session = self.runtime.session().clone(); self.permission_mode = permission_mode_from_label(normalized); - self.runtime = build_runtime( + let runtime = build_runtime( session, + &self.session.id, self.model.clone(), self.system_prompt.clone(), true, + true, self.allowed_tools.clone(), self.permission_mode, + None, )?; + self.replace_runtime(runtime)?; println!( "{}", format_permissions_switch_report(&previous, normalized) @@ -1351,20 +4313,29 @@ impl LiveCli { return Ok(false); } - self.session = create_managed_session_handle()?; - self.runtime = build_runtime( - Session::new(), + let previous_session = self.session.clone(); + let session_state = Session::new(); + self.session = create_managed_session_handle(&session_state.session_id)?; + let runtime = build_runtime( + session_state.with_persistence_path(self.session.path.clone()), + &self.session.id, self.model.clone(), self.system_prompt.clone(), true, + true, self.allowed_tools.clone(), self.permission_mode, + None, )?; + self.replace_runtime(runtime)?; println!( - "Session cleared\n Mode fresh session\n Preserved model {}\n Permission mode {}\n Session {}", + "Session cleared\n Mode fresh session\n Previous session {}\n Resume previous /resume {}\n Preserved model {}\n Permission mode {}\n New session {}\n Session file {}", + previous_session.id, + previous_session.id, self.model, self.permission_mode.as_str(), self.session.id, + self.session.path.display(), ); Ok(true) } @@ -1379,22 +4350,30 @@ impl LiveCli { session_path: Option, ) -> Result> { let Some(session_ref) = session_path else { - println!("Usage: /resume "); + println!("{}", render_resume_usage()); return Ok(false); }; let handle = resolve_session_reference(&session_ref)?; let session = Session::load_from_path(&handle.path)?; let message_count = session.messages.len(); - self.runtime = build_runtime( + let session_id = session.session_id.clone(); + let runtime = build_runtime( session, + &handle.id, self.model.clone(), self.system_prompt.clone(), true, + true, self.allowed_tools.clone(), self.permission_mode, + None, )?; - self.session = handle; + self.replace_runtime(runtime)?; + self.session = SessionHandle { + id: session_id, + path: handle.path, + }; println!( "{}", format_resume_report( @@ -1416,13 +4395,90 @@ impl LiveCli { Ok(()) } + fn print_agents( + args: Option<&str>, + output_format: CliOutputFormat, + ) -> Result<(), Box> { + let cwd = env::current_dir()?; + match output_format { + CliOutputFormat::Text => println!("{}", handle_agents_slash_command(args, &cwd)?), + CliOutputFormat::Json => println!( + "{}", + serde_json::to_string_pretty(&handle_agents_slash_command_json(args, &cwd)?)? + ), + } + Ok(()) + } + + fn print_mcp( + args: Option<&str>, + output_format: CliOutputFormat, + ) -> Result<(), Box> { + // `claw mcp serve` starts a stdio MCP server exposing claw's built-in + // tools. All other `mcp` subcommands fall through to the existing + // configured-server reporter (`list`, `status`, ...). + if matches!(args.map(str::trim), Some("serve")) { + return run_mcp_serve(); + } + let cwd = env::current_dir()?; + match output_format { + CliOutputFormat::Text => println!("{}", handle_mcp_slash_command(args, &cwd)?), + CliOutputFormat::Json => println!( + "{}", + serde_json::to_string_pretty(&handle_mcp_slash_command_json(args, &cwd)?)? + ), + } + Ok(()) + } + + fn print_skills( + args: Option<&str>, + output_format: CliOutputFormat, + ) -> Result<(), Box> { + let cwd = env::current_dir()?; + match output_format { + CliOutputFormat::Text => println!("{}", handle_skills_slash_command(args, &cwd)?), + CliOutputFormat::Json => println!( + "{}", + serde_json::to_string_pretty(&handle_skills_slash_command_json(args, &cwd)?)? + ), + } + Ok(()) + } + + fn print_plugins( + action: Option<&str>, + target: Option<&str>, + output_format: CliOutputFormat, + ) -> Result<(), Box> { + let cwd = env::current_dir()?; + let loader = ConfigLoader::default_for(&cwd); + let runtime_config = loader.load()?; + let mut manager = build_plugin_manager(&cwd, &loader, &runtime_config); + let result = handle_plugins_slash_command(action, target, &mut manager)?; + match output_format { + CliOutputFormat::Text => println!("{}", result.message), + CliOutputFormat::Json => println!( + "{}", + serde_json::to_string_pretty(&json!({ + "kind": "plugin", + "action": action.unwrap_or("list"), + "target": target, + "message": result.message, + "reload_runtime": result.reload_runtime, + }))? + ), + } + Ok(()) + } + fn print_diff() -> Result<(), Box> { println!("{}", render_diff_report()?); Ok(()) } - fn print_version() { - println!("{}", render_version_report()); + fn print_version(output_format: CliOutputFormat) { + let _ = crate::print_version(output_format); } fn export_session( @@ -1457,15 +4513,23 @@ impl LiveCli { let handle = resolve_session_reference(target)?; let session = Session::load_from_path(&handle.path)?; let message_count = session.messages.len(); - self.runtime = build_runtime( + let session_id = session.session_id.clone(); + let runtime = build_runtime( session, + &handle.id, self.model.clone(), self.system_prompt.clone(), true, + true, self.allowed_tools.clone(), self.permission_mode, + None, )?; - self.session = handle; + self.replace_runtime(runtime)?; + self.session = SessionHandle { + id: session_id, + path: handle.path, + }; println!( "Session switched\n Active session {}\n File {}\n Messages {}", self.session.id, @@ -1474,105 +4538,443 @@ impl LiveCli { ); Ok(true) } + Some("fork") => { + let forked = self.runtime.fork_session(target.map(ToOwned::to_owned)); + let parent_session_id = self.session.id.clone(); + let handle = create_managed_session_handle(&forked.session_id)?; + let branch_name = forked + .fork + .as_ref() + .and_then(|fork| fork.branch_name.clone()); + let forked = forked.with_persistence_path(handle.path.clone()); + let message_count = forked.messages.len(); + forked.save_to_path(&handle.path)?; + let runtime = build_runtime( + forked, + &handle.id, + self.model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + self.replace_runtime(runtime)?; + self.session = handle; + println!( + "Session forked\n Parent session {}\n Active session {}\n Branch {}\n File {}\n Messages {}", + parent_session_id, + self.session.id, + branch_name.as_deref().unwrap_or("(unnamed)"), + self.session.path.display(), + message_count, + ); + Ok(true) + } + Some("delete") => { + let Some(target) = target else { + println!("Usage: /session delete [--force]"); + return Ok(false); + }; + let handle = resolve_session_reference(target)?; + if handle.id == self.session.id { + println!( + "delete: refusing to delete the active session '{}'.\nSwitch to another session first with /session switch .", + handle.id + ); + return Ok(false); + } + if !confirm_session_deletion(&handle.id) { + println!("delete: cancelled."); + return Ok(false); + } + delete_managed_session(&handle.path)?; + println!( + "Session deleted\n Deleted session {}\n File {}", + handle.id, + handle.path.display(), + ); + Ok(false) + } + Some("delete-force") => { + let Some(target) = target else { + println!("Usage: /session delete [--force]"); + return Ok(false); + }; + let handle = resolve_session_reference(target)?; + if handle.id == self.session.id { + println!( + "delete: refusing to delete the active session '{}'.\nSwitch to another session first with /session switch .", + handle.id + ); + return Ok(false); + } + delete_managed_session(&handle.path)?; + println!( + "Session deleted\n Deleted session {}\n File {}", + handle.id, + handle.path.display(), + ); + Ok(false) + } Some(other) => { - println!("Unknown /session action '{other}'. Use /session list or /session switch ."); + println!( + "Unknown /session action '{other}'. Use /session list, /session switch , /session fork [branch-name], or /session delete [--force]." + ); Ok(false) } } } + fn handle_plugins_command( + &mut self, + action: Option<&str>, + target: Option<&str>, + ) -> Result> { + let cwd = env::current_dir()?; + let loader = ConfigLoader::default_for(&cwd); + let runtime_config = loader.load()?; + let mut manager = build_plugin_manager(&cwd, &loader, &runtime_config); + let result = handle_plugins_slash_command(action, target, &mut manager)?; + println!("{}", result.message); + if result.reload_runtime { + self.reload_runtime_features()?; + } + Ok(false) + } + + fn reload_runtime_features(&mut self) -> Result<(), Box> { + let runtime = build_runtime( + self.runtime.session().clone(), + &self.session.id, + self.model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + self.replace_runtime(runtime)?; + self.persist_session() + } + fn compact(&mut self) -> Result<(), Box> { let result = self.runtime.compact(CompactionConfig::default()); let removed = result.removed_message_count; let kept = result.compacted_session.messages.len(); let skipped = removed == 0; - self.runtime = build_runtime( + let runtime = build_runtime( result.compacted_session, + &self.session.id, self.model.clone(), self.system_prompt.clone(), true, + true, self.allowed_tools.clone(), self.permission_mode, + None, )?; + self.replace_runtime(runtime)?; self.persist_session()?; println!("{}", format_compact_report(removed, kept, skipped)); Ok(()) } + + fn run_internal_prompt_text_with_progress( + &self, + prompt: &str, + enable_tools: bool, + progress: Option, + ) -> Result> { + let session = self.runtime.session().clone(); + let mut runtime = build_runtime( + session, + &self.session.id, + self.model.clone(), + self.system_prompt.clone(), + enable_tools, + false, + self.allowed_tools.clone(), + self.permission_mode, + progress, + )?; + let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); + let summary = runtime.run_turn(prompt, Some(&mut permission_prompter))?; + let text = final_assistant_text(&summary).trim().to_string(); + runtime.shutdown_plugins()?; + Ok(text) + } + + fn run_internal_prompt_text( + &self, + prompt: &str, + enable_tools: bool, + ) -> Result> { + self.run_internal_prompt_text_with_progress(prompt, enable_tools, None) + } + + fn run_bughunter(&self, scope: Option<&str>) -> Result<(), Box> { + println!("{}", format_bughunter_report(scope)); + Ok(()) + } + + fn run_ultraplan(&self, task: Option<&str>) -> Result<(), Box> { + println!("{}", format_ultraplan_report(task)); + Ok(()) + } + + fn run_teleport(target: Option<&str>) -> Result<(), Box> { + let Some(target) = target.map(str::trim).filter(|value| !value.is_empty()) else { + println!("Usage: /teleport "); + return Ok(()); + }; + + println!("{}", render_teleport_report(target)?); + Ok(()) + } + + fn run_debug_tool_call(&self, args: Option<&str>) -> Result<(), Box> { + validate_no_args("/debug-tool-call", args)?; + println!("{}", render_last_tool_debug_report(self.runtime.session())?); + Ok(()) + } + + fn run_commit(&mut self, args: Option<&str>) -> Result<(), Box> { + validate_no_args("/commit", args)?; + let status = git_output(&["status", "--short", "--branch"])?; + let summary = parse_git_workspace_summary(Some(&status)); + let branch = parse_git_status_branch(Some(&status)); + if summary.is_clean() { + println!("{}", format_commit_skipped_report()); + return Ok(()); + } + + println!( + "{}", + format_commit_preflight_report(branch.as_deref(), summary) + ); + Ok(()) + } + + fn run_pr(&self, context: Option<&str>) -> Result<(), Box> { + let branch = + resolve_git_branch_for(&env::current_dir()?).unwrap_or_else(|| "unknown".to_string()); + println!("{}", format_pr_report(&branch, context)); + Ok(()) + } + + fn run_issue(&self, context: Option<&str>) -> Result<(), Box> { + println!("{}", format_issue_report(context)); + Ok(()) + } } fn sessions_dir() -> Result> { let cwd = env::current_dir()?; - let path = cwd.join(".claude").join("sessions"); - fs::create_dir_all(&path)?; - Ok(path) + let store = runtime::SessionStore::from_cwd(&cwd) + .map_err(|e| Box::new(e) as Box)?; + Ok(store.sessions_dir().to_path_buf()) } -fn create_managed_session_handle() -> Result> { - let id = generate_session_id(); - let path = sessions_dir()?.join(format!("{id}.json")); +fn create_managed_session_handle( + session_id: &str, +) -> Result> { + let id = session_id.to_string(); + let path = sessions_dir()?.join(format!("{id}.{PRIMARY_SESSION_EXTENSION}")); Ok(SessionHandle { id, path }) } -fn generate_session_id() -> String { - let millis = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_millis()) - .unwrap_or_default(); - format!("session-{millis}") -} - fn resolve_session_reference(reference: &str) -> Result> { + if SESSION_REFERENCE_ALIASES + .iter() + .any(|alias| reference.eq_ignore_ascii_case(alias)) + { + let latest = latest_managed_session()?; + return Ok(SessionHandle { + id: latest.id, + path: latest.path, + }); + } + let direct = PathBuf::from(reference); + let looks_like_path = direct.extension().is_some() || direct.components().count() > 1; let path = if direct.exists() { direct + } else if looks_like_path { + return Err(format_missing_session_reference(reference).into()); } else { - sessions_dir()?.join(format!("{reference}.json")) + resolve_managed_session_path(reference)? }; - if !path.exists() { - return Err(format!("session not found: {reference}").into()); - } let id = path - .file_stem() + .file_name() .and_then(|value| value.to_str()) + .and_then(|name| { + name.strip_suffix(&format!(".{PRIMARY_SESSION_EXTENSION}")) + .or_else(|| name.strip_suffix(&format!(".{LEGACY_SESSION_EXTENSION}"))) + }) .unwrap_or(reference) .to_string(); Ok(SessionHandle { id, path }) } -fn list_managed_sessions() -> Result, Box> { - let mut sessions = Vec::new(); - for entry in fs::read_dir(sessions_dir()?)? { +fn resolve_managed_session_path(session_id: &str) -> Result> { + let directory = sessions_dir()?; + for extension in [PRIMARY_SESSION_EXTENSION, LEGACY_SESSION_EXTENSION] { + let path = directory.join(format!("{session_id}.{extension}")); + if path.exists() { + return Ok(path); + } + } + // Backward compatibility: pre-isolation sessions were stored at + // `.claw/sessions/.{jsonl,json}` without the per-workspace hash + // subdirectory. Walk up from `directory` to the `.claw/sessions/` root + // and try the flat layout as a fallback so users do not lose access + // to their pre-upgrade managed sessions. + if let Some(legacy_root) = directory + .parent() + .filter(|parent| parent.file_name().is_some_and(|name| name == "sessions")) + { + for extension in [PRIMARY_SESSION_EXTENSION, LEGACY_SESSION_EXTENSION] { + let path = legacy_root.join(format!("{session_id}.{extension}")); + if path.exists() { + return Ok(path); + } + } + } + Err(format_missing_session_reference(session_id).into()) +} + +fn is_managed_session_file(path: &Path) -> bool { + path.extension() + .and_then(|ext| ext.to_str()) + .is_some_and(|extension| { + extension == PRIMARY_SESSION_EXTENSION || extension == LEGACY_SESSION_EXTENSION + }) +} + +fn collect_sessions_from_dir( + directory: &Path, + sessions: &mut Vec, +) -> Result<(), Box> { + if !directory.exists() { + return Ok(()); + } + for entry in fs::read_dir(directory)? { let entry = entry?; let path = entry.path(); - if path.extension().and_then(|ext| ext.to_str()) != Some("json") { + if !is_managed_session_file(&path) { continue; } let metadata = entry.metadata()?; - let modified_epoch_secs = metadata + let modified_epoch_millis = metadata .modified() .ok() .and_then(|time| time.duration_since(UNIX_EPOCH).ok()) - .map(|duration| duration.as_secs()) + .map(|duration| duration.as_millis()) .unwrap_or_default(); - let message_count = Session::load_from_path(&path) - .map(|session| session.messages.len()) - .unwrap_or_default(); - let id = path - .file_stem() - .and_then(|value| value.to_str()) - .unwrap_or("unknown") - .to_string(); + let (id, message_count, parent_session_id, branch_name) = + match Session::load_from_path(&path) { + Ok(session) => { + let parent_session_id = session + .fork + .as_ref() + .map(|fork| fork.parent_session_id.clone()); + let branch_name = session + .fork + .as_ref() + .and_then(|fork| fork.branch_name.clone()); + ( + session.session_id, + session.messages.len(), + parent_session_id, + branch_name, + ) + } + Err(_) => ( + path.file_stem() + .and_then(|value| value.to_str()) + .unwrap_or("unknown") + .to_string(), + 0, + None, + None, + ), + }; sessions.push(ManagedSessionSummary { id, path, - modified_epoch_secs, + modified_epoch_millis, message_count, + parent_session_id, + branch_name, }); } - sessions.sort_by(|left, right| right.modified_epoch_secs.cmp(&left.modified_epoch_secs)); + Ok(()) +} + +fn list_managed_sessions() -> Result, Box> { + let mut sessions = Vec::new(); + let primary_dir = sessions_dir()?; + collect_sessions_from_dir(&primary_dir, &mut sessions)?; + + // Backward compatibility: include sessions stored in the pre-isolation + // flat `.claw/sessions/` root so users do not lose access to existing + // managed sessions after the workspace-hashed subdirectory rollout. + if let Some(legacy_root) = primary_dir + .parent() + .filter(|parent| parent.file_name().is_some_and(|name| name == "sessions")) + { + collect_sessions_from_dir(legacy_root, &mut sessions)?; + } + + sessions.sort_by(|left, right| { + right + .modified_epoch_millis + .cmp(&left.modified_epoch_millis) + .then_with(|| right.id.cmp(&left.id)) + }); Ok(sessions) } +fn latest_managed_session() -> Result> { + list_managed_sessions()? + .into_iter() + .next() + .ok_or_else(|| format_no_managed_sessions().into()) +} + +fn delete_managed_session(path: &Path) -> Result<(), Box> { + if !path.exists() { + return Err(format!("session file does not exist: {}", path.display()).into()); + } + fs::remove_file(path)?; + Ok(()) +} + +fn confirm_session_deletion(session_id: &str) -> bool { + print!("Delete session '{session_id}'? This cannot be undone. [y/N]: "); + io::stdout().flush().unwrap_or(()); + let mut answer = String::new(); + if io::stdin().read_line(&mut answer).is_err() { + return false; + } + matches!(answer.trim(), "y" | "Y" | "yes" | "Yes" | "YES") +} + +fn format_missing_session_reference(reference: &str) -> String { + format!( + "session not found: {reference}\nHint: managed sessions live in .claw/sessions/. Try `{LATEST_SESSION_REFERENCE}` for the most recent session or `/session list` in the REPL." + ) +} + +fn format_no_managed_sessions() -> String { + format!( + "no managed sessions found in .claw/sessions/\nStart `claw` to create a session, then rerun with `--resume {LATEST_SESSION_REFERENCE}`." + ) +} + fn render_session_list(active_session_id: &str) -> Result> { let sessions = list_managed_sessions()?; let mut lines = vec![ @@ -1589,28 +4991,84 @@ fn render_session_list(active_session_id: &str) -> Result { + format!(" branch={branch_name} from={parent_session_id}") + } + (None, Some(parent_session_id)) => format!(" from={parent_session_id}"), + (Some(branch_name), None) => format!(" branch={branch_name}"), + (None, None) => String::new(), + }; lines.push(format!( - " {id:<20} {marker:<10} msgs={msgs:<4} modified={modified} path={path}", + " {id:<20} {marker:<10} msgs={msgs:<4} modified={modified}{lineage} path={path}", id = session.id, msgs = session.message_count, - modified = session.modified_epoch_secs, + modified = format_session_modified_age(session.modified_epoch_millis), + lineage = lineage, path = session.path.display(), )); } Ok(lines.join("\n")) } +fn format_session_modified_age(modified_epoch_millis: u128) -> String { + let now = std::time::SystemTime::now() + .duration_since(UNIX_EPOCH) + .ok() + .map_or(modified_epoch_millis, |duration| duration.as_millis()); + let delta_seconds = now + .saturating_sub(modified_epoch_millis) + .checked_div(1_000) + .unwrap_or_default(); + match delta_seconds { + 0..=4 => "just-now".to_string(), + 5..=59 => format!("{delta_seconds}s-ago"), + 60..=3_599 => format!("{}m-ago", delta_seconds / 60), + 3_600..=86_399 => format!("{}h-ago", delta_seconds / 3_600), + _ => format!("{}d-ago", delta_seconds / 86_400), + } +} + +fn write_session_clear_backup( + session: &Session, + session_path: &Path, +) -> Result> { + let backup_path = session_clear_backup_path(session_path); + session.save_to_path(&backup_path)?; + Ok(backup_path) +} + +fn session_clear_backup_path(session_path: &Path) -> PathBuf { + let timestamp = std::time::SystemTime::now() + .duration_since(UNIX_EPOCH) + .ok() + .map_or(0, |duration| duration.as_millis()); + let file_name = session_path + .file_name() + .and_then(|value| value.to_str()) + .unwrap_or("session.jsonl"); + session_path.with_file_name(format!("{file_name}.before-clear-{timestamp}.bak")) +} + fn render_repl_help() -> String { [ "REPL".to_string(), " /exit Quit the REPL".to_string(), " /quit Quit the REPL".to_string(), " Up/Down Navigate prompt history".to_string(), - " Tab Complete slash commands".to_string(), + " Ctrl-R Reverse-search prompt history".to_string(), + " Tab Complete commands, modes, and recent sessions".to_string(), " Ctrl-C Clear input (or exit on empty prompt)".to_string(), " Shift+Enter/Ctrl+J Insert a newline".to_string(), + " Auto-save .claw/sessions/.jsonl".to_string(), + " Resume latest /resume latest".to_string(), + " Browse sessions /session list".to_string(), + " Show prompt history /history [count]".to_string(), String::new(), - render_slash_command_help(), + render_slash_command_help_filtered(STUB_COMMANDS), ] .join( " @@ -1618,6 +5076,93 @@ fn render_repl_help() -> String { ) } +fn print_status_snapshot( + model: &str, + permission_mode: PermissionMode, + output_format: CliOutputFormat, +) -> Result<(), Box> { + let usage = StatusUsage { + message_count: 0, + turns: 0, + latest: TokenUsage::default(), + cumulative: TokenUsage::default(), + estimated_tokens: 0, + }; + let context = status_context(None)?; + match output_format { + CliOutputFormat::Text => println!( + "{}", + format_status_report(model, usage, permission_mode.as_str(), &context) + ), + CliOutputFormat::Json => println!( + "{}", + serde_json::to_string_pretty(&status_json_value( + Some(model), + usage, + permission_mode.as_str(), + &context, + ))? + ), + } + Ok(()) +} + +fn status_json_value( + model: Option<&str>, + usage: StatusUsage, + permission_mode: &str, + context: &StatusContext, +) -> serde_json::Value { + json!({ + "kind": "status", + "model": model, + "permission_mode": permission_mode, + "usage": { + "messages": usage.message_count, + "turns": usage.turns, + "latest_total": usage.latest.total_tokens(), + "cumulative_input": usage.cumulative.input_tokens, + "cumulative_output": usage.cumulative.output_tokens, + "cumulative_total": usage.cumulative.total_tokens(), + "estimated_tokens": usage.estimated_tokens, + }, + "workspace": { + "cwd": context.cwd, + "project_root": context.project_root, + "git_branch": context.git_branch, + "git_state": context.git_summary.headline(), + "changed_files": context.git_summary.changed_files, + "staged_files": context.git_summary.staged_files, + "unstaged_files": context.git_summary.unstaged_files, + "untracked_files": context.git_summary.untracked_files, + "session": context.session_path.as_ref().map_or_else(|| "live-repl".to_string(), |path| path.display().to_string()), + "session_id": context.session_path.as_ref().and_then(|path| { + // Session files are named .jsonl directly under + // .claw/sessions/. Extract the stem (drop the .jsonl extension). + path.file_stem().map(|n| n.to_string_lossy().into_owned()) + }), + "loaded_config_files": context.loaded_config_files, + "discovered_config_files": context.discovered_config_files, + "memory_file_count": context.memory_file_count, + }, + "sandbox": { + "enabled": context.sandbox_status.enabled, + "active": context.sandbox_status.active, + "supported": context.sandbox_status.supported, + "in_container": context.sandbox_status.in_container, + "requested_namespace": context.sandbox_status.requested.namespace_restrictions, + "active_namespace": context.sandbox_status.namespace_active, + "requested_network": context.sandbox_status.requested.network_isolation, + "active_network": context.sandbox_status.network_active, + "filesystem_mode": context.sandbox_status.filesystem_mode.as_str(), + "filesystem_active": context.sandbox_status.filesystem_active, + "allowed_mounts": context.sandbox_status.allowed_mounts, + "markers": context.sandbox_status.container_markers, + "fallback_reason": context.sandbox_status.fallback_reason, + } + }) +} + fn status_context( session_path: Option<&Path>, ) -> Result> { @@ -1628,6 +5173,8 @@ fn status_context( let project_context = ProjectContext::discover_with_git(&cwd, DEFAULT_DATE)?; let (project_root, git_branch) = parse_git_status_metadata(project_context.git_status.as_deref()); + let git_summary = parse_git_workspace_summary(project_context.git_status.as_deref()); + let sandbox_status = resolve_sandbox_status(runtime_config.sandbox(), &cwd); Ok(StatusContext { cwd, session_path: session_path.map(Path::to_path_buf), @@ -1636,6 +5183,8 @@ fn status_context( memory_file_count: project_context.instruction_files.len(), project_root, git_branch, + git_summary, + sandbox_status, }) } @@ -1671,15 +5220,26 @@ fn format_status_report( Cwd {} Project root {} Git branch {} + Git state {} + Changed files {} + Staged {} + Unstaged {} + Untracked {} Session {} Config files loaded {}/{} - Memory files {}", + Memory files {} + Suggested flow /status → /diff → /commit", context.cwd.display(), context .project_root .as_ref() .map_or_else(|| "unknown".to_string(), |path| path.display().to_string()), context.git_branch.as_deref().unwrap_or("unknown"), + context.git_summary.headline(), + context.git_summary.changed_files, + context.git_summary.staged_files, + context.git_summary.unstaged_files, + context.git_summary.untracked_files, context.session_path.as_ref().map_or_else( || "live-repl".to_string(), |path| path.display().to_string() @@ -1688,6 +5248,7 @@ fn format_status_report( context.discovered_config_files, context.memory_file_count, ), + format_sandbox_report(&context.sandbox_status), ] .join( " @@ -1696,6 +5257,137 @@ fn format_status_report( ) } +fn format_sandbox_report(status: &runtime::SandboxStatus) -> String { + format!( + "Sandbox + Enabled {} + Active {} + Supported {} + In container {} + Requested ns {} + Active ns {} + Requested net {} + Active net {} + Filesystem mode {} + Filesystem active {} + Allowed mounts {} + Markers {} + Fallback reason {}", + status.enabled, + status.active, + status.supported, + status.in_container, + status.requested.namespace_restrictions, + status.namespace_active, + status.requested.network_isolation, + status.network_active, + status.filesystem_mode.as_str(), + status.filesystem_active, + if status.allowed_mounts.is_empty() { + "".to_string() + } else { + status.allowed_mounts.join(", ") + }, + if status.container_markers.is_empty() { + "".to_string() + } else { + status.container_markers.join(", ") + }, + status + .fallback_reason + .clone() + .unwrap_or_else(|| "".to_string()), + ) +} + +fn format_commit_preflight_report(branch: Option<&str>, summary: GitWorkspaceSummary) -> String { + format!( + "Commit + Result ready + Branch {} + Workspace {} + Changed files {} + Action create a git commit from the current workspace changes", + branch.unwrap_or("unknown"), + summary.headline(), + summary.changed_files, + ) +} + +fn format_commit_skipped_report() -> String { + "Commit + Result skipped + Reason no workspace changes + Action create a git commit from the current workspace changes + Next /status to inspect context · /diff to inspect repo changes" + .to_string() +} + +fn print_sandbox_status_snapshot( + output_format: CliOutputFormat, +) -> Result<(), Box> { + let cwd = env::current_dir()?; + let loader = ConfigLoader::default_for(&cwd); + let runtime_config = loader + .load() + .unwrap_or_else(|_| runtime::RuntimeConfig::empty()); + let status = resolve_sandbox_status(runtime_config.sandbox(), &cwd); + match output_format { + CliOutputFormat::Text => println!("{}", format_sandbox_report(&status)), + CliOutputFormat::Json => println!( + "{}", + serde_json::to_string_pretty(&sandbox_json_value(&status))? + ), + } + Ok(()) +} + +fn sandbox_json_value(status: &runtime::SandboxStatus) -> serde_json::Value { + json!({ + "kind": "sandbox", + "enabled": status.enabled, + "active": status.active, + "supported": status.supported, + "in_container": status.in_container, + "requested_namespace": status.requested.namespace_restrictions, + "active_namespace": status.namespace_active, + "requested_network": status.requested.network_isolation, + "active_network": status.network_active, + "filesystem_mode": status.filesystem_mode.as_str(), + "filesystem_active": status.filesystem_active, + "allowed_mounts": status.allowed_mounts, + "markers": status.container_markers, + "fallback_reason": status.fallback_reason, + }) +} + +fn render_help_topic(topic: LocalHelpTopic) -> String { + match topic { + LocalHelpTopic::Status => "Status + Usage claw status + Purpose show the local workspace snapshot without entering the REPL + Output model, permissions, git state, config files, and sandbox status + Related /status · claw --resume latest /status" + .to_string(), + LocalHelpTopic::Sandbox => "Sandbox + Usage claw sandbox + Purpose inspect the resolved sandbox and isolation state for the current directory + Output namespace, network, filesystem, and fallback details + Related /sandbox · claw status" + .to_string(), + LocalHelpTopic::Doctor => "Doctor + Usage claw doctor + Purpose diagnose local auth, config, workspace, sandbox, and build metadata + Output local-only health report; no provider request or session resume required + Related /doctor · claw --resume latest /doctor" + .to_string(), + } +} + +fn print_help_topic(topic: LocalHelpTopic) { + println!("{}", render_help_topic(topic)); +} + fn render_config_report(section: Option<&str>) -> Result> { let cwd = env::current_dir()?; let loader = ConfigLoader::default_for(&cwd); @@ -1741,9 +5433,12 @@ fn render_config_report(section: Option<&str>) -> Result runtime_config.get("env"), "hooks" => runtime_config.get("hooks"), "model" => runtime_config.get("model"), + "plugins" => runtime_config + .get("plugins") + .or_else(|| runtime_config.get("enabledPlugins")), other => { lines.push(format!( - " Unsupported config section '{other}'. Use env, hooks, or model." + " Unsupported config section '{other}'. Use env, hooks, model, or plugins." )); return Ok(lines.join( " @@ -1772,6 +5467,49 @@ fn render_config_report(section: Option<&str>) -> Result, +) -> Result> { + let cwd = env::current_dir()?; + let loader = ConfigLoader::default_for(&cwd); + let discovered = loader.discover(); + let runtime_config = loader.load()?; + + let loaded_paths: Vec<_> = runtime_config + .loaded_entries() + .iter() + .map(|e| e.path.display().to_string()) + .collect(); + + let files: Vec<_> = discovered + .iter() + .map(|e| { + let source = match e.source { + ConfigSource::User => "user", + ConfigSource::Project => "project", + ConfigSource::Local => "local", + }; + let loaded = runtime_config + .loaded_entries() + .iter() + .any(|le| le.path == e.path); + serde_json::json!({ + "path": e.path.display().to_string(), + "source": source, + "loaded": loaded, + }) + }) + .collect(); + + Ok(serde_json::json!({ + "kind": "config", + "cwd": cwd.display().to_string(), + "loaded_files": loaded_paths.len(), + "merged_keys": runtime_config.merged().len(), + "files": files, + })) +} + fn render_memory_report() -> Result> { let cwd = env::current_dir()?; let project_context = ProjectContext::discover(&cwd, DEFAULT_DATE)?; @@ -1811,16 +5549,52 @@ fn render_memory_report() -> Result> { )) } +fn render_memory_json() -> Result> { + let cwd = env::current_dir()?; + let project_context = ProjectContext::discover(&cwd, DEFAULT_DATE)?; + let files: Vec<_> = project_context + .instruction_files + .iter() + .map(|f| { + json!({ + "path": f.path.display().to_string(), + "lines": f.content.lines().count(), + "preview": f.content.lines().next().unwrap_or("").trim(), + }) + }) + .collect(); + Ok(json!({ + "kind": "memory", + "cwd": cwd.display().to_string(), + "instruction_files": files.len(), + "files": files, + })) +} + fn init_claude_md() -> Result> { let cwd = env::current_dir()?; Ok(initialize_repo(&cwd)?.render()) } -fn run_init() -> Result<(), Box> { - println!("{}", init_claude_md()?); +fn run_init(output_format: CliOutputFormat) -> Result<(), Box> { + let message = init_claude_md()?; + match output_format { + CliOutputFormat::Text => println!("{message}"), + CliOutputFormat::Json => println!( + "{}", + serde_json::to_string_pretty(&init_json_value(&message))? + ), + } Ok(()) } +fn init_json_value(message: &str) -> serde_json::Value { + json!({ + "kind": "init", + "message": message, + }) +} + fn normalize_permission_mode(mode: &str) -> Option<&'static str> { match mode.trim() { "read-only" => Some("read-only"), @@ -1831,22 +5605,447 @@ fn normalize_permission_mode(mode: &str) -> Option<&'static str> { } fn render_diff_report() -> Result> { - let output = std::process::Command::new("git") - .args(["diff", "--", ":(exclude).omx"]) - .current_dir(env::current_dir()?) - .output()?; - if !output.status.success() { - let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); - return Err(format!("git diff failed: {stderr}").into()); + render_diff_report_for(&env::current_dir()?) +} + +fn render_diff_report_for(cwd: &Path) -> Result> { + // Verify we are inside a git repository before calling `git diff`. + // Running `git diff --cached` outside a git tree produces a misleading + // "unknown option `cached`" error because git falls back to --no-index mode. + let in_git_repo = std::process::Command::new("git") + .args(["rev-parse", "--is-inside-work-tree"]) + .current_dir(cwd) + .output() + .map(|o| o.status.success()) + .unwrap_or(false); + if !in_git_repo { + return Ok(format!( + "Diff\n Result no git repository\n Detail {} is not inside a git project", + cwd.display() + )); } - let diff = String::from_utf8(output.stdout)?; - if diff.trim().is_empty() { + let staged = run_git_diff_command_in(cwd, &["diff", "--cached"])?; + let unstaged = run_git_diff_command_in(cwd, &["diff"])?; + if staged.trim().is_empty() && unstaged.trim().is_empty() { return Ok( "Diff\n Result clean working tree\n Detail no current changes" .to_string(), ); } - Ok(format!("Diff\n\n{}", diff.trim_end())) + + let mut sections = Vec::new(); + if !staged.trim().is_empty() { + sections.push(format!("Staged changes:\n{}", staged.trim_end())); + } + if !unstaged.trim().is_empty() { + sections.push(format!("Unstaged changes:\n{}", unstaged.trim_end())); + } + + Ok(format!("Diff\n\n{}", sections.join("\n\n"))) +} + +fn render_diff_json_for(cwd: &Path) -> Result> { + let in_git_repo = std::process::Command::new("git") + .args(["rev-parse", "--is-inside-work-tree"]) + .current_dir(cwd) + .output() + .map(|o| o.status.success()) + .unwrap_or(false); + if !in_git_repo { + return Ok(serde_json::json!({ + "kind": "diff", + "result": "no_git_repo", + "detail": format!("{} is not inside a git project", cwd.display()), + })); + } + let staged = run_git_diff_command_in(cwd, &["diff", "--cached"])?; + let unstaged = run_git_diff_command_in(cwd, &["diff"])?; + Ok(serde_json::json!({ + "kind": "diff", + "result": if staged.trim().is_empty() && unstaged.trim().is_empty() { "clean" } else { "changes" }, + "staged": staged.trim(), + "unstaged": unstaged.trim(), + })) +} + +fn run_git_diff_command_in( + cwd: &Path, + args: &[&str], +) -> Result> { + let output = std::process::Command::new("git") + .args(args) + .current_dir(cwd) + .output()?; + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + return Err(format!("git {} failed: {stderr}", args.join(" ")).into()); + } + Ok(String::from_utf8(output.stdout)?) +} + +fn render_teleport_report(target: &str) -> Result> { + let cwd = env::current_dir()?; + + let file_list = Command::new("rg") + .args(["--files"]) + .current_dir(&cwd) + .output()?; + let file_matches = if file_list.status.success() { + String::from_utf8(file_list.stdout)? + .lines() + .filter(|line| line.contains(target)) + .take(10) + .map(ToOwned::to_owned) + .collect::>() + } else { + Vec::new() + }; + + let content_output = Command::new("rg") + .args(["-n", "-S", "--color", "never", target, "."]) + .current_dir(&cwd) + .output()?; + + let mut lines = vec![ + "Teleport".to_string(), + format!(" Target {target}"), + " Action search workspace files and content for the target".to_string(), + ]; + if !file_matches.is_empty() { + lines.push(String::new()); + lines.push("File matches".to_string()); + lines.extend(file_matches.into_iter().map(|path| format!(" {path}"))); + } + + if content_output.status.success() { + let matches = String::from_utf8(content_output.stdout)?; + if !matches.trim().is_empty() { + lines.push(String::new()); + lines.push("Content matches".to_string()); + lines.push(truncate_for_prompt(&matches, 4_000)); + } + } + + if lines.len() == 1 { + lines.push(" Result no matches found".to_string()); + } + + Ok(lines.join("\n")) +} + +fn render_last_tool_debug_report(session: &Session) -> Result> { + let last_tool_use = session + .messages + .iter() + .rev() + .find_map(|message| { + message.blocks.iter().rev().find_map(|block| match block { + ContentBlock::ToolUse { id, name, input } => { + Some((id.clone(), name.clone(), input.clone())) + } + _ => None, + }) + }) + .ok_or_else(|| "no prior tool call found in session".to_string())?; + + let tool_result = session.messages.iter().rev().find_map(|message| { + message.blocks.iter().rev().find_map(|block| match block { + ContentBlock::ToolResult { + tool_use_id, + tool_name, + output, + is_error, + } if tool_use_id == &last_tool_use.0 => { + Some((tool_name.clone(), output.clone(), *is_error)) + } + _ => None, + }) + }); + + let mut lines = vec![ + "Debug tool call".to_string(), + " Action inspect the last recorded tool call and its result".to_string(), + format!(" Tool id {}", last_tool_use.0), + format!(" Tool name {}", last_tool_use.1), + " Input".to_string(), + indent_block(&last_tool_use.2, 4), + ]; + + match tool_result { + Some((tool_name, output, is_error)) => { + lines.push(" Result".to_string()); + lines.push(format!(" name {tool_name}")); + lines.push(format!( + " status {}", + if is_error { "error" } else { "ok" } + )); + lines.push(indent_block(&output, 4)); + } + None => lines.push(" Result missing tool result".to_string()), + } + + Ok(lines.join("\n")) +} + +fn indent_block(value: &str, spaces: usize) -> String { + let indent = " ".repeat(spaces); + value + .lines() + .map(|line| format!("{indent}{line}")) + .collect::>() + .join("\n") +} + +fn validate_no_args( + command_name: &str, + args: Option<&str>, +) -> Result<(), Box> { + if let Some(args) = args.map(str::trim).filter(|value| !value.is_empty()) { + return Err(format!( + "{command_name} does not accept arguments. Received: {args}\nUsage: {command_name}" + ) + .into()); + } + Ok(()) +} + +fn format_bughunter_report(scope: Option<&str>) -> String { + format!( + "Bughunter + Scope {} + Action inspect the selected code for likely bugs and correctness issues + Output findings should include file paths, severity, and suggested fixes", + scope.unwrap_or("the current repository") + ) +} + +fn format_ultraplan_report(task: Option<&str>) -> String { + format!( + "Ultraplan + Task {} + Action break work into a multi-step execution plan + Output plan should cover goals, risks, sequencing, verification, and rollback", + task.unwrap_or("the current repo work") + ) +} + +fn format_pr_report(branch: &str, context: Option<&str>) -> String { + format!( + "PR + Branch {branch} + Context {} + Action draft or create a pull request for the current branch + Output title and markdown body suitable for GitHub", + context.unwrap_or("none") + ) +} + +fn format_issue_report(context: Option<&str>) -> String { + format!( + "Issue + Context {} + Action draft or create a GitHub issue from the current context + Output title and markdown body suitable for GitHub", + context.unwrap_or("none") + ) +} + +fn git_output(args: &[&str]) -> Result> { + let output = Command::new("git") + .args(args) + .current_dir(env::current_dir()?) + .output()?; + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + return Err(format!("git {} failed: {stderr}", args.join(" ")).into()); + } + Ok(String::from_utf8(output.stdout)?) +} + +fn git_status_ok(args: &[&str]) -> Result<(), Box> { + let output = Command::new("git") + .args(args) + .current_dir(env::current_dir()?) + .output()?; + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + return Err(format!("git {} failed: {stderr}", args.join(" ")).into()); + } + Ok(()) +} + +fn command_exists(name: &str) -> bool { + Command::new("which") + .arg(name) + .output() + .map(|output| output.status.success()) + .unwrap_or(false) +} + +fn write_temp_text_file( + filename: &str, + contents: &str, +) -> Result> { + let path = env::temp_dir().join(filename); + fs::write(&path, contents)?; + Ok(path) +} + +const DEFAULT_HISTORY_LIMIT: usize = 20; + +fn parse_history_count(raw: Option<&str>) -> Result { + let Some(raw) = raw else { + return Ok(DEFAULT_HISTORY_LIMIT); + }; + let parsed: usize = raw + .parse() + .map_err(|_| format!("history: invalid count '{raw}'. Expected a positive integer."))?; + if parsed == 0 { + return Err("history: count must be greater than 0.".to_string()); + } + Ok(parsed) +} + +fn format_history_timestamp(timestamp_ms: u64) -> String { + let secs = timestamp_ms / 1_000; + let subsec_ms = timestamp_ms % 1_000; + let days_since_epoch = secs / 86_400; + let seconds_of_day = secs % 86_400; + let hours = seconds_of_day / 3_600; + let minutes = (seconds_of_day % 3_600) / 60; + let seconds = seconds_of_day % 60; + + let (year, month, day) = civil_from_days(i64::try_from(days_since_epoch).unwrap_or(0)); + format!("{year:04}-{month:02}-{day:02}T{hours:02}:{minutes:02}:{seconds:02}.{subsec_ms:03}Z") +} + +// Computes civil (Gregorian) year/month/day from days since the Unix epoch +// (1970-01-01) using Howard Hinnant's `civil_from_days` algorithm. +fn civil_from_days(days: i64) -> (i32, u32, u32) { + let z = days + 719_468; + let era = if z >= 0 { + z / 146_097 + } else { + (z - 146_096) / 146_097 + }; + let doe = (z - era * 146_097) as u64; // [0, 146_096] + let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // [0, 399] + let y = yoe as i64 + era * 400; + let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // [0, 365] + let mp = (5 * doy + 2) / 153; // [0, 11] + let d = doy - (153 * mp + 2) / 5 + 1; // [1, 31] + let m = if mp < 10 { mp + 3 } else { mp - 9 }; // [1, 12] + let y = y + i64::from(m <= 2); + (y as i32, m as u32, d as u32) +} + +fn render_prompt_history_report(entries: &[PromptHistoryEntry], limit: usize) -> String { + if entries.is_empty() { + return "Prompt history\n Result no prompts recorded yet".to_string(); + } + + let total = entries.len(); + let start = total.saturating_sub(limit); + let shown = &entries[start..]; + let mut lines = vec![ + "Prompt history".to_string(), + format!(" Total {total}"), + format!(" Showing {} most recent", shown.len()), + format!(" Reverse search Ctrl-R in the REPL"), + String::new(), + ]; + for (offset, entry) in shown.iter().enumerate() { + let absolute_index = start + offset + 1; + let timestamp = format_history_timestamp(entry.timestamp_ms); + let first_line = entry.text.lines().next().unwrap_or("").trim(); + let display = if first_line.chars().count() > 80 { + let truncated: String = first_line.chars().take(77).collect(); + format!("{truncated}...") + } else { + first_line.to_string() + }; + lines.push(format!(" {absolute_index:>3}. [{timestamp}] {display}")); + } + lines.join("\n") +} + +fn collect_session_prompt_history(session: &Session) -> Vec { + if !session.prompt_history.is_empty() { + return session + .prompt_history + .iter() + .map(|entry| PromptHistoryEntry { + timestamp_ms: entry.timestamp_ms, + text: entry.text.clone(), + }) + .collect(); + } + let timestamp_ms = session.updated_at_ms; + session + .messages + .iter() + .filter(|message| message.role == MessageRole::User) + .filter_map(|message| { + message.blocks.iter().find_map(|block| match block { + ContentBlock::Text { text } => Some(PromptHistoryEntry { + timestamp_ms, + text: text.clone(), + }), + _ => None, + }) + }) + .collect() +} + +fn recent_user_context(session: &Session, limit: usize) -> String { + let requests = session + .messages + .iter() + .filter(|message| message.role == MessageRole::User) + .filter_map(|message| { + message.blocks.iter().find_map(|block| match block { + ContentBlock::Text { text } => Some(text.trim().to_string()), + _ => None, + }) + }) + .rev() + .take(limit) + .collect::>(); + + if requests.is_empty() { + "".to_string() + } else { + requests + .into_iter() + .rev() + .enumerate() + .map(|(index, text)| format!("{}. {}", index + 1, text)) + .collect::>() + .join("\n") + } +} + +fn truncate_for_prompt(value: &str, limit: usize) -> String { + if value.chars().count() <= limit { + value.trim().to_string() + } else { + let truncated = value.chars().take(limit).collect::(); + format!("{}\n…[truncated]", truncated.trim_end()) + } +} + +fn sanitize_generated_message(value: &str) -> String { + value.trim().trim_matches('`').trim().replace("\r\n", "\n") +} + +fn parse_titled_body(value: &str) -> Option<(String, String)> { + let normalized = sanitize_generated_message(value); + let title = normalized + .lines() + .find_map(|line| line.strip_prefix("TITLE:").map(str::trim))?; + let body_start = normalized.find("BODY:")?; + let body = normalized[body_start + "BODY:".len()..].trim(); + Some((title.to_string(), body.to_string())) } fn render_version_report() -> String { @@ -1870,12 +6069,6 @@ fn render_export_text(session: &Session) -> String { for block in &message.blocks { match block { ContentBlock::Text { text } => lines.push(text.clone()), - ContentBlock::Thinking { thinking, .. } => { - lines.push(format!("[thinking] {thinking}")); - } - ContentBlock::RedactedThinking { .. } => { - lines.push("[thinking] ".to_string()); - } ContentBlock::ToolUse { id, name, input } => { lines.push(format!("[tool_use id={id} name={name}] {input}")); } @@ -1950,6 +6143,172 @@ fn resolve_export_path( Ok(cwd.join(final_name)) } +const SESSION_MARKDOWN_TOOL_SUMMARY_LIMIT: usize = 280; + +fn summarize_tool_payload_for_markdown(payload: &str) -> String { + let compact = match serde_json::from_str::(payload) { + Ok(value) => value.to_string(), + Err(_) => payload.split_whitespace().collect::>().join(" "), + }; + if compact.is_empty() { + return String::new(); + } + truncate_for_summary(&compact, SESSION_MARKDOWN_TOOL_SUMMARY_LIMIT) +} + +fn run_export( + session_reference: &str, + output_path: Option<&Path>, + output_format: CliOutputFormat, +) -> Result<(), Box> { + let handle = resolve_session_reference(session_reference)?; + let session = Session::load_from_path(&handle.path)?; + let markdown = render_session_markdown(&session, &handle.id, &handle.path); + + if let Some(path) = output_path { + fs::write(path, &markdown)?; + let report = format!( + "Export\n Result wrote markdown transcript\n File {}\n Session {}\n Messages {}", + path.display(), + handle.id, + session.messages.len(), + ); + match output_format { + CliOutputFormat::Text => println!("{report}"), + CliOutputFormat::Json => println!( + "{}", + serde_json::to_string_pretty(&json!({ + "kind": "export", + "message": report, + "session_id": handle.id, + "file": path.display().to_string(), + "messages": session.messages.len(), + }))? + ), + } + return Ok(()); + } + + match output_format { + CliOutputFormat::Text => { + print!("{markdown}"); + if !markdown.ends_with('\n') { + println!(); + } + } + CliOutputFormat::Json => println!( + "{}", + serde_json::to_string_pretty(&json!({ + "kind": "export", + "session_id": handle.id, + "file": handle.path.display().to_string(), + "messages": session.messages.len(), + "markdown": markdown, + }))? + ), + } + Ok(()) +} + +fn render_session_markdown(session: &Session, session_id: &str, session_path: &Path) -> String { + let mut lines = vec![ + "# Conversation Export".to_string(), + String::new(), + format!("- **Session**: `{session_id}`"), + format!("- **File**: `{}`", session_path.display()), + format!("- **Messages**: {}", session.messages.len()), + ]; + if let Some(workspace_root) = session.workspace_root() { + lines.push(format!("- **Workspace**: `{}`", workspace_root.display())); + } + if let Some(fork) = &session.fork { + let branch = fork.branch_name.as_deref().unwrap_or("(unnamed)"); + lines.push(format!( + "- **Forked from**: `{}` (branch `{branch}`)", + fork.parent_session_id + )); + } + if let Some(compaction) = &session.compaction { + lines.push(format!( + "- **Compactions**: {} (last removed {} messages)", + compaction.count, compaction.removed_message_count + )); + } + lines.push(String::new()); + lines.push("---".to_string()); + lines.push(String::new()); + + for (index, message) in session.messages.iter().enumerate() { + let role = match message.role { + MessageRole::System => "System", + MessageRole::User => "User", + MessageRole::Assistant => "Assistant", + MessageRole::Tool => "Tool", + }; + lines.push(format!("## {}. {role}", index + 1)); + lines.push(String::new()); + for block in &message.blocks { + match block { + ContentBlock::Text { text } => { + let trimmed = text.trim_end(); + if !trimmed.is_empty() { + lines.push(trimmed.to_string()); + lines.push(String::new()); + } + } + ContentBlock::ToolUse { id, name, input } => { + lines.push(format!( + "**Tool call** `{name}` _(id `{}`)_", + short_tool_id(id) + )); + let summary = summarize_tool_payload_for_markdown(input); + if !summary.is_empty() { + lines.push(format!("> {summary}")); + } + lines.push(String::new()); + } + ContentBlock::ToolResult { + tool_use_id, + tool_name, + output, + is_error, + } => { + let status = if *is_error { "error" } else { "ok" }; + lines.push(format!( + "**Tool result** `{tool_name}` _(id `{}`, {status})_", + short_tool_id(tool_use_id) + )); + let summary = summarize_tool_payload_for_markdown(output); + if !summary.is_empty() { + lines.push(format!("> {summary}")); + } + lines.push(String::new()); + } + } + } + if let Some(usage) = message.usage { + lines.push(format!( + "_tokens: in={} out={} cache_create={} cache_read={}_", + usage.input_tokens, + usage.output_tokens, + usage.cache_creation_input_tokens, + usage.cache_read_input_tokens, + )); + lines.push(String::new()); + } + } + lines.join("\n") +} + +fn short_tool_id(id: &str) -> String { + let char_count = id.chars().count(); + if char_count <= 12 { + return id.to_string(); + } + let prefix: String = id.chars().take(12).collect(); + format!("{prefix}…") +} + fn build_system_prompt() -> Result, Box> { Ok(load_system_prompt( env::current_dir()?, @@ -1959,22 +6318,522 @@ fn build_system_prompt() -> Result, Box> { )?) } +fn build_runtime_plugin_state() -> Result> { + let cwd = env::current_dir()?; + let loader = ConfigLoader::default_for(&cwd); + let runtime_config = loader.load()?; + build_runtime_plugin_state_with_loader(&cwd, &loader, &runtime_config) +} + +fn build_runtime_plugin_state_with_loader( + cwd: &Path, + loader: &ConfigLoader, + runtime_config: &runtime::RuntimeConfig, +) -> Result> { + let plugin_manager = build_plugin_manager(cwd, loader, runtime_config); + let plugin_registry = plugin_manager.plugin_registry()?; + let plugin_hook_config = + runtime_hook_config_from_plugin_hooks(plugin_registry.aggregated_hooks()?); + let feature_config = runtime_config + .feature_config() + .clone() + .with_hooks(runtime_config.hooks().merged(&plugin_hook_config)); + let (mcp_state, runtime_tools) = build_runtime_mcp_state(runtime_config)?; + let tool_registry = GlobalToolRegistry::with_plugin_tools(plugin_registry.aggregated_tools()?)? + .with_runtime_tools(runtime_tools)?; + Ok(RuntimePluginState { + feature_config, + tool_registry, + plugin_registry, + mcp_state, + }) +} + +fn build_plugin_manager( + cwd: &Path, + loader: &ConfigLoader, + runtime_config: &runtime::RuntimeConfig, +) -> PluginManager { + let plugin_settings = runtime_config.plugins(); + let mut plugin_config = PluginManagerConfig::new(loader.config_home().to_path_buf()); + plugin_config.enabled_plugins = plugin_settings.enabled_plugins().clone(); + plugin_config.external_dirs = plugin_settings + .external_directories() + .iter() + .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)) + .collect(); + plugin_config.install_root = plugin_settings + .install_root() + .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)); + plugin_config.registry_path = plugin_settings + .registry_path() + .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)); + plugin_config.bundled_root = plugin_settings + .bundled_root() + .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)); + PluginManager::new(plugin_config) +} + +fn resolve_plugin_path(cwd: &Path, config_home: &Path, value: &str) -> PathBuf { + let path = PathBuf::from(value); + if path.is_absolute() { + path + } else if value.starts_with('.') { + cwd.join(path) + } else { + config_home.join(path) + } +} + +fn runtime_hook_config_from_plugin_hooks(hooks: PluginHooks) -> runtime::RuntimeHookConfig { + runtime::RuntimeHookConfig::new( + hooks.pre_tool_use, + hooks.post_tool_use, + hooks.post_tool_use_failure, + ) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct InternalPromptProgressState { + command_label: &'static str, + task_label: String, + step: usize, + phase: String, + detail: Option, + saw_final_text: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum InternalPromptProgressEvent { + Started, + Update, + Heartbeat, + Complete, + Failed, +} + +#[derive(Debug)] +struct InternalPromptProgressShared { + state: Mutex, + output_lock: Mutex<()>, + started_at: Instant, +} + +#[derive(Debug, Clone)] +struct InternalPromptProgressReporter { + shared: Arc, +} + +#[derive(Debug)] +struct InternalPromptProgressRun { + reporter: InternalPromptProgressReporter, + heartbeat_stop: Option>, + heartbeat_handle: Option>, +} + +impl InternalPromptProgressReporter { + fn ultraplan(task: &str) -> Self { + Self { + shared: Arc::new(InternalPromptProgressShared { + state: Mutex::new(InternalPromptProgressState { + command_label: "Ultraplan", + task_label: task.to_string(), + step: 0, + phase: "planning started".to_string(), + detail: Some(format!("task: {task}")), + saw_final_text: false, + }), + output_lock: Mutex::new(()), + started_at: Instant::now(), + }), + } + } + + fn emit(&self, event: InternalPromptProgressEvent, error: Option<&str>) { + let snapshot = self.snapshot(); + let line = format_internal_prompt_progress_line(event, &snapshot, self.elapsed(), error); + self.write_line(&line); + } + + fn mark_model_phase(&self) { + let snapshot = { + let mut state = self + .shared + .state + .lock() + .expect("internal prompt progress state poisoned"); + state.step += 1; + state.phase = if state.step == 1 { + "analyzing request".to_string() + } else { + "reviewing findings".to_string() + }; + state.detail = Some(format!("task: {}", state.task_label)); + state.clone() + }; + self.write_line(&format_internal_prompt_progress_line( + InternalPromptProgressEvent::Update, + &snapshot, + self.elapsed(), + None, + )); + } + + fn mark_tool_phase(&self, name: &str, input: &str) { + let detail = describe_tool_progress(name, input); + let snapshot = { + let mut state = self + .shared + .state + .lock() + .expect("internal prompt progress state poisoned"); + state.step += 1; + state.phase = format!("running {name}"); + state.detail = Some(detail); + state.clone() + }; + self.write_line(&format_internal_prompt_progress_line( + InternalPromptProgressEvent::Update, + &snapshot, + self.elapsed(), + None, + )); + } + + fn mark_text_phase(&self, text: &str) { + let trimmed = text.trim(); + if trimmed.is_empty() { + return; + } + let detail = truncate_for_summary(first_visible_line(trimmed), 120); + let snapshot = { + let mut state = self + .shared + .state + .lock() + .expect("internal prompt progress state poisoned"); + if state.saw_final_text { + return; + } + state.saw_final_text = true; + state.step += 1; + state.phase = "drafting final plan".to_string(); + state.detail = (!detail.is_empty()).then_some(detail); + state.clone() + }; + self.write_line(&format_internal_prompt_progress_line( + InternalPromptProgressEvent::Update, + &snapshot, + self.elapsed(), + None, + )); + } + + fn emit_heartbeat(&self) { + let snapshot = self.snapshot(); + self.write_line(&format_internal_prompt_progress_line( + InternalPromptProgressEvent::Heartbeat, + &snapshot, + self.elapsed(), + None, + )); + } + + fn snapshot(&self) -> InternalPromptProgressState { + self.shared + .state + .lock() + .expect("internal prompt progress state poisoned") + .clone() + } + + fn elapsed(&self) -> Duration { + self.shared.started_at.elapsed() + } + + fn write_line(&self, line: &str) { + let _guard = self + .shared + .output_lock + .lock() + .expect("internal prompt progress output lock poisoned"); + let mut stdout = io::stdout(); + let _ = writeln!(stdout, "{line}"); + let _ = stdout.flush(); + } +} + +impl InternalPromptProgressRun { + fn start_ultraplan(task: &str) -> Self { + let reporter = InternalPromptProgressReporter::ultraplan(task); + reporter.emit(InternalPromptProgressEvent::Started, None); + + let (heartbeat_stop, heartbeat_rx) = mpsc::channel(); + let heartbeat_reporter = reporter.clone(); + let heartbeat_handle = thread::spawn(move || loop { + match heartbeat_rx.recv_timeout(INTERNAL_PROGRESS_HEARTBEAT_INTERVAL) { + Ok(()) | Err(RecvTimeoutError::Disconnected) => break, + Err(RecvTimeoutError::Timeout) => heartbeat_reporter.emit_heartbeat(), + } + }); + + Self { + reporter, + heartbeat_stop: Some(heartbeat_stop), + heartbeat_handle: Some(heartbeat_handle), + } + } + + fn reporter(&self) -> InternalPromptProgressReporter { + self.reporter.clone() + } + + fn finish_success(&mut self) { + self.stop_heartbeat(); + self.reporter + .emit(InternalPromptProgressEvent::Complete, None); + } + + fn finish_failure(&mut self, error: &str) { + self.stop_heartbeat(); + self.reporter + .emit(InternalPromptProgressEvent::Failed, Some(error)); + } + + fn stop_heartbeat(&mut self) { + if let Some(sender) = self.heartbeat_stop.take() { + let _ = sender.send(()); + } + if let Some(handle) = self.heartbeat_handle.take() { + let _ = handle.join(); + } + } +} + +impl Drop for InternalPromptProgressRun { + fn drop(&mut self) { + self.stop_heartbeat(); + } +} + +fn format_internal_prompt_progress_line( + event: InternalPromptProgressEvent, + snapshot: &InternalPromptProgressState, + elapsed: Duration, + error: Option<&str>, +) -> String { + let elapsed_seconds = elapsed.as_secs(); + let step_label = if snapshot.step == 0 { + "current step pending".to_string() + } else { + format!("current step {}", snapshot.step) + }; + let mut status_bits = vec![step_label, format!("phase {}", snapshot.phase)]; + if let Some(detail) = snapshot + .detail + .as_deref() + .filter(|detail| !detail.is_empty()) + { + status_bits.push(detail.to_string()); + } + let status = status_bits.join(" · "); + match event { + InternalPromptProgressEvent::Started => { + format!( + "🧭 {} status · planning started · {status}", + snapshot.command_label + ) + } + InternalPromptProgressEvent::Update => { + format!("… {} status · {status}", snapshot.command_label) + } + InternalPromptProgressEvent::Heartbeat => format!( + "… {} heartbeat · {elapsed_seconds}s elapsed · {status}", + snapshot.command_label + ), + InternalPromptProgressEvent::Complete => format!( + "✔ {} status · completed · {elapsed_seconds}s elapsed · {} steps total", + snapshot.command_label, snapshot.step + ), + InternalPromptProgressEvent::Failed => format!( + "✘ {} status · failed · {elapsed_seconds}s elapsed · {}", + snapshot.command_label, + error.unwrap_or("unknown error") + ), + } +} + +fn describe_tool_progress(name: &str, input: &str) -> String { + let parsed: serde_json::Value = + serde_json::from_str(input).unwrap_or(serde_json::Value::String(input.to_string())); + match name { + "bash" | "Bash" => { + let command = parsed + .get("command") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + if command.is_empty() { + "running shell command".to_string() + } else { + format!("command {}", truncate_for_summary(command.trim(), 100)) + } + } + "read_file" | "Read" => format!("reading {}", extract_tool_path(&parsed)), + "write_file" | "Write" => format!("writing {}", extract_tool_path(&parsed)), + "edit_file" | "Edit" => format!("editing {}", extract_tool_path(&parsed)), + "glob_search" | "Glob" => { + let pattern = parsed + .get("pattern") + .and_then(|value| value.as_str()) + .unwrap_or("?"); + let scope = parsed + .get("path") + .and_then(|value| value.as_str()) + .unwrap_or("."); + format!("glob `{pattern}` in {scope}") + } + "grep_search" | "Grep" => { + let pattern = parsed + .get("pattern") + .and_then(|value| value.as_str()) + .unwrap_or("?"); + let scope = parsed + .get("path") + .and_then(|value| value.as_str()) + .unwrap_or("."); + format!("grep `{pattern}` in {scope}") + } + "web_search" | "WebSearch" => parsed + .get("query") + .and_then(|value| value.as_str()) + .map_or_else( + || "running web search".to_string(), + |query| format!("query {}", truncate_for_summary(query, 100)), + ), + _ => { + let summary = summarize_tool_payload(input); + if summary.is_empty() { + format!("running {name}") + } else { + format!("{name}: {summary}") + } + } + } +} + +#[allow(clippy::needless_pass_by_value)] +#[allow(clippy::too_many_arguments)] fn build_runtime( session: Session, + session_id: &str, model: String, system_prompt: Vec, enable_tools: bool, + emit_output: bool, allowed_tools: Option, permission_mode: PermissionMode, -) -> Result, Box> -{ - Ok(ConversationRuntime::new( + progress_reporter: Option, +) -> Result> { + let runtime_plugin_state = build_runtime_plugin_state()?; + build_runtime_with_plugin_state( session, - AnthropicRuntimeClient::new(model, enable_tools, allowed_tools.clone())?, - CliToolExecutor::new(allowed_tools), - permission_policy(permission_mode), + session_id, + model, system_prompt, - )) + enable_tools, + emit_output, + allowed_tools, + permission_mode, + progress_reporter, + runtime_plugin_state, + ) +} + +#[allow(clippy::needless_pass_by_value)] +#[allow(clippy::too_many_arguments)] +fn build_runtime_with_plugin_state( + mut session: Session, + session_id: &str, + model: String, + system_prompt: Vec, + enable_tools: bool, + emit_output: bool, + allowed_tools: Option, + permission_mode: PermissionMode, + progress_reporter: Option, + runtime_plugin_state: RuntimePluginState, +) -> Result> { + // Persist the model in session metadata so resumed sessions can report it. + if session.model.is_none() { + session.model = Some(model.clone()); + } + let RuntimePluginState { + feature_config, + tool_registry, + plugin_registry, + mcp_state, + } = runtime_plugin_state; + plugin_registry.initialize()?; + let policy = permission_policy(permission_mode, &feature_config, &tool_registry) + .map_err(std::io::Error::other)?; + let mut runtime = ConversationRuntime::new_with_features( + session, + AnthropicRuntimeClient::new( + session_id, + model, + enable_tools, + emit_output, + allowed_tools.clone(), + tool_registry.clone(), + progress_reporter, + )?, + CliToolExecutor::new( + allowed_tools.clone(), + emit_output, + tool_registry.clone(), + mcp_state.clone(), + ), + policy, + system_prompt, + &feature_config, + ); + if emit_output { + runtime = runtime.with_hook_progress_reporter(Box::new(CliHookProgressReporter)); + } + Ok(BuiltRuntime::new(runtime, plugin_registry, mcp_state)) +} + +struct CliHookProgressReporter; + +impl runtime::HookProgressReporter for CliHookProgressReporter { + fn on_event(&mut self, event: &runtime::HookProgressEvent) { + match event { + runtime::HookProgressEvent::Started { + event, + tool_name, + command, + } => eprintln!( + "[hook {event_name}] {tool_name}: {command}", + event_name = event.as_str() + ), + runtime::HookProgressEvent::Completed { + event, + tool_name, + command, + } => eprintln!( + "[hook done {event_name}] {tool_name}: {command}", + event_name = event.as_str() + ), + runtime::HookProgressEvent::Cancelled { + event, + tool_name, + command, + } => eprintln!( + "[hook cancelled {event_name}] {tool_name}: {command}", + event_name = event.as_str() + ), + } + } } struct CliPermissionPrompter { @@ -1997,6 +6856,9 @@ impl runtime::PermissionPrompter for CliPermissionPrompter { println!(" Tool {}", request.tool_name); println!(" Current mode {}", self.current_mode.as_str()); println!(" Required mode {}", request.required_mode.as_str()); + if let Some(reason) = &request.reason { + println!(" Reason {reason}"); + } println!(" Input {}", request.input); print!("Approve this tool call? [y/N]: "); let _ = io::stdout().flush(); @@ -2023,190 +6885,1086 @@ impl runtime::PermissionPrompter for CliPermissionPrompter { } } +// NOTE: Despite the historical name `AnthropicRuntimeClient`, this struct +// now holds an `ApiProviderClient` which dispatches to Anthropic, xAI, +// OpenAI, or DashScope at construction time based on +// `detect_provider_kind(&model)`. The struct name is kept to avoid +// churning `BuiltRuntime` and every Deref/DerefMut site that references +// it. See ROADMAP #29 for the provider-dispatch routing fix. struct AnthropicRuntimeClient { runtime: tokio::runtime::Runtime, - client: ClawApiClient, + client: ApiProviderClient, + session_id: String, model: String, enable_tools: bool, + emit_output: bool, allowed_tools: Option, + tool_registry: GlobalToolRegistry, + progress_reporter: Option, + reasoning_effort: Option, } impl AnthropicRuntimeClient { fn new( + session_id: &str, model: String, enable_tools: bool, + emit_output: bool, allowed_tools: Option, + tool_registry: GlobalToolRegistry, + progress_reporter: Option, ) -> Result> { + // Dispatch to the correct provider at construction time. + // `ApiProviderClient` (exposed by the api crate as + // `ProviderClient`) is an enum over Anthropic / xAI / OpenAI + // variants, where xAI and OpenAI both use the OpenAI-compat + // wire format under the hood. We consult + // `detect_provider_kind(&resolved_model)` so model-name prefix + // routing (`openai/`, `gpt-`, `grok`, `qwen/`) wins over + // env-var presence. + // + // For Anthropic we build the client directly instead of going + // through `ApiProviderClient::from_model_with_anthropic_auth` + // so we can explicitly apply `api::read_base_url()` — that + // reads `ANTHROPIC_BASE_URL` and is required for the local + // mock-server test harness + // (`crates/rusty-claude-cli/tests/compact_output.rs`) to point + // claw at its fake Anthropic endpoint. We also attach a + // session-scoped prompt cache on the Anthropic path; the + // prompt cache is Anthropic-only so non-Anthropic variants + // skip it. + let resolved_model = api::resolve_model_alias(&model); + let client = match detect_provider_kind(&resolved_model) { + ProviderKind::Anthropic => { + let auth = resolve_cli_auth_source()?; + let inner = AnthropicClient::from_auth(auth) + .with_base_url(api::read_base_url()) + .with_prompt_cache(PromptCache::new(session_id)); + ApiProviderClient::Anthropic(inner) + } + ProviderKind::Xai | ProviderKind::OpenAi => { + // The api crate's `ProviderClient::from_model_with_anthropic_auth` + // with `None` for the anthropic auth routes via + // `detect_provider_kind` and builds an + // `OpenAiCompatClient::from_env` with the matching + // `OpenAiCompatConfig` (openai / xai / dashscope). + // That reads the correct API-key env var and BASE_URL + // override internally, so this one call covers OpenAI, + // OpenRouter, xAI, DashScope, Ollama, and any other + // OpenAI-compat endpoint users configure via + // `OPENAI_BASE_URL` / `XAI_BASE_URL` / `DASHSCOPE_BASE_URL`. + ApiProviderClient::from_model_with_anthropic_auth(&resolved_model, None)? + } + }; Ok(Self { runtime: tokio::runtime::Runtime::new()?, - client: ClawApiClient::from_auth(resolve_cli_auth_source()?).with_base_url(api::read_base_url()), + client, + session_id: session_id.to_string(), model, enable_tools, + emit_output, allowed_tools, + tool_registry, + progress_reporter, + reasoning_effort: None, }) } + + fn set_reasoning_effort(&mut self, effort: Option) { + self.reasoning_effort = effort; + } } fn resolve_cli_auth_source() -> Result> { - Ok(resolve_startup_auth_source(|| { - let cwd = env::current_dir().map_err(api::ApiError::from)?; - let config = ConfigLoader::default_for(&cwd).load().map_err(|error| { - api::ApiError::Auth(format!("failed to load runtime OAuth config: {error}")) - })?; - Ok(config.oauth().cloned()) - })?) + let cwd = env::current_dir()?; + Ok(resolve_cli_auth_source_for_cwd(&cwd, default_oauth_config)?) +} + +fn resolve_cli_auth_source_for_cwd( + cwd: &Path, + default_oauth: F, +) -> Result +where + F: FnOnce() -> OAuthConfig, +{ + resolve_startup_auth_source(|| { + Ok(Some( + load_runtime_oauth_config_for(cwd)?.unwrap_or_else(default_oauth), + )) + }) +} + +fn load_runtime_oauth_config_for(cwd: &Path) -> Result, api::ApiError> { + let config = ConfigLoader::default_for(cwd).load().map_err(|error| { + api::ApiError::Auth(format!("failed to load runtime OAuth config: {error}")) + })?; + Ok(config.oauth().cloned()) } impl ApiClient for AnthropicRuntimeClient { #[allow(clippy::too_many_lines)] fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { + if let Some(progress_reporter) = &self.progress_reporter { + progress_reporter.mark_model_phase(); + } + let is_post_tool = request_ends_with_tool_result(&request); let message_request = MessageRequest { model: self.model.clone(), - max_tokens: DEFAULT_MAX_TOKENS, + max_tokens: max_tokens_for_model(&self.model), messages: convert_messages(&request.messages), system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")), - tools: self.enable_tools.then(|| { - filter_tool_specs(self.allowed_tools.as_ref()) - .into_iter() - .map(|spec| ToolDefinition { - name: spec.name.to_string(), - description: Some(spec.description.to_string()), - input_schema: spec.input_schema, - }) - .collect() - }), + tools: self + .enable_tools + .then(|| filter_tool_specs(&self.tool_registry, self.allowed_tools.as_ref())), tool_choice: self.enable_tools.then_some(ToolChoice::Auto), stream: true, + reasoning_effort: self.reasoning_effort.clone(), + ..Default::default() }; self.runtime.block_on(async { - let mut stream = self - .client - .stream_message(&message_request) - .await - .map_err(|error| RuntimeError::new(error.to_string()))?; - let mut stdout = io::stdout(); - let mut events = Vec::new(); - let mut pending_tool: Option<(String, String, String)> = None; - let mut saw_stop = false; + // When resuming after tool execution, apply a stall timeout on the + // first stream event. If the model does not respond within the + // deadline we drop the stalled connection and re-send the request as + // a continuation nudge (one retry only). + let max_attempts: usize = if is_post_tool { 2 } else { 1 }; - while let Some(event) = stream - .next_event() - .await - .map_err(|error| RuntimeError::new(error.to_string()))? - { - match event { - ApiStreamEvent::MessageStart(start) => { - for block in start.message.content { - push_output_block(block, &mut stdout, &mut events, &mut pending_tool)?; - } - } - ApiStreamEvent::ContentBlockStart(start) => { - push_output_block( - start.content_block, - &mut stdout, - &mut events, - &mut pending_tool, - )?; - } - ApiStreamEvent::ContentBlockDelta(delta) => match delta.delta { - ContentBlockDelta::TextDelta { text } => { - if !text.is_empty() { - write!(stdout, "{text}") - .and_then(|()| stdout.flush()) - .map_err(|error| RuntimeError::new(error.to_string()))?; - events.push(AssistantEvent::TextDelta(text)); - } - } - ContentBlockDelta::InputJsonDelta { partial_json } => { - if let Some((_, _, input)) = &mut pending_tool { - input.push_str(&partial_json); - } - } - ContentBlockDelta::ThinkingDelta { .. } - | ContentBlockDelta::SignatureDelta { .. } => {} - }, - ApiStreamEvent::ContentBlockStop(_) => { - if let Some((id, name, input)) = pending_tool.take() { - events.push(AssistantEvent::ToolUse { id, name, input }); - } - } - ApiStreamEvent::MessageDelta(delta) => { - events.push(AssistantEvent::Usage(TokenUsage { - input_tokens: delta.usage.input_tokens, - output_tokens: delta.usage.output_tokens, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - })); - } - ApiStreamEvent::MessageStop(_) => { - saw_stop = true; - events.push(AssistantEvent::MessageStop); + for attempt in 1..=max_attempts { + let result = self + .consume_stream(&message_request, is_post_tool && attempt == 1) + .await; + match result { + Ok(events) => return Ok(events), + Err(error) + if error.to_string().contains("post-tool stall") + && attempt < max_attempts => + { + // Stalled after tool completion — nudge the model by + // re-sending the same request. + continue; } + Err(error) => return Err(error), } } - if !saw_stop - && events.iter().any(|event| { - matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) - || matches!(event, AssistantEvent::ToolUse { .. }) - }) - { - events.push(AssistantEvent::MessageStop); - } - - if events - .iter() - .any(|event| matches!(event, AssistantEvent::MessageStop)) - { - return Ok(events); - } - - let response = self - .client - .send_message(&MessageRequest { - stream: false, - ..message_request.clone() - }) - .await - .map_err(|error| RuntimeError::new(error.to_string()))?; - response_to_events(response, &mut stdout) + Err(RuntimeError::new("post-tool continuation nudge exhausted")) }) } } -fn slash_command_completion_candidates() -> Vec { - slash_command_specs() +impl AnthropicRuntimeClient { + /// Consume a single streaming response, optionally applying a stall + /// timeout on the first event for post-tool continuations. + #[allow(clippy::too_many_lines)] + async fn consume_stream( + &self, + message_request: &MessageRequest, + apply_stall_timeout: bool, + ) -> Result, RuntimeError> { + let mut stream = self + .client + .stream_message(message_request) + .await + .map_err(|error| { + RuntimeError::new(format_user_visible_api_error(&self.session_id, &error)) + })?; + let mut stdout = io::stdout(); + let mut sink = io::sink(); + let out: &mut dyn Write = if self.emit_output { + &mut stdout + } else { + &mut sink + }; + let renderer = TerminalRenderer::new(); + let mut markdown_stream = MarkdownStreamState::default(); + let mut events = Vec::new(); + let mut pending_tool: Option<(String, String, String)> = None; + let mut block_has_thinking_summary = false; + let mut saw_stop = false; + let mut received_any_event = false; + + loop { + let next = if apply_stall_timeout && !received_any_event { + match tokio::time::timeout(POST_TOOL_STALL_TIMEOUT, stream.next_event()).await { + Ok(inner) => inner.map_err(|error| { + RuntimeError::new(format_user_visible_api_error(&self.session_id, &error)) + })?, + Err(_elapsed) => { + return Err(RuntimeError::new( + "post-tool stall: model did not respond within timeout", + )); + } + } + } else { + stream.next_event().await.map_err(|error| { + RuntimeError::new(format_user_visible_api_error(&self.session_id, &error)) + })? + }; + + let Some(event) = next else { + break; + }; + received_any_event = true; + + match event { + ApiStreamEvent::MessageStart(start) => { + for block in start.message.content { + push_output_block( + block, + out, + &mut events, + &mut pending_tool, + true, + &mut block_has_thinking_summary, + )?; + } + } + ApiStreamEvent::ContentBlockStart(start) => { + push_output_block( + start.content_block, + out, + &mut events, + &mut pending_tool, + true, + &mut block_has_thinking_summary, + )?; + } + ApiStreamEvent::ContentBlockDelta(delta) => match delta.delta { + ContentBlockDelta::TextDelta { text } => { + if !text.is_empty() { + if let Some(progress_reporter) = &self.progress_reporter { + progress_reporter.mark_text_phase(&text); + } + if let Some(rendered) = markdown_stream.push(&renderer, &text) { + write!(out, "{rendered}") + .and_then(|()| out.flush()) + .map_err(|error| RuntimeError::new(error.to_string()))?; + } + events.push(AssistantEvent::TextDelta(text)); + } + } + ContentBlockDelta::InputJsonDelta { partial_json } => { + if let Some((_, _, input)) = &mut pending_tool { + input.push_str(&partial_json); + } + } + ContentBlockDelta::ThinkingDelta { .. } => { + if !block_has_thinking_summary { + render_thinking_block_summary(out, None, false)?; + block_has_thinking_summary = true; + } + } + ContentBlockDelta::SignatureDelta { .. } => {} + }, + ApiStreamEvent::ContentBlockStop(_) => { + block_has_thinking_summary = false; + if let Some(rendered) = markdown_stream.flush(&renderer) { + write!(out, "{rendered}") + .and_then(|()| out.flush()) + .map_err(|error| RuntimeError::new(error.to_string()))?; + } + if let Some((id, name, input)) = pending_tool.take() { + if let Some(progress_reporter) = &self.progress_reporter { + progress_reporter.mark_tool_phase(&name, &input); + } + // Display tool call now that input is fully accumulated + writeln!(out, "\n{}", format_tool_call_start(&name, &input)) + .and_then(|()| out.flush()) + .map_err(|error| RuntimeError::new(error.to_string()))?; + events.push(AssistantEvent::ToolUse { id, name, input }); + } + } + ApiStreamEvent::MessageDelta(delta) => { + events.push(AssistantEvent::Usage(delta.usage.token_usage())); + } + ApiStreamEvent::MessageStop(_) => { + saw_stop = true; + if let Some(rendered) = markdown_stream.flush(&renderer) { + write!(out, "{rendered}") + .and_then(|()| out.flush()) + .map_err(|error| RuntimeError::new(error.to_string()))?; + } + events.push(AssistantEvent::MessageStop); + } + } + } + + push_prompt_cache_record(&self.client, &mut events); + + if !saw_stop + && events.iter().any(|event| { + matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) + || matches!(event, AssistantEvent::ToolUse { .. }) + }) + { + events.push(AssistantEvent::MessageStop); + } + + if events + .iter() + .any(|event| matches!(event, AssistantEvent::MessageStop)) + { + return Ok(events); + } + + let response = self + .client + .send_message(&MessageRequest { + stream: false, + ..message_request.clone() + }) + .await + .map_err(|error| { + RuntimeError::new(format_user_visible_api_error(&self.session_id, &error)) + })?; + let mut events = response_to_events(response, out)?; + push_prompt_cache_record(&self.client, &mut events); + Ok(events) + } +} + +/// Returns `true` when the conversation ends with a tool-result message, +/// meaning the model is expected to continue after tool execution. +fn request_ends_with_tool_result(request: &ApiRequest) -> bool { + request + .messages + .last() + .is_some_and(|message| message.role == MessageRole::Tool) +} + +fn format_user_visible_api_error(session_id: &str, error: &api::ApiError) -> String { + if error.is_context_window_failure() { + format_context_window_blocked_error(session_id, error) + } else if error.is_generic_fatal_wrapper() { + let mut qualifiers = vec![format!("session {session_id}")]; + if let Some(request_id) = error.request_id() { + qualifiers.push(format!("trace {request_id}")); + } + format!( + "{} ({}): {}", + error.safe_failure_class(), + qualifiers.join(", "), + error + ) + } else { + error.to_string() + } +} + +fn format_context_window_blocked_error(session_id: &str, error: &api::ApiError) -> String { + let mut lines = vec![ + "Context window blocked".to_string(), + " Failure class context_window_blocked".to_string(), + format!(" Session {session_id}"), + ]; + + if let Some(request_id) = error.request_id() { + lines.push(format!(" Trace {request_id}")); + } + + match error { + api::ApiError::ContextWindowExceeded { + model, + estimated_input_tokens, + requested_output_tokens, + estimated_total_tokens, + context_window_tokens, + } => { + lines.push(format!(" Model {model}")); + lines.push(format!( + " Input estimate ~{estimated_input_tokens} tokens (heuristic)" + )); + lines.push(format!( + " Requested output {requested_output_tokens} tokens" + )); + lines.push(format!( + " Total estimate ~{estimated_total_tokens} tokens (heuristic)" + )); + lines.push(format!(" Context window {context_window_tokens} tokens")); + } + api::ApiError::Api { message, body, .. } => { + let detail = message.as_deref().unwrap_or(body).trim(); + if !detail.is_empty() { + lines.push(format!( + " Detail {}", + truncate_for_summary(detail, 120) + )); + } + } + api::ApiError::RetriesExhausted { last_error, .. } => { + let detail = match last_error.as_ref() { + api::ApiError::Api { message, body, .. } => message.as_deref().unwrap_or(body), + other => return format_context_window_blocked_error(session_id, other), + } + .trim(); + if !detail.is_empty() { + lines.push(format!( + " Detail {}", + truncate_for_summary(detail, 120) + )); + } + } + _ => {} + } + + lines.push(String::new()); + lines.push("Recovery".to_string()); + lines.push(" Compact /compact".to_string()); + lines.push(format!( + " Resume compact claw --resume {session_id} /compact" + )); + lines.push(" Fresh session /clear --confirm".to_string()); + lines.push( + " Reduce scope remove large pasted context/files or ask for a smaller slice" + .to_string(), + ); + lines.push(" Retry rerun after compacting or reducing the request".to_string()); + + lines.join("\n") +} + +fn final_assistant_text(summary: &runtime::TurnSummary) -> String { + summary + .assistant_messages + .last() + .map(|message| { + message + .blocks + .iter() + .filter_map(|block| match block { + ContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join("") + }) + .unwrap_or_default() +} + +fn collect_tool_uses(summary: &runtime::TurnSummary) -> Vec { + summary + .assistant_messages .iter() - .map(|spec| format!("/{}", spec.name)) + .flat_map(|message| message.blocks.iter()) + .filter_map(|block| match block { + ContentBlock::ToolUse { id, name, input } => Some(json!({ + "id": id, + "name": name, + "input": input, + })), + _ => None, + }) .collect() } +fn collect_tool_results(summary: &runtime::TurnSummary) -> Vec { + summary + .tool_results + .iter() + .flat_map(|message| message.blocks.iter()) + .filter_map(|block| match block { + ContentBlock::ToolResult { + tool_use_id, + tool_name, + output, + is_error, + } => Some(json!({ + "tool_use_id": tool_use_id, + "tool_name": tool_name, + "output": output, + "is_error": is_error, + })), + _ => None, + }) + .collect() +} + +fn collect_prompt_cache_events(summary: &runtime::TurnSummary) -> Vec { + summary + .prompt_cache_events + .iter() + .map(|event| { + json!({ + "unexpected": event.unexpected, + "reason": event.reason, + "previous_cache_read_input_tokens": event.previous_cache_read_input_tokens, + "current_cache_read_input_tokens": event.current_cache_read_input_tokens, + "token_drop": event.token_drop, + }) + }) + .collect() +} + +/// Slash commands that are registered in the spec list but not yet implemented +/// in this build. Used to filter both REPL completions and help output so the +/// discovery surface only shows commands that actually work (ROADMAP #39). +const STUB_COMMANDS: &[&str] = &[ + "login", + "logout", + "vim", + "upgrade", + "share", + "feedback", + "files", + "fast", + "exit", + "summary", + "desktop", + "brief", + "advisor", + "stickers", + "insights", + "thinkback", + "release-notes", + "security-review", + "keybindings", + "privacy-settings", + "plan", + "review", + "tasks", + "theme", + "voice", + "usage", + "rename", + "copy", + "hooks", + "context", + "color", + "effort", + "branch", + "rewind", + "ide", + "tag", + "output-style", + "add-dir", + // Spec entries with no parse arm — produce circular "Did you mean" error + // without this guard. Adding here routes them to the proper unsupported + // message and excludes them from REPL completions / help. + // NOTE: do NOT add "stats", "tokens", "cache" — they are implemented. + "allowed-tools", + "bookmarks", + "workspace", + "reasoning", + "budget", + "rate-limit", + "changelog", + "diagnostics", + "metrics", + "tool-details", + "focus", + "unfocus", + "pin", + "unpin", + "language", + "profile", + "max-tokens", + "temperature", + "system-prompt", + "notifications", + "telemetry", + "env", + "project", + "terminal-setup", + "api-key", + "reset", + "undo", + "stop", + "retry", + "paste", + "screenshot", + "image", + "search", + "listen", + "speak", + "format", + "test", + "lint", + "build", + "run", + "git", + "stash", + "blame", + "log", + "cron", + "team", + "benchmark", + "migrate", + "templates", + "explain", + "refactor", + "docs", + "fix", + "perf", + "chat", + "web", + "map", + "symbols", + "references", + "definition", + "hover", + "autofix", + "multi", + "macro", + "alias", + "parallel", + "subagent", + "agent", +]; + +fn slash_command_completion_candidates_with_sessions( + model: &str, + active_session_id: Option<&str>, + recent_session_ids: Vec, +) -> Vec { + let mut completions = BTreeSet::new(); + + for spec in slash_command_specs() { + if STUB_COMMANDS.contains(&spec.name) { + continue; + } + completions.insert(format!("/{}", spec.name)); + for alias in spec.aliases { + if !STUB_COMMANDS.contains(alias) { + completions.insert(format!("/{alias}")); + } + } + } + + for candidate in [ + "/bughunter ", + "/clear --confirm", + "/config ", + "/config env", + "/config hooks", + "/config model", + "/config plugins", + "/mcp ", + "/mcp list", + "/mcp show ", + "/export ", + "/issue ", + "/model ", + "/model opus", + "/model sonnet", + "/model haiku", + "/permissions ", + "/permissions read-only", + "/permissions workspace-write", + "/permissions danger-full-access", + "/plugin list", + "/plugin install ", + "/plugin enable ", + "/plugin disable ", + "/plugin uninstall ", + "/plugin update ", + "/plugins list", + "/pr ", + "/resume ", + "/session list", + "/session switch ", + "/session fork ", + "/teleport ", + "/ultraplan ", + "/agents help", + "/mcp help", + "/skills help", + ] { + completions.insert(candidate.to_string()); + } + + if !model.trim().is_empty() { + completions.insert(format!("/model {}", resolve_model_alias(model))); + completions.insert(format!("/model {model}")); + } + + if let Some(active_session_id) = active_session_id.filter(|value| !value.trim().is_empty()) { + completions.insert(format!("/resume {active_session_id}")); + completions.insert(format!("/session switch {active_session_id}")); + } + + for session_id in recent_session_ids + .into_iter() + .filter(|value| !value.trim().is_empty()) + .take(10) + { + completions.insert(format!("/resume {session_id}")); + completions.insert(format!("/session switch {session_id}")); + } + + completions.into_iter().collect() +} + fn format_tool_call_start(name: &str, input: &str) -> String { + let parsed: serde_json::Value = + serde_json::from_str(input).unwrap_or(serde_json::Value::String(input.to_string())); + + let detail = match name { + "bash" | "Bash" => format_bash_call(&parsed), + "read_file" | "Read" => { + let path = extract_tool_path(&parsed); + format!("\x1b[2m📄 Reading {path}…\x1b[0m") + } + "write_file" | "Write" => { + let path = extract_tool_path(&parsed); + let lines = parsed + .get("content") + .and_then(|value| value.as_str()) + .map_or(0, |content| content.lines().count()); + format!("\x1b[1;32m✏️ Writing {path}\x1b[0m \x1b[2m({lines} lines)\x1b[0m") + } + "edit_file" | "Edit" => { + let path = extract_tool_path(&parsed); + let old_value = parsed + .get("old_string") + .or_else(|| parsed.get("oldString")) + .and_then(|value| value.as_str()) + .unwrap_or_default(); + let new_value = parsed + .get("new_string") + .or_else(|| parsed.get("newString")) + .and_then(|value| value.as_str()) + .unwrap_or_default(); + format!( + "\x1b[1;33m📝 Editing {path}\x1b[0m{}", + format_patch_preview(old_value, new_value) + .map(|preview| format!("\n{preview}")) + .unwrap_or_default() + ) + } + "glob_search" | "Glob" => format_search_start("🔎 Glob", &parsed), + "grep_search" | "Grep" => format_search_start("🔎 Grep", &parsed), + "web_search" | "WebSearch" => parsed + .get("query") + .and_then(|value| value.as_str()) + .unwrap_or("?") + .to_string(), + _ => summarize_tool_payload(input), + }; + + let border = "─".repeat(name.len() + 8); format!( - "Tool call - Name {name} - Input {}", - summarize_tool_payload(input) + "\x1b[38;5;245m╭─ \x1b[1;36m{name}\x1b[0;38;5;245m ─╮\x1b[0m\n\x1b[38;5;245m│\x1b[0m {detail}\n\x1b[38;5;245m╰{border}╯\x1b[0m" ) } fn format_tool_result(name: &str, output: &str, is_error: bool) -> String { - let status = if is_error { "error" } else { "ok" }; + let icon = if is_error { + "\x1b[1;31m✗\x1b[0m" + } else { + "\x1b[1;32m✓\x1b[0m" + }; + if is_error { + let summary = truncate_for_summary(output.trim(), 160); + return if summary.is_empty() { + format!("{icon} \x1b[38;5;245m{name}\x1b[0m") + } else { + format!("{icon} \x1b[38;5;245m{name}\x1b[0m\n\x1b[38;5;203m{summary}\x1b[0m") + }; + } + + let parsed: serde_json::Value = + serde_json::from_str(output).unwrap_or(serde_json::Value::String(output.to_string())); + match name { + "bash" | "Bash" => format_bash_result(icon, &parsed), + "read_file" | "Read" => format_read_result(icon, &parsed), + "write_file" | "Write" => format_write_result(icon, &parsed), + "edit_file" | "Edit" => format_edit_result(icon, &parsed), + "glob_search" | "Glob" => format_glob_result(icon, &parsed), + "grep_search" | "Grep" => format_grep_result(icon, &parsed), + _ => format_generic_tool_result(icon, name, &parsed), + } +} + +const DISPLAY_TRUNCATION_NOTICE: &str = + "\x1b[2m… output truncated for display; full result preserved in session.\x1b[0m"; +const READ_DISPLAY_MAX_LINES: usize = 80; +const READ_DISPLAY_MAX_CHARS: usize = 6_000; +const TOOL_OUTPUT_DISPLAY_MAX_LINES: usize = 60; +const TOOL_OUTPUT_DISPLAY_MAX_CHARS: usize = 4_000; + +fn extract_tool_path(parsed: &serde_json::Value) -> String { + parsed + .get("file_path") + .or_else(|| parsed.get("filePath")) + .or_else(|| parsed.get("path")) + .and_then(|value| value.as_str()) + .unwrap_or("?") + .to_string() +} + +fn format_search_start(label: &str, parsed: &serde_json::Value) -> String { + let pattern = parsed + .get("pattern") + .and_then(|value| value.as_str()) + .unwrap_or("?"); + let scope = parsed + .get("path") + .and_then(|value| value.as_str()) + .unwrap_or("."); + format!("{label} {pattern}\n\x1b[2min {scope}\x1b[0m") +} + +fn format_patch_preview(old_value: &str, new_value: &str) -> Option { + if old_value.is_empty() && new_value.is_empty() { + return None; + } + Some(format!( + "\x1b[38;5;203m- {}\x1b[0m\n\x1b[38;5;70m+ {}\x1b[0m", + truncate_for_summary(first_visible_line(old_value), 72), + truncate_for_summary(first_visible_line(new_value), 72) + )) +} + +fn format_bash_call(parsed: &serde_json::Value) -> String { + let command = parsed + .get("command") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + if command.is_empty() { + String::new() + } else { + format!( + "\x1b[48;5;236;38;5;255m $ {} \x1b[0m", + truncate_for_summary(command, 160) + ) + } +} + +fn first_visible_line(text: &str) -> &str { + text.lines() + .find(|line| !line.trim().is_empty()) + .unwrap_or(text) +} + +fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String { + use std::fmt::Write as _; + + let mut lines = vec![format!("{icon} \x1b[38;5;245mbash\x1b[0m")]; + if let Some(task_id) = parsed + .get("backgroundTaskId") + .and_then(|value| value.as_str()) + { + write!(&mut lines[0], " backgrounded ({task_id})").expect("write to string"); + } else if let Some(status) = parsed + .get("returnCodeInterpretation") + .and_then(|value| value.as_str()) + .filter(|status| !status.is_empty()) + { + write!(&mut lines[0], " {status}").expect("write to string"); + } + + if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) { + if !stdout.trim().is_empty() { + lines.push(truncate_output_for_display( + stdout, + TOOL_OUTPUT_DISPLAY_MAX_LINES, + TOOL_OUTPUT_DISPLAY_MAX_CHARS, + )); + } + } + if let Some(stderr) = parsed.get("stderr").and_then(|value| value.as_str()) { + if !stderr.trim().is_empty() { + lines.push(format!( + "\x1b[38;5;203m{}\x1b[0m", + truncate_output_for_display( + stderr, + TOOL_OUTPUT_DISPLAY_MAX_LINES, + TOOL_OUTPUT_DISPLAY_MAX_CHARS, + ) + )); + } + } + + lines.join("\n\n") +} + +fn format_read_result(icon: &str, parsed: &serde_json::Value) -> String { + let file = parsed.get("file").unwrap_or(parsed); + let path = extract_tool_path(file); + let start_line = file + .get("startLine") + .and_then(serde_json::Value::as_u64) + .unwrap_or(1); + let num_lines = file + .get("numLines") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + let total_lines = file + .get("totalLines") + .and_then(serde_json::Value::as_u64) + .unwrap_or(num_lines); + let content = file + .get("content") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + let end_line = start_line.saturating_add(num_lines.saturating_sub(1)); + format!( - "### Tool `{name}` - -- Status: {status} -- Output: - -```json -{} -``` -", - prettify_tool_payload(output) + "{icon} \x1b[2m📄 Read {path} (lines {}-{} of {})\x1b[0m\n{}", + start_line, + end_line.max(start_line), + total_lines, + truncate_output_for_display(content, READ_DISPLAY_MAX_LINES, READ_DISPLAY_MAX_CHARS) ) } +fn format_write_result(icon: &str, parsed: &serde_json::Value) -> String { + let path = extract_tool_path(parsed); + let kind = parsed + .get("type") + .and_then(|value| value.as_str()) + .unwrap_or("write"); + let line_count = parsed + .get("content") + .and_then(|value| value.as_str()) + .map_or(0, |content| content.lines().count()); + format!( + "{icon} \x1b[1;32m✏️ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m", + if kind == "create" { "Wrote" } else { "Updated" }, + ) +} + +fn format_structured_patch_preview(parsed: &serde_json::Value) -> Option { + let hunks = parsed.get("structuredPatch")?.as_array()?; + let mut preview = Vec::new(); + for hunk in hunks.iter().take(2) { + let lines = hunk.get("lines")?.as_array()?; + for line in lines.iter().filter_map(|value| value.as_str()).take(6) { + match line.chars().next() { + Some('+') => preview.push(format!("\x1b[38;5;70m{line}\x1b[0m")), + Some('-') => preview.push(format!("\x1b[38;5;203m{line}\x1b[0m")), + _ => preview.push(line.to_string()), + } + } + } + if preview.is_empty() { + None + } else { + Some(preview.join("\n")) + } +} + +fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String { + let path = extract_tool_path(parsed); + let suffix = if parsed + .get("replaceAll") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + { + " (replace all)" + } else { + "" + }; + let preview = format_structured_patch_preview(parsed).or_else(|| { + let old_value = parsed + .get("oldString") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + let new_value = parsed + .get("newString") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + format_patch_preview(old_value, new_value) + }); + + match preview { + Some(preview) => format!("{icon} \x1b[1;33m📝 Edited {path}{suffix}\x1b[0m\n{preview}"), + None => format!("{icon} \x1b[1;33m📝 Edited {path}{suffix}\x1b[0m"), + } +} + +fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String { + let num_files = parsed + .get("numFiles") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + let filenames = parsed + .get("filenames") + .and_then(|value| value.as_array()) + .map(|files| { + files + .iter() + .filter_map(|value| value.as_str()) + .take(8) + .collect::>() + .join("\n") + }) + .unwrap_or_default(); + if filenames.is_empty() { + format!("{icon} \x1b[38;5;245mglob_search\x1b[0m matched {num_files} files") + } else { + format!("{icon} \x1b[38;5;245mglob_search\x1b[0m matched {num_files} files\n{filenames}") + } +} + +fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String { + let num_matches = parsed + .get("numMatches") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + let num_files = parsed + .get("numFiles") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + let content = parsed + .get("content") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + let filenames = parsed + .get("filenames") + .and_then(|value| value.as_array()) + .map(|files| { + files + .iter() + .filter_map(|value| value.as_str()) + .take(8) + .collect::>() + .join("\n") + }) + .unwrap_or_default(); + let summary = format!( + "{icon} \x1b[38;5;245mgrep_search\x1b[0m {num_matches} matches across {num_files} files" + ); + if !content.trim().is_empty() { + format!( + "{summary}\n{}", + truncate_output_for_display( + content, + TOOL_OUTPUT_DISPLAY_MAX_LINES, + TOOL_OUTPUT_DISPLAY_MAX_CHARS, + ) + ) + } else if !filenames.is_empty() { + format!("{summary}\n{filenames}") + } else { + summary + } +} + +fn format_generic_tool_result(icon: &str, name: &str, parsed: &serde_json::Value) -> String { + let rendered_output = match parsed { + serde_json::Value::String(text) => text.clone(), + serde_json::Value::Null => String::new(), + serde_json::Value::Object(_) | serde_json::Value::Array(_) => { + serde_json::to_string_pretty(parsed).unwrap_or_else(|_| parsed.to_string()) + } + _ => parsed.to_string(), + }; + let preview = truncate_output_for_display( + &rendered_output, + TOOL_OUTPUT_DISPLAY_MAX_LINES, + TOOL_OUTPUT_DISPLAY_MAX_CHARS, + ); + + if preview.is_empty() { + format!("{icon} \x1b[38;5;245m{name}\x1b[0m") + } else if preview.contains('\n') { + format!("{icon} \x1b[38;5;245m{name}\x1b[0m\n{preview}") + } else { + format!("{icon} \x1b[38;5;245m{name}:\x1b[0m {preview}") + } +} + fn summarize_tool_payload(payload: &str) -> String { let compact = match serde_json::from_str::(payload) { Ok(value) => value.to_string(), @@ -2215,13 +7973,6 @@ fn summarize_tool_payload(payload: &str) -> String { truncate_for_summary(&compact, 96) } -fn prettify_tool_payload(payload: &str) -> String { - match serde_json::from_str::(payload) { - Ok(value) => serde_json::to_string_pretty(&value).unwrap_or_else(|_| payload.to_string()), - Err(_) => payload.to_string(), - } -} - fn truncate_for_summary(value: &str, limit: usize) -> String { let mut chars = value.chars(); let truncated = chars.by_ref().take(limit).collect::(); @@ -2232,71 +7983,245 @@ fn truncate_for_summary(value: &str, limit: usize) -> String { } } +fn truncate_output_for_display(content: &str, max_lines: usize, max_chars: usize) -> String { + let original = content.trim_end_matches('\n'); + if original.is_empty() { + return String::new(); + } + + let mut preview_lines = Vec::new(); + let mut used_chars = 0usize; + let mut truncated = false; + + for (index, line) in original.lines().enumerate() { + if index >= max_lines { + truncated = true; + break; + } + + let newline_cost = usize::from(!preview_lines.is_empty()); + let available = max_chars.saturating_sub(used_chars + newline_cost); + if available == 0 { + truncated = true; + break; + } + + let line_chars = line.chars().count(); + if line_chars > available { + preview_lines.push(line.chars().take(available).collect::()); + truncated = true; + break; + } + + preview_lines.push(line.to_string()); + used_chars += newline_cost + line_chars; + } + + let mut preview = preview_lines.join("\n"); + if truncated { + if !preview.is_empty() { + preview.push('\n'); + } + preview.push_str(DISPLAY_TRUNCATION_NOTICE); + } + preview +} + +fn render_thinking_block_summary( + out: &mut (impl Write + ?Sized), + char_count: Option, + redacted: bool, +) -> Result<(), RuntimeError> { + let summary = if redacted { + "\n▶ Thinking block hidden by provider\n".to_string() + } else if let Some(char_count) = char_count { + format!("\n▶ Thinking ({char_count} chars hidden)\n") + } else { + "\n▶ Thinking hidden\n".to_string() + }; + write!(out, "{summary}") + .and_then(|()| out.flush()) + .map_err(|error| RuntimeError::new(error.to_string())) +} + fn push_output_block( block: OutputContentBlock, - out: &mut impl Write, + out: &mut (impl Write + ?Sized), events: &mut Vec, pending_tool: &mut Option<(String, String, String)>, + streaming_tool_input: bool, + block_has_thinking_summary: &mut bool, ) -> Result<(), RuntimeError> { match block { OutputContentBlock::Text { text } => { if !text.is_empty() { - write!(out, "{text}") + let rendered = TerminalRenderer::new().markdown_to_ansi(&text); + write!(out, "{rendered}") .and_then(|()| out.flush()) .map_err(|error| RuntimeError::new(error.to_string()))?; events.push(AssistantEvent::TextDelta(text)); } } OutputContentBlock::ToolUse { id, name, input } => { - writeln!( - out, - " -{}", - format_tool_call_start(&name, &input.to_string()) - ) - .and_then(|()| out.flush()) - .map_err(|error| RuntimeError::new(error.to_string()))?; - *pending_tool = Some((id, name, input.to_string())); + // During streaming, the initial content_block_start has an empty input ({}). + // The real input arrives via input_json_delta events. In + // non-streaming responses, preserve a legitimate empty object. + let initial_input = if streaming_tool_input + && input.is_object() + && input.as_object().is_some_and(serde_json::Map::is_empty) + { + String::new() + } else { + input.to_string() + }; + *pending_tool = Some((id, name, initial_input)); + } + OutputContentBlock::Thinking { thinking, .. } => { + render_thinking_block_summary(out, Some(thinking.chars().count()), false)?; + *block_has_thinking_summary = true; + } + OutputContentBlock::RedactedThinking { .. } => { + render_thinking_block_summary(out, None, true)?; + *block_has_thinking_summary = true; } - OutputContentBlock::Thinking { .. } | OutputContentBlock::RedactedThinking { .. } => {} } Ok(()) } fn response_to_events( response: MessageResponse, - out: &mut impl Write, + out: &mut (impl Write + ?Sized), ) -> Result, RuntimeError> { let mut events = Vec::new(); let mut pending_tool = None; for block in response.content { - push_output_block(block, out, &mut events, &mut pending_tool)?; + let mut block_has_thinking_summary = false; + push_output_block( + block, + out, + &mut events, + &mut pending_tool, + false, + &mut block_has_thinking_summary, + )?; if let Some((id, name, input)) = pending_tool.take() { events.push(AssistantEvent::ToolUse { id, name, input }); } } - events.push(AssistantEvent::Usage(TokenUsage { - input_tokens: response.usage.input_tokens, - output_tokens: response.usage.output_tokens, - cache_creation_input_tokens: response.usage.cache_creation_input_tokens, - cache_read_input_tokens: response.usage.cache_read_input_tokens, - })); + events.push(AssistantEvent::Usage(response.usage.token_usage())); events.push(AssistantEvent::MessageStop); Ok(events) } +fn push_prompt_cache_record(client: &ApiProviderClient, events: &mut Vec) { + // `ApiProviderClient::take_last_prompt_cache_record` is a pass-through + // to the Anthropic variant and returns `None` for OpenAI-compat / + // xAI variants, which do not have a prompt cache. So this helper + // remains a no-op on non-Anthropic providers without any extra + // branching here. + if let Some(record) = client.take_last_prompt_cache_record() { + if let Some(event) = prompt_cache_record_to_runtime_event(record) { + events.push(AssistantEvent::PromptCache(event)); + } + } +} + +fn prompt_cache_record_to_runtime_event( + record: api::PromptCacheRecord, +) -> Option { + let cache_break = record.cache_break?; + Some(PromptCacheEvent { + unexpected: cache_break.unexpected, + reason: cache_break.reason, + previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens, + current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens, + token_drop: cache_break.token_drop, + }) +} + struct CliToolExecutor { renderer: TerminalRenderer, + emit_output: bool, allowed_tools: Option, + tool_registry: GlobalToolRegistry, + mcp_state: Option>>, } impl CliToolExecutor { - fn new(allowed_tools: Option) -> Self { + fn new( + allowed_tools: Option, + emit_output: bool, + tool_registry: GlobalToolRegistry, + mcp_state: Option>>, + ) -> Self { Self { renderer: TerminalRenderer::new(), + emit_output, allowed_tools, + tool_registry, + mcp_state, + } + } + + fn execute_search_tool(&self, value: serde_json::Value) -> Result { + let input: ToolSearchRequest = serde_json::from_value(value) + .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; + let (pending_mcp_servers, mcp_degraded) = + self.mcp_state.as_ref().map_or((None, None), |state| { + let state = state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + (state.pending_servers(), state.degraded_report()) + }); + serde_json::to_string_pretty(&self.tool_registry.search( + &input.query, + input.max_results.unwrap_or(5), + pending_mcp_servers, + mcp_degraded, + )) + .map_err(|error| ToolError::new(error.to_string())) + } + + fn execute_runtime_tool( + &self, + tool_name: &str, + value: serde_json::Value, + ) -> Result { + let Some(mcp_state) = &self.mcp_state else { + return Err(ToolError::new(format!( + "runtime tool `{tool_name}` is unavailable without configured MCP servers" + ))); + }; + let mut mcp_state = mcp_state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + + match tool_name { + "MCPTool" => { + let input: McpToolRequest = serde_json::from_value(value) + .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; + let qualified_name = input + .qualified_name + .or(input.tool) + .ok_or_else(|| ToolError::new("missing required field `qualifiedName`"))?; + mcp_state.call_tool(&qualified_name, input.arguments) + } + "ListMcpResourcesTool" => { + let input: ListMcpResourcesRequest = serde_json::from_value(value) + .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; + match input.server { + Some(server_name) => mcp_state.list_resources_for_server(&server_name), + None => mcp_state.list_resources_for_all_servers(), + } + } + "ReadMcpResourceTool" => { + let input: ReadMcpResourceRequest = serde_json::from_value(value) + .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; + mcp_state.read_resource(&input.server, &input.uri) + } + _ => mcp_state.call_tool(tool_name, Some(value)), } } } @@ -2314,35 +8239,49 @@ impl ToolExecutor for CliToolExecutor { } let value = serde_json::from_str(input) .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; - match execute_tool(tool_name, &value) { + let result = if tool_name == "ToolSearch" { + self.execute_search_tool(value) + } else if self.tool_registry.has_runtime_tool(tool_name) { + self.execute_runtime_tool(tool_name, value) + } else { + self.tool_registry + .execute(tool_name, &value) + .map_err(ToolError::new) + }; + match result { Ok(output) => { - let markdown = format_tool_result(tool_name, &output, false); - self.renderer - .stream_markdown(&markdown, &mut io::stdout()) - .map_err(|error| ToolError::new(error.to_string()))?; + if self.emit_output { + let markdown = format_tool_result(tool_name, &output, false); + self.renderer + .stream_markdown(&markdown, &mut io::stdout()) + .map_err(|error| ToolError::new(error.to_string()))?; + } Ok(output) } Err(error) => { - let markdown = format_tool_result(tool_name, &error, true); - self.renderer - .stream_markdown(&markdown, &mut io::stdout()) - .map_err(|stream_error| ToolError::new(stream_error.to_string()))?; - Err(ToolError::new(error)) + if self.emit_output { + let markdown = format_tool_result(tool_name, &error.to_string(), true); + self.renderer + .stream_markdown(&markdown, &mut io::stdout()) + .map_err(|stream_error| ToolError::new(stream_error.to_string()))?; + } + Err(error) } } } } -fn permission_policy(mode: PermissionMode) -> PermissionPolicy { - tool_permission_specs() - .into_iter() - .fold(PermissionPolicy::new(mode), |policy, spec| { - policy.with_tool_requirement(spec.name, spec.required_permission) - }) -} - -fn tool_permission_specs() -> Vec { - mvp_tool_specs() +fn permission_policy( + mode: PermissionMode, + feature_config: &runtime::RuntimeFeatureConfig, + tool_registry: &GlobalToolRegistry, +) -> Result { + Ok(tool_registry.permission_specs(None)?.into_iter().fold( + PermissionPolicy::new(mode).with_permission_rules(feature_config.permission_rules()), + |policy, (name, required_permission)| { + policy.with_tool_requirement(name, required_permission) + }, + )) } fn convert_messages(messages: &[ConversationMessage]) -> Vec { @@ -2358,16 +8297,6 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec { .iter() .map(|block| match block { ContentBlock::Text { text } => InputContentBlock::Text { text: text.clone() }, - ContentBlock::Thinking { - thinking, - signature, - } => InputContentBlock::Thinking { - thinking: thinking.clone(), - signature: signature.clone(), - }, - ContentBlock::RedactedThinking { data } => InputContentBlock::RedactedThinking { - data: serde_json::from_str(&data.render()).unwrap_or(serde_json::Value::Null), - }, ContentBlock::ToolUse { id, name, input } => InputContentBlock::ToolUse { id: id.clone(), name: name.clone(), @@ -2396,6 +8325,7 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec { .collect() } +#[allow(clippy::too_many_lines)] fn print_help_to(out: &mut impl Write) -> io::Result<()> { writeln!(out, "claw v{VERSION}")?; writeln!(out)?; @@ -2417,21 +8347,45 @@ fn print_help_to(out: &mut impl Write) -> io::Result<()> { writeln!(out, " Shorthand non-interactive prompt mode")?; writeln!( out, - " claw --resume SESSION.json [/status] [/compact] [...]" + " claw --resume [SESSION.jsonl|session-id|latest] [/status] [/compact] [...]" )?; writeln!( out, " Inspect or maintain a saved session without entering the REPL" )?; - writeln!(out, " claw dump-manifests")?; - writeln!(out, " claw bootstrap-plan")?; + writeln!(out, " claw help")?; + writeln!(out, " Alias for --help")?; + writeln!(out, " claw version")?; + writeln!(out, " Alias for --version")?; + writeln!(out, " claw status")?; writeln!( out, - " claw system-prompt [--cwd PATH] [--date YYYY-MM-DD]" + " Show the current local workspace status snapshot" )?; + writeln!(out, " claw sandbox")?; + writeln!(out, " Show the current sandbox isolation snapshot")?; + writeln!(out, " claw doctor")?; + writeln!( + out, + " Diagnose local auth, config, workspace, and sandbox health" + )?; + writeln!(out, " claw dump-manifests")?; + writeln!(out, " claw bootstrap-plan")?; + writeln!(out, " claw agents")?; + writeln!(out, " claw mcp")?; + writeln!(out, " claw skills")?; + writeln!(out, " claw system-prompt [--cwd PATH] [--date YYYY-MM-DD]")?; writeln!(out, " claw login")?; writeln!(out, " claw logout")?; writeln!(out, " claw init")?; + writeln!( + out, + " claw export [PATH] [--session SESSION] [--output PATH]" + )?; + writeln!( + out, + " Dump the latest (or named) session as markdown; writes to PATH or stdout" + )?; writeln!(out)?; writeln!(out, "Flags:")?; writeln!( @@ -2442,10 +8396,18 @@ fn print_help_to(out: &mut impl Write) -> io::Result<()> { out, " --output-format FORMAT Non-interactive output format: text or json" )?; + writeln!( + out, + " --compact Strip tool call details; print only the final assistant text (text mode only; useful for piping)" + )?; writeln!( out, " --permission-mode MODE Set read-only, workspace-write, or danger-full-access" )?; + writeln!( + out, + " --dangerously-skip-permissions Skip all permission checks" + )?; writeln!(out, " --allowedTools TOOLS Restrict enabled tools (repeatable; comma-separated aliases supported)")?; writeln!( out, @@ -2453,7 +8415,7 @@ fn print_help_to(out: &mut impl Write) -> io::Result<()> { )?; writeln!(out)?; writeln!(out, "Interactive slash commands:")?; - writeln!(out, "{}", render_slash_command_help())?; + writeln!(out, "{}", render_slash_command_help_filtered(STUB_COMMANDS))?; writeln!(out)?; let resume_commands = resume_supported_slash_commands() .into_iter() @@ -2464,60 +8426,563 @@ fn print_help_to(out: &mut impl Write) -> io::Result<()> { .collect::>() .join(", "); writeln!(out, "Resume-safe commands: {resume_commands}")?; - writeln!(out, "Examples:")?; + writeln!(out)?; + writeln!(out, "Session shortcuts:")?; writeln!( out, - " claw --model claude-opus \"summarize this repo\"" + " REPL turns auto-save to .claw/sessions/.{PRIMARY_SESSION_EXTENSION}" )?; + writeln!( + out, + " Use `{LATEST_SESSION_REFERENCE}` with --resume, /resume, or /session switch to target the newest saved session" + )?; + writeln!( + out, + " Use /session list in the REPL to browse managed sessions" + )?; + writeln!(out, "Examples:")?; + writeln!(out, " claw --model claude-opus \"summarize this repo\"")?; writeln!( out, " claw --output-format json prompt \"explain src/main.rs\"" )?; + writeln!(out, " claw --compact \"summarize Cargo.toml\" | wc -l")?; writeln!( out, " claw --allowedTools read,glob \"summarize Cargo.toml\"" )?; + writeln!(out, " claw --resume {LATEST_SESSION_REFERENCE}")?; writeln!( out, - " claw --resume session.json /status /diff /export notes.txt" + " claw --resume {LATEST_SESSION_REFERENCE} /status /diff /export notes.txt" )?; + writeln!(out, " claw agents")?; + writeln!(out, " claw mcp show my-server")?; + writeln!(out, " claw /skills")?; + writeln!(out, " claw doctor")?; writeln!(out, " claw login")?; writeln!(out, " claw init")?; + writeln!(out, " claw export")?; + writeln!(out, " claw export conversation.md")?; Ok(()) } -fn print_help() { - let _ = print_help_to(&mut io::stdout()); +fn print_help(output_format: CliOutputFormat) -> Result<(), Box> { + let mut buffer = Vec::new(); + print_help_to(&mut buffer)?; + let message = String::from_utf8(buffer)?; + match output_format { + CliOutputFormat::Text => print!("{message}"), + CliOutputFormat::Json => println!( + "{}", + serde_json::to_string_pretty(&json!({ + "kind": "help", + "message": message, + }))? + ), + } + Ok(()) } #[cfg(test)] mod tests { use super::{ - filter_tool_specs, format_compact_report, format_cost_report, format_model_report, - format_model_switch_report, format_permissions_report, format_permissions_switch_report, + build_runtime_plugin_state_with_loader, build_runtime_with_plugin_state, + collect_session_prompt_history, create_managed_session_handle, describe_tool_progress, + filter_tool_specs, format_bughunter_report, format_commit_preflight_report, + format_commit_skipped_report, format_compact_report, format_connected_line, + format_cost_report, format_history_timestamp, format_internal_prompt_progress_line, + format_issue_report, format_model_report, format_model_switch_report, + format_permissions_report, format_permissions_switch_report, format_pr_report, format_resume_report, format_status_report, format_tool_call_start, format_tool_result, - normalize_permission_mode, parse_args, parse_git_status_metadata, print_help_to, - render_config_report, render_memory_report, render_repl_help, - resume_supported_slash_commands, status_context, CliAction, CliOutputFormat, SlashCommand, - StatusUsage, DEFAULT_MODEL, + format_ultraplan_report, format_unknown_slash_command, + format_unknown_slash_command_message, format_user_visible_api_error, + merge_prompt_with_stdin, normalize_permission_mode, parse_args, parse_export_args, + parse_git_status_branch, parse_git_status_metadata_for, parse_git_workspace_summary, + parse_history_count, permission_policy, print_help_to, push_output_block, + render_config_report, render_diff_report, render_diff_report_for, render_memory_report, + render_prompt_history_report, render_repl_help, render_resume_usage, + render_session_markdown, resolve_model_alias, resolve_model_alias_with_config, + resolve_repl_model, resolve_session_reference, response_to_events, + resume_supported_slash_commands, run_resume_command, short_tool_id, + slash_command_completion_candidates_with_sessions, status_context, + summarize_tool_payload_for_markdown, validate_no_args, write_mcp_server_fixture, CliAction, + CliOutputFormat, CliToolExecutor, GitWorkspaceSummary, InternalPromptProgressEvent, + InternalPromptProgressState, LiveCli, LocalHelpTopic, PromptHistoryEntry, SlashCommand, + StatusUsage, DEFAULT_MODEL, LATEST_SESSION_REFERENCE, STUB_COMMANDS, }; - use runtime::{ContentBlock, ConversationMessage, MessageRole, PermissionMode}; - use std::path::PathBuf; + use api::{ApiError, MessageResponse, OutputContentBlock, Usage}; + use plugins::{ + PluginManager, PluginManagerConfig, PluginTool, PluginToolDefinition, PluginToolPermission, + }; + use runtime::{ + load_oauth_credentials, save_oauth_credentials, AssistantEvent, ConfigLoader, ContentBlock, + ConversationMessage, MessageRole, OAuthConfig, PermissionMode, Session, ToolExecutor, + }; + use serde_json::json; + use std::fs; + use std::io::{Read, Write}; + use std::net::TcpListener; + use std::path::{Path, PathBuf}; + use std::process::Command; + use std::sync::{Mutex, MutexGuard, OnceLock}; + use std::thread; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + use tools::GlobalToolRegistry; + + fn registry_with_plugin_tool() -> GlobalToolRegistry { + GlobalToolRegistry::with_plugin_tools(vec![PluginTool::new( + "plugin-demo@external", + "plugin-demo", + PluginToolDefinition { + name: "plugin_echo".to_string(), + description: Some("Echo plugin payload".to_string()), + input_schema: json!({ + "type": "object", + "properties": { + "message": { "type": "string" } + }, + "required": ["message"], + "additionalProperties": false + }), + }, + "echo".to_string(), + Vec::new(), + PluginToolPermission::WorkspaceWrite, + None, + )]) + .expect("plugin tool registry should build") + } + #[test] + fn opaque_provider_wrapper_surfaces_failure_class_session_and_trace() { + let error = ApiError::Api { + status: "500".parse().expect("status"), + error_type: Some("api_error".to_string()), + message: Some( + "Something went wrong while processing your request. Please try again, or use /new to start a fresh session." + .to_string(), + ), + request_id: Some("req_jobdori_789".to_string()), + body: String::new(), + retryable: true, + }; + + let rendered = format_user_visible_api_error("session-issue-22", &error); + assert!(rendered.contains("provider_internal")); + assert!(rendered.contains("session session-issue-22")); + assert!(rendered.contains("trace req_jobdori_789")); + } + + #[test] + fn retry_exhaustion_uses_retry_failure_class_for_generic_provider_wrapper() { + let error = ApiError::RetriesExhausted { + attempts: 3, + last_error: Box::new(ApiError::Api { + status: "502".parse().expect("status"), + error_type: Some("api_error".to_string()), + message: Some( + "Something went wrong while processing your request. Please try again, or use /new to start a fresh session." + .to_string(), + ), + request_id: Some("req_jobdori_790".to_string()), + body: String::new(), + retryable: true, + }), + }; + + let rendered = format_user_visible_api_error("session-issue-22", &error); + assert!(rendered.contains("provider_retry_exhausted"), "{rendered}"); + assert!(rendered.contains("session session-issue-22")); + assert!(rendered.contains("trace req_jobdori_790")); + } + + #[test] + fn context_window_preflight_errors_render_recovery_steps() { + let error = ApiError::ContextWindowExceeded { + model: "claude-sonnet-4-6".to_string(), + estimated_input_tokens: 182_000, + requested_output_tokens: 64_000, + estimated_total_tokens: 246_000, + context_window_tokens: 200_000, + }; + + let rendered = format_user_visible_api_error("session-issue-32", &error); + assert!(rendered.contains("Context window blocked"), "{rendered}"); + assert!(rendered.contains("context_window_blocked"), "{rendered}"); + assert!( + rendered.contains("Session session-issue-32"), + "{rendered}" + ); + assert!( + rendered.contains("Model claude-sonnet-4-6"), + "{rendered}" + ); + assert!( + rendered.contains("Input estimate ~182000 tokens (heuristic)"), + "{rendered}" + ); + assert!( + rendered.contains("Total estimate ~246000 tokens (heuristic)"), + "{rendered}" + ); + assert!(rendered.contains("Compact /compact"), "{rendered}"); + assert!( + rendered.contains("Resume compact claw --resume session-issue-32 /compact"), + "{rendered}" + ); + assert!( + rendered.contains("Fresh session /clear --confirm"), + "{rendered}" + ); + assert!(rendered.contains("Reduce scope"), "{rendered}"); + assert!(rendered.contains("Retry rerun"), "{rendered}"); + } + + #[test] + fn provider_context_window_errors_are_reframed_with_same_guidance() { + let error = ApiError::Api { + status: "400".parse().expect("status"), + error_type: Some("invalid_request_error".to_string()), + message: Some( + "This model's maximum context length is 200000 tokens, but your request used 230000 tokens." + .to_string(), + ), + request_id: Some("req_ctx_456".to_string()), + body: String::new(), + retryable: false, + }; + + let rendered = format_user_visible_api_error("session-issue-32", &error); + assert!(rendered.contains("context_window_blocked"), "{rendered}"); + assert!( + rendered.contains("Trace req_ctx_456"), + "{rendered}" + ); + assert!( + rendered + .contains("Detail This model's maximum context length is 200000 tokens"), + "{rendered}" + ); + assert!(rendered.contains("Compact /compact"), "{rendered}"); + assert!( + rendered.contains("Fresh session /clear --confirm"), + "{rendered}" + ); + } + + #[test] + fn retry_wrapped_context_window_errors_keep_recovery_guidance() { + let error = ApiError::RetriesExhausted { + attempts: 2, + last_error: Box::new(ApiError::Api { + status: "413".parse().expect("status"), + error_type: Some("invalid_request_error".to_string()), + message: Some("Request is too large for this model's context window.".to_string()), + request_id: Some("req_ctx_retry_789".to_string()), + body: String::new(), + retryable: false, + }), + }; + + let rendered = format_user_visible_api_error("session-issue-32", &error); + assert!(rendered.contains("Context window blocked"), "{rendered}"); + assert!(rendered.contains("context_window_blocked"), "{rendered}"); + assert!( + rendered.contains("Trace req_ctx_retry_789"), + "{rendered}" + ); + assert!( + rendered + .contains("Detail Request is too large for this model's context window."), + "{rendered}" + ); + assert!(rendered.contains("Compact /compact"), "{rendered}"); + assert!( + rendered.contains("Resume compact claw --resume session-issue-32 /compact"), + "{rendered}" + ); + } + + fn temp_dir() -> PathBuf { + use std::sync::atomic::{AtomicU64, Ordering}; + + static COUNTER: AtomicU64 = AtomicU64::new(0); + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + let unique = COUNTER.fetch_add(1, Ordering::Relaxed); + std::env::temp_dir().join(format!("rusty-claude-cli-{nanos}-{unique}")) + } + + fn git(args: &[&str], cwd: &Path) { + let status = Command::new("git") + .args(args) + .current_dir(cwd) + .status() + .expect("git command should run"); + assert!( + status.success(), + "git command failed: git {}", + args.join(" ") + ); + } + + fn env_lock() -> MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } + + fn sample_oauth_config(token_url: String) -> OAuthConfig { + OAuthConfig { + client_id: "runtime-client".to_string(), + authorize_url: "https://console.test/oauth/authorize".to_string(), + token_url, + callback_port: Some(4545), + manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), + scopes: vec!["org:create_api_key".to_string(), "user:profile".to_string()], + } + } + + fn spawn_token_server(response_body: &'static str) -> String { + let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); + let address = listener.local_addr().expect("local addr"); + thread::spawn(move || { + let (mut stream, _) = listener.accept().expect("accept connection"); + let mut buffer = [0_u8; 4096]; + let _ = stream.read(&mut buffer).expect("read request"); + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}", + response_body.len(), + response_body + ); + stream + .write_all(response.as_bytes()) + .expect("write response"); + }); + format!("http://{address}/oauth/token") + } + + fn with_current_dir(cwd: &Path, f: impl FnOnce() -> T) -> T { + let _guard = cwd_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let previous = std::env::current_dir().expect("cwd should load"); + std::env::set_current_dir(cwd).expect("cwd should change"); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)); + std::env::set_current_dir(previous).expect("cwd should restore"); + match result { + Ok(value) => value, + Err(payload) => std::panic::resume_unwind(payload), + } + } + + fn write_plugin_fixture(root: &Path, name: &str, include_hooks: bool, include_lifecycle: bool) { + fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir"); + if include_hooks { + fs::create_dir_all(root.join("hooks")).expect("hooks dir"); + fs::write( + root.join("hooks").join("pre.sh"), + "#!/bin/sh\nprintf 'plugin pre hook'\n", + ) + .expect("write hook"); + } + if include_lifecycle { + fs::create_dir_all(root.join("lifecycle")).expect("lifecycle dir"); + fs::write( + root.join("lifecycle").join("init.sh"), + "#!/bin/sh\nprintf 'init\\n' >> lifecycle.log\n", + ) + .expect("write init lifecycle"); + fs::write( + root.join("lifecycle").join("shutdown.sh"), + "#!/bin/sh\nprintf 'shutdown\\n' >> lifecycle.log\n", + ) + .expect("write shutdown lifecycle"); + } + + let hooks = if include_hooks { + ",\n \"hooks\": {\n \"PreToolUse\": [\"./hooks/pre.sh\"]\n }" + } else { + "" + }; + let lifecycle = if include_lifecycle { + ",\n \"lifecycle\": {\n \"Init\": [\"./lifecycle/init.sh\"],\n \"Shutdown\": [\"./lifecycle/shutdown.sh\"]\n }" + } else { + "" + }; + fs::write( + root.join(".claude-plugin").join("plugin.json"), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"runtime plugin fixture\"{hooks}{lifecycle}\n}}" + ), + ) + .expect("write plugin manifest"); + } #[test] fn defaults_to_repl_when_no_args() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); assert_eq!( parse_args(&[]).expect("args should parse"), CliAction::Repl { model: DEFAULT_MODEL.to_string(), allowed_tools: None, - permission_mode: PermissionMode::WorkspaceWrite, + permission_mode: PermissionMode::DangerFullAccess, + base_commit: None, + reasoning_effort: None, + allow_broad_cwd: false, } ); } + #[test] + fn default_permission_mode_uses_project_config_when_env_is_unset() { + let _guard = env_lock(); + let root = temp_dir(); + let cwd = root.join("project"); + let config_home = root.join("config-home"); + std::fs::create_dir_all(cwd.join(".claw")).expect("project config dir should exist"); + std::fs::create_dir_all(&config_home).expect("config home should exist"); + std::fs::write( + cwd.join(".claw").join("settings.json"), + r#"{"permissionMode":"acceptEdits"}"#, + ) + .expect("project config should write"); + + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); + let original_permission_mode = std::env::var("RUSTY_CLAUDE_PERMISSION_MODE").ok(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + + let resolved = with_current_dir(&cwd, super::default_permission_mode); + + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), + } + match original_permission_mode { + Some(value) => std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", value), + None => std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"), + } + std::fs::remove_dir_all(root).expect("temp config root should clean up"); + + assert_eq!(resolved, PermissionMode::WorkspaceWrite); + } + + #[test] + fn env_permission_mode_overrides_project_config_default() { + let _guard = env_lock(); + let root = temp_dir(); + let cwd = root.join("project"); + let config_home = root.join("config-home"); + std::fs::create_dir_all(cwd.join(".claw")).expect("project config dir should exist"); + std::fs::create_dir_all(&config_home).expect("config home should exist"); + std::fs::write( + cwd.join(".claw").join("settings.json"), + r#"{"permissionMode":"acceptEdits"}"#, + ) + .expect("project config should write"); + + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); + let original_permission_mode = std::env::var("RUSTY_CLAUDE_PERMISSION_MODE").ok(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "read-only"); + + let resolved = with_current_dir(&cwd, super::default_permission_mode); + + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), + } + match original_permission_mode { + Some(value) => std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", value), + None => std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"), + } + std::fs::remove_dir_all(root).expect("temp config root should clean up"); + + assert_eq!(resolved, PermissionMode::ReadOnly); + } + + #[test] + fn load_runtime_oauth_config_for_returns_none_without_project_config() { + let _guard = env_lock(); + let root = temp_dir(); + std::fs::create_dir_all(&root).expect("workspace should exist"); + + let oauth = super::load_runtime_oauth_config_for(&root) + .expect("loading config should succeed when files are absent"); + + std::fs::remove_dir_all(root).expect("temp workspace should clean up"); + + assert_eq!(oauth, None); + } + + #[test] + fn resolve_cli_auth_source_uses_default_oauth_when_runtime_config_is_missing() { + let _guard = env_lock(); + let workspace = temp_dir(); + let config_home = temp_dir(); + std::fs::create_dir_all(&workspace).expect("workspace should exist"); + std::fs::create_dir_all(&config_home).expect("config home should exist"); + + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); + let original_api_key = std::env::var("ANTHROPIC_API_KEY").ok(); + let original_auth_token = std::env::var("ANTHROPIC_AUTH_TOKEN").ok(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_API_KEY"); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(0), + scopes: vec!["org:create_api_key".to_string(), "user:profile".to_string()], + }) + .expect("save expired oauth credentials"); + + let token_url = spawn_token_server( + r#"{"access_token":"refreshed-access-token","refresh_token":"refreshed-refresh-token","expires_at":4102444800,"scopes":["org:create_api_key","user:profile"]}"#, + ); + + let auth = + super::resolve_cli_auth_source_for_cwd(&workspace, || sample_oauth_config(token_url)) + .expect("expired saved oauth should refresh via default config"); + + let stored = load_oauth_credentials() + .expect("load stored credentials") + .expect("stored credentials should exist"); + + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), + } + match original_api_key { + Some(value) => std::env::set_var("ANTHROPIC_API_KEY", value), + None => std::env::remove_var("ANTHROPIC_API_KEY"), + } + match original_auth_token { + Some(value) => std::env::set_var("ANTHROPIC_AUTH_TOKEN", value), + None => std::env::remove_var("ANTHROPIC_AUTH_TOKEN"), + } + std::fs::remove_dir_all(workspace).expect("temp workspace should clean up"); + std::fs::remove_dir_all(config_home).expect("temp config home should clean up"); + + assert_eq!(auth.bearer_token(), Some("refreshed-access-token")); + assert_eq!(stored.access_token, "refreshed-access-token"); + assert_eq!( + stored.refresh_token.as_deref(), + Some("refreshed-refresh-token") + ); + } + #[test] fn parses_prompt_subcommand() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); let args = vec![ "prompt".to_string(), "hello".to_string(), @@ -2530,13 +8995,83 @@ mod tests { model: DEFAULT_MODEL.to_string(), output_format: CliOutputFormat::Text, allowed_tools: None, - permission_mode: PermissionMode::WorkspaceWrite, + permission_mode: PermissionMode::DangerFullAccess, + compact: false, + base_commit: None, + reasoning_effort: None, + allow_broad_cwd: false, } ); } + #[test] + fn merge_prompt_with_stdin_returns_prompt_unchanged_when_no_pipe() { + // given + let prompt = "Review this"; + + // when + let merged = merge_prompt_with_stdin(prompt, None); + + // then + assert_eq!(merged, "Review this"); + } + + #[test] + fn merge_prompt_with_stdin_ignores_whitespace_only_pipe() { + // given + let prompt = "Review this"; + let piped = " \n\t\n "; + + // when + let merged = merge_prompt_with_stdin(prompt, Some(piped)); + + // then + assert_eq!(merged, "Review this"); + } + + #[test] + fn merge_prompt_with_stdin_appends_piped_content_as_context() { + // given + let prompt = "Review this"; + let piped = "fn main() { println!(\"hi\"); }\n"; + + // when + let merged = merge_prompt_with_stdin(prompt, Some(piped)); + + // then + assert_eq!(merged, "Review this\n\nfn main() { println!(\"hi\"); }"); + } + + #[test] + fn merge_prompt_with_stdin_trims_surrounding_whitespace_on_pipe() { + // given + let prompt = "Summarize"; + let piped = "\n\n some notes \n\n"; + + // when + let merged = merge_prompt_with_stdin(prompt, Some(piped)); + + // then + assert_eq!(merged, "Summarize\n\nsome notes"); + } + + #[test] + fn merge_prompt_with_stdin_returns_pipe_when_prompt_is_empty() { + // given + let prompt = ""; + let piped = "standalone body"; + + // when + let merged = merge_prompt_with_stdin(prompt, Some(piped)); + + // then + assert_eq!(merged, "standalone body"); + } + #[test] fn parses_bare_prompt_and_json_output_flag() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); let args = vec![ "--output-format=json".to_string(), "--model".to_string(), @@ -2551,20 +9086,149 @@ mod tests { model: "claude-opus".to_string(), output_format: CliOutputFormat::Json, allowed_tools: None, - permission_mode: PermissionMode::WorkspaceWrite, + permission_mode: PermissionMode::DangerFullAccess, + compact: false, + base_commit: None, + reasoning_effort: None, + allow_broad_cwd: false, } ); } + #[test] + fn parses_compact_flag_for_prompt_mode() { + // given a bare prompt invocation that includes the --compact flag + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + let args = vec![ + "--compact".to_string(), + "summarize".to_string(), + "this".to_string(), + ]; + + // when parse_args interprets the flag + let parsed = parse_args(&args).expect("args should parse"); + + // then compact mode is propagated and other defaults stay unchanged + assert_eq!( + parsed, + CliAction::Prompt { + prompt: "summarize this".to_string(), + model: DEFAULT_MODEL.to_string(), + output_format: CliOutputFormat::Text, + allowed_tools: None, + permission_mode: PermissionMode::DangerFullAccess, + compact: true, + base_commit: None, + reasoning_effort: None, + allow_broad_cwd: false, + } + ); + } + + #[test] + fn prompt_subcommand_defaults_compact_to_false() { + // given a `prompt` subcommand invocation without --compact + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + let args = vec!["prompt".to_string(), "hello".to_string()]; + + // when parse_args runs + let parsed = parse_args(&args).expect("args should parse"); + + // then compact stays false (opt-in flag) + match parsed { + CliAction::Prompt { compact, .. } => assert!(!compact), + other => panic!("expected Prompt action, got {other:?}"), + } + } + + #[test] + fn resolves_model_aliases_in_args() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + let args = vec![ + "--model".to_string(), + "opus".to_string(), + "explain".to_string(), + "this".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::Prompt { + prompt: "explain this".to_string(), + model: "claude-opus-4-6".to_string(), + output_format: CliOutputFormat::Text, + allowed_tools: None, + permission_mode: PermissionMode::DangerFullAccess, + compact: false, + base_commit: None, + reasoning_effort: None, + allow_broad_cwd: false, + } + ); + } + + #[test] + fn resolves_known_model_aliases() { + assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6"); + assert_eq!(resolve_model_alias("sonnet"), "claude-sonnet-4-6"); + assert_eq!(resolve_model_alias("haiku"), "claude-haiku-4-5-20251213"); + assert_eq!(resolve_model_alias("claude-opus"), "claude-opus"); + } + + #[test] + fn user_defined_aliases_resolve_before_provider_dispatch() { + // given + let _guard = env_lock(); + let root = temp_dir(); + let cwd = root.join("project"); + let config_home = root.join("config-home"); + std::fs::create_dir_all(cwd.join(".claw")).expect("project config dir should exist"); + std::fs::create_dir_all(&config_home).expect("config home should exist"); + std::fs::write( + cwd.join(".claw").join("settings.json"), + r#"{"aliases":{"fast":"claude-haiku-4-5-20251213","smart":"opus","cheap":"grok-3-mini"}}"#, + ) + .expect("project config should write"); + + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + + // when + let direct = with_current_dir(&cwd, || resolve_model_alias_with_config("fast")); + let chained = with_current_dir(&cwd, || resolve_model_alias_with_config("smart")); + let cross_provider = with_current_dir(&cwd, || resolve_model_alias_with_config("cheap")); + let unknown = with_current_dir(&cwd, || resolve_model_alias_with_config("unknown-model")); + let builtin = with_current_dir(&cwd, || resolve_model_alias_with_config("haiku")); + + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), + } + std::fs::remove_dir_all(root).expect("temp config root should clean up"); + + // then + assert_eq!(direct, "claude-haiku-4-5-20251213"); + assert_eq!(chained, "claude-opus-4-6"); + assert_eq!(cross_provider, "grok-3-mini"); + assert_eq!(unknown, "unknown-model"); + assert_eq!(builtin, "claude-haiku-4-5-20251213"); + } + #[test] fn parses_version_flags_without_initializing_prompt_mode() { assert_eq!( parse_args(&["--version".to_string()]).expect("args should parse"), - CliAction::Version + CliAction::Version { + output_format: CliOutputFormat::Text, + } ); assert_eq!( parse_args(&["-V".to_string()]).expect("args should parse"), - CliAction::Version + CliAction::Version { + output_format: CliOutputFormat::Text, + } ); } @@ -2577,12 +9241,68 @@ mod tests { model: DEFAULT_MODEL.to_string(), allowed_tools: None, permission_mode: PermissionMode::ReadOnly, + base_commit: None, + reasoning_effort: None, + allow_broad_cwd: false, + } + ); + } + + #[test] + fn dangerously_skip_permissions_flag_forces_danger_full_access_in_repl() { + let _guard = env_lock(); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "read-only"); + let args = vec!["--dangerously-skip-permissions".to_string()]; + let parsed = parse_args(&args).expect("args should parse"); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + + assert_eq!( + parsed, + CliAction::Repl { + model: DEFAULT_MODEL.to_string(), + allowed_tools: None, + permission_mode: PermissionMode::DangerFullAccess, + base_commit: None, + reasoning_effort: None, + allow_broad_cwd: false, + } + ); + } + + #[test] + fn dangerously_skip_permissions_flag_applies_to_prompt_subcommand() { + let _guard = env_lock(); + std::env::set_var("RUSTY_CLAUDE_PERMISSION_MODE", "read-only"); + let args = vec![ + "--dangerously-skip-permissions".to_string(), + "prompt".to_string(), + "do".to_string(), + "the".to_string(), + "thing".to_string(), + ]; + let parsed = parse_args(&args).expect("args should parse"); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + + assert_eq!( + parsed, + CliAction::Prompt { + prompt: "do the thing".to_string(), + model: DEFAULT_MODEL.to_string(), + output_format: CliOutputFormat::Text, + allowed_tools: None, + permission_mode: PermissionMode::DangerFullAccess, + compact: false, + base_commit: None, + reasoning_effort: None, + allow_broad_cwd: false, } ); } #[test] fn parses_allowed_tools_flags_with_aliases_and_lists() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); let args = vec![ "--allowedTools".to_string(), "read,glob".to_string(), @@ -2598,7 +9318,10 @@ mod tests { .map(str::to_string) .collect() ), - permission_mode: PermissionMode::WorkspaceWrite, + permission_mode: PermissionMode::DangerFullAccess, + base_commit: None, + reasoning_effort: None, + allow_broad_cwd: false, } ); } @@ -2624,6 +9347,7 @@ mod tests { CliAction::PrintSystemPrompt { cwd: PathBuf::from("/tmp/project"), date: "2026-04-01".to_string(), + output_format: CliOutputFormat::Text, } ); } @@ -2632,30 +9356,640 @@ mod tests { fn parses_login_and_logout_subcommands() { assert_eq!( parse_args(&["login".to_string()]).expect("login should parse"), - CliAction::Login + CliAction::Login { + output_format: CliOutputFormat::Text, + } ); assert_eq!( parse_args(&["logout".to_string()]).expect("logout should parse"), - CliAction::Logout + CliAction::Logout { + output_format: CliOutputFormat::Text, + } + ); + assert_eq!( + parse_args(&["doctor".to_string()]).expect("doctor should parse"), + CliAction::Doctor { + output_format: CliOutputFormat::Text, + } + ); + assert_eq!( + parse_args(&["state".to_string()]).expect("state should parse"), + CliAction::State { + output_format: CliOutputFormat::Text, + } + ); + assert_eq!( + parse_args(&[ + "state".to_string(), + "--output-format".to_string(), + "json".to_string() + ]) + .expect("state --output-format json should parse"), + CliAction::State { + output_format: CliOutputFormat::Json, + } ); assert_eq!( parse_args(&["init".to_string()]).expect("init should parse"), - CliAction::Init + CliAction::Init { + output_format: CliOutputFormat::Text, + } ); + assert_eq!( + parse_args(&["agents".to_string()]).expect("agents should parse"), + CliAction::Agents { + args: None, + output_format: CliOutputFormat::Text + } + ); + assert_eq!( + parse_args(&["mcp".to_string()]).expect("mcp should parse"), + CliAction::Mcp { + args: None, + output_format: CliOutputFormat::Text, + } + ); + assert_eq!( + parse_args(&["skills".to_string()]).expect("skills should parse"), + CliAction::Skills { + args: None, + output_format: CliOutputFormat::Text, + } + ); + assert_eq!( + parse_args(&[ + "skills".to_string(), + "help".to_string(), + "overview".to_string() + ]) + .expect("skills help overview should invoke"), + CliAction::Prompt { + prompt: "$help overview".to_string(), + model: DEFAULT_MODEL.to_string(), + output_format: CliOutputFormat::Text, + allowed_tools: None, + permission_mode: crate::default_permission_mode(), + compact: false, + base_commit: None, + reasoning_effort: None, + allow_broad_cwd: false, + } + ); + assert_eq!( + parse_args(&["agents".to_string(), "--help".to_string()]) + .expect("agents help should parse"), + CliAction::Agents { + args: Some("--help".to_string()), + output_format: CliOutputFormat::Text, + } + ); + } + + #[test] + fn local_command_help_flags_stay_on_the_local_parser_path() { + assert_eq!( + parse_args(&["status".to_string(), "--help".to_string()]) + .expect("status help should parse"), + CliAction::HelpTopic(LocalHelpTopic::Status) + ); + assert_eq!( + parse_args(&["sandbox".to_string(), "-h".to_string()]) + .expect("sandbox help should parse"), + CliAction::HelpTopic(LocalHelpTopic::Sandbox) + ); + assert_eq!( + parse_args(&["doctor".to_string(), "--help".to_string()]) + .expect("doctor help should parse"), + CliAction::HelpTopic(LocalHelpTopic::Doctor) + ); + } + + #[test] + fn parses_single_word_command_aliases_without_falling_back_to_prompt_mode() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + assert_eq!( + parse_args(&["help".to_string()]).expect("help should parse"), + CliAction::Help { + output_format: CliOutputFormat::Text, + } + ); + assert_eq!( + parse_args(&["version".to_string()]).expect("version should parse"), + CliAction::Version { + output_format: CliOutputFormat::Text, + } + ); + assert_eq!( + parse_args(&["status".to_string()]).expect("status should parse"), + CliAction::Status { + model: DEFAULT_MODEL.to_string(), + permission_mode: PermissionMode::DangerFullAccess, + output_format: CliOutputFormat::Text, + } + ); + assert_eq!( + parse_args(&["sandbox".to_string()]).expect("sandbox should parse"), + CliAction::Sandbox { + output_format: CliOutputFormat::Text, + } + ); + } + + #[test] + fn parses_bare_export_subcommand_targeting_latest_session() { + // given + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + let args = vec!["export".to_string()]; + + // when + let parsed = parse_args(&args).expect("bare export should parse"); + + // then + assert_eq!( + parsed, + CliAction::Export { + session_reference: LATEST_SESSION_REFERENCE.to_string(), + output_path: None, + output_format: CliOutputFormat::Text, + } + ); + } + + #[test] + fn parses_export_subcommand_with_positional_output_path() { + // given + let args = vec!["export".to_string(), "conversation.md".to_string()]; + + // when + let parsed = parse_args(&args).expect("export with path should parse"); + + // then + assert_eq!( + parsed, + CliAction::Export { + session_reference: LATEST_SESSION_REFERENCE.to_string(), + output_path: Some(PathBuf::from("conversation.md")), + output_format: CliOutputFormat::Text, + } + ); + } + + #[test] + fn parses_export_subcommand_with_session_and_output_flags() { + // given + let args = vec![ + "export".to_string(), + "--session".to_string(), + "session-alpha".to_string(), + "--output".to_string(), + "/tmp/share.md".to_string(), + ]; + + // when + let parsed = parse_args(&args).expect("export flags should parse"); + + // then + assert_eq!( + parsed, + CliAction::Export { + session_reference: "session-alpha".to_string(), + output_path: Some(PathBuf::from("/tmp/share.md")), + output_format: CliOutputFormat::Text, + } + ); + } + + #[test] + fn parses_export_subcommand_with_inline_flag_values() { + // given + let args = vec![ + "export".to_string(), + "--session=session-beta".to_string(), + "--output=/tmp/beta.md".to_string(), + ]; + + // when + let parsed = parse_args(&args).expect("export inline flags should parse"); + + // then + assert_eq!( + parsed, + CliAction::Export { + session_reference: "session-beta".to_string(), + output_path: Some(PathBuf::from("/tmp/beta.md")), + output_format: CliOutputFormat::Text, + } + ); + } + + #[test] + fn parses_export_subcommand_with_json_output_format() { + // given + let args = vec![ + "--output-format=json".to_string(), + "export".to_string(), + "/tmp/notes.md".to_string(), + ]; + + // when + let parsed = parse_args(&args).expect("json export should parse"); + + // then + assert_eq!( + parsed, + CliAction::Export { + session_reference: LATEST_SESSION_REFERENCE.to_string(), + output_path: Some(PathBuf::from("/tmp/notes.md")), + output_format: CliOutputFormat::Json, + } + ); + } + + #[test] + fn rejects_unknown_export_options_with_helpful_message() { + // given + let args = vec!["export".to_string(), "--bogus".to_string()]; + + // when + let error = parse_args(&args).expect_err("unknown export option should fail"); + + // then + assert!(error.contains("unknown export option: --bogus")); + } + + #[test] + fn rejects_export_with_extra_positional_after_path() { + // given + let args = vec![ + "export".to_string(), + "first.md".to_string(), + "second.md".to_string(), + ]; + + // when + let error = parse_args(&args).expect_err("multiple positionals should fail"); + + // then + assert!(error.contains("unexpected export argument: second.md")); + } + + #[test] + fn parse_export_args_helper_defaults_to_latest_reference_and_no_output() { + // given + let args: Vec = vec![]; + + // when + let parsed = parse_export_args(&args, CliOutputFormat::Text) + .expect("empty export args should parse"); + + // then + assert_eq!( + parsed, + CliAction::Export { + session_reference: LATEST_SESSION_REFERENCE.to_string(), + output_path: None, + output_format: CliOutputFormat::Text, + } + ); + } + + #[test] + fn render_session_markdown_includes_header_and_summarized_tool_calls() { + // given + let mut session = Session::new(); + session.session_id = "session-export-test".to_string(); + session.messages = vec![ + ConversationMessage::user_text("How do I list files?"), + ConversationMessage::assistant(vec![ + ContentBlock::Text { + text: "I'll run a tool.".to_string(), + }, + ContentBlock::ToolUse { + id: "toolu_abcdefghijklmnop".to_string(), + name: "bash".to_string(), + input: r#"{"command":"ls -la"}"#.to_string(), + }, + ]), + ConversationMessage { + role: MessageRole::Tool, + blocks: vec![ContentBlock::ToolResult { + tool_use_id: "toolu_abcdefghijklmnop".to_string(), + tool_name: "bash".to_string(), + output: "total 8\ndrwxr-xr-x 2 user staff 64 Apr 7 12:00 .".to_string(), + is_error: false, + }], + usage: None, + }, + ]; + + // when + let markdown = render_session_markdown( + &session, + "session-export-test", + std::path::Path::new("/tmp/sessions/session-export-test.jsonl"), + ); + + // then + assert!(markdown.starts_with("# Conversation Export")); + assert!(markdown.contains("- **Session**: `session-export-test`")); + assert!(markdown.contains("- **Messages**: 3")); + assert!(markdown.contains("## 1. User")); + assert!(markdown.contains("How do I list files?")); + assert!(markdown.contains("## 2. Assistant")); + assert!(markdown.contains("**Tool call** `bash`")); + assert!(markdown.contains("toolu_abcdef…")); + assert!(markdown.contains("ls -la")); + assert!(markdown.contains("## 3. Tool")); + assert!(markdown.contains("**Tool result** `bash`")); + assert!(markdown.contains("ok")); + assert!(markdown.contains("total 8")); + } + + #[test] + fn render_session_markdown_marks_tool_errors_and_skips_empty_summaries() { + // given + let mut session = Session::new(); + session.session_id = "errs".to_string(); + session.messages = vec![ConversationMessage { + role: MessageRole::Tool, + blocks: vec![ContentBlock::ToolResult { + tool_use_id: "short".to_string(), + tool_name: "read_file".to_string(), + output: " ".to_string(), + is_error: true, + }], + usage: None, + }]; + + // when + let markdown = + render_session_markdown(&session, "errs", std::path::Path::new("errs.jsonl")); + + // then + assert!(markdown.contains("**Tool result** `read_file` _(id `short`, error)_")); + // an empty summary should not produce a stray blockquote line + assert!(!markdown.contains("> \n")); + } + + #[test] + fn summarize_tool_payload_for_markdown_compacts_json_and_truncates_overflow() { + // given + let json_payload = r#"{ + "command": "ls -la", + "cwd": "/tmp" + }"#; + let long_payload = "a".repeat(600); + + // when + let compacted = summarize_tool_payload_for_markdown(json_payload); + let truncated = summarize_tool_payload_for_markdown(&long_payload); + + // then + assert_eq!(compacted, r#"{"command":"ls -la","cwd":"/tmp"}"#); + assert!(truncated.ends_with('…')); + assert!(truncated.chars().count() <= 281); + } + + #[test] + fn short_tool_id_truncates_long_identifiers_with_ellipsis() { + // given + let long = "toolu_01ABCDEFGHIJKLMN"; + let short = "tool_1"; + + // when + let trimmed_long = short_tool_id(long); + let trimmed_short = short_tool_id(short); + + // then + assert_eq!(trimmed_long, "toolu_01ABCD…"); + assert_eq!(trimmed_short, "tool_1"); + } + + #[test] + fn parses_json_output_for_mcp_and_skills_commands() { + assert_eq!( + parse_args(&["--output-format=json".to_string(), "mcp".to_string()]) + .expect("json mcp should parse"), + CliAction::Mcp { + args: None, + output_format: CliOutputFormat::Json, + } + ); + assert_eq!( + parse_args(&[ + "--output-format=json".to_string(), + "/skills".to_string(), + "help".to_string(), + ]) + .expect("json /skills help should parse"), + CliAction::Skills { + args: Some("help".to_string()), + output_format: CliOutputFormat::Json, + } + ); + } + + #[test] + fn single_word_slash_command_names_return_guidance_instead_of_hitting_prompt_mode() { + let error = parse_args(&["cost".to_string()]).expect_err("cost should return guidance"); + assert!(error.contains("slash command")); + assert!(error.contains("/cost")); + } + + #[test] + fn multi_word_prompt_still_uses_shorthand_prompt_mode() { + let _guard = env_lock(); + std::env::remove_var("RUSTY_CLAUDE_PERMISSION_MODE"); + // Input is ["help", "me", "debug"] so the joined prompt shorthand + // must be "help me debug". A previous batch accidentally rewrote + // the expected string to "$help overview" (copy-paste slip). + assert_eq!( + parse_args(&["help".to_string(), "me".to_string(), "debug".to_string()]) + .expect("prompt shorthand should still work"), + CliAction::Prompt { + prompt: "help me debug".to_string(), + model: DEFAULT_MODEL.to_string(), + output_format: CliOutputFormat::Text, + allowed_tools: None, + permission_mode: crate::default_permission_mode(), + compact: false, + base_commit: None, + reasoning_effort: None, + allow_broad_cwd: false, + } + ); + } + + #[test] + fn parses_direct_agents_mcp_and_skills_slash_commands() { + assert_eq!( + parse_args(&["/agents".to_string()]).expect("/agents should parse"), + CliAction::Agents { + args: None, + output_format: CliOutputFormat::Text + } + ); + assert_eq!( + parse_args(&["/mcp".to_string(), "show".to_string(), "demo".to_string()]) + .expect("/mcp show demo should parse"), + CliAction::Mcp { + args: Some("show demo".to_string()), + output_format: CliOutputFormat::Text, + } + ); + assert_eq!( + parse_args(&["/skills".to_string()]).expect("/skills should parse"), + CliAction::Skills { + args: None, + output_format: CliOutputFormat::Text, + } + ); + assert_eq!( + parse_args(&["/skill".to_string()]).expect("/skill should parse"), + CliAction::Skills { + args: None, + output_format: CliOutputFormat::Text, + } + ); + assert_eq!( + parse_args(&["/skills".to_string(), "help".to_string()]) + .expect("/skills help should parse"), + CliAction::Skills { + args: Some("help".to_string()), + output_format: CliOutputFormat::Text, + } + ); + assert_eq!( + parse_args(&["/skill".to_string(), "list".to_string()]) + .expect("/skill list should parse"), + CliAction::Skills { + args: Some("list".to_string()), + output_format: CliOutputFormat::Text, + } + ); + assert_eq!( + parse_args(&[ + "/skills".to_string(), + "help".to_string(), + "overview".to_string() + ]) + .expect("/skills help overview should invoke"), + CliAction::Prompt { + prompt: "$help overview".to_string(), + model: DEFAULT_MODEL.to_string(), + output_format: CliOutputFormat::Text, + allowed_tools: None, + permission_mode: crate::default_permission_mode(), + compact: false, + base_commit: None, + reasoning_effort: None, + allow_broad_cwd: false, + } + ); + assert_eq!( + parse_args(&[ + "/skills".to_string(), + "install".to_string(), + "./fixtures/help-skill".to_string(), + ]) + .expect("/skills install should parse"), + CliAction::Skills { + args: Some("install ./fixtures/help-skill".to_string()), + output_format: CliOutputFormat::Text, + } + ); + assert_eq!( + parse_args(&["/skills".to_string(), "/test".to_string()]) + .expect("/skills /test should normalize to a single skill prompt prefix"), + CliAction::Prompt { + prompt: "$test".to_string(), + model: DEFAULT_MODEL.to_string(), + output_format: CliOutputFormat::Text, + allowed_tools: None, + permission_mode: crate::default_permission_mode(), + compact: false, + base_commit: None, + reasoning_effort: None, + allow_broad_cwd: false, + } + ); + let error = parse_args(&["/status".to_string()]) + .expect_err("/status should remain REPL-only when invoked directly"); + assert!(error.contains("interactive-only")); + assert!(error.contains("claw --resume SESSION.jsonl /status")); + } + + #[test] + fn direct_slash_commands_surface_shared_validation_errors() { + let compact_error = parse_args(&["/compact".to_string(), "now".to_string()]) + .expect_err("invalid /compact shape should be rejected"); + assert!(compact_error.contains("Unexpected arguments for /compact.")); + assert!(compact_error.contains("Usage /compact")); + + let plugins_error = parse_args(&[ + "/plugins".to_string(), + "list".to_string(), + "extra".to_string(), + ]) + .expect_err("invalid /plugins list shape should be rejected"); + assert!(plugins_error.contains("Usage: /plugin list")); + assert!(plugins_error.contains("Aliases /plugins, /marketplace")); + } + + #[test] + fn formats_unknown_slash_command_with_suggestions() { + let report = format_unknown_slash_command_message("statsu"); + assert!(report.contains("unknown slash command: /statsu")); + assert!(report.contains("Did you mean")); + assert!(report.contains("Use /help")); + } + + #[test] + fn formats_namespaced_omc_slash_command_with_contract_guidance() { + let report = format_unknown_slash_command_message("oh-my-claudecode:hud"); + assert!(report.contains("unknown slash command: /oh-my-claudecode:hud")); + assert!(report.contains("Claude Code/OMC plugin command")); + assert!(report.contains("plugin slash commands")); + assert!(report.contains("statusline")); + assert!(report.contains("session hooks")); } #[test] fn parses_resume_flag_with_slash_command() { let args = vec![ "--resume".to_string(), - "session.json".to_string(), + "session.jsonl".to_string(), "/compact".to_string(), ]; assert_eq!( parse_args(&args).expect("args should parse"), CliAction::ResumeSession { - session_path: PathBuf::from("session.json"), + session_path: PathBuf::from("session.jsonl"), commands: vec!["/compact".to_string()], + output_format: CliOutputFormat::Text, + } + ); + } + + #[test] + fn parses_resume_flag_without_path_as_latest_session() { + assert_eq!( + parse_args(&["--resume".to_string()]).expect("args should parse"), + CliAction::ResumeSession { + session_path: PathBuf::from("latest"), + commands: vec![], + output_format: CliOutputFormat::Text, + } + ); + assert_eq!( + parse_args(&["--resume".to_string(), "/status".to_string()]) + .expect("resume shortcut should parse"), + CliAction::ResumeSession { + session_path: PathBuf::from("latest"), + commands: vec!["/status".to_string()], + output_format: CliOutputFormat::Text, } ); } @@ -2664,7 +9998,7 @@ mod tests { fn parses_resume_flag_with_multiple_slash_commands() { let args = vec![ "--resume".to_string(), - "session.json".to_string(), + "session.jsonl".to_string(), "/status".to_string(), "/compact".to_string(), "/cost".to_string(), @@ -2672,12 +10006,63 @@ mod tests { assert_eq!( parse_args(&args).expect("args should parse"), CliAction::ResumeSession { - session_path: PathBuf::from("session.json"), + session_path: PathBuf::from("session.jsonl"), commands: vec![ "/status".to_string(), "/compact".to_string(), "/cost".to_string(), ], + output_format: CliOutputFormat::Text, + } + ); + } + + #[test] + fn rejects_unknown_options_with_helpful_guidance() { + let error = parse_args(&["--resum".to_string()]).expect_err("unknown option should fail"); + assert!(error.contains("unknown option: --resum")); + assert!(error.contains("Did you mean --resume?")); + assert!(error.contains("claw --help")); + } + + #[test] + fn parses_resume_flag_with_slash_command_arguments() { + let args = vec![ + "--resume".to_string(), + "session.jsonl".to_string(), + "/export".to_string(), + "notes.txt".to_string(), + "/clear".to_string(), + "--confirm".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::ResumeSession { + session_path: PathBuf::from("session.jsonl"), + commands: vec![ + "/export notes.txt".to_string(), + "/clear --confirm".to_string(), + ], + output_format: CliOutputFormat::Text, + } + ); + } + + #[test] + fn parses_resume_flag_with_absolute_export_path() { + let args = vec![ + "--resume".to_string(), + "session.jsonl".to_string(), + "/export".to_string(), + "/tmp/notes.txt".to_string(), + "/status".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::ResumeSession { + session_path: PathBuf::from("session.jsonl"), + commands: vec!["/export /tmp/notes.txt".to_string(), "/status".to_string()], + output_format: CliOutputFormat::Text, } ); } @@ -2688,7 +10073,7 @@ mod tests { .into_iter() .map(str::to_string) .collect(); - let filtered = filter_tool_specs(Some(&allowed)); + let filtered = filter_tool_specs(&GlobalToolRegistry::builtin(), Some(&allowed)); let names = filtered .into_iter() .map(|spec| spec.name) @@ -2696,11 +10081,35 @@ mod tests { assert_eq!(names, vec!["read_file", "grep_search"]); } + #[test] + fn filtered_tool_specs_include_plugin_tools() { + let filtered = filter_tool_specs(®istry_with_plugin_tool(), None); + let names = filtered + .into_iter() + .map(|definition| definition.name) + .collect::>(); + assert!(names.contains(&"bash".to_string())); + assert!(names.contains(&"plugin_echo".to_string())); + } + + #[test] + fn permission_policy_uses_plugin_tool_permissions() { + let feature_config = runtime::RuntimeFeatureConfig::default(); + let policy = permission_policy( + PermissionMode::ReadOnly, + &feature_config, + ®istry_with_plugin_tool(), + ) + .expect("permission policy should build"); + let required = policy.required_mode_for("plugin_echo"); + assert_eq!(required, PermissionMode::WorkspaceWrite); + } + #[test] fn shared_help_uses_resume_annotation_copy() { let help = commands::render_slash_command_help(); assert!(help.contains("Slash commands")); - assert!(help.contains("works with --resume SESSION.json")); + assert!(help.contains("works with --resume SESSION.jsonl")); } #[test] @@ -2708,20 +10117,141 @@ mod tests { let help = render_repl_help(); assert!(help.contains("REPL")); assert!(help.contains("/help")); + assert!(help.contains("Complete commands, modes, and recent sessions")); assert!(help.contains("/status")); + assert!(help.contains("/sandbox")); assert!(help.contains("/model [model]")); assert!(help.contains("/permissions [read-only|workspace-write|danger-full-access]")); assert!(help.contains("/clear [--confirm]")); assert!(help.contains("/cost")); assert!(help.contains("/resume ")); - assert!(help.contains("/config [env|hooks|model]")); + assert!(help.contains("/config [env|hooks|model|plugins]")); + assert!(help.contains("/mcp [list|show |help]")); assert!(help.contains("/memory")); assert!(help.contains("/init")); assert!(help.contains("/diff")); assert!(help.contains("/version")); assert!(help.contains("/export [file]")); - assert!(help.contains("/session [list|switch ]")); + // Batch 5 added `/session delete`; match on the stable core rather than + // the trailing bracket so future additions don't re-break this. + assert!(help.contains("/session [list|switch |fork [branch-name]")); + assert!(help.contains( + "/plugin [list|install |enable |disable |uninstall |update ]" + )); + assert!(help.contains("aliases: /plugins, /marketplace")); + assert!(help.contains("/agents")); + assert!(help.contains("/skills")); assert!(help.contains("/exit")); + assert!(help.contains("Auto-save .claw/sessions/.jsonl")); + assert!(help.contains("Resume latest /resume latest")); + } + + #[test] + fn completion_candidates_include_workflow_shortcuts_and_dynamic_sessions() { + let completions = slash_command_completion_candidates_with_sessions( + "sonnet", + Some("session-current"), + vec!["session-old".to_string()], + ); + + assert!(completions.contains(&"/model claude-sonnet-4-6".to_string())); + assert!(completions.contains(&"/permissions workspace-write".to_string())); + assert!(completions.contains(&"/session list".to_string())); + assert!(completions.contains(&"/session switch session-current".to_string())); + assert!(completions.contains(&"/resume session-old".to_string())); + assert!(completions.contains(&"/mcp list".to_string())); + assert!(completions.contains(&"/ultraplan ".to_string())); + } + + #[test] + fn startup_banner_mentions_workflow_completions() { + let _guard = env_lock(); + // Inject dummy credentials so LiveCli can construct without real Anthropic key + std::env::set_var("ANTHROPIC_API_KEY", "test-dummy-key-for-banner-test"); + let root = temp_dir(); + fs::create_dir_all(&root).expect("root dir"); + + let banner = with_current_dir(&root, || { + LiveCli::new( + "claude-sonnet-4-6".to_string(), + true, + None, + PermissionMode::DangerFullAccess, + ) + .expect("cli should initialize") + .startup_banner() + }); + + assert!(banner.contains("Tab")); + assert!(banner.contains("workflow completions")); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + + #[test] + fn format_connected_line_renders_anthropic_provider_for_claude_model() { + let model = "claude-sonnet-4-6"; + + let line = format_connected_line(model); + + assert_eq!(line, "Connected: claude-sonnet-4-6 via anthropic"); + } + + #[test] + fn format_connected_line_renders_xai_provider_for_grok_model() { + let model = "grok-3"; + + let line = format_connected_line(model); + + assert_eq!(line, "Connected: grok-3 via xai"); + } + + #[test] + fn resolve_repl_model_returns_user_supplied_model_unchanged_when_explicit() { + let user_model = "claude-sonnet-4-6".to_string(); + + let resolved = resolve_repl_model(user_model); + + assert_eq!(resolved, "claude-sonnet-4-6"); + } + + #[test] + fn resolve_repl_model_falls_back_to_anthropic_model_env_when_default() { + let _guard = env_lock(); + let root = temp_dir(); + fs::create_dir_all(&root).expect("root dir"); + let config_home = root.join("config"); + fs::create_dir_all(&config_home).expect("config home dir"); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_MODEL"); + std::env::set_var("ANTHROPIC_MODEL", "sonnet"); + + let resolved = with_current_dir(&root, || resolve_repl_model(DEFAULT_MODEL.to_string())); + + assert_eq!(resolved, "claude-sonnet-4-6"); + + std::env::remove_var("ANTHROPIC_MODEL"); + std::env::remove_var("CLAW_CONFIG_HOME"); + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn resolve_repl_model_returns_default_when_env_unset_and_no_config() { + let _guard = env_lock(); + let root = temp_dir(); + fs::create_dir_all(&root).expect("root dir"); + let config_home = root.join("config"); + fs::create_dir_all(&config_home).expect("config home dir"); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_MODEL"); + + let resolved = with_current_dir(&root, || resolve_repl_model(DEFAULT_MODEL.to_string())); + + assert_eq!(resolved, DEFAULT_MODEL); + + std::env::remove_var("CLAW_CONFIG_HOME"); + fs::remove_dir_all(root).expect("cleanup temp dir"); } #[test] @@ -2730,20 +10260,23 @@ mod tests { .into_iter() .map(|spec| spec.name) .collect::>(); - assert_eq!( - names, - vec![ - "help", "status", "compact", "clear", "cost", "config", "memory", "init", "diff", - "version", "export", - ] + // Now with 135+ slash commands, verify minimum resume support + assert!( + names.len() >= 39, + "expected at least 39 resume-supported commands, got {}", + names.len() ); + // Verify key resume commands still exist + assert!(names.contains(&"help")); + assert!(names.contains(&"status")); + assert!(names.contains(&"compact")); } #[test] fn resume_report_uses_sectioned_layout() { - let report = format_resume_report("session.json", 14, 6); + let report = format_resume_report("session.jsonl", 14, 6); assert!(report.contains("Session resumed")); - assert!(report.contains("Session file session.json")); + assert!(report.contains("Session file session.jsonl")); assert!(report.contains("Messages 14")); assert!(report.contains("Turns 6")); } @@ -2800,7 +10333,15 @@ mod tests { let mut help = Vec::new(); print_help_to(&mut help).expect("help should render"); let help = String::from_utf8(help).expect("help should be utf8"); + assert!(help.contains("claw help")); + assert!(help.contains("claw version")); + assert!(help.contains("claw status")); + assert!(help.contains("claw sandbox")); assert!(help.contains("claw init")); + assert!(help.contains("claw agents")); + assert!(help.contains("claw mcp")); + assert!(help.contains("claw skills")); + assert!(help.contains("claw /skills")); } #[test] @@ -2845,12 +10386,20 @@ mod tests { "workspace-write", &super::StatusContext { cwd: PathBuf::from("/tmp/project"), - session_path: Some(PathBuf::from("session.json")), + session_path: Some(PathBuf::from("session.jsonl")), loaded_config_files: 2, discovered_config_files: 3, memory_file_count: 4, project_root: Some(PathBuf::from("/tmp")), git_branch: Some("main".to_string()), + git_summary: GitWorkspaceSummary { + changed_files: 3, + staged_files: 1, + unstaged_files: 1, + untracked_files: 1, + conflicted_files: 0, + }, + sandbox_status: runtime::SandboxStatus::default(), }, ); assert!(status.contains("Status")); @@ -2862,15 +10411,84 @@ mod tests { assert!(status.contains("Cwd /tmp/project")); assert!(status.contains("Project root /tmp")); assert!(status.contains("Git branch main")); - assert!(status.contains("Session session.json")); + assert!( + status.contains("Git state dirty · 3 files · 1 staged, 1 unstaged, 1 untracked") + ); + assert!(status.contains("Changed files 3")); + assert!(status.contains("Staged 1")); + assert!(status.contains("Unstaged 1")); + assert!(status.contains("Untracked 1")); + assert!(status.contains("Session session.jsonl")); assert!(status.contains("Config files loaded 2/3")); assert!(status.contains("Memory files 4")); + assert!(status.contains("Suggested flow /status → /diff → /commit")); + } + + #[test] + fn commit_reports_surface_workspace_context() { + let summary = GitWorkspaceSummary { + changed_files: 2, + staged_files: 1, + unstaged_files: 1, + untracked_files: 0, + conflicted_files: 0, + }; + + let preflight = format_commit_preflight_report(Some("feature/ux"), summary); + assert!(preflight.contains("Result ready")); + assert!(preflight.contains("Branch feature/ux")); + assert!(preflight.contains("Workspace dirty · 2 files · 1 staged, 1 unstaged")); + assert!(preflight + .contains("Action create a git commit from the current workspace changes")); + } + + #[test] + fn commit_skipped_report_points_to_next_steps() { + let report = format_commit_skipped_report(); + assert!(report.contains("Reason no workspace changes")); + assert!(report + .contains("Action create a git commit from the current workspace changes")); + assert!(report.contains("/status to inspect context")); + assert!(report.contains("/diff to inspect repo changes")); + } + + #[test] + fn runtime_slash_reports_describe_command_behavior() { + let bughunter = format_bughunter_report(Some("runtime")); + assert!(bughunter.contains("Scope runtime")); + assert!(bughunter.contains("inspect the selected code for likely bugs")); + + let ultraplan = format_ultraplan_report(Some("ship the release")); + assert!(ultraplan.contains("Task ship the release")); + assert!(ultraplan.contains("break work into a multi-step execution plan")); + + let pr = format_pr_report("feature/ux", Some("ready for review")); + assert!(pr.contains("Branch feature/ux")); + assert!(pr.contains("draft or create a pull request")); + + let issue = format_issue_report(Some("flaky test")); + assert!(issue.contains("Context flaky test")); + assert!(issue.contains("draft or create a GitHub issue")); + } + + #[test] + fn no_arg_commands_reject_unexpected_arguments() { + assert!(validate_no_args("/commit", None).is_ok()); + + let error = validate_no_args("/commit", Some("now")) + .expect_err("unexpected arguments should fail") + .to_string(); + assert!(error.contains("/commit does not accept arguments")); + assert!(error.contains("Received: now")); } #[test] fn config_report_supports_section_views() { let report = render_config_report(Some("env")).expect("config report should render"); assert!(report.contains("Merged section: env")); + let plugins_report = + render_config_report(Some("plugins")).expect("plugins config report should render"); + assert!(plugins_report.contains("Merged section: plugins")); } #[test] @@ -2892,19 +10510,161 @@ mod tests { #[test] fn parses_git_status_metadata() { - let (root, branch) = parse_git_status_metadata(Some( - "## rcc/cli...origin/rcc/cli + let _guard = env_lock(); + let temp_root = temp_dir(); + fs::create_dir_all(&temp_root).expect("root dir"); + let (project_root, branch) = parse_git_status_metadata_for( + &temp_root, + Some( + "## rcc/cli...origin/rcc/cli M src/main.rs", - )); + ), + ); assert_eq!(branch.as_deref(), Some("rcc/cli")); - let _ = root; + assert!(project_root.is_none()); + fs::remove_dir_all(temp_root).expect("cleanup temp dir"); + } + + #[test] + fn parses_detached_head_from_status_snapshot() { + let _guard = env_lock(); + assert_eq!( + parse_git_status_branch(Some( + "## HEAD (no branch) + M src/main.rs" + )), + Some("detached HEAD".to_string()) + ); + } + + #[test] + fn parses_git_workspace_summary_counts() { + let summary = parse_git_workspace_summary(Some( + "## feature/ux +M src/main.rs + M README.md +?? notes.md +UU conflicted.rs", + )); + + assert_eq!( + summary, + GitWorkspaceSummary { + changed_files: 4, + staged_files: 2, + unstaged_files: 2, + untracked_files: 1, + conflicted_files: 1, + } + ); + assert_eq!( + summary.headline(), + "dirty · 4 files · 2 staged, 2 unstaged, 1 untracked, 1 conflicted" + ); + } + + #[test] + fn render_diff_report_shows_clean_tree_for_committed_repo() { + let _guard = env_lock(); + let root = temp_dir(); + fs::create_dir_all(&root).expect("root dir"); + git(&["init", "--quiet"], &root); + git(&["config", "user.email", "tests@example.com"], &root); + git(&["config", "user.name", "Rusty Claude Tests"], &root); + fs::write(root.join("tracked.txt"), "hello\n").expect("write file"); + git(&["add", "tracked.txt"], &root); + git(&["commit", "-m", "init", "--quiet"], &root); + + let report = render_diff_report_for(&root).expect("diff report should render"); + assert!(report.contains("clean working tree")); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn render_diff_report_includes_staged_and_unstaged_sections() { + let _guard = env_lock(); + let root = temp_dir(); + fs::create_dir_all(&root).expect("root dir"); + git(&["init", "--quiet"], &root); + git(&["config", "user.email", "tests@example.com"], &root); + git(&["config", "user.name", "Rusty Claude Tests"], &root); + fs::write(root.join("tracked.txt"), "hello\n").expect("write file"); + git(&["add", "tracked.txt"], &root); + git(&["commit", "-m", "init", "--quiet"], &root); + + fs::write(root.join("tracked.txt"), "hello\nstaged\n").expect("update file"); + git(&["add", "tracked.txt"], &root); + fs::write(root.join("tracked.txt"), "hello\nstaged\nunstaged\n") + .expect("update file twice"); + + let report = render_diff_report_for(&root).expect("diff report should render"); + assert!(report.contains("Staged changes:")); + assert!(report.contains("Unstaged changes:")); + assert!(report.contains("tracked.txt")); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn render_diff_report_omits_ignored_files() { + let _guard = env_lock(); + let root = temp_dir(); + fs::create_dir_all(&root).expect("root dir"); + git(&["init", "--quiet"], &root); + git(&["config", "user.email", "tests@example.com"], &root); + git(&["config", "user.name", "Rusty Claude Tests"], &root); + fs::write(root.join(".gitignore"), ".omx/\nignored.txt\n").expect("write gitignore"); + fs::write(root.join("tracked.txt"), "hello\n").expect("write tracked"); + git(&["add", ".gitignore", "tracked.txt"], &root); + git(&["commit", "-m", "init", "--quiet"], &root); + fs::create_dir_all(root.join(".omx")).expect("write omx dir"); + fs::write(root.join(".omx").join("state.json"), "{}").expect("write ignored omx"); + fs::write(root.join("ignored.txt"), "secret\n").expect("write ignored file"); + fs::write(root.join("tracked.txt"), "hello\nworld\n").expect("write tracked change"); + + let report = render_diff_report_for(&root).expect("diff report should render"); + assert!(report.contains("tracked.txt")); + assert!(!report.contains("+++ b/ignored.txt")); + assert!(!report.contains("+++ b/.omx/state.json")); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn resume_diff_command_renders_report_for_saved_session() { + let _guard = env_lock(); + let root = temp_dir(); + fs::create_dir_all(&root).expect("root dir"); + git(&["init", "--quiet"], &root); + git(&["config", "user.email", "tests@example.com"], &root); + git(&["config", "user.name", "Rusty Claude Tests"], &root); + fs::write(root.join("tracked.txt"), "hello\n").expect("write tracked"); + git(&["add", "tracked.txt"], &root); + git(&["commit", "-m", "init", "--quiet"], &root); + fs::write(root.join("tracked.txt"), "hello\nworld\n").expect("modify tracked"); + let session_path = root.join("session.json"); + Session::new() + .save_to_path(&session_path) + .expect("session should save"); + + let session = Session::load_from_path(&session_path).expect("session should load"); + let outcome = with_current_dir(&root, || { + run_resume_command(&session_path, &session, &SlashCommand::Diff) + .expect("resume diff should work") + }); + let message = outcome.message.expect("diff message should exist"); + assert!(message.contains("Unstaged changes:")); + assert!(message.contains("tracked.txt")); + + fs::remove_dir_all(root).expect("cleanup temp dir"); } #[test] fn status_context_reads_real_workspace_metadata() { let context = status_context(None).expect("status context should load"); assert!(context.cwd.is_absolute()); - assert_eq!(context.discovered_config_files, 5); + assert!(context.discovered_config_files >= context.loaded_config_files); assert!(context.loaded_config_files <= context.discovered_config_files); } @@ -2926,43 +10686,176 @@ mod tests { fn clear_command_requires_explicit_confirmation_flag() { assert_eq!( SlashCommand::parse("/clear"), - Some(SlashCommand::Clear { confirm: false }) + Ok(Some(SlashCommand::Clear { confirm: false })) ); assert_eq!( SlashCommand::parse("/clear --confirm"), - Some(SlashCommand::Clear { confirm: true }) + Ok(Some(SlashCommand::Clear { confirm: true })) ); } #[test] fn parses_resume_and_config_slash_commands() { assert_eq!( - SlashCommand::parse("/resume saved-session.json"), - Some(SlashCommand::Resume { - session_path: Some("saved-session.json".to_string()) - }) + SlashCommand::parse("/resume saved-session.jsonl"), + Ok(Some(SlashCommand::Resume { + session_path: Some("saved-session.jsonl".to_string()) + })) ); assert_eq!( SlashCommand::parse("/clear --confirm"), - Some(SlashCommand::Clear { confirm: true }) + Ok(Some(SlashCommand::Clear { confirm: true })) ); assert_eq!( SlashCommand::parse("/config"), - Some(SlashCommand::Config { section: None }) + Ok(Some(SlashCommand::Config { section: None })) ); assert_eq!( SlashCommand::parse("/config env"), - Some(SlashCommand::Config { + Ok(Some(SlashCommand::Config { section: Some("env".to_string()) - }) + })) ); - assert_eq!(SlashCommand::parse("/memory"), Some(SlashCommand::Memory)); - assert_eq!(SlashCommand::parse("/init"), Some(SlashCommand::Init)); + assert_eq!( + SlashCommand::parse("/memory"), + Ok(Some(SlashCommand::Memory)) + ); + assert_eq!(SlashCommand::parse("/init"), Ok(Some(SlashCommand::Init))); + assert_eq!( + SlashCommand::parse("/session fork incident-review"), + Ok(Some(SlashCommand::Session { + action: Some("fork".to_string()), + target: Some("incident-review".to_string()) + })) + ); + } + + #[test] + fn help_mentions_jsonl_resume_examples() { + let mut help = Vec::new(); + print_help_to(&mut help).expect("help should render"); + let help = String::from_utf8(help).expect("help should be utf8"); + assert!(help.contains("claw --resume [SESSION.jsonl|session-id|latest]")); + assert!(help.contains("Use `latest` with --resume, /resume, or /session switch")); + assert!(help.contains("claw --resume latest")); + assert!(help.contains("claw --resume latest /status /diff /export notes.txt")); + } + + #[test] + fn managed_sessions_default_to_jsonl_and_resolve_legacy_json() { + let _guard = cwd_lock().lock().expect("cwd lock"); + let workspace = temp_workspace("session-resolution"); + std::fs::create_dir_all(&workspace).expect("workspace should create"); + let previous = std::env::current_dir().expect("cwd"); + std::env::set_current_dir(&workspace).expect("switch cwd"); + + let handle = create_managed_session_handle("session-alpha").expect("jsonl handle"); + assert!(handle.path.ends_with("session-alpha.jsonl")); + + let legacy_path = workspace.join(".claw/sessions/legacy.json"); + std::fs::create_dir_all( + legacy_path + .parent() + .expect("legacy path should have parent directory"), + ) + .expect("session dir should exist"); + Session::new() + .with_persistence_path(legacy_path.clone()) + .save_to_path(&legacy_path) + .expect("legacy session should save"); + + let resolved = resolve_session_reference("legacy").expect("legacy session should resolve"); + assert_eq!( + resolved + .path + .canonicalize() + .expect("resolved path should exist"), + legacy_path + .canonicalize() + .expect("legacy path should exist") + ); + + std::env::set_current_dir(previous).expect("restore cwd"); + std::fs::remove_dir_all(workspace).expect("workspace should clean up"); + } + + #[test] + fn latest_session_alias_resolves_most_recent_managed_session() { + let _guard = cwd_lock().lock().expect("cwd lock"); + let workspace = temp_workspace("latest-session-alias"); + std::fs::create_dir_all(&workspace).expect("workspace should create"); + let previous = std::env::current_dir().expect("cwd"); + std::env::set_current_dir(&workspace).expect("switch cwd"); + + let older = create_managed_session_handle("session-older").expect("older handle"); + Session::new() + .with_persistence_path(older.path.clone()) + .save_to_path(&older.path) + .expect("older session should save"); + std::thread::sleep(Duration::from_millis(20)); + let newer = create_managed_session_handle("session-newer").expect("newer handle"); + Session::new() + .with_persistence_path(newer.path.clone()) + .save_to_path(&newer.path) + .expect("newer session should save"); + + let resolved = resolve_session_reference("latest").expect("latest session should resolve"); + assert_eq!( + resolved + .path + .canonicalize() + .expect("resolved path should exist"), + newer.path.canonicalize().expect("newer path should exist") + ); + + std::env::set_current_dir(previous).expect("restore cwd"); + std::fs::remove_dir_all(workspace).expect("workspace should clean up"); + } + + #[test] + fn unknown_slash_command_guidance_suggests_nearby_commands() { + let message = format_unknown_slash_command("stats"); + assert!(message.contains("Unknown slash command: /stats")); + assert!(message.contains("/status")); + assert!(message.contains("/help")); + } + + #[test] + fn unknown_omc_slash_command_guidance_explains_runtime_gap() { + let message = format_unknown_slash_command("oh-my-claudecode:hud"); + assert!(message.contains("Unknown slash command: /oh-my-claudecode:hud")); + assert!(message.contains("Claude Code/OMC plugin command")); + assert!(message.contains("does not yet load plugin slash commands")); + } + + #[test] + fn resume_usage_mentions_latest_shortcut() { + let usage = render_resume_usage(); + assert!(usage.contains("/resume ")); + assert!(usage.contains(".claw/sessions/.jsonl")); + assert!(usage.contains("/session list")); + } + + fn cwd_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + } + + fn temp_workspace(label: &str) -> PathBuf { + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("system time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("claw-cli-{label}-{nanos}")) } #[test] fn init_template_mentions_detected_rust_workspace() { - let rendered = crate::init::render_init_claude_md(std::path::Path::new(".")); + let _guard = cwd_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let workspace_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../.."); + let rendered = crate::init::render_init_claude_md(&workspace_root); assert!(rendered.contains("# CLAUDE.md")); assert!(rendered.contains("cargo clippy --workspace --all-targets -- -D warnings")); } @@ -2999,16 +10892,987 @@ mod tests { assert!(help.contains("Up/Down")); assert!(help.contains("Tab")); assert!(help.contains("Shift+Enter/Ctrl+J")); + assert!(help.contains("Ctrl-R")); + assert!(help.contains("Reverse-search prompt history")); + assert!(help.contains("/history [count]")); + } + + #[test] + fn parse_history_count_defaults_to_twenty_when_missing() { + // given + let raw: Option<&str> = None; + + // when + let parsed = parse_history_count(raw); + + // then + assert_eq!(parsed, Ok(20)); + } + + #[test] + fn parse_history_count_accepts_positive_integers() { + // given + let raw = Some("25"); + + // when + let parsed = parse_history_count(raw); + + // then + assert_eq!(parsed, Ok(25)); + } + + #[test] + fn parse_history_count_rejects_zero() { + // given + let raw = Some("0"); + + // when + let parsed = parse_history_count(raw); + + // then + assert!(parsed.is_err()); + assert!(parsed.unwrap_err().contains("greater than 0")); + } + + #[test] + fn parse_history_count_rejects_non_numeric() { + // given + let raw = Some("abc"); + + // when + let parsed = parse_history_count(raw); + + // then + assert!(parsed.is_err()); + assert!(parsed.unwrap_err().contains("invalid count 'abc'")); + } + + #[test] + fn format_history_timestamp_renders_iso8601_utc() { + // given + // 2023-01-15T12:34:56.789Z -> 1673786096789 ms + let timestamp_ms: u64 = 1_673_786_096_789; + + // when + let formatted = format_history_timestamp(timestamp_ms); + + // then + assert_eq!(formatted, "2023-01-15T12:34:56.789Z"); + } + + #[test] + fn format_history_timestamp_renders_unix_epoch_origin() { + // given + let timestamp_ms: u64 = 0; + + // when + let formatted = format_history_timestamp(timestamp_ms); + + // then + assert_eq!(formatted, "1970-01-01T00:00:00.000Z"); + } + + #[test] + fn render_prompt_history_report_lists_entries_with_timestamps() { + // given + let entries = vec![ + PromptHistoryEntry { + timestamp_ms: 1_673_786_096_000, + text: "first prompt".to_string(), + }, + PromptHistoryEntry { + timestamp_ms: 1_673_786_100_000, + text: "second prompt".to_string(), + }, + ]; + + // when + let rendered = render_prompt_history_report(&entries, 10); + + // then + assert!(rendered.contains("Prompt history")); + assert!(rendered.contains("Total 2")); + assert!(rendered.contains("Showing 2 most recent")); + assert!(rendered.contains("Reverse search Ctrl-R in the REPL")); + assert!(rendered.contains("2023-01-15T12:34:56.000Z")); + assert!(rendered.contains("first prompt")); + assert!(rendered.contains("second prompt")); + } + + #[test] + fn render_prompt_history_report_truncates_to_limit_from_the_tail() { + // given + let entries = vec![ + PromptHistoryEntry { + timestamp_ms: 1_000, + text: "older".to_string(), + }, + PromptHistoryEntry { + timestamp_ms: 2_000, + text: "middle".to_string(), + }, + PromptHistoryEntry { + timestamp_ms: 3_000, + text: "latest".to_string(), + }, + ]; + + // when + let rendered = render_prompt_history_report(&entries, 2); + + // then + assert!(rendered.contains("Total 3")); + assert!(rendered.contains("Showing 2 most recent")); + assert!(!rendered.contains("older")); + assert!(rendered.contains("middle")); + assert!(rendered.contains("latest")); + } + + #[test] + fn render_prompt_history_report_handles_empty_history() { + // given + let entries: Vec = Vec::new(); + + // when + let rendered = render_prompt_history_report(&entries, 10); + + // then + assert!(rendered.contains("no prompts recorded yet")); + } + + #[test] + fn collect_session_prompt_history_extracts_user_text_blocks() { + // given + let mut session = Session::new(); + session.push_user_text("hello").unwrap(); + session.push_user_text("world").unwrap(); + + // when + let entries = collect_session_prompt_history(&session); + + // then + assert_eq!(entries.len(), 2); + assert_eq!(entries[0].text, "hello"); + assert_eq!(entries[1].text, "world"); } #[test] fn tool_rendering_helpers_compact_output() { let start = format_tool_call_start("read_file", r#"{"path":"src/main.rs"}"#); - assert!(start.contains("Tool call")); + assert!(start.contains("read_file")); assert!(start.contains("src/main.rs")); - let done = format_tool_result("read_file", r#"{"contents":"hello"}"#, false); - assert!(done.contains("Tool `read_file`")); - assert!(done.contains("contents")); + let done = format_tool_result( + "read_file", + r#"{"file":{"filePath":"src/main.rs","content":"hello","numLines":1,"startLine":1,"totalLines":1}}"#, + false, + ); + assert!(done.contains("📄 Read src/main.rs")); + assert!(done.contains("hello")); + } + + #[test] + fn tool_rendering_truncates_large_read_output_for_display_only() { + let content = (0..200) + .map(|index| format!("line {index:03}")) + .collect::>() + .join("\n"); + let output = json!({ + "file": { + "filePath": "src/main.rs", + "content": content, + "numLines": 200, + "startLine": 1, + "totalLines": 200 + } + }) + .to_string(); + + let rendered = format_tool_result("read_file", &output, false); + + assert!(rendered.contains("line 000")); + assert!(rendered.contains("line 079")); + assert!(!rendered.contains("line 199")); + assert!(rendered.contains("full result preserved in session")); + assert!(output.contains("line 199")); + } + + #[test] + fn tool_rendering_truncates_large_bash_output_for_display_only() { + let stdout = (0..120) + .map(|index| format!("stdout {index:03}")) + .collect::>() + .join("\n"); + let output = json!({ + "stdout": stdout, + "stderr": "", + "returnCodeInterpretation": "completed successfully" + }) + .to_string(); + + let rendered = format_tool_result("bash", &output, false); + + assert!(rendered.contains("stdout 000")); + assert!(rendered.contains("stdout 059")); + assert!(!rendered.contains("stdout 119")); + assert!(rendered.contains("full result preserved in session")); + assert!(output.contains("stdout 119")); + } + + #[test] + fn tool_rendering_truncates_generic_long_output_for_display_only() { + let items = (0..120) + .map(|index| format!("payload {index:03}")) + .collect::>(); + let output = json!({ + "summary": "plugin payload", + "items": items, + }) + .to_string(); + + let rendered = format_tool_result("plugin_echo", &output, false); + + assert!(rendered.contains("plugin_echo")); + assert!(rendered.contains("payload 000")); + assert!(rendered.contains("payload 040")); + assert!(!rendered.contains("payload 080")); + assert!(!rendered.contains("payload 119")); + assert!(rendered.contains("full result preserved in session")); + assert!(output.contains("payload 119")); + } + + #[test] + fn tool_rendering_truncates_raw_generic_output_for_display_only() { + let output = (0..120) + .map(|index| format!("raw {index:03}")) + .collect::>() + .join("\n"); + + let rendered = format_tool_result("plugin_echo", &output, false); + + assert!(rendered.contains("plugin_echo")); + assert!(rendered.contains("raw 000")); + assert!(rendered.contains("raw 059")); + assert!(!rendered.contains("raw 119")); + assert!(rendered.contains("full result preserved in session")); + assert!(output.contains("raw 119")); + } + + #[test] + fn ultraplan_progress_lines_include_phase_step_and_elapsed_status() { + let snapshot = InternalPromptProgressState { + command_label: "Ultraplan", + task_label: "ship plugin progress".to_string(), + step: 3, + phase: "running read_file".to_string(), + detail: Some("reading rust/crates/rusty-claude-cli/src/main.rs".to_string()), + saw_final_text: false, + }; + + let started = format_internal_prompt_progress_line( + InternalPromptProgressEvent::Started, + &snapshot, + Duration::from_secs(0), + None, + ); + let heartbeat = format_internal_prompt_progress_line( + InternalPromptProgressEvent::Heartbeat, + &snapshot, + Duration::from_secs(9), + None, + ); + let completed = format_internal_prompt_progress_line( + InternalPromptProgressEvent::Complete, + &snapshot, + Duration::from_secs(12), + None, + ); + let failed = format_internal_prompt_progress_line( + InternalPromptProgressEvent::Failed, + &snapshot, + Duration::from_secs(12), + Some("network timeout"), + ); + + assert!(started.contains("planning started")); + assert!(started.contains("current step 3")); + assert!(heartbeat.contains("heartbeat")); + assert!(heartbeat.contains("9s elapsed")); + assert!(heartbeat.contains("phase running read_file")); + assert!(completed.contains("completed")); + assert!(completed.contains("3 steps total")); + assert!(failed.contains("failed")); + assert!(failed.contains("network timeout")); + } + + #[test] + fn describe_tool_progress_summarizes_known_tools() { + assert_eq!( + describe_tool_progress("read_file", r#"{"path":"src/main.rs"}"#), + "reading src/main.rs" + ); + assert!( + describe_tool_progress("bash", r#"{"command":"cargo test -p rusty-claude-cli"}"#) + .contains("cargo test -p rusty-claude-cli") + ); + assert_eq!( + describe_tool_progress("grep_search", r#"{"pattern":"ultraplan","path":"rust"}"#), + "grep `ultraplan` in rust" + ); + } + + #[test] + fn push_output_block_renders_markdown_text() { + let mut out = Vec::new(); + let mut events = Vec::new(); + let mut pending_tool = None; + let mut block_has_thinking_summary = false; + + push_output_block( + OutputContentBlock::Text { + text: "# Heading".to_string(), + }, + &mut out, + &mut events, + &mut pending_tool, + false, + &mut block_has_thinking_summary, + ) + .expect("text block should render"); + + let rendered = String::from_utf8(out).expect("utf8"); + assert!(rendered.contains("Heading")); + assert!(rendered.contains('\u{1b}')); + } + + #[test] + fn push_output_block_skips_empty_object_prefix_for_tool_streams() { + let mut out = Vec::new(); + let mut events = Vec::new(); + let mut pending_tool = None; + let mut block_has_thinking_summary = false; + + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({}), + }, + &mut out, + &mut events, + &mut pending_tool, + true, + &mut block_has_thinking_summary, + ) + .expect("tool block should accumulate"); + + assert!(events.is_empty()); + assert_eq!( + pending_tool, + Some(("tool-1".to_string(), "read_file".to_string(), String::new(),)) + ); + } + + #[test] + fn response_to_events_preserves_empty_object_json_input_outside_streaming() { + let mut out = Vec::new(); + let events = response_to_events( + MessageResponse { + id: "msg-1".to_string(), + kind: "message".to_string(), + model: "claude-opus-4-6".to_string(), + role: "assistant".to_string(), + content: vec![OutputContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({}), + }], + stop_reason: Some("tool_use".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 1, + output_tokens: 1, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + request_id: None, + }, + &mut out, + ) + .expect("response conversion should succeed"); + + assert!(matches!( + &events[0], + AssistantEvent::ToolUse { name, input, .. } + if name == "read_file" && input == "{}" + )); + } + + #[test] + fn response_to_events_preserves_non_empty_json_input_outside_streaming() { + let mut out = Vec::new(); + let events = response_to_events( + MessageResponse { + id: "msg-2".to_string(), + kind: "message".to_string(), + model: "claude-opus-4-6".to_string(), + role: "assistant".to_string(), + content: vec![OutputContentBlock::ToolUse { + id: "tool-2".to_string(), + name: "read_file".to_string(), + input: json!({ "path": "rust/Cargo.toml" }), + }], + stop_reason: Some("tool_use".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 1, + output_tokens: 1, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + request_id: None, + }, + &mut out, + ) + .expect("response conversion should succeed"); + + assert!(matches!( + &events[0], + AssistantEvent::ToolUse { name, input, .. } + if name == "read_file" && input == "{\"path\":\"rust/Cargo.toml\"}" + )); + } + + #[test] + fn response_to_events_renders_collapsed_thinking_summary() { + let mut out = Vec::new(); + let events = response_to_events( + MessageResponse { + id: "msg-3".to_string(), + kind: "message".to_string(), + model: "claude-opus-4-6".to_string(), + role: "assistant".to_string(), + content: vec![ + OutputContentBlock::Thinking { + thinking: "step 1".to_string(), + signature: Some("sig_123".to_string()), + }, + OutputContentBlock::Text { + text: "Final answer".to_string(), + }, + ], + stop_reason: Some("end_turn".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 1, + output_tokens: 1, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + request_id: None, + }, + &mut out, + ) + .expect("response conversion should succeed"); + + assert!(matches!( + &events[0], + AssistantEvent::TextDelta(text) if text == "Final answer" + )); + let rendered = String::from_utf8(out).expect("utf8"); + assert!(rendered.contains("▶ Thinking (6 chars hidden)")); + assert!(!rendered.contains("step 1")); + } + + #[test] + fn login_browser_failure_keeps_json_stdout_clean() { + let mut stdout = Vec::new(); + let mut stderr = Vec::new(); + let error = std::io::Error::new( + std::io::ErrorKind::NotFound, + "no supported browser opener command found", + ); + + super::emit_login_browser_open_failure( + CliOutputFormat::Json, + "https://example.test/oauth/authorize", + &error, + &mut stdout, + &mut stderr, + ) + .expect("browser warning should render"); + + assert!(stdout.is_empty()); + let stderr = String::from_utf8(stderr).expect("utf8"); + assert!(stderr.contains("failed to open browser automatically")); + assert!(stderr.contains("Open this URL manually:")); + assert!(stderr.contains("https://example.test/oauth/authorize")); + } + + #[test] + fn build_runtime_plugin_state_merges_plugin_hooks_into_runtime_features() { + let config_home = temp_dir(); + let workspace = temp_dir(); + let source_root = temp_dir(); + fs::create_dir_all(&config_home).expect("config home"); + fs::create_dir_all(&workspace).expect("workspace"); + fs::create_dir_all(&source_root).expect("source root"); + write_plugin_fixture(&source_root, "hook-runtime-demo", true, false); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + manager + .install(source_root.to_str().expect("utf8 source path")) + .expect("plugin install should succeed"); + let loader = ConfigLoader::new(&workspace, &config_home); + let runtime_config = loader.load().expect("runtime config should load"); + let state = build_runtime_plugin_state_with_loader(&workspace, &loader, &runtime_config) + .expect("plugin state should load"); + let pre_hooks = state.feature_config.hooks().pre_tool_use(); + assert_eq!(pre_hooks.len(), 1); + assert!( + pre_hooks[0].ends_with("hooks/pre.sh"), + "expected installed plugin hook path, got {pre_hooks:?}" + ); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(workspace); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + #[allow(clippy::too_many_lines)] + fn build_runtime_plugin_state_discovers_mcp_tools_and_surfaces_pending_servers() { + let config_home = temp_dir(); + let workspace = temp_dir(); + fs::create_dir_all(&config_home).expect("config home"); + fs::create_dir_all(&workspace).expect("workspace"); + let script_path = workspace.join("fixture-mcp.py"); + write_mcp_server_fixture(&script_path); + fs::write( + config_home.join("settings.json"), + format!( + r#"{{ + "mcpServers": {{ + "alpha": {{ + "command": "python3", + "args": ["{}"] + }}, + "broken": {{ + "command": "python3", + "args": ["-c", "import sys; sys.exit(0)"] + }} + }} + }}"#, + script_path.to_string_lossy() + ), + ) + .expect("write mcp settings"); + + let loader = ConfigLoader::new(&workspace, &config_home); + let runtime_config = loader.load().expect("runtime config should load"); + let state = build_runtime_plugin_state_with_loader(&workspace, &loader, &runtime_config) + .expect("runtime plugin state should load"); + + let allowed = state + .tool_registry + .normalize_allowed_tools(&["mcp__alpha__echo".to_string(), "MCPTool".to_string()]) + .expect("mcp tools should be allow-listable") + .expect("allow-list should exist"); + assert!(allowed.contains("mcp__alpha__echo")); + assert!(allowed.contains("MCPTool")); + + let mut executor = CliToolExecutor::new( + None, + false, + state.tool_registry.clone(), + state.mcp_state.clone(), + ); + + let tool_output = executor + .execute("mcp__alpha__echo", r#"{"text":"hello"}"#) + .expect("discovered mcp tool should execute"); + let tool_json: serde_json::Value = + serde_json::from_str(&tool_output).expect("tool output should be json"); + assert_eq!(tool_json["structuredContent"]["echoed"], "hello"); + + let wrapped_output = executor + .execute( + "MCPTool", + r#"{"qualifiedName":"mcp__alpha__echo","arguments":{"text":"wrapped"}}"#, + ) + .expect("generic mcp wrapper should execute"); + let wrapped_json: serde_json::Value = + serde_json::from_str(&wrapped_output).expect("wrapped output should be json"); + assert_eq!(wrapped_json["structuredContent"]["echoed"], "wrapped"); + + let search_output = executor + .execute("ToolSearch", r#"{"query":"alpha echo","max_results":5}"#) + .expect("tool search should execute"); + let search_json: serde_json::Value = + serde_json::from_str(&search_output).expect("search output should be json"); + assert_eq!(search_json["matches"][0], "mcp__alpha__echo"); + assert_eq!(search_json["pending_mcp_servers"][0], "broken"); + assert_eq!( + search_json["mcp_degraded"]["failed_servers"][0]["server_name"], + "broken" + ); + assert_eq!( + search_json["mcp_degraded"]["failed_servers"][0]["phase"], + "tool_discovery" + ); + assert_eq!( + search_json["mcp_degraded"]["available_tools"][0], + "mcp__alpha__echo" + ); + + let listed = executor + .execute("ListMcpResourcesTool", r#"{"server":"alpha"}"#) + .expect("resources should list"); + let listed_json: serde_json::Value = + serde_json::from_str(&listed).expect("resource output should be json"); + assert_eq!(listed_json["resources"][0]["uri"], "file://guide.txt"); + + let read = executor + .execute( + "ReadMcpResourceTool", + r#"{"server":"alpha","uri":"file://guide.txt"}"#, + ) + .expect("resource should read"); + let read_json: serde_json::Value = + serde_json::from_str(&read).expect("resource read output should be json"); + assert_eq!( + read_json["contents"][0]["text"], + "contents for file://guide.txt" + ); + + if let Some(mcp_state) = state.mcp_state { + mcp_state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .shutdown() + .expect("mcp shutdown should succeed"); + } + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(workspace); + } + + #[test] + fn build_runtime_plugin_state_surfaces_unsupported_mcp_servers_structurally() { + let config_home = temp_dir(); + let workspace = temp_dir(); + fs::create_dir_all(&config_home).expect("config home"); + fs::create_dir_all(&workspace).expect("workspace"); + fs::write( + config_home.join("settings.json"), + r#"{ + "mcpServers": { + "remote": { + "url": "https://example.test/mcp" + } + } + }"#, + ) + .expect("write mcp settings"); + + let loader = ConfigLoader::new(&workspace, &config_home); + let runtime_config = loader.load().expect("runtime config should load"); + let state = build_runtime_plugin_state_with_loader(&workspace, &loader, &runtime_config) + .expect("runtime plugin state should load"); + let mut executor = CliToolExecutor::new( + None, + false, + state.tool_registry.clone(), + state.mcp_state.clone(), + ); + + let search_output = executor + .execute("ToolSearch", r#"{"query":"remote","max_results":5}"#) + .expect("tool search should execute"); + let search_json: serde_json::Value = + serde_json::from_str(&search_output).expect("search output should be json"); + assert_eq!(search_json["pending_mcp_servers"][0], "remote"); + assert_eq!( + search_json["mcp_degraded"]["failed_servers"][0]["server_name"], + "remote" + ); + assert_eq!( + search_json["mcp_degraded"]["failed_servers"][0]["phase"], + "server_registration" + ); + assert_eq!( + search_json["mcp_degraded"]["failed_servers"][0]["error"]["context"]["transport"], + "http" + ); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(workspace); + } + + #[test] + fn build_runtime_runs_plugin_lifecycle_init_and_shutdown() { + // Serialize access to process-wide env vars so parallel tests that + // set/remove ANTHROPIC_API_KEY do not race with this test. + let _guard = env_lock(); + let config_home = temp_dir(); + // Inject a dummy API key so runtime construction succeeds without real credentials. + // This test only exercises plugin lifecycle (init/shutdown), never calls the API. + std::env::set_var("ANTHROPIC_API_KEY", "test-dummy-key-for-plugin-lifecycle"); + let workspace = temp_dir(); + let source_root = temp_dir(); + fs::create_dir_all(&config_home).expect("config home"); + fs::create_dir_all(&workspace).expect("workspace"); + fs::create_dir_all(&source_root).expect("source root"); + write_plugin_fixture(&source_root, "lifecycle-runtime-demo", false, true); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + let install = manager + .install(source_root.to_str().expect("utf8 source path")) + .expect("plugin install should succeed"); + let log_path = install.install_path.join("lifecycle.log"); + let loader = ConfigLoader::new(&workspace, &config_home); + let runtime_config = loader.load().expect("runtime config should load"); + let runtime_plugin_state = + build_runtime_plugin_state_with_loader(&workspace, &loader, &runtime_config) + .expect("plugin state should load"); + let mut runtime = build_runtime_with_plugin_state( + Session::new(), + "runtime-plugin-lifecycle", + DEFAULT_MODEL.to_string(), + vec!["test system prompt".to_string()], + true, + false, + None, + PermissionMode::DangerFullAccess, + None, + runtime_plugin_state, + ) + .expect("runtime should build"); + + assert_eq!( + fs::read_to_string(&log_path).expect("init log should exist"), + "init\n" + ); + + runtime + .shutdown_plugins() + .expect("plugin shutdown should succeed"); + + assert_eq!( + fs::read_to_string(&log_path).expect("shutdown log should exist"), + "init\nshutdown\n" + ); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(workspace); + let _ = fs::remove_dir_all(source_root); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + + #[test] + fn rejects_invalid_reasoning_effort_value() { + let err = parse_args(&[ + "--reasoning-effort".to_string(), + "turbo".to_string(), + "prompt".to_string(), + "hello".to_string(), + ]) + .unwrap_err(); + assert!( + err.contains("invalid value for --reasoning-effort"), + "unexpected error: {err}" + ); + assert!(err.contains("turbo"), "unexpected error: {err}"); + } + + #[test] + fn accepts_valid_reasoning_effort_values() { + for value in ["low", "medium", "high"] { + let result = parse_args(&[ + "--reasoning-effort".to_string(), + value.to_string(), + "prompt".to_string(), + "hello".to_string(), + ]); + assert!( + result.is_ok(), + "--reasoning-effort {value} should be accepted, got: {:?}", + result + ); + if let Ok(CliAction::Prompt { + reasoning_effort, .. + }) = result + { + assert_eq!(reasoning_effort.as_deref(), Some(value)); + } + } + } + + #[test] + fn stub_commands_absent_from_repl_completions() { + let candidates = + slash_command_completion_candidates_with_sessions("claude-3-5-sonnet", None, vec![]); + for stub in STUB_COMMANDS { + let with_slash = format!("/{stub}"); + assert!( + !candidates.contains(&with_slash), + "stub command {with_slash} should not appear in REPL completions" + ); + } + } +} + +fn write_mcp_server_fixture(script_path: &Path) { + let script = [ + "#!/usr/bin/env python3", + "import json, sys", + "", + "def read_message():", + " header = b''", + r" while not header.endswith(b'\r\n\r\n'):", + " chunk = sys.stdin.buffer.read(1)", + " if not chunk:", + " return None", + " header += chunk", + " length = 0", + r" for line in header.decode().split('\r\n'):", + r" if line.lower().startswith('content-length:'):", + " length = int(line.split(':', 1)[1].strip())", + " payload = sys.stdin.buffer.read(length)", + " return json.loads(payload.decode())", + "", + "def send_message(message):", + " payload = json.dumps(message).encode()", + r" sys.stdout.buffer.write(f'Content-Length: {len(payload)}\r\n\r\n'.encode() + payload)", + " sys.stdout.buffer.flush()", + "", + "while True:", + " request = read_message()", + " if request is None:", + " break", + " method = request['method']", + " if method == 'initialize':", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'result': {", + " 'protocolVersion': request['params']['protocolVersion'],", + " 'capabilities': {'tools': {}, 'resources': {}},", + " 'serverInfo': {'name': 'fixture', 'version': '1.0.0'}", + " }", + " })", + " elif method == 'tools/list':", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'result': {", + " 'tools': [", + " {", + " 'name': 'echo',", + " 'description': 'Echo from MCP fixture',", + " 'inputSchema': {", + " 'type': 'object',", + " 'properties': {'text': {'type': 'string'}},", + " 'required': ['text'],", + " 'additionalProperties': False", + " },", + " 'annotations': {'readOnlyHint': True}", + " }", + " ]", + " }", + " })", + " elif method == 'tools/call':", + " args = request['params'].get('arguments') or {}", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'result': {", + " 'content': [{'type': 'text', 'text': f\"echo:{args.get('text', '')}\"}],", + " 'structuredContent': {'echoed': args.get('text', '')},", + " 'isError': False", + " }", + " })", + " elif method == 'resources/list':", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'result': {", + " 'resources': [{'uri': 'file://guide.txt', 'name': 'guide', 'mimeType': 'text/plain'}]", + " }", + " })", + " elif method == 'resources/read':", + " uri = request['params']['uri']", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'result': {", + " 'contents': [{'uri': uri, 'mimeType': 'text/plain', 'text': f'contents for {uri}'}]", + " }", + " })", + " else:", + " send_message({", + " 'jsonrpc': '2.0',", + " 'id': request['id'],", + " 'error': {'code': -32601, 'message': method}", + " })", + "", + ] + .join("\n"); + fs::write(script_path, script).expect("mcp fixture script should write"); +} + +#[cfg(test)] +mod sandbox_report_tests { + use super::{format_sandbox_report, HookAbortMonitor}; + use runtime::HookAbortSignal; + use std::sync::mpsc; + use std::time::Duration; + + #[test] + fn sandbox_report_renders_expected_fields() { + let report = format_sandbox_report(&runtime::SandboxStatus::default()); + assert!(report.contains("Sandbox")); + assert!(report.contains("Enabled")); + assert!(report.contains("Filesystem mode")); + assert!(report.contains("Fallback reason")); + } + + #[test] + fn hook_abort_monitor_stops_without_aborting() { + let abort_signal = HookAbortSignal::new(); + let (ready_tx, ready_rx) = mpsc::channel(); + let monitor = HookAbortMonitor::spawn_with_waiter( + abort_signal.clone(), + move |stop_rx, abort_signal| { + ready_tx.send(()).expect("ready signal"); + let _ = stop_rx.recv(); + assert!(!abort_signal.is_aborted()); + }, + ); + + ready_rx.recv().expect("waiter should be ready"); + monitor.stop(); + + assert!(!abort_signal.is_aborted()); + } + + #[test] + fn hook_abort_monitor_propagates_interrupt() { + let abort_signal = HookAbortSignal::new(); + let (done_tx, done_rx) = mpsc::channel(); + let monitor = HookAbortMonitor::spawn_with_waiter( + abort_signal.clone(), + move |_stop_rx, abort_signal| { + abort_signal.abort(); + done_tx.send(()).expect("done signal"); + }, + ); + + done_rx + .recv_timeout(Duration::from_secs(1)) + .expect("interrupt should complete"); + monitor.stop(); + + assert!(abort_signal.is_aborted()); } } diff --git a/crates/rusty-claude-cli/src/render.rs b/crates/rusty-claude-cli/src/render.rs index 18423b3..cb7828d 100644 --- a/crates/rusty-claude-cli/src/render.rs +++ b/crates/rusty-claude-cli/src/render.rs @@ -1,7 +1,5 @@ use std::fmt::Write as FmtWrite; use std::io::{self, Write}; -use std::thread; -use std::time::Duration; use crossterm::cursor::{MoveToColumn, RestorePosition, SavePosition}; use crossterm::style::{Color, Print, ResetColor, SetForegroundColor, Stylize}; @@ -22,6 +20,7 @@ pub struct ColorTheme { link: Color, quote: Color, table_border: Color, + code_block_border: Color, spinner_active: Color, spinner_done: Color, spinner_failed: Color, @@ -37,6 +36,7 @@ impl Default for ColorTheme { link: Color::Blue, quote: Color::DarkGrey, table_border: Color::DarkCyan, + code_block_border: Color::DarkGrey, spinner_active: Color::Blue, spinner_done: Color::Green, spinner_failed: Color::Red, @@ -154,33 +154,64 @@ impl TableState { struct RenderState { emphasis: usize, strong: usize, + heading_level: Option, quote: usize, list_stack: Vec, + link_stack: Vec, table: Option, } +#[derive(Debug, Clone, PartialEq, Eq)] +struct LinkState { + destination: String, + text: String, +} + impl RenderState { fn style_text(&self, text: &str, theme: &ColorTheme) -> String { - let mut styled = text.to_string(); - if self.strong > 0 { - styled = format!("{}", styled.bold().with(theme.strong)); + let mut style = text.stylize(); + + if matches!(self.heading_level, Some(1 | 2)) || self.strong > 0 { + style = style.bold(); } if self.emphasis > 0 { - styled = format!("{}", styled.italic().with(theme.emphasis)); + style = style.italic(); } + + if let Some(level) = self.heading_level { + style = match level { + 1 => style.with(theme.heading), + 2 => style.white(), + 3 => style.with(Color::Blue), + _ => style.with(Color::Grey), + }; + } else if self.strong > 0 { + style = style.with(theme.strong); + } else if self.emphasis > 0 { + style = style.with(theme.emphasis); + } + if self.quote > 0 { - styled = format!("{}", styled.with(theme.quote)); + style = style.with(theme.quote); } - styled + + format!("{style}") } - fn capture_target_mut<'a>(&'a mut self, output: &'a mut String) -> &'a mut String { - if let Some(table) = self.table.as_mut() { - &mut table.current_cell + fn append_raw(&mut self, output: &mut String, text: &str) { + if let Some(link) = self.link_stack.last_mut() { + link.text.push_str(text); + } else if let Some(table) = self.table.as_mut() { + table.current_cell.push_str(text); } else { - output + output.push_str(text); } } + + fn append_styled(&mut self, output: &mut String, text: &str, theme: &ColorTheme) { + let styled = self.style_text(text, theme); + self.append_raw(output, &styled); + } } #[derive(Debug)] @@ -218,13 +249,14 @@ impl TerminalRenderer { #[must_use] pub fn render_markdown(&self, markdown: &str) -> String { + let normalized = normalize_nested_fences(markdown); let mut output = String::new(); let mut state = RenderState::default(); let mut code_language = String::new(); let mut code_buffer = String::new(); let mut in_code_block = false; - for event in Parser::new_ext(markdown, Options::all()) { + for event in Parser::new_ext(&normalized, Options::all()) { self.render_event( event, &mut state, @@ -238,6 +270,11 @@ impl TerminalRenderer { output.trim_end().to_string() } + #[must_use] + pub fn markdown_to_ansi(&self, markdown: &str) -> String { + self.render_markdown(markdown) + } + #[allow(clippy::too_many_lines)] fn render_event( &self, @@ -249,15 +286,21 @@ impl TerminalRenderer { in_code_block: &mut bool, ) { match event { - Event::Start(Tag::Heading { level, .. }) => self.start_heading(level as u8, output), - Event::End(TagEnd::Heading(..) | TagEnd::Paragraph) => output.push_str("\n\n"), + Event::Start(Tag::Heading { level, .. }) => { + Self::start_heading(state, level as u8, output); + } + Event::End(TagEnd::Paragraph) => output.push_str("\n\n"), Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output), Event::End(TagEnd::BlockQuote(..)) => { state.quote = state.quote.saturating_sub(1); output.push('\n'); } + Event::End(TagEnd::Heading(..)) => { + state.heading_level = None; + output.push_str("\n\n"); + } Event::End(TagEnd::Item) | Event::SoftBreak | Event::HardBreak => { - state.capture_target_mut(output).push('\n'); + state.append_raw(output, "\n"); } Event::Start(Tag::List(first_item)) => { let kind = match first_item { @@ -293,41 +336,52 @@ impl TerminalRenderer { Event::Code(code) => { let rendered = format!("{}", format!("`{code}`").with(self.color_theme.inline_code)); - state.capture_target_mut(output).push_str(&rendered); + state.append_raw(output, &rendered); } Event::Rule => output.push_str("---\n"), Event::Text(text) => { self.push_text(text.as_ref(), state, output, code_buffer, *in_code_block); } Event::Html(html) | Event::InlineHtml(html) => { - state.capture_target_mut(output).push_str(&html); + state.append_raw(output, &html); } Event::FootnoteReference(reference) => { - let _ = write!(state.capture_target_mut(output), "[{reference}]"); + state.append_raw(output, &format!("[{reference}]")); } Event::TaskListMarker(done) => { - state - .capture_target_mut(output) - .push_str(if done { "[x] " } else { "[ ] " }); + state.append_raw(output, if done { "[x] " } else { "[ ] " }); } Event::InlineMath(math) | Event::DisplayMath(math) => { - state.capture_target_mut(output).push_str(&math); + state.append_raw(output, &math); } Event::Start(Tag::Link { dest_url, .. }) => { - let rendered = format!( - "{}", - format!("[{dest_url}]") - .underlined() - .with(self.color_theme.link) - ); - state.capture_target_mut(output).push_str(&rendered); + state.link_stack.push(LinkState { + destination: dest_url.to_string(), + text: String::new(), + }); + } + Event::End(TagEnd::Link) => { + if let Some(link) = state.link_stack.pop() { + let label = if link.text.is_empty() { + link.destination.clone() + } else { + link.text + }; + let rendered = format!( + "{}", + format!("[{label}]({})", link.destination) + .underlined() + .with(self.color_theme.link) + ); + state.append_raw(output, &rendered); + } } Event::Start(Tag::Image { dest_url, .. }) => { let rendered = format!( "{}", format!("[image:{dest_url}]").with(self.color_theme.link) ); - state.capture_target_mut(output).push_str(&rendered); + state.append_raw(output, &rendered); } Event::Start(Tag::Table(..)) => state.table = Some(TableState::default()), Event::End(TagEnd::Table) => { @@ -369,19 +423,15 @@ impl TerminalRenderer { } } Event::Start(Tag::Paragraph | Tag::MetadataBlock(..) | _) - | Event::End(TagEnd::Link | TagEnd::Image | TagEnd::MetadataBlock(..) | _) => {} + | Event::End(TagEnd::Image | TagEnd::MetadataBlock(..) | _) => {} } } - fn start_heading(&self, level: u8, output: &mut String) { - output.push('\n'); - let prefix = match level { - 1 => "# ", - 2 => "## ", - 3 => "### ", - _ => "#### ", - }; - let _ = write!(output, "{}", prefix.bold().with(self.color_theme.heading)); + fn start_heading(state: &mut RenderState, level: u8, output: &mut String) { + state.heading_level = Some(level); + if !output.is_empty() { + output.push('\n'); + } } fn start_quote(&self, state: &mut RenderState, output: &mut String) { @@ -405,20 +455,27 @@ impl TerminalRenderer { } fn start_code_block(&self, code_language: &str, output: &mut String) { - if !code_language.is_empty() { - let _ = writeln!( - output, - "{}", - format!("╭─ {code_language}").with(self.color_theme.heading) - ); - } + let label = if code_language.is_empty() { + "code".to_string() + } else { + code_language.to_string() + }; + let _ = writeln!( + output, + "{}", + format!("╭─ {label}") + .bold() + .with(self.color_theme.code_block_border) + ); } fn finish_code_block(&self, code_buffer: &str, code_language: &str, output: &mut String) { output.push_str(&self.highlight_code(code_buffer, code_language)); - if !code_language.is_empty() { - let _ = write!(output, "{}", "╰─".with(self.color_theme.heading)); - } + let _ = write!( + output, + "{}", + "╰─".bold().with(self.color_theme.code_block_border) + ); output.push_str("\n\n"); } @@ -433,8 +490,7 @@ impl TerminalRenderer { if in_code_block { code_buffer.push_str(text); } else { - let rendered = state.style_text(text, &self.color_theme); - state.capture_target_mut(output).push_str(&rendered); + state.append_styled(output, text, &self.color_theme); } } @@ -521,9 +577,10 @@ impl TerminalRenderer { for line in LinesWithEndings::from(code) { match syntax_highlighter.highlight_line(line, &self.syntax_set) { Ok(ranges) => { - colored_output.push_str(&as_24_bit_terminal_escaped(&ranges[..], false)); + let escaped = as_24_bit_terminal_escaped(&ranges[..], false); + colored_output.push_str(&apply_code_block_background(&escaped)); } - Err(_) => colored_output.push_str(line), + Err(_) => colored_output.push_str(&apply_code_block_background(line)), } } @@ -531,16 +588,296 @@ impl TerminalRenderer { } pub fn stream_markdown(&self, markdown: &str, out: &mut impl Write) -> io::Result<()> { - let rendered_markdown = self.render_markdown(markdown); - for chunk in rendered_markdown.split_inclusive(char::is_whitespace) { - write!(out, "{chunk}")?; - out.flush()?; - thread::sleep(Duration::from_millis(8)); + let rendered_markdown = self.markdown_to_ansi(markdown); + write!(out, "{rendered_markdown}")?; + if !rendered_markdown.ends_with('\n') { + writeln!(out)?; } - writeln!(out) + out.flush() } } +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct MarkdownStreamState { + pending: String, +} + +impl MarkdownStreamState { + #[must_use] + pub fn push(&mut self, renderer: &TerminalRenderer, delta: &str) -> Option { + self.pending.push_str(delta); + let split = find_stream_safe_boundary(&self.pending)?; + let ready = self.pending[..split].to_string(); + self.pending.drain(..split); + Some(renderer.markdown_to_ansi(&ready)) + } + + #[must_use] + pub fn flush(&mut self, renderer: &TerminalRenderer) -> Option { + if self.pending.trim().is_empty() { + self.pending.clear(); + None + } else { + let pending = std::mem::take(&mut self.pending); + Some(renderer.markdown_to_ansi(&pending)) + } + } +} + +fn apply_code_block_background(line: &str) -> String { + let trimmed = line.trim_end_matches('\n'); + let trailing_newline = if trimmed.len() == line.len() { + "" + } else { + "\n" + }; + let with_background = trimmed.replace("\u{1b}[0m", "\u{1b}[0;48;5;236m"); + format!("\u{1b}[48;5;236m{with_background}\u{1b}[0m{trailing_newline}") +} + +/// Pre-process raw markdown so that fenced code blocks whose body contains +/// fence markers of equal or greater length are wrapped with a longer fence. +/// +/// LLMs frequently emit triple-backtick code blocks that contain triple-backtick +/// examples. CommonMark (and pulldown-cmark) treats the inner marker as the +/// closing fence, breaking the render. This function detects the situation and +/// upgrades the outer fence to use enough backticks (or tildes) that the inner +/// markers become ordinary content. +fn normalize_nested_fences(markdown: &str) -> String { + // A fence line is either "labeled" (has an info string ⇒ always an opener) + // or "bare" (no info string ⇒ could be opener or closer). + #[derive(Debug, Clone)] + struct FenceLine { + char: char, + len: usize, + has_info: bool, + indent: usize, + } + + fn parse_fence_line(line: &str) -> Option { + let trimmed = line.trim_end_matches('\n').trim_end_matches('\r'); + let indent = trimmed.chars().take_while(|c| *c == ' ').count(); + if indent > 3 { + return None; + } + let rest = &trimmed[indent..]; + let ch = rest.chars().next()?; + if ch != '`' && ch != '~' { + return None; + } + let len = rest.chars().take_while(|c| *c == ch).count(); + if len < 3 { + return None; + } + let after = &rest[len..]; + if ch == '`' && after.contains('`') { + return None; + } + let has_info = !after.trim().is_empty(); + Some(FenceLine { + char: ch, + len, + has_info, + indent, + }) + } + + let lines: Vec<&str> = markdown.split_inclusive('\n').collect(); + // Handle final line that may lack trailing newline. + // split_inclusive already keeps the original chunks, including a + // final chunk without '\n' if the input doesn't end with one. + + // First pass: classify every line. + let fence_info: Vec> = lines.iter().map(|l| parse_fence_line(l)).collect(); + + // Second pass: pair openers with closers using a stack, recording + // (opener_idx, closer_idx) pairs plus the max fence length found between + // them. + struct StackEntry { + line_idx: usize, + fence: FenceLine, + } + + let mut stack: Vec = Vec::new(); + // Paired blocks: (opener_line, closer_line, max_inner_fence_len) + let mut pairs: Vec<(usize, usize, usize)> = Vec::new(); + + for (i, fi) in fence_info.iter().enumerate() { + let Some(fl) = fi else { continue }; + + if fl.has_info { + // Labeled fence ⇒ always an opener. + stack.push(StackEntry { + line_idx: i, + fence: fl.clone(), + }); + } else { + // Bare fence ⇒ try to close the top of the stack if compatible. + let closes_top = stack + .last() + .is_some_and(|top| top.fence.char == fl.char && fl.len >= top.fence.len); + if closes_top { + let opener = stack.pop().unwrap(); + // Find max fence length of any fence line strictly between + // opener and closer (these are the nested fences). + let inner_max = fence_info[opener.line_idx + 1..i] + .iter() + .filter_map(|fi| fi.as_ref().map(|f| f.len)) + .max() + .unwrap_or(0); + pairs.push((opener.line_idx, i, inner_max)); + } else { + // Treat as opener. + stack.push(StackEntry { + line_idx: i, + fence: fl.clone(), + }); + } + } + } + + // Determine which lines need rewriting. A pair needs rewriting when + // its opener length <= max inner fence length. + struct Rewrite { + char: char, + new_len: usize, + indent: usize, + } + let mut rewrites: std::collections::HashMap = std::collections::HashMap::new(); + + for (opener_idx, closer_idx, inner_max) in &pairs { + let opener_fl = fence_info[*opener_idx].as_ref().unwrap(); + if opener_fl.len <= *inner_max { + let new_len = inner_max + 1; + let info_part = { + let trimmed = lines[*opener_idx] + .trim_end_matches('\n') + .trim_end_matches('\r'); + let rest = &trimmed[opener_fl.indent..]; + rest[opener_fl.len..].to_string() + }; + rewrites.insert( + *opener_idx, + Rewrite { + char: opener_fl.char, + new_len, + indent: opener_fl.indent, + }, + ); + let closer_fl = fence_info[*closer_idx].as_ref().unwrap(); + rewrites.insert( + *closer_idx, + Rewrite { + char: closer_fl.char, + new_len, + indent: closer_fl.indent, + }, + ); + // Store info string only in the opener; closer keeps the trailing + // portion which is already handled through the original line. + // Actually, we rebuild both lines from scratch below, including + // the info string for the opener. + let _ = info_part; // consumed in rebuild + } + } + + if rewrites.is_empty() { + return markdown.to_string(); + } + + // Rebuild. + let mut out = String::with_capacity(markdown.len() + rewrites.len() * 4); + for (i, line) in lines.iter().enumerate() { + if let Some(rw) = rewrites.get(&i) { + let fence_str: String = std::iter::repeat(rw.char).take(rw.new_len).collect(); + let indent_str: String = std::iter::repeat(' ').take(rw.indent).collect(); + // Recover the original info string (if any) and trailing newline. + let trimmed = line.trim_end_matches('\n').trim_end_matches('\r'); + let fi = fence_info[i].as_ref().unwrap(); + let info = &trimmed[fi.indent + fi.len..]; + let trailing = &line[trimmed.len()..]; + out.push_str(&indent_str); + out.push_str(&fence_str); + out.push_str(info); + out.push_str(trailing); + } else { + out.push_str(line); + } + } + out +} + +fn find_stream_safe_boundary(markdown: &str) -> Option { + let mut open_fence: Option = None; + let mut last_boundary = None; + + for (offset, line) in markdown.split_inclusive('\n').scan(0usize, |cursor, line| { + let start = *cursor; + *cursor += line.len(); + Some((start, line)) + }) { + let line_without_newline = line.trim_end_matches('\n'); + if let Some(opener) = open_fence { + if line_closes_fence(line_without_newline, opener) { + open_fence = None; + last_boundary = Some(offset + line.len()); + } + continue; + } + + if let Some(opener) = parse_fence_opener(line_without_newline) { + open_fence = Some(opener); + continue; + } + + if line_without_newline.trim().is_empty() { + last_boundary = Some(offset + line.len()); + } + } + + last_boundary +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct FenceMarker { + character: char, + length: usize, +} + +fn parse_fence_opener(line: &str) -> Option { + let indent = line.chars().take_while(|c| *c == ' ').count(); + if indent > 3 { + return None; + } + let rest = &line[indent..]; + let character = rest.chars().next()?; + if character != '`' && character != '~' { + return None; + } + let length = rest.chars().take_while(|c| *c == character).count(); + if length < 3 { + return None; + } + let info_string = &rest[length..]; + if character == '`' && info_string.contains('`') { + return None; + } + Some(FenceMarker { character, length }) +} + +fn line_closes_fence(line: &str, opener: FenceMarker) -> bool { + let indent = line.chars().take_while(|c| *c == ' ').count(); + if indent > 3 { + return false; + } + let rest = &line[indent..]; + let length = rest.chars().take_while(|c| *c == opener.character).count(); + if length < opener.length { + return false; + } + rest[length..].chars().all(|c| c == ' ' || c == '\t') +} + fn visible_width(input: &str) -> usize { strip_ansi(input).chars().count() } @@ -569,7 +906,7 @@ fn strip_ansi(input: &str) -> String { #[cfg(test)] mod tests { - use super::{strip_ansi, Spinner, TerminalRenderer}; + use super::{strip_ansi, MarkdownStreamState, Spinner, TerminalRenderer}; #[test] fn renders_markdown_with_styling_and_lists() { @@ -583,16 +920,28 @@ mod tests { assert!(markdown_output.contains('\u{1b}')); } + #[test] + fn renders_links_as_colored_markdown_labels() { + let terminal_renderer = TerminalRenderer::new(); + let markdown_output = + terminal_renderer.render_markdown("See [Claw](https://example.com/docs) now."); + let plain_text = strip_ansi(&markdown_output); + + assert!(plain_text.contains("[Claw](https://example.com/docs)")); + assert!(markdown_output.contains('\u{1b}')); + } + #[test] fn highlights_fenced_code_blocks() { let terminal_renderer = TerminalRenderer::new(); let markdown_output = - terminal_renderer.render_markdown("```rust\nfn hi() { println!(\"hi\"); }\n```"); + terminal_renderer.markdown_to_ansi("```rust\nfn hi() { println!(\"hi\"); }\n```"); let plain_text = strip_ansi(&markdown_output); assert!(plain_text.contains("╭─ rust")); assert!(plain_text.contains("fn hi")); assert!(markdown_output.contains('\u{1b}')); + assert!(markdown_output.contains("[48;5;236m")); } #[test] @@ -623,6 +972,80 @@ mod tests { assert!(markdown_output.contains('\u{1b}')); } + #[test] + fn streaming_state_waits_for_complete_blocks() { + let renderer = TerminalRenderer::new(); + let mut state = MarkdownStreamState::default(); + + assert_eq!(state.push(&renderer, "# Heading"), None); + let flushed = state + .push(&renderer, "\n\nParagraph\n\n") + .expect("completed block"); + let plain_text = strip_ansi(&flushed); + assert!(plain_text.contains("Heading")); + assert!(plain_text.contains("Paragraph")); + + assert_eq!(state.push(&renderer, "```rust\nfn main() {}\n"), None); + let code = state + .push(&renderer, "```\n") + .expect("closed code fence flushes"); + assert!(strip_ansi(&code).contains("fn main()")); + } + + #[test] + fn streaming_state_holds_outer_fence_with_nested_inner_fence() { + let renderer = TerminalRenderer::new(); + let mut state = MarkdownStreamState::default(); + + assert_eq!( + state.push(&renderer, "````markdown\n```rust\nfn inner() {}\n"), + None, + "inner triple backticks must not close the outer four-backtick fence" + ); + assert_eq!( + state.push(&renderer, "```\n"), + None, + "closing the inner fence must not flush the outer fence" + ); + let flushed = state + .push(&renderer, "````\n") + .expect("closing the outer four-backtick fence flushes the buffered block"); + let plain_text = strip_ansi(&flushed); + assert!(plain_text.contains("fn inner()")); + assert!(plain_text.contains("```rust")); + } + + #[test] + fn streaming_state_distinguishes_backtick_and_tilde_fences() { + let renderer = TerminalRenderer::new(); + let mut state = MarkdownStreamState::default(); + + assert_eq!(state.push(&renderer, "~~~text\n"), None); + assert_eq!( + state.push(&renderer, "```\nstill inside tilde fence\n"), + None, + "a backtick fence cannot close a tilde-opened fence" + ); + assert_eq!(state.push(&renderer, "```\n"), None); + let flushed = state + .push(&renderer, "~~~\n") + .expect("matching tilde marker closes the fence"); + let plain_text = strip_ansi(&flushed); + assert!(plain_text.contains("still inside tilde fence")); + } + + #[test] + fn renders_nested_fenced_code_block_preserves_inner_markers() { + let terminal_renderer = TerminalRenderer::new(); + let markdown_output = + terminal_renderer.markdown_to_ansi("````markdown\n```rust\nfn nested() {}\n```\n````"); + let plain_text = strip_ansi(&markdown_output); + + assert!(plain_text.contains("╭─ markdown")); + assert!(plain_text.contains("```rust")); + assert!(plain_text.contains("fn nested()")); + } + #[test] fn spinner_advances_frames() { let terminal_renderer = TerminalRenderer::new(); diff --git a/crates/rusty-claude-cli/tests/cli_flags_and_config_defaults.rs b/crates/rusty-claude-cli/tests/cli_flags_and_config_defaults.rs new file mode 100644 index 0000000..21a93e2 --- /dev/null +++ b/crates/rusty-claude-cli/tests/cli_flags_and_config_defaults.rs @@ -0,0 +1,298 @@ +use std::fs; +use std::path::{Path, PathBuf}; +use std::process::{Command, Output}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use runtime::Session; + +static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0); + +#[test] +fn status_command_applies_model_and_permission_mode_flags() { + // given + let temp_dir = unique_temp_dir("status-flags"); + fs::create_dir_all(&temp_dir).expect("temp dir should exist"); + + // when + let output = Command::new(env!("CARGO_BIN_EXE_claw")) + .current_dir(&temp_dir) + .args([ + "--model", + "sonnet", + "--permission-mode", + "read-only", + "status", + ]) + .output() + .expect("claw should launch"); + + // then + assert_success(&output); + let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8"); + assert!(stdout.contains("Status")); + assert!(stdout.contains("Model claude-sonnet-4-6")); + assert!(stdout.contains("Permission mode read-only")); + + fs::remove_dir_all(temp_dir).expect("cleanup temp dir"); +} + +#[test] +fn resume_flag_loads_a_saved_session_and_dispatches_status() { + // given + let temp_dir = unique_temp_dir("resume-status"); + fs::create_dir_all(&temp_dir).expect("temp dir should exist"); + let session_path = write_session(&temp_dir, "resume-status"); + + // when + let output = Command::new(env!("CARGO_BIN_EXE_claw")) + .current_dir(&temp_dir) + .args([ + "--resume", + session_path.to_str().expect("utf8 path"), + "/status", + ]) + .output() + .expect("claw should launch"); + + // then + assert_success(&output); + let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8"); + assert!(stdout.contains("Status")); + assert!(stdout.contains("Messages 1")); + assert!(stdout.contains("Session ")); + assert!(stdout.contains(session_path.to_str().expect("utf8 path"))); + + fs::remove_dir_all(temp_dir).expect("cleanup temp dir"); +} + +#[test] +fn slash_command_names_match_known_commands_and_suggest_nearby_unknown_ones() { + // given + let temp_dir = unique_temp_dir("slash-dispatch"); + fs::create_dir_all(&temp_dir).expect("temp dir should exist"); + + // when + let help_output = Command::new(env!("CARGO_BIN_EXE_claw")) + .current_dir(&temp_dir) + .arg("/help") + .output() + .expect("claw should launch"); + let unknown_output = Command::new(env!("CARGO_BIN_EXE_claw")) + .current_dir(&temp_dir) + .arg("/zstats") + .output() + .expect("claw should launch"); + + // then + assert_success(&help_output); + let help_stdout = String::from_utf8(help_output.stdout).expect("stdout should be utf8"); + assert!(help_stdout.contains("Interactive slash commands:")); + assert!(help_stdout.contains("/status")); + + assert!( + !unknown_output.status.success(), + "stdout:\n{}\n\nstderr:\n{}", + String::from_utf8_lossy(&unknown_output.stdout), + String::from_utf8_lossy(&unknown_output.stderr) + ); + let stderr = String::from_utf8(unknown_output.stderr).expect("stderr should be utf8"); + assert!(stderr.contains("unknown slash command outside the REPL: /zstats")); + assert!(stderr.contains("Did you mean")); + assert!(stderr.contains("/status")); + + fs::remove_dir_all(temp_dir).expect("cleanup temp dir"); +} + +#[test] +fn omc_namespaced_slash_commands_surface_a_targeted_compatibility_hint() { + let temp_dir = unique_temp_dir("slash-dispatch-omc"); + fs::create_dir_all(&temp_dir).expect("temp dir should exist"); + + let output = Command::new(env!("CARGO_BIN_EXE_claw")) + .current_dir(&temp_dir) + .arg("/oh-my-claudecode:hud") + .output() + .expect("claw should launch"); + + assert!( + !output.status.success(), + "stdout:\n{}\n\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + let stderr = String::from_utf8(output.stderr).expect("stderr should be utf8"); + assert!(stderr.contains("unknown slash command outside the REPL: /oh-my-claudecode:hud")); + assert!(stderr.contains("Claude Code/OMC plugin command")); + assert!(stderr.contains("does not yet load plugin slash commands")); + + fs::remove_dir_all(temp_dir).expect("cleanup temp dir"); +} + +#[test] +fn config_command_loads_defaults_from_standard_config_locations() { + // given + let temp_dir = unique_temp_dir("config-defaults"); + let config_home = temp_dir.join("home").join(".claw"); + fs::create_dir_all(temp_dir.join(".claw")).expect("project config dir should exist"); + fs::create_dir_all(&config_home).expect("home config dir should exist"); + + fs::write(config_home.join("settings.json"), r#"{"model":"haiku"}"#) + .expect("write user settings"); + fs::write(temp_dir.join(".claw.json"), r#"{"model":"sonnet"}"#) + .expect("write project settings"); + fs::write( + temp_dir.join(".claw").join("settings.local.json"), + r#"{"model":"opus"}"#, + ) + .expect("write local settings"); + let session_path = write_session(&temp_dir, "config-defaults"); + + // when + let output = command_in(&temp_dir) + .env("CLAW_CONFIG_HOME", &config_home) + .args([ + "--resume", + session_path.to_str().expect("utf8 path"), + "/config", + "model", + ]) + .output() + .expect("claw should launch"); + + // then + assert_success(&output); + let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8"); + assert!(stdout.contains("Config")); + assert!(stdout.contains("Loaded files 3")); + assert!(stdout.contains("Merged section: model")); + assert!(stdout.contains("opus")); + assert!(stdout.contains( + config_home + .join("settings.json") + .to_str() + .expect("utf8 path") + )); + assert!(stdout.contains(temp_dir.join(".claw.json").to_str().expect("utf8 path"))); + assert!(stdout.contains( + temp_dir + .join(".claw") + .join("settings.local.json") + .to_str() + .expect("utf8 path") + )); + + fs::remove_dir_all(temp_dir).expect("cleanup temp dir"); +} + +#[test] +fn doctor_command_runs_as_a_local_shell_entrypoint() { + // given + let temp_dir = unique_temp_dir("doctor-entrypoint"); + let config_home = temp_dir.join("home").join(".claw"); + fs::create_dir_all(&config_home).expect("config home should exist"); + + // when + let output = command_in(&temp_dir) + .env("CLAW_CONFIG_HOME", &config_home) + .env_remove("ANTHROPIC_API_KEY") + .env_remove("ANTHROPIC_AUTH_TOKEN") + .env("ANTHROPIC_BASE_URL", "http://127.0.0.1:9") + .arg("doctor") + .output() + .expect("claw doctor should launch"); + + // then + assert_success(&output); + let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8"); + assert!(stdout.contains("Doctor")); + assert!(stdout.contains("Auth")); + assert!(stdout.contains("Config")); + assert!(stdout.contains("Workspace")); + assert!(stdout.contains("Sandbox")); + assert!(!stdout.contains("Thinking")); + + fs::remove_dir_all(temp_dir).expect("cleanup temp dir"); +} + +#[test] +fn local_subcommand_help_does_not_fall_through_to_runtime_or_provider_calls() { + let temp_dir = unique_temp_dir("subcommand-help"); + let config_home = temp_dir.join("home").join(".claw"); + fs::create_dir_all(&config_home).expect("config home should exist"); + + let doctor_help = command_in(&temp_dir) + .env("CLAW_CONFIG_HOME", &config_home) + .env_remove("ANTHROPIC_API_KEY") + .env_remove("ANTHROPIC_AUTH_TOKEN") + .env("ANTHROPIC_BASE_URL", "http://127.0.0.1:9") + .args(["doctor", "--help"]) + .output() + .expect("doctor help should launch"); + let status_help = command_in(&temp_dir) + .env("CLAW_CONFIG_HOME", &config_home) + .env_remove("ANTHROPIC_API_KEY") + .env_remove("ANTHROPIC_AUTH_TOKEN") + .env("ANTHROPIC_BASE_URL", "http://127.0.0.1:9") + .args(["status", "--help"]) + .output() + .expect("status help should launch"); + + assert_success(&doctor_help); + let doctor_stdout = String::from_utf8(doctor_help.stdout).expect("stdout should be utf8"); + assert!(doctor_stdout.contains("Usage claw doctor")); + assert!(doctor_stdout.contains("local-only health report")); + assert!(!doctor_stdout.contains("Thinking")); + + assert_success(&status_help); + let status_stdout = String::from_utf8(status_help.stdout).expect("stdout should be utf8"); + assert!(status_stdout.contains("Usage claw status")); + assert!(status_stdout.contains("local workspace snapshot")); + assert!(!status_stdout.contains("Thinking")); + + let doctor_stderr = String::from_utf8(doctor_help.stderr).expect("stderr should be utf8"); + let status_stderr = String::from_utf8(status_help.stderr).expect("stderr should be utf8"); + assert!(!doctor_stderr.contains("auth_unavailable")); + assert!(!status_stderr.contains("auth_unavailable")); + + fs::remove_dir_all(temp_dir).expect("cleanup temp dir"); +} + +fn command_in(cwd: &Path) -> Command { + let mut command = Command::new(env!("CARGO_BIN_EXE_claw")); + command.current_dir(cwd); + command +} + +fn write_session(root: &Path, label: &str) -> PathBuf { + let session_path = root.join(format!("{label}.jsonl")); + let mut session = Session::new(); + session + .push_user_text(format!("session fixture for {label}")) + .expect("session write should succeed"); + session + .save_to_path(&session_path) + .expect("session should persist"); + session_path +} + +fn assert_success(output: &Output) { + assert!( + output.status.success(), + "stdout:\n{}\n\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); +} + +fn unique_temp_dir(label: &str) -> PathBuf { + let millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("clock should be after epoch") + .as_millis(); + let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed); + std::env::temp_dir().join(format!( + "claw-{label}-{}-{millis}-{counter}", + std::process::id() + )) +} diff --git a/crates/rusty-claude-cli/tests/compact_output.rs b/crates/rusty-claude-cli/tests/compact_output.rs new file mode 100644 index 0000000..456862f --- /dev/null +++ b/crates/rusty-claude-cli/tests/compact_output.rs @@ -0,0 +1,159 @@ +use std::fs; +use std::path::PathBuf; +use std::process::{Command, Output}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use mock_anthropic_service::{MockAnthropicService, SCENARIO_PREFIX}; + +static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0); + +#[test] +fn compact_flag_prints_only_final_assistant_text_without_tool_call_details() { + // given a workspace pointed at the mock Anthropic service and a fixture file + // that the read_file_roundtrip scenario will fetch through a tool call + let runtime = tokio::runtime::Runtime::new().expect("tokio runtime should build"); + let server = runtime + .block_on(MockAnthropicService::spawn()) + .expect("mock service should start"); + let base_url = server.base_url(); + + let workspace = unique_temp_dir("compact-read-file"); + let config_home = workspace.join("config-home"); + let home = workspace.join("home"); + fs::create_dir_all(&workspace).expect("workspace should exist"); + fs::create_dir_all(&config_home).expect("config home should exist"); + fs::create_dir_all(&home).expect("home should exist"); + fs::write(workspace.join("fixture.txt"), "alpha parity line\n").expect("fixture should write"); + + // when we run claw in compact text mode against a tool-using scenario + let prompt = format!("{SCENARIO_PREFIX}read_file_roundtrip"); + let output = run_claw( + &workspace, + &config_home, + &home, + &base_url, + &[ + "--model", + "sonnet", + "--permission-mode", + "read-only", + "--allowedTools", + "read_file", + "--compact", + &prompt, + ], + ); + + // then the command exits successfully and stdout contains exactly the final + // assistant text with no tool call IDs, JSON envelopes, or spinner output + assert!( + output.status.success(), + "compact run should succeed\nstdout:\n{}\n\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr), + ); + let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8"); + let trimmed = stdout.trim_end_matches('\n'); + assert_eq!( + trimmed, "read_file roundtrip complete: alpha parity line", + "compact stdout should contain only the final assistant text" + ); + assert!( + !stdout.contains("toolu_"), + "compact stdout must not leak tool_use_id ({stdout:?})" + ); + assert!( + !stdout.contains("\"tool_uses\""), + "compact stdout must not leak json envelopes ({stdout:?})" + ); + assert!( + !stdout.contains("Thinking"), + "compact stdout must not include the spinner banner ({stdout:?})" + ); + + fs::remove_dir_all(&workspace).expect("workspace cleanup should succeed"); +} + +#[test] +fn compact_flag_streaming_text_only_emits_final_message_text() { + // given a workspace pointed at the mock Anthropic service running the + // streaming_text scenario which only emits a single assistant text block + let runtime = tokio::runtime::Runtime::new().expect("tokio runtime should build"); + let server = runtime + .block_on(MockAnthropicService::spawn()) + .expect("mock service should start"); + let base_url = server.base_url(); + + let workspace = unique_temp_dir("compact-streaming-text"); + let config_home = workspace.join("config-home"); + let home = workspace.join("home"); + fs::create_dir_all(&workspace).expect("workspace should exist"); + fs::create_dir_all(&config_home).expect("config home should exist"); + fs::create_dir_all(&home).expect("home should exist"); + + // when we invoke claw with --compact for the streaming text scenario + let prompt = format!("{SCENARIO_PREFIX}streaming_text"); + let output = run_claw( + &workspace, + &config_home, + &home, + &base_url, + &[ + "--model", + "sonnet", + "--permission-mode", + "read-only", + "--compact", + &prompt, + ], + ); + + // then stdout should be exactly the assistant text followed by a newline + assert!( + output.status.success(), + "compact streaming run should succeed\nstdout:\n{}\n\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr), + ); + let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8"); + assert_eq!( + stdout, "Mock streaming says hello from the parity harness.\n", + "compact streaming stdout should contain only the final assistant text" + ); + + fs::remove_dir_all(&workspace).expect("workspace cleanup should succeed"); +} + +fn run_claw( + cwd: &std::path::Path, + config_home: &std::path::Path, + home: &std::path::Path, + base_url: &str, + args: &[&str], +) -> Output { + let mut command = Command::new(env!("CARGO_BIN_EXE_claw")); + command + .current_dir(cwd) + .env_clear() + .env("ANTHROPIC_API_KEY", "test-compact-key") + .env("ANTHROPIC_BASE_URL", base_url) + .env("CLAW_CONFIG_HOME", config_home) + .env("HOME", home) + .env("NO_COLOR", "1") + .env("PATH", "/usr/bin:/bin") + .args(args); + command.output().expect("claw should launch") +} + +fn unique_temp_dir(label: &str) -> PathBuf { + let millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("clock should be after epoch") + .as_millis(); + let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed); + std::env::temp_dir().join(format!( + "claw-compact-{label}-{}-{millis}-{counter}", + std::process::id() + )) +} diff --git a/crates/rusty-claude-cli/tests/mock_parity_harness.rs b/crates/rusty-claude-cli/tests/mock_parity_harness.rs new file mode 100644 index 0000000..8e8cbf9 --- /dev/null +++ b/crates/rusty-claude-cli/tests/mock_parity_harness.rs @@ -0,0 +1,884 @@ +#![cfg(unix)] +use std::collections::BTreeMap; +use std::fs; +use std::io::Write; +use std::os::unix::fs::PermissionsExt; +use std::path::{Path, PathBuf}; +use std::process::{Command, Output, Stdio}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use mock_anthropic_service::{MockAnthropicService, SCENARIO_PREFIX}; +use serde_json::{json, Value}; + +static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0); + +#[test] +#[allow(clippy::too_many_lines)] +fn clean_env_cli_reaches_mock_anthropic_service_across_scripted_parity_scenarios() { + let manifest_entries = load_scenario_manifest(); + let manifest = manifest_entries + .iter() + .cloned() + .map(|entry| (entry.name.clone(), entry)) + .collect::>(); + let runtime = tokio::runtime::Runtime::new().expect("tokio runtime should build"); + let server = runtime + .block_on(MockAnthropicService::spawn()) + .expect("mock service should start"); + let base_url = server.base_url(); + + let cases = [ + ScenarioCase { + name: "streaming_text", + permission_mode: "read-only", + allowed_tools: None, + stdin: None, + prepare: prepare_noop, + assert: assert_streaming_text, + extra_env: None, + resume_session: None, + }, + ScenarioCase { + name: "read_file_roundtrip", + permission_mode: "read-only", + allowed_tools: Some("read_file"), + stdin: None, + prepare: prepare_read_fixture, + assert: assert_read_file_roundtrip, + extra_env: None, + resume_session: None, + }, + ScenarioCase { + name: "grep_chunk_assembly", + permission_mode: "read-only", + allowed_tools: Some("grep_search"), + stdin: None, + prepare: prepare_grep_fixture, + assert: assert_grep_chunk_assembly, + extra_env: None, + resume_session: None, + }, + ScenarioCase { + name: "write_file_allowed", + permission_mode: "workspace-write", + allowed_tools: Some("write_file"), + stdin: None, + prepare: prepare_noop, + assert: assert_write_file_allowed, + extra_env: None, + resume_session: None, + }, + ScenarioCase { + name: "write_file_denied", + permission_mode: "read-only", + allowed_tools: Some("write_file"), + stdin: None, + prepare: prepare_noop, + assert: assert_write_file_denied, + extra_env: None, + resume_session: None, + }, + ScenarioCase { + name: "multi_tool_turn_roundtrip", + permission_mode: "read-only", + allowed_tools: Some("read_file,grep_search"), + stdin: None, + prepare: prepare_multi_tool_fixture, + assert: assert_multi_tool_turn_roundtrip, + extra_env: None, + resume_session: None, + }, + ScenarioCase { + name: "bash_stdout_roundtrip", + permission_mode: "danger-full-access", + allowed_tools: Some("bash"), + stdin: None, + prepare: prepare_noop, + assert: assert_bash_stdout_roundtrip, + extra_env: None, + resume_session: None, + }, + ScenarioCase { + name: "bash_permission_prompt_approved", + permission_mode: "workspace-write", + allowed_tools: Some("bash"), + stdin: Some("y\n"), + prepare: prepare_noop, + assert: assert_bash_permission_prompt_approved, + extra_env: None, + resume_session: None, + }, + ScenarioCase { + name: "bash_permission_prompt_denied", + permission_mode: "workspace-write", + allowed_tools: Some("bash"), + stdin: Some("n\n"), + prepare: prepare_noop, + assert: assert_bash_permission_prompt_denied, + extra_env: None, + resume_session: None, + }, + ScenarioCase { + name: "plugin_tool_roundtrip", + permission_mode: "workspace-write", + allowed_tools: None, + stdin: None, + prepare: prepare_plugin_fixture, + assert: assert_plugin_tool_roundtrip, + extra_env: None, + resume_session: None, + }, + ScenarioCase { + name: "auto_compact_triggered", + permission_mode: "read-only", + allowed_tools: None, + stdin: None, + prepare: prepare_noop, + assert: assert_auto_compact_triggered, + extra_env: None, + resume_session: None, + }, + ScenarioCase { + name: "token_cost_reporting", + permission_mode: "read-only", + allowed_tools: None, + stdin: None, + prepare: prepare_noop, + assert: assert_token_cost_reporting, + extra_env: None, + resume_session: None, + }, + ]; + + let case_names = cases.iter().map(|case| case.name).collect::>(); + let manifest_names = manifest_entries + .iter() + .map(|entry| entry.name.as_str()) + .collect::>(); + assert_eq!( + case_names, manifest_names, + "manifest and harness cases must stay aligned" + ); + + let mut scenario_reports = Vec::new(); + + for case in cases { + let workspace = HarnessWorkspace::new(unique_temp_dir(case.name)); + workspace.create().expect("workspace should exist"); + (case.prepare)(&workspace); + + let run = run_case(case, &workspace, &base_url); + (case.assert)(&workspace, &run); + + let manifest_entry = manifest + .get(case.name) + .unwrap_or_else(|| panic!("missing manifest entry for {}", case.name)); + scenario_reports.push(build_scenario_report( + case.name, + manifest_entry, + &run.response, + )); + + fs::remove_dir_all(&workspace.root).expect("workspace cleanup should succeed"); + } + + let captured = runtime.block_on(server.captured_requests()); + // After `be561bf` added count_tokens preflight, each turn sends an + // extra POST to `/v1/messages/count_tokens` before the messages POST. + // The original count (21) assumed messages-only requests. We now + // filter to `/v1/messages` and verify that subset matches the original + // scenario expectation. + let messages_only: Vec<_> = captured + .iter() + .filter(|r| r.path == "/v1/messages") + .collect(); + assert_eq!( + messages_only.len(), + 21, + "twelve scenarios should produce twenty-one /v1/messages requests (total captured: {}, includes count_tokens)", + captured.len() + ); + assert!(messages_only.iter().all(|request| request.stream)); + + let scenarios = messages_only + .iter() + .map(|request| request.scenario.as_str()) + .collect::>(); + assert_eq!( + scenarios, + vec![ + "streaming_text", + "read_file_roundtrip", + "read_file_roundtrip", + "grep_chunk_assembly", + "grep_chunk_assembly", + "write_file_allowed", + "write_file_allowed", + "write_file_denied", + "write_file_denied", + "multi_tool_turn_roundtrip", + "multi_tool_turn_roundtrip", + "bash_stdout_roundtrip", + "bash_stdout_roundtrip", + "bash_permission_prompt_approved", + "bash_permission_prompt_approved", + "bash_permission_prompt_denied", + "bash_permission_prompt_denied", + "plugin_tool_roundtrip", + "plugin_tool_roundtrip", + "auto_compact_triggered", + "token_cost_reporting", + ] + ); + + let mut request_counts = BTreeMap::new(); + for request in &captured { + *request_counts + .entry(request.scenario.as_str()) + .or_insert(0_usize) += 1; + } + for report in &mut scenario_reports { + report.request_count = *request_counts + .get(report.name.as_str()) + .unwrap_or_else(|| panic!("missing request count for {}", report.name)); + } + + maybe_write_report(&scenario_reports); +} + +#[derive(Clone, Copy)] +struct ScenarioCase { + name: &'static str, + permission_mode: &'static str, + allowed_tools: Option<&'static str>, + stdin: Option<&'static str>, + prepare: fn(&HarnessWorkspace), + assert: fn(&HarnessWorkspace, &ScenarioRun), + extra_env: Option<(&'static str, &'static str)>, + resume_session: Option<&'static str>, +} + +struct HarnessWorkspace { + root: PathBuf, + config_home: PathBuf, + home: PathBuf, +} + +impl HarnessWorkspace { + fn new(root: PathBuf) -> Self { + Self { + config_home: root.join("config-home"), + home: root.join("home"), + root, + } + } + + fn create(&self) -> std::io::Result<()> { + fs::create_dir_all(&self.root)?; + fs::create_dir_all(&self.config_home)?; + fs::create_dir_all(&self.home)?; + Ok(()) + } +} + +struct ScenarioRun { + response: Value, + stdout: String, +} + +#[derive(Debug, Clone)] +struct ScenarioManifestEntry { + name: String, + category: String, + description: String, + parity_refs: Vec, +} + +#[derive(Debug)] +struct ScenarioReport { + name: String, + category: String, + description: String, + parity_refs: Vec, + iterations: u64, + request_count: usize, + tool_uses: Vec, + tool_error_count: usize, + final_message: String, +} + +fn run_case(case: ScenarioCase, workspace: &HarnessWorkspace, base_url: &str) -> ScenarioRun { + let mut command = Command::new(env!("CARGO_BIN_EXE_claw")); + command + .current_dir(&workspace.root) + .env_clear() + .env("ANTHROPIC_API_KEY", "test-parity-key") + .env("ANTHROPIC_BASE_URL", base_url) + .env("CLAW_CONFIG_HOME", &workspace.config_home) + .env("HOME", &workspace.home) + .env("NO_COLOR", "1") + .env("PATH", "/usr/bin:/bin") + .args([ + "--model", + "sonnet", + "--permission-mode", + case.permission_mode, + "--output-format=json", + ]); + + if let Some(allowed_tools) = case.allowed_tools { + command.args(["--allowedTools", allowed_tools]); + } + if let Some((key, value)) = case.extra_env { + command.env(key, value); + } + if let Some(session_id) = case.resume_session { + command.args(["--resume", session_id]); + } + + let prompt = format!("{SCENARIO_PREFIX}{}", case.name); + command.arg(prompt); + + let output = if let Some(stdin) = case.stdin { + let mut child = command + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("claw should launch"); + child + .stdin + .as_mut() + .expect("stdin should be piped") + .write_all(stdin.as_bytes()) + .expect("stdin should write"); + child.wait_with_output().expect("claw should finish") + } else { + command.output().expect("claw should launch") + }; + + assert_success(&output); + let stdout = String::from_utf8_lossy(&output.stdout).into_owned(); + ScenarioRun { + response: parse_json_output(&stdout), + stdout, + } +} + +#[allow(dead_code)] +fn prepare_auto_compact_fixture(workspace: &HarnessWorkspace) { + let sessions_dir = workspace.root.join(".claw").join("sessions"); + fs::create_dir_all(&sessions_dir).expect("sessions dir should exist"); + + // Write a pre-seeded session with 6 messages so auto-compact can remove them + let session_id = "parity-auto-compact-seed"; + let session_jsonl = r#"{"type":"session_meta","version":3,"session_id":"parity-auto-compact-seed","created_at_ms":1743724800000,"updated_at_ms":1743724800000} +{"type":"message","message":{"role":"user","blocks":[{"type":"text","text":"step one of the parity scenario"}]}} +{"type":"message","message":{"role":"assistant","blocks":[{"type":"text","text":"acknowledged step one"}]}} +{"type":"message","message":{"role":"user","blocks":[{"type":"text","text":"step two of the parity scenario"}]}} +{"type":"message","message":{"role":"assistant","blocks":[{"type":"text","text":"acknowledged step two"}]}} +{"type":"message","message":{"role":"user","blocks":[{"type":"text","text":"step three of the parity scenario"}]}} +{"type":"message","message":{"role":"assistant","blocks":[{"type":"text","text":"acknowledged step three"}]}} +"#; + fs::write( + sessions_dir.join(format!("{session_id}.jsonl")), + session_jsonl, + ) + .expect("pre-seeded session should write"); +} + +fn prepare_noop(_: &HarnessWorkspace) {} + +fn prepare_read_fixture(workspace: &HarnessWorkspace) { + fs::write(workspace.root.join("fixture.txt"), "alpha parity line\n") + .expect("fixture should write"); +} + +fn prepare_grep_fixture(workspace: &HarnessWorkspace) { + fs::write( + workspace.root.join("fixture.txt"), + "alpha parity line\nbeta line\ngamma parity line\n", + ) + .expect("grep fixture should write"); +} + +fn prepare_multi_tool_fixture(workspace: &HarnessWorkspace) { + fs::write( + workspace.root.join("fixture.txt"), + "alpha parity line\nbeta line\ngamma parity line\n", + ) + .expect("multi tool fixture should write"); +} + +fn prepare_plugin_fixture(workspace: &HarnessWorkspace) { + let plugin_root = workspace + .root + .join("external-plugins") + .join("parity-plugin"); + let tool_dir = plugin_root.join("tools"); + let manifest_dir = plugin_root.join(".claude-plugin"); + fs::create_dir_all(&tool_dir).expect("plugin tools dir"); + fs::create_dir_all(&manifest_dir).expect("plugin manifest dir"); + + let script_path = tool_dir.join("echo-json.sh"); + fs::write( + &script_path, + "#!/bin/sh\nINPUT=$(cat)\nprintf '{\"plugin\":\"%s\",\"tool\":\"%s\",\"input\":%s}\\n' \"$CLAWD_PLUGIN_ID\" \"$CLAWD_TOOL_NAME\" \"$INPUT\"\n", + ) + .expect("plugin script should write"); + let mut permissions = fs::metadata(&script_path) + .expect("plugin script metadata") + .permissions(); + permissions.set_mode(0o755); + fs::set_permissions(&script_path, permissions).expect("plugin script should be executable"); + + fs::write( + manifest_dir.join("plugin.json"), + r#"{ + "name": "parity-plugin", + "version": "1.0.0", + "description": "mock parity plugin", + "tools": [ + { + "name": "plugin_echo", + "description": "Echo JSON input", + "inputSchema": { + "type": "object", + "properties": { + "message": { "type": "string" } + }, + "required": ["message"], + "additionalProperties": false + }, + "command": "./tools/echo-json.sh", + "requiredPermission": "workspace-write" + } + ] +}"#, + ) + .expect("plugin manifest should write"); + + fs::write( + workspace.config_home.join("settings.json"), + json!({ + "enabledPlugins": { + "parity-plugin@external": true + }, + "plugins": { + "externalDirectories": [plugin_root.parent().expect("plugin parent").display().to_string()] + } + }) + .to_string(), + ) + .expect("plugin settings should write"); +} + +fn assert_streaming_text(_: &HarnessWorkspace, run: &ScenarioRun) { + assert_eq!( + run.response["message"], + Value::String("Mock streaming says hello from the parity harness.".to_string()) + ); + assert_eq!(run.response["iterations"], Value::from(1)); + assert_eq!(run.response["tool_uses"], Value::Array(Vec::new())); + assert_eq!(run.response["tool_results"], Value::Array(Vec::new())); +} + +fn assert_read_file_roundtrip(workspace: &HarnessWorkspace, run: &ScenarioRun) { + assert_eq!(run.response["iterations"], Value::from(2)); + assert_eq!( + run.response["tool_uses"][0]["name"], + Value::String("read_file".to_string()) + ); + assert_eq!( + run.response["tool_uses"][0]["input"], + Value::String(r#"{"path":"fixture.txt"}"#.to_string()) + ); + assert!(run.response["message"] + .as_str() + .expect("message text") + .contains("alpha parity line")); + let output = run.response["tool_results"][0]["output"] + .as_str() + .expect("tool output"); + assert!(output.contains(&workspace.root.join("fixture.txt").display().to_string())); + assert!(output.contains("alpha parity line")); +} + +fn assert_grep_chunk_assembly(_: &HarnessWorkspace, run: &ScenarioRun) { + assert_eq!(run.response["iterations"], Value::from(2)); + assert_eq!( + run.response["tool_uses"][0]["name"], + Value::String("grep_search".to_string()) + ); + assert_eq!( + run.response["tool_uses"][0]["input"], + Value::String( + r#"{"pattern":"parity","path":"fixture.txt","output_mode":"count"}"#.to_string() + ) + ); + assert!(run.response["message"] + .as_str() + .expect("message text") + .contains("2 occurrences")); + assert_eq!( + run.response["tool_results"][0]["is_error"], + Value::Bool(false) + ); +} + +fn assert_write_file_allowed(workspace: &HarnessWorkspace, run: &ScenarioRun) { + assert_eq!(run.response["iterations"], Value::from(2)); + assert_eq!( + run.response["tool_uses"][0]["name"], + Value::String("write_file".to_string()) + ); + assert!(run.response["message"] + .as_str() + .expect("message text") + .contains("generated/output.txt")); + let generated = workspace.root.join("generated").join("output.txt"); + let contents = fs::read_to_string(&generated).expect("generated file should exist"); + assert_eq!(contents, "created by mock service\n"); + assert_eq!( + run.response["tool_results"][0]["is_error"], + Value::Bool(false) + ); +} + +fn assert_write_file_denied(workspace: &HarnessWorkspace, run: &ScenarioRun) { + assert_eq!(run.response["iterations"], Value::from(2)); + assert_eq!( + run.response["tool_uses"][0]["name"], + Value::String("write_file".to_string()) + ); + let tool_output = run.response["tool_results"][0]["output"] + .as_str() + .expect("tool output"); + assert!(tool_output.contains("requires workspace-write permission")); + assert_eq!( + run.response["tool_results"][0]["is_error"], + Value::Bool(true) + ); + assert!(run.response["message"] + .as_str() + .expect("message text") + .contains("denied as expected")); + assert!(!workspace.root.join("generated").join("denied.txt").exists()); +} + +fn assert_multi_tool_turn_roundtrip(_: &HarnessWorkspace, run: &ScenarioRun) { + assert_eq!(run.response["iterations"], Value::from(2)); + let tool_uses = run.response["tool_uses"] + .as_array() + .expect("tool uses array"); + assert_eq!( + tool_uses.len(), + 2, + "expected two tool uses in a single turn" + ); + assert_eq!(tool_uses[0]["name"], Value::String("read_file".to_string())); + assert_eq!( + tool_uses[1]["name"], + Value::String("grep_search".to_string()) + ); + let tool_results = run.response["tool_results"] + .as_array() + .expect("tool results array"); + assert_eq!( + tool_results.len(), + 2, + "expected two tool results in a single turn" + ); + assert!(run.response["message"] + .as_str() + .expect("message text") + .contains("alpha parity line")); + assert!(run.response["message"] + .as_str() + .expect("message text") + .contains("2 occurrences")); +} + +fn assert_bash_stdout_roundtrip(_: &HarnessWorkspace, run: &ScenarioRun) { + assert_eq!(run.response["iterations"], Value::from(2)); + assert_eq!( + run.response["tool_uses"][0]["name"], + Value::String("bash".to_string()) + ); + let tool_output = run.response["tool_results"][0]["output"] + .as_str() + .expect("tool output"); + let parsed: Value = serde_json::from_str(tool_output).expect("bash output json"); + assert_eq!( + parsed["stdout"], + Value::String("alpha from bash".to_string()) + ); + assert_eq!( + run.response["tool_results"][0]["is_error"], + Value::Bool(false) + ); + assert!(run.response["message"] + .as_str() + .expect("message text") + .contains("alpha from bash")); +} + +fn assert_bash_permission_prompt_approved(_: &HarnessWorkspace, run: &ScenarioRun) { + assert!(run.stdout.contains("Permission approval required")); + assert!(run.stdout.contains("Approve this tool call? [y/N]:")); + assert_eq!(run.response["iterations"], Value::from(2)); + assert_eq!( + run.response["tool_results"][0]["is_error"], + Value::Bool(false) + ); + let tool_output = run.response["tool_results"][0]["output"] + .as_str() + .expect("tool output"); + let parsed: Value = serde_json::from_str(tool_output).expect("bash output json"); + assert_eq!( + parsed["stdout"], + Value::String("approved via prompt".to_string()) + ); + assert!(run.response["message"] + .as_str() + .expect("message text") + .contains("approved and executed")); +} + +fn assert_bash_permission_prompt_denied(_: &HarnessWorkspace, run: &ScenarioRun) { + assert!(run.stdout.contains("Permission approval required")); + assert!(run.stdout.contains("Approve this tool call? [y/N]:")); + assert_eq!(run.response["iterations"], Value::from(2)); + let tool_output = run.response["tool_results"][0]["output"] + .as_str() + .expect("tool output"); + assert!(tool_output.contains("denied by user approval prompt")); + assert_eq!( + run.response["tool_results"][0]["is_error"], + Value::Bool(true) + ); + assert!(run.response["message"] + .as_str() + .expect("message text") + .contains("denied as expected")); +} + +fn assert_plugin_tool_roundtrip(_: &HarnessWorkspace, run: &ScenarioRun) { + assert_eq!(run.response["iterations"], Value::from(2)); + assert_eq!( + run.response["tool_uses"][0]["name"], + Value::String("plugin_echo".to_string()) + ); + let tool_output = run.response["tool_results"][0]["output"] + .as_str() + .expect("tool output"); + let parsed: Value = serde_json::from_str(tool_output).expect("plugin output json"); + assert_eq!( + parsed["plugin"], + Value::String("parity-plugin@external".to_string()) + ); + assert_eq!(parsed["tool"], Value::String("plugin_echo".to_string())); + assert_eq!( + parsed["input"]["message"], + Value::String("hello from plugin parity".to_string()) + ); + assert!(run.response["message"] + .as_str() + .expect("message text") + .contains("hello from plugin parity")); +} + +fn assert_auto_compact_triggered(_: &HarnessWorkspace, run: &ScenarioRun) { + // Validates that the auto_compaction field is present in JSON output (format parity). + // Trigger behavior is covered by conversation::tests::auto_compacts_when_cumulative_input_threshold_is_crossed. + assert_eq!(run.response["iterations"], Value::from(1)); + assert_eq!(run.response["tool_uses"], Value::Array(Vec::new())); + assert!( + run.response["message"] + .as_str() + .expect("message text") + .contains("auto compact parity complete."), + "expected auto compact message in response" + ); + // auto_compaction key must be present in JSON (may be null for below-threshold sessions) + assert!( + run.response + .as_object() + .expect("response object") + .contains_key("auto_compaction"), + "auto_compaction key must be present in JSON output" + ); + // Verify input_tokens field reflects the large mock token counts + let input_tokens = run.response["usage"]["input_tokens"] + .as_u64() + .expect("input_tokens should be present"); + assert!( + input_tokens >= 50_000, + "input_tokens should reflect mock service value (got {input_tokens})" + ); +} + +fn assert_token_cost_reporting(_: &HarnessWorkspace, run: &ScenarioRun) { + assert_eq!(run.response["iterations"], Value::from(1)); + assert!(run.response["message"] + .as_str() + .expect("message text") + .contains("token cost reporting parity complete."),); + let usage = &run.response["usage"]; + assert!( + usage["input_tokens"].as_u64().unwrap_or(0) > 0, + "input_tokens should be non-zero" + ); + assert!( + usage["output_tokens"].as_u64().unwrap_or(0) > 0, + "output_tokens should be non-zero" + ); + assert!( + run.response["estimated_cost"] + .as_str() + .is_some_and(|cost| cost.starts_with('$')), + "estimated_cost should be a dollar-prefixed string" + ); +} + +fn parse_json_output(stdout: &str) -> Value { + if let Some(index) = stdout.rfind("{\"auto_compaction\"") { + return serde_json::from_str(&stdout[index..]).unwrap_or_else(|error| { + panic!("failed to parse JSON response from stdout: {error}\n{stdout}") + }); + } + + stdout + .lines() + .rev() + .find_map(|line| { + let trimmed = line.trim(); + if trimmed.starts_with('{') && trimmed.ends_with('}') { + serde_json::from_str(trimmed).ok() + } else { + None + } + }) + .unwrap_or_else(|| panic!("no JSON response line found in stdout:\n{stdout}")) +} + +fn build_scenario_report( + name: &str, + manifest_entry: &ScenarioManifestEntry, + response: &Value, +) -> ScenarioReport { + ScenarioReport { + name: name.to_string(), + category: manifest_entry.category.clone(), + description: manifest_entry.description.clone(), + parity_refs: manifest_entry.parity_refs.clone(), + iterations: response["iterations"] + .as_u64() + .expect("iterations should exist"), + request_count: 0, + tool_uses: response["tool_uses"] + .as_array() + .expect("tool uses array") + .iter() + .filter_map(|value| value["name"].as_str().map(ToOwned::to_owned)) + .collect(), + tool_error_count: response["tool_results"] + .as_array() + .expect("tool results array") + .iter() + .filter(|value| value["is_error"].as_bool().unwrap_or(false)) + .count(), + final_message: response["message"] + .as_str() + .expect("message text") + .to_string(), + } +} + +fn maybe_write_report(reports: &[ScenarioReport]) { + let Some(path) = std::env::var_os("MOCK_PARITY_REPORT_PATH") else { + return; + }; + + let payload = json!({ + "scenario_count": reports.len(), + "request_count": reports.iter().map(|report| report.request_count).sum::(), + "scenarios": reports.iter().map(scenario_report_json).collect::>(), + }); + fs::write( + path, + serde_json::to_vec_pretty(&payload).expect("report json should serialize"), + ) + .expect("report should write"); +} + +fn load_scenario_manifest() -> Vec { + let manifest_path = + Path::new(env!("CARGO_MANIFEST_DIR")).join("../../mock_parity_scenarios.json"); + let manifest = fs::read_to_string(&manifest_path).expect("scenario manifest should exist"); + serde_json::from_str::>(&manifest) + .expect("scenario manifest should parse") + .into_iter() + .map(|entry| ScenarioManifestEntry { + name: entry["name"] + .as_str() + .expect("scenario name should be a string") + .to_string(), + category: entry["category"] + .as_str() + .expect("scenario category should be a string") + .to_string(), + description: entry["description"] + .as_str() + .expect("scenario description should be a string") + .to_string(), + parity_refs: entry["parity_refs"] + .as_array() + .expect("parity refs should be an array") + .iter() + .map(|value| { + value + .as_str() + .expect("parity ref should be a string") + .to_string() + }) + .collect(), + }) + .collect() +} + +fn scenario_report_json(report: &ScenarioReport) -> Value { + json!({ + "name": report.name, + "category": report.category, + "description": report.description, + "parity_refs": report.parity_refs, + "iterations": report.iterations, + "request_count": report.request_count, + "tool_uses": report.tool_uses, + "tool_error_count": report.tool_error_count, + "final_message": report.final_message, + }) +} + +fn assert_success(output: &Output) { + assert!( + output.status.success(), + "stdout:\n{}\n\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); +} + +fn unique_temp_dir(label: &str) -> PathBuf { + let millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("clock should be after epoch") + .as_millis(); + let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed); + std::env::temp_dir().join(format!( + "claw-mock-parity-{label}-{}-{millis}-{counter}", + std::process::id() + )) +} diff --git a/crates/rusty-claude-cli/tests/output_format_contract.rs b/crates/rusty-claude-cli/tests/output_format_contract.rs new file mode 100644 index 0000000..7d28330 --- /dev/null +++ b/crates/rusty-claude-cli/tests/output_format_contract.rs @@ -0,0 +1,429 @@ +use std::fs; +use std::path::{Path, PathBuf}; +use std::process::{Command, Output}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use serde_json::Value; + +static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0); + +#[test] +fn help_emits_json_when_requested() { + let root = unique_temp_dir("help-json"); + fs::create_dir_all(&root).expect("temp dir should exist"); + + let parsed = assert_json_command(&root, &["--output-format", "json", "help"]); + assert_eq!(parsed["kind"], "help"); + assert!(parsed["message"] + .as_str() + .expect("help text") + .contains("Usage:")); +} + +#[test] +fn version_emits_json_when_requested() { + let root = unique_temp_dir("version-json"); + fs::create_dir_all(&root).expect("temp dir should exist"); + + let parsed = assert_json_command(&root, &["--output-format", "json", "version"]); + assert_eq!(parsed["kind"], "version"); + assert_eq!(parsed["version"], env!("CARGO_PKG_VERSION")); +} + +#[test] +fn status_and_sandbox_emit_json_when_requested() { + let root = unique_temp_dir("status-sandbox-json"); + fs::create_dir_all(&root).expect("temp dir should exist"); + + let status = assert_json_command(&root, &["--output-format", "json", "status"]); + assert_eq!(status["kind"], "status"); + assert!(status["workspace"]["cwd"].as_str().is_some()); + + let sandbox = assert_json_command(&root, &["--output-format", "json", "sandbox"]); + assert_eq!(sandbox["kind"], "sandbox"); + assert!(sandbox["filesystem_mode"].as_str().is_some()); +} + +#[test] +fn inventory_commands_emit_structured_json_when_requested() { + let root = unique_temp_dir("inventory-json"); + fs::create_dir_all(&root).expect("temp dir should exist"); + + let isolated_home = root.join("home"); + let isolated_config = root.join("config-home"); + let isolated_codex = root.join("codex-home"); + fs::create_dir_all(&isolated_home).expect("isolated home should exist"); + + let agents = assert_json_command_with_env( + &root, + &["--output-format", "json", "agents"], + &[ + ("HOME", isolated_home.to_str().expect("utf8 home")), + ( + "CLAW_CONFIG_HOME", + isolated_config.to_str().expect("utf8 config home"), + ), + ( + "CODEX_HOME", + isolated_codex.to_str().expect("utf8 codex home"), + ), + ], + ); + assert_eq!(agents["kind"], "agents"); + assert_eq!(agents["action"], "list"); + assert_eq!(agents["count"], 0); + assert_eq!(agents["summary"]["active"], 0); + assert!(agents["agents"] + .as_array() + .expect("agents array") + .is_empty()); + + let mcp = assert_json_command(&root, &["--output-format", "json", "mcp"]); + assert_eq!(mcp["kind"], "mcp"); + assert_eq!(mcp["action"], "list"); + + let skills = assert_json_command(&root, &["--output-format", "json", "skills"]); + assert_eq!(skills["kind"], "skills"); + assert_eq!(skills["action"], "list"); +} + +#[test] +fn agents_command_emits_structured_agent_entries_when_requested() { + let root = unique_temp_dir("agents-json-populated"); + let workspace = root.join("workspace"); + let project_agents = workspace.join(".codex").join("agents"); + let home = root.join("home"); + let user_agents = home.join(".codex").join("agents"); + let isolated_config = root.join("config-home"); + let isolated_codex = root.join("codex-home"); + fs::create_dir_all(&workspace).expect("workspace should exist"); + write_agent( + &project_agents, + "planner", + "Project planner", + "gpt-5.4", + "medium", + ); + write_agent( + &project_agents, + "verifier", + "Verification agent", + "gpt-5.4-mini", + "high", + ); + write_agent( + &user_agents, + "planner", + "User planner", + "gpt-5.4-mini", + "high", + ); + + let parsed = assert_json_command_with_env( + &workspace, + &["--output-format", "json", "agents"], + &[ + ("HOME", home.to_str().expect("utf8 home")), + ( + "CLAW_CONFIG_HOME", + isolated_config.to_str().expect("utf8 config home"), + ), + ( + "CODEX_HOME", + isolated_codex.to_str().expect("utf8 codex home"), + ), + ], + ); + + assert_eq!(parsed["kind"], "agents"); + assert_eq!(parsed["action"], "list"); + assert_eq!(parsed["count"], 3); + assert_eq!(parsed["summary"]["active"], 2); + assert_eq!(parsed["summary"]["shadowed"], 1); + assert_eq!(parsed["agents"][0]["name"], "planner"); + assert_eq!(parsed["agents"][0]["source"]["id"], "project_claw"); + assert_eq!(parsed["agents"][0]["active"], true); + assert_eq!(parsed["agents"][1]["name"], "verifier"); + assert_eq!(parsed["agents"][2]["name"], "planner"); + assert_eq!(parsed["agents"][2]["active"], false); + assert_eq!(parsed["agents"][2]["shadowed_by"]["id"], "project_claw"); +} + +#[test] +fn bootstrap_and_system_prompt_emit_json_when_requested() { + let root = unique_temp_dir("bootstrap-system-prompt-json"); + fs::create_dir_all(&root).expect("temp dir should exist"); + + let plan = assert_json_command(&root, &["--output-format", "json", "bootstrap-plan"]); + assert_eq!(plan["kind"], "bootstrap-plan"); + assert!(plan["phases"].as_array().expect("phases").len() > 1); + + let prompt = assert_json_command(&root, &["--output-format", "json", "system-prompt"]); + assert_eq!(prompt["kind"], "system-prompt"); + assert!(prompt["message"] + .as_str() + .expect("prompt text") + .contains("interactive agent")); +} + +#[test] +fn dump_manifests_and_init_emit_json_when_requested() { + let root = unique_temp_dir("manifest-init-json"); + fs::create_dir_all(&root).expect("temp dir should exist"); + + let upstream = write_upstream_fixture(&root); + let manifests = assert_json_command_with_env( + &root, + &["--output-format", "json", "dump-manifests"], + &[( + "CLAUDE_CODE_UPSTREAM", + upstream.to_str().expect("utf8 upstream"), + )], + ); + assert_eq!(manifests["kind"], "dump-manifests"); + assert_eq!(manifests["commands"], 1); + assert_eq!(manifests["tools"], 1); + + let workspace = root.join("workspace"); + fs::create_dir_all(&workspace).expect("workspace should exist"); + let init = assert_json_command(&workspace, &["--output-format", "json", "init"]); + assert_eq!(init["kind"], "init"); + assert!(workspace.join("CLAUDE.md").exists()); +} + +#[test] +fn doctor_and_resume_status_emit_json_when_requested() { + let root = unique_temp_dir("doctor-resume-json"); + fs::create_dir_all(&root).expect("temp dir should exist"); + + let doctor = assert_json_command(&root, &["--output-format", "json", "doctor"]); + assert_eq!(doctor["kind"], "doctor"); + assert!(doctor["message"].is_string()); + let summary = doctor["summary"].as_object().expect("doctor summary"); + assert!(summary["ok"].as_u64().is_some()); + assert!(summary["warnings"].as_u64().is_some()); + assert!(summary["failures"].as_u64().is_some()); + + let checks = doctor["checks"].as_array().expect("doctor checks"); + assert_eq!(checks.len(), 5); + let check_names = checks + .iter() + .map(|check| { + assert!(check["status"].as_str().is_some()); + assert!(check["summary"].as_str().is_some()); + assert!(check["details"].is_array()); + check["name"].as_str().expect("doctor check name") + }) + .collect::>(); + assert_eq!( + check_names, + vec!["auth", "config", "workspace", "sandbox", "system"] + ); + + let workspace = checks + .iter() + .find(|check| check["name"] == "workspace") + .expect("workspace check"); + assert!(workspace["cwd"].as_str().is_some()); + assert!(workspace["in_git_repo"].is_boolean()); + + let sandbox = checks + .iter() + .find(|check| check["name"] == "sandbox") + .expect("sandbox check"); + assert!(sandbox["filesystem_mode"].as_str().is_some()); + assert!(sandbox["enabled"].is_boolean()); + assert!(sandbox["fallback_reason"].is_null() || sandbox["fallback_reason"].is_string()); + + let session_path = root.join("session.jsonl"); + fs::write( + &session_path, + "{\"type\":\"session_meta\",\"version\":3,\"session_id\":\"resume-json\",\"created_at_ms\":0,\"updated_at_ms\":0}\n{\"type\":\"message\",\"message\":{\"role\":\"user\",\"blocks\":[{\"type\":\"text\",\"text\":\"hello\"}]}}\n", + ) + .expect("session should write"); + let resumed = assert_json_command( + &root, + &[ + "--output-format", + "json", + "--resume", + session_path.to_str().expect("utf8 session path"), + "/status", + ], + ); + assert_eq!(resumed["kind"], "status"); + // model is null in resume mode (not known without --model flag) + assert!(resumed["model"].is_null()); + assert_eq!(resumed["usage"]["messages"], 1); + assert!(resumed["workspace"]["cwd"].as_str().is_some()); + assert!(resumed["sandbox"]["filesystem_mode"].as_str().is_some()); +} + +#[test] +fn resumed_inventory_commands_emit_structured_json_when_requested() { + let root = unique_temp_dir("resume-inventory-json"); + let config_home = root.join("config-home"); + let home = root.join("home"); + fs::create_dir_all(&config_home).expect("config home should exist"); + fs::create_dir_all(&home).expect("home should exist"); + + let session_path = root.join("session.jsonl"); + fs::write( + &session_path, + "{\"type\":\"session_meta\",\"version\":3,\"session_id\":\"resume-inventory-json\",\"created_at_ms\":0,\"updated_at_ms\":0}\n{\"type\":\"message\",\"message\":{\"role\":\"user\",\"blocks\":[{\"type\":\"text\",\"text\":\"inventory\"}]}}\n", + ) + .expect("session should write"); + + let mcp = assert_json_command_with_env( + &root, + &[ + "--output-format", + "json", + "--resume", + session_path.to_str().expect("utf8 session path"), + "/mcp", + ], + &[ + ( + "CLAW_CONFIG_HOME", + config_home.to_str().expect("utf8 config home"), + ), + ("HOME", home.to_str().expect("utf8 home")), + ], + ); + assert_eq!(mcp["kind"], "mcp"); + assert_eq!(mcp["action"], "list"); + assert!(mcp["servers"].is_array()); + + let skills = assert_json_command_with_env( + &root, + &[ + "--output-format", + "json", + "--resume", + session_path.to_str().expect("utf8 session path"), + "/skills", + ], + &[ + ( + "CLAW_CONFIG_HOME", + config_home.to_str().expect("utf8 config home"), + ), + ("HOME", home.to_str().expect("utf8 home")), + ], + ); + assert_eq!(skills["kind"], "skills"); + assert_eq!(skills["action"], "list"); + assert!(skills["summary"]["total"].is_number()); + assert!(skills["skills"].is_array()); +} + +#[test] +fn resumed_version_and_init_emit_structured_json_when_requested() { + let root = unique_temp_dir("resume-version-init-json"); + fs::create_dir_all(&root).expect("temp dir should exist"); + + let session_path = root.join("session.jsonl"); + fs::write( + &session_path, + "{\"type\":\"session_meta\",\"version\":3,\"session_id\":\"resume-version-init-json\",\"created_at_ms\":0,\"updated_at_ms\":0}\n", + ) + .expect("session should write"); + + let version = assert_json_command( + &root, + &[ + "--output-format", + "json", + "--resume", + session_path.to_str().expect("utf8 session path"), + "/version", + ], + ); + assert_eq!(version["kind"], "version"); + assert_eq!(version["version"], env!("CARGO_PKG_VERSION")); + + let init = assert_json_command( + &root, + &[ + "--output-format", + "json", + "--resume", + session_path.to_str().expect("utf8 session path"), + "/init", + ], + ); + assert_eq!(init["kind"], "init"); + assert!(root.join("CLAUDE.md").exists()); +} + +fn assert_json_command(current_dir: &Path, args: &[&str]) -> Value { + assert_json_command_with_env(current_dir, args, &[]) +} + +fn assert_json_command_with_env(current_dir: &Path, args: &[&str], envs: &[(&str, &str)]) -> Value { + let output = run_claw(current_dir, args, envs); + assert!( + output.status.success(), + "stdout:\n{}\n\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + serde_json::from_slice(&output.stdout).expect("stdout should be valid json") +} + +fn run_claw(current_dir: &Path, args: &[&str], envs: &[(&str, &str)]) -> Output { + let mut command = Command::new(env!("CARGO_BIN_EXE_claw")); + command.current_dir(current_dir).args(args); + for (key, value) in envs { + command.env(key, value); + } + command.output().expect("claw should launch") +} + +fn write_upstream_fixture(root: &Path) -> PathBuf { + let upstream = root.join("claw-code"); + let src = upstream.join("src"); + let entrypoints = src.join("entrypoints"); + fs::create_dir_all(&entrypoints).expect("upstream entrypoints dir should exist"); + fs::write( + src.join("commands.ts"), + "import FooCommand from './commands/foo'\n", + ) + .expect("commands fixture should write"); + fs::write( + src.join("tools.ts"), + "import ReadTool from './tools/read'\n", + ) + .expect("tools fixture should write"); + fs::write( + entrypoints.join("cli.tsx"), + "if (args[0] === '--version') {}\nstartupProfiler()\n", + ) + .expect("cli fixture should write"); + upstream +} + +fn write_agent(root: &Path, name: &str, description: &str, model: &str, reasoning: &str) { + fs::create_dir_all(root).expect("agent root should exist"); + fs::write( + root.join(format!("{name}.toml")), + format!( + "name = \"{name}\"\ndescription = \"{description}\"\nmodel = \"{model}\"\nmodel_reasoning_effort = \"{reasoning}\"\n" + ), + ) + .expect("agent fixture should write"); +} + +fn unique_temp_dir(label: &str) -> PathBuf { + let millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("clock should be after epoch") + .as_millis(); + let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed); + std::env::temp_dir().join(format!( + "claw-output-format-{label}-{}-{millis}-{counter}", + std::process::id() + )) +} diff --git a/crates/rusty-claude-cli/tests/resume_slash_commands.rs b/crates/rusty-claude-cli/tests/resume_slash_commands.rs new file mode 100644 index 0000000..556cbdb --- /dev/null +++ b/crates/rusty-claude-cli/tests/resume_slash_commands.rs @@ -0,0 +1,555 @@ +use std::fs; +use std::path::Path; +use std::path::PathBuf; +use std::process::{Command, Output}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use runtime::ContentBlock; +use runtime::Session; +use serde_json::Value; + +static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0); + +#[test] +fn resumed_binary_accepts_slash_commands_with_arguments() { + // given + let temp_dir = unique_temp_dir("resume-slash-commands"); + fs::create_dir_all(&temp_dir).expect("temp dir should exist"); + + let session_path = temp_dir.join("session.jsonl"); + let export_path = temp_dir.join("notes.txt"); + + let mut session = Session::new(); + session + .push_user_text("ship the slash command harness") + .expect("session write should succeed"); + session + .save_to_path(&session_path) + .expect("session should persist"); + + // when + let output = run_claw( + &temp_dir, + &[ + "--resume", + session_path.to_str().expect("utf8 path"), + "/export", + export_path.to_str().expect("utf8 path"), + "/clear", + "--confirm", + ], + ); + + // then + assert!( + output.status.success(), + "stdout:\n{}\n\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + + let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8"); + assert!(stdout.contains("Export")); + assert!(stdout.contains("wrote transcript")); + assert!(stdout.contains(export_path.to_str().expect("utf8 path"))); + assert!(stdout.contains("Session cleared")); + assert!(stdout.contains("Mode resumed session reset")); + assert!(stdout.contains("Previous session")); + assert!(stdout.contains("Resume previous claw --resume")); + assert!(stdout.contains("Backup ")); + assert!(stdout.contains("Session file ")); + + let export = fs::read_to_string(&export_path).expect("export file should exist"); + assert!(export.contains("# Conversation Export")); + assert!(export.contains("ship the slash command harness")); + + let restored = Session::load_from_path(&session_path).expect("cleared session should load"); + assert!(restored.messages.is_empty()); + + let backup_path = stdout + .lines() + .find_map(|line| line.strip_prefix(" Backup ")) + .map(PathBuf::from) + .expect("clear output should include backup path"); + let backup = Session::load_from_path(&backup_path).expect("backup session should load"); + assert_eq!(backup.messages.len(), 1); + assert!(matches!( + backup.messages[0].blocks.first(), + Some(ContentBlock::Text { text }) if text == "ship the slash command harness" + )); +} + +#[test] +fn status_command_applies_cli_flags_end_to_end() { + // given + let temp_dir = unique_temp_dir("status-command-flags"); + fs::create_dir_all(&temp_dir).expect("temp dir should exist"); + + // when + let output = run_claw( + &temp_dir, + &[ + "--model", + "sonnet", + "--permission-mode", + "read-only", + "status", + ], + ); + + // then + assert!( + output.status.success(), + "stdout:\n{}\n\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + + let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8"); + assert!(stdout.contains("Status")); + assert!(stdout.contains("Model claude-sonnet-4-6")); + assert!(stdout.contains("Permission mode read-only")); +} + +#[test] +fn resumed_config_command_loads_settings_files_end_to_end() { + // given + let temp_dir = unique_temp_dir("resume-config"); + let project_dir = temp_dir.join("project"); + let config_home = temp_dir.join("home").join(".claw"); + fs::create_dir_all(project_dir.join(".claw")).expect("project config dir should exist"); + fs::create_dir_all(&config_home).expect("config home should exist"); + + let session_path = project_dir.join("session.jsonl"); + Session::new() + .with_persistence_path(&session_path) + .save_to_path(&session_path) + .expect("session should persist"); + + fs::write(config_home.join("settings.json"), r#"{"model":"haiku"}"#) + .expect("user config should write"); + fs::write( + project_dir.join(".claw").join("settings.local.json"), + r#"{"model":"opus"}"#, + ) + .expect("local config should write"); + + // when + let output = run_claw_with_env( + &project_dir, + &[ + "--resume", + session_path.to_str().expect("utf8 path"), + "/config", + "model", + ], + &[("CLAW_CONFIG_HOME", config_home.to_str().expect("utf8 path"))], + ); + + // then + assert!( + output.status.success(), + "stdout:\n{}\n\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + + let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8"); + assert!(stdout.contains("Config")); + assert!(stdout.contains("Loaded files 2")); + assert!(stdout.contains( + config_home + .join("settings.json") + .to_str() + .expect("utf8 path") + )); + assert!(stdout.contains( + project_dir + .join(".claw") + .join("settings.local.json") + .to_str() + .expect("utf8 path") + )); + assert!(stdout.contains("Merged section: model")); + assert!(stdout.contains("opus")); +} + +#[test] +fn resume_latest_restores_the_most_recent_managed_session() { + // given + let temp_dir = unique_temp_dir("resume-latest"); + let project_dir = temp_dir.join("project"); + let sessions_dir = project_dir.join(".claw").join("sessions"); + fs::create_dir_all(&sessions_dir).expect("sessions dir should exist"); + + let older_path = sessions_dir.join("session-older.jsonl"); + let newer_path = sessions_dir.join("session-newer.jsonl"); + + let mut older = Session::new().with_persistence_path(&older_path); + older + .push_user_text("older session") + .expect("older session write should succeed"); + older + .save_to_path(&older_path) + .expect("older session should persist"); + + let mut newer = Session::new().with_persistence_path(&newer_path); + newer + .push_user_text("newer session") + .expect("newer session write should succeed"); + newer + .push_user_text("resume me") + .expect("newer session write should succeed"); + newer + .save_to_path(&newer_path) + .expect("newer session should persist"); + + // when + let output = run_claw(&project_dir, &["--resume", "latest", "/status"]); + + // then + assert!( + output.status.success(), + "stdout:\n{}\n\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + + let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8"); + assert!(stdout.contains("Status")); + assert!(stdout.contains("Messages 2")); + assert!(stdout.contains(newer_path.to_str().expect("utf8 path"))); +} + +#[test] +fn resumed_status_command_emits_structured_json_when_requested() { + // given + let temp_dir = unique_temp_dir("resume-status-json"); + fs::create_dir_all(&temp_dir).expect("temp dir should exist"); + let session_path = temp_dir.join("session.jsonl"); + + let mut session = Session::new(); + session + .push_user_text("resume status json fixture") + .expect("session write should succeed"); + session + .save_to_path(&session_path) + .expect("session should persist"); + + // when + let output = run_claw( + &temp_dir, + &[ + "--output-format", + "json", + "--resume", + session_path.to_str().expect("utf8 path"), + "/status", + ], + ); + + // then + assert!( + output.status.success(), + "stdout:\n{}\n\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + + let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8"); + let parsed: Value = + serde_json::from_str(stdout.trim()).expect("resume status output should be json"); + assert_eq!(parsed["kind"], "status"); + // model is null in resume mode (not known without --model flag) + assert!(parsed["model"].is_null()); + assert_eq!(parsed["permission_mode"], "danger-full-access"); + assert_eq!(parsed["usage"]["messages"], 1); + assert!(parsed["usage"]["turns"].is_number()); + assert!(parsed["workspace"]["cwd"].as_str().is_some()); + assert_eq!( + parsed["workspace"]["session"], + session_path.to_str().expect("utf8 path") + ); + assert!(parsed["workspace"]["changed_files"].is_number()); + assert_eq!(parsed["workspace"]["loaded_config_files"].as_u64(), Some(0)); + assert!(parsed["sandbox"]["filesystem_mode"].as_str().is_some()); +} + +#[test] +fn resumed_status_surfaces_persisted_model() { + // given — create a session with model already set + let temp_dir = unique_temp_dir("resume-status-model"); + fs::create_dir_all(&temp_dir).expect("temp dir should exist"); + let session_path = temp_dir.join("session.jsonl"); + + let mut session = Session::new(); + session.model = Some("claude-sonnet-4-6".to_string()); + session + .push_user_text("model persistence fixture") + .expect("write ok"); + session.save_to_path(&session_path).expect("persist ok"); + + // when + let output = run_claw( + &temp_dir, + &[ + "--output-format", + "json", + "--resume", + session_path.to_str().expect("utf8 path"), + "/status", + ], + ); + + // then + assert!( + output.status.success(), + "stderr:\n{}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8(output.stdout).expect("utf8"); + let parsed: Value = serde_json::from_str(stdout.trim()).expect("should be json"); + assert_eq!(parsed["kind"], "status"); + assert_eq!( + parsed["model"], "claude-sonnet-4-6", + "model should round-trip through session metadata" + ); +} + +#[test] +fn resumed_sandbox_command_emits_structured_json_when_requested() { + // given + let temp_dir = unique_temp_dir("resume-sandbox-json"); + fs::create_dir_all(&temp_dir).expect("temp dir should exist"); + let session_path = temp_dir.join("session.jsonl"); + + Session::new() + .save_to_path(&session_path) + .expect("session should persist"); + + // when + let output = run_claw( + &temp_dir, + &[ + "--output-format", + "json", + "--resume", + session_path.to_str().expect("utf8 path"), + "/sandbox", + ], + ); + + // then + assert!( + output.status.success(), + "stdout:\n{}\n\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + + let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8"); + let parsed: Value = + serde_json::from_str(stdout.trim()).expect("resume sandbox output should be json"); + assert_eq!(parsed["kind"], "sandbox"); + assert!(parsed["enabled"].is_boolean()); + assert!(parsed["active"].is_boolean()); + assert!(parsed["supported"].is_boolean()); + assert!(parsed["filesystem_mode"].as_str().is_some()); + assert!(parsed["allowed_mounts"].is_array()); + assert!(parsed["markers"].is_array()); +} + +#[test] +fn resumed_version_command_emits_structured_json() { + let temp_dir = unique_temp_dir("resume-version-json"); + fs::create_dir_all(&temp_dir).expect("temp dir should exist"); + let session_path = temp_dir.join("session.jsonl"); + Session::new() + .save_to_path(&session_path) + .expect("session should persist"); + + let output = run_claw( + &temp_dir, + &[ + "--output-format", + "json", + "--resume", + session_path.to_str().expect("utf8 path"), + "/version", + ], + ); + + assert!( + output.status.success(), + "stderr:\n{}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8(output.stdout).expect("utf8"); + let parsed: Value = serde_json::from_str(stdout.trim()).expect("should be json"); + assert_eq!(parsed["kind"], "version"); + assert!(parsed["version"].as_str().is_some()); + assert!(parsed["git_sha"].as_str().is_some()); + assert!(parsed["target"].as_str().is_some()); +} + +#[test] +fn resumed_export_command_emits_structured_json() { + let temp_dir = unique_temp_dir("resume-export-json"); + fs::create_dir_all(&temp_dir).expect("temp dir should exist"); + let session_path = temp_dir.join("session.jsonl"); + let mut session = Session::new(); + session + .push_user_text("export json fixture") + .expect("write ok"); + session.save_to_path(&session_path).expect("persist ok"); + + let output = run_claw( + &temp_dir, + &[ + "--output-format", + "json", + "--resume", + session_path.to_str().expect("utf8 path"), + "/export", + ], + ); + + assert!( + output.status.success(), + "stderr:\n{}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8(output.stdout).expect("utf8"); + let parsed: Value = serde_json::from_str(stdout.trim()).expect("should be json"); + assert_eq!(parsed["kind"], "export"); + assert!(parsed["file"].as_str().is_some()); + assert_eq!(parsed["message_count"], 1); +} + +#[test] +fn resumed_help_command_emits_structured_json() { + let temp_dir = unique_temp_dir("resume-help-json"); + fs::create_dir_all(&temp_dir).expect("temp dir should exist"); + let session_path = temp_dir.join("session.jsonl"); + Session::new() + .save_to_path(&session_path) + .expect("persist ok"); + + let output = run_claw( + &temp_dir, + &[ + "--output-format", + "json", + "--resume", + session_path.to_str().expect("utf8 path"), + "/help", + ], + ); + + assert!( + output.status.success(), + "stderr:\n{}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8(output.stdout).expect("utf8"); + let parsed: Value = serde_json::from_str(stdout.trim()).expect("should be json"); + assert_eq!(parsed["kind"], "help"); + assert!(parsed["text"].as_str().is_some()); + let text = parsed["text"].as_str().unwrap(); + assert!(text.contains("/status"), "help text should list /status"); +} + +#[test] +fn resumed_no_command_emits_restored_json() { + let temp_dir = unique_temp_dir("resume-no-cmd-json"); + fs::create_dir_all(&temp_dir).expect("temp dir should exist"); + let session_path = temp_dir.join("session.jsonl"); + let mut session = Session::new(); + session + .push_user_text("restored json fixture") + .expect("write ok"); + session.save_to_path(&session_path).expect("persist ok"); + + let output = run_claw( + &temp_dir, + &[ + "--output-format", + "json", + "--resume", + session_path.to_str().expect("utf8 path"), + ], + ); + + assert!( + output.status.success(), + "stderr:\n{}", + String::from_utf8_lossy(&output.stderr) + ); + let stdout = String::from_utf8(output.stdout).expect("utf8"); + let parsed: Value = serde_json::from_str(stdout.trim()).expect("should be json"); + assert_eq!(parsed["kind"], "restored"); + assert!(parsed["session_id"].as_str().is_some()); + assert!(parsed["path"].as_str().is_some()); + assert_eq!(parsed["message_count"], 1); +} + +#[test] +fn resumed_stub_command_emits_not_implemented_json() { + let temp_dir = unique_temp_dir("resume-stub-json"); + fs::create_dir_all(&temp_dir).expect("temp dir should exist"); + let session_path = temp_dir.join("session.jsonl"); + Session::new() + .save_to_path(&session_path) + .expect("persist ok"); + + let output = run_claw( + &temp_dir, + &[ + "--output-format", + "json", + "--resume", + session_path.to_str().expect("utf8 path"), + "/allowed-tools", + ], + ); + + // Stub commands exit with code 2 + assert!(!output.status.success()); + let stderr = String::from_utf8(output.stderr).expect("utf8"); + let parsed: Value = serde_json::from_str(stderr.trim()).expect("should be json"); + assert_eq!(parsed["type"], "error"); + assert!( + parsed["error"] + .as_str() + .unwrap() + .contains("not yet implemented"), + "error should say not yet implemented: {:?}", + parsed["error"] + ); +} + +fn run_claw(current_dir: &Path, args: &[&str]) -> Output { + run_claw_with_env(current_dir, args, &[]) +} + +fn run_claw_with_env(current_dir: &Path, args: &[&str], envs: &[(&str, &str)]) -> Output { + let mut command = Command::new(env!("CARGO_BIN_EXE_claw")); + command.current_dir(current_dir).args(args); + for (key, value) in envs { + command.env(key, value); + } + command.output().expect("claw should launch") +} + +fn unique_temp_dir(label: &str) -> PathBuf { + let millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("clock should be after epoch") + .as_millis(); + let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed); + std::env::temp_dir().join(format!( + "claw-{label}-{}-{millis}-{counter}", + std::process::id() + )) +} diff --git a/crates/server/src/lib.rs b/crates/server/src/lib.rs index 25a3b7c..dfc1736 100644 --- a/crates/server/src/lib.rs +++ b/crates/server/src/lib.rs @@ -20,7 +20,7 @@ use runtime::{ ToolError, ToolExecutor, }; use api::{ - max_tokens_for_model, resolve_startup_auth_source, AuthSource, ClawApiClient, + max_tokens_for_model, resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, InputMessage, MessageRequest, MessageResponse, OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice, ToolResultContentBlock, @@ -104,7 +104,7 @@ impl SessionEvent { // ── ServerApiClient:实现 runtime::ApiClient trait ──────────────────── struct ServerApiClient { - client: ClawApiClient, + client: AnthropicClient, model: String, tool_registry: GlobalToolRegistry, allowed_tools: Option>, @@ -127,6 +127,12 @@ impl ApiClient for ServerApiClient { .then(|| self.tool_registry.definitions(self.allowed_tools.as_ref())), tool_choice: self.enable_tools.then_some(ToolChoice::Auto), stream: true, + temperature: None, + top_p: None, + frequency_penalty: None, + presence_penalty: None, + stop: None, + reasoning_effort: None, }; let rt = tokio::runtime::Runtime::new() @@ -192,7 +198,6 @@ impl ApiClient for ServerApiClient { session_id: self.session_id.clone(), thinking: thinking.clone(), }); - events.push(AssistantEvent::ThinkingDelta(thinking)); } } ContentBlockDelta::SignatureDelta { .. } => {} @@ -303,7 +308,6 @@ fn push_output_block( session_id: session_id.to_string(), thinking: thinking.clone(), }); - events.push(AssistantEvent::ThinkingDelta(thinking)); } } OutputContentBlock::RedactedThinking { .. } => {} @@ -412,17 +416,6 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec { ContentBlock::Text { text } => { InputContentBlock::Text { text: text.clone() } } - ContentBlock::Thinking { - thinking, - signature, - } => InputContentBlock::Thinking { - thinking: thinking.clone(), - signature: signature.clone(), - }, - ContentBlock::RedactedThinking { data } => InputContentBlock::RedactedThinking { - data: serde_json::from_str(&data.render()) - .unwrap_or(serde_json::Value::Null), - }, ContentBlock::ToolUse { id, name, input } => InputContentBlock::ToolUse { id: id.clone(), name: name.clone(), @@ -454,6 +447,7 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec { fn permission_policy(mode: PermissionMode, tool_registry: &GlobalToolRegistry) -> PermissionPolicy { tool_registry .permission_specs(None) + .unwrap_or_default() .into_iter() .fold(PermissionPolicy::new(mode), |policy, (name, req)| { policy.with_tool_requirement(name, req) @@ -632,7 +626,7 @@ impl Session { let (events_tx, _) = broadcast::channel(BROADCAST_CAPACITY); let auth = resolve_server_auth_source(cwd)?; - let client = ClawApiClient::from_auth(auth).with_base_url(api::read_base_url()); + let client = AnthropicClient::from_auth(auth).with_base_url(api::read_base_url()); let api_client = ServerApiClient { client, @@ -658,7 +652,7 @@ impl Session { tool_executor, policy, system_prompt, - feature_config.clone(), + &feature_config, ); Ok(Self { diff --git a/crates/telemetry/Cargo.toml b/crates/telemetry/Cargo.toml new file mode 100644 index 0000000..d501850 --- /dev/null +++ b/crates/telemetry/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "telemetry" +version.workspace = true +edition.workspace = true +license.workspace = true +publish.workspace = true + +[dependencies] +serde = { version = "1", features = ["derive"] } +serde_json = "1" + +[lints] +workspace = true diff --git a/crates/telemetry/src/lib.rs b/crates/telemetry/src/lib.rs new file mode 100644 index 0000000..6e369e1 --- /dev/null +++ b/crates/telemetry/src/lib.rs @@ -0,0 +1,526 @@ +use std::fmt::{Debug, Formatter}; +use std::fs::{File, OpenOptions}; +use std::io::Write; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; + +pub const DEFAULT_ANTHROPIC_VERSION: &str = "2023-06-01"; +pub const DEFAULT_APP_NAME: &str = "claude-code"; +pub const DEFAULT_RUNTIME: &str = "rust"; +pub const DEFAULT_AGENTIC_BETA: &str = "claude-code-20250219"; +pub const DEFAULT_PROMPT_CACHING_SCOPE_BETA: &str = "prompt-caching-scope-2026-01-05"; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ClientIdentity { + pub app_name: String, + pub app_version: String, + pub runtime: String, +} + +impl ClientIdentity { + #[must_use] + pub fn new(app_name: impl Into, app_version: impl Into) -> Self { + Self { + app_name: app_name.into(), + app_version: app_version.into(), + runtime: DEFAULT_RUNTIME.to_string(), + } + } + + #[must_use] + pub fn with_runtime(mut self, runtime: impl Into) -> Self { + self.runtime = runtime.into(); + self + } + + #[must_use] + pub fn user_agent(&self) -> String { + format!("{}/{}", self.app_name, self.app_version) + } +} + +impl Default for ClientIdentity { + fn default() -> Self { + Self::new(DEFAULT_APP_NAME, env!("CARGO_PKG_VERSION")) + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct AnthropicRequestProfile { + pub anthropic_version: String, + pub client_identity: ClientIdentity, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub betas: Vec, + #[serde(default, skip_serializing_if = "Map::is_empty")] + pub extra_body: Map, +} + +impl AnthropicRequestProfile { + #[must_use] + pub fn new(client_identity: ClientIdentity) -> Self { + Self { + anthropic_version: DEFAULT_ANTHROPIC_VERSION.to_string(), + client_identity, + betas: vec![ + DEFAULT_AGENTIC_BETA.to_string(), + DEFAULT_PROMPT_CACHING_SCOPE_BETA.to_string(), + ], + extra_body: Map::new(), + } + } + + #[must_use] + pub fn with_beta(mut self, beta: impl Into) -> Self { + let beta = beta.into(); + if !self.betas.contains(&beta) { + self.betas.push(beta); + } + self + } + + #[must_use] + pub fn with_extra_body(mut self, key: impl Into, value: Value) -> Self { + self.extra_body.insert(key.into(), value); + self + } + + #[must_use] + pub fn header_pairs(&self) -> Vec<(String, String)> { + let mut headers = vec![ + ( + "anthropic-version".to_string(), + self.anthropic_version.clone(), + ), + ("user-agent".to_string(), self.client_identity.user_agent()), + ]; + if !self.betas.is_empty() { + headers.push(("anthropic-beta".to_string(), self.betas.join(","))); + } + headers + } + + pub fn render_json_body(&self, request: &T) -> Result { + let mut body = serde_json::to_value(request)?; + let object = body.as_object_mut().ok_or_else(|| { + serde_json::Error::io(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "request body must serialize to a JSON object", + )) + })?; + for (key, value) in &self.extra_body { + object.insert(key.clone(), value.clone()); + } + if !self.betas.is_empty() { + object.insert( + "betas".to_string(), + Value::Array(self.betas.iter().cloned().map(Value::String).collect()), + ); + } + Ok(body) + } +} + +impl Default for AnthropicRequestProfile { + fn default() -> Self { + Self::new(ClientIdentity::default()) + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct AnalyticsEvent { + pub namespace: String, + pub action: String, + #[serde(default, skip_serializing_if = "Map::is_empty")] + pub properties: Map, +} + +impl AnalyticsEvent { + #[must_use] + pub fn new(namespace: impl Into, action: impl Into) -> Self { + Self { + namespace: namespace.into(), + action: action.into(), + properties: Map::new(), + } + } + + #[must_use] + pub fn with_property(mut self, key: impl Into, value: Value) -> Self { + self.properties.insert(key.into(), value); + self + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct SessionTraceRecord { + pub session_id: String, + pub sequence: u64, + pub name: String, + pub timestamp_ms: u64, + #[serde(default, skip_serializing_if = "Map::is_empty")] + pub attributes: Map, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum TelemetryEvent { + HttpRequestStarted { + session_id: String, + attempt: u32, + method: String, + path: String, + #[serde(default, skip_serializing_if = "Map::is_empty")] + attributes: Map, + }, + HttpRequestSucceeded { + session_id: String, + attempt: u32, + method: String, + path: String, + status: u16, + #[serde(default, skip_serializing_if = "Option::is_none")] + request_id: Option, + #[serde(default, skip_serializing_if = "Map::is_empty")] + attributes: Map, + }, + HttpRequestFailed { + session_id: String, + attempt: u32, + method: String, + path: String, + error: String, + retryable: bool, + #[serde(default, skip_serializing_if = "Map::is_empty")] + attributes: Map, + }, + Analytics(AnalyticsEvent), + SessionTrace(SessionTraceRecord), +} + +pub trait TelemetrySink: Send + Sync { + fn record(&self, event: TelemetryEvent); +} + +#[derive(Default)] +pub struct MemoryTelemetrySink { + events: Mutex>, +} + +impl MemoryTelemetrySink { + #[must_use] + pub fn events(&self) -> Vec { + self.events + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone() + } +} + +impl TelemetrySink for MemoryTelemetrySink { + fn record(&self, event: TelemetryEvent) { + self.events + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .push(event); + } +} + +pub struct JsonlTelemetrySink { + path: PathBuf, + file: Mutex, +} + +impl Debug for JsonlTelemetrySink { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("JsonlTelemetrySink") + .field("path", &self.path) + .finish_non_exhaustive() + } +} + +impl JsonlTelemetrySink { + pub fn new(path: impl AsRef) -> Result { + let path = path.as_ref().to_path_buf(); + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + let file = OpenOptions::new().create(true).append(true).open(&path)?; + Ok(Self { + path, + file: Mutex::new(file), + }) + } + + #[must_use] + pub fn path(&self) -> &Path { + &self.path + } +} + +impl TelemetrySink for JsonlTelemetrySink { + fn record(&self, event: TelemetryEvent) { + let Ok(line) = serde_json::to_string(&event) else { + return; + }; + let mut file = self + .file + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let _ = writeln!(file, "{line}"); + let _ = file.flush(); + } +} + +#[derive(Clone)] +pub struct SessionTracer { + session_id: String, + sequence: Arc, + sink: Arc, +} + +impl Debug for SessionTracer { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SessionTracer") + .field("session_id", &self.session_id) + .finish_non_exhaustive() + } +} + +impl SessionTracer { + #[must_use] + pub fn new(session_id: impl Into, sink: Arc) -> Self { + Self { + session_id: session_id.into(), + sequence: Arc::new(AtomicU64::new(0)), + sink, + } + } + + #[must_use] + pub fn session_id(&self) -> &str { + &self.session_id + } + + pub fn record(&self, name: impl Into, attributes: Map) { + let record = SessionTraceRecord { + session_id: self.session_id.clone(), + sequence: self.sequence.fetch_add(1, Ordering::Relaxed), + name: name.into(), + timestamp_ms: current_timestamp_ms(), + attributes, + }; + self.sink.record(TelemetryEvent::SessionTrace(record)); + } + + pub fn record_http_request_started( + &self, + attempt: u32, + method: impl Into, + path: impl Into, + attributes: Map, + ) { + let method = method.into(); + let path = path.into(); + self.sink.record(TelemetryEvent::HttpRequestStarted { + session_id: self.session_id.clone(), + attempt, + method: method.clone(), + path: path.clone(), + attributes: attributes.clone(), + }); + self.record( + "http_request_started", + merge_trace_fields(method, path, attempt, attributes), + ); + } + + pub fn record_http_request_succeeded( + &self, + attempt: u32, + method: impl Into, + path: impl Into, + status: u16, + request_id: Option, + attributes: Map, + ) { + let method = method.into(); + let path = path.into(); + self.sink.record(TelemetryEvent::HttpRequestSucceeded { + session_id: self.session_id.clone(), + attempt, + method: method.clone(), + path: path.clone(), + status, + request_id: request_id.clone(), + attributes: attributes.clone(), + }); + let mut trace_attributes = merge_trace_fields(method, path, attempt, attributes); + trace_attributes.insert("status".to_string(), Value::from(status)); + if let Some(request_id) = request_id { + trace_attributes.insert("request_id".to_string(), Value::String(request_id)); + } + self.record("http_request_succeeded", trace_attributes); + } + + pub fn record_http_request_failed( + &self, + attempt: u32, + method: impl Into, + path: impl Into, + error: impl Into, + retryable: bool, + attributes: Map, + ) { + let method = method.into(); + let path = path.into(); + let error = error.into(); + self.sink.record(TelemetryEvent::HttpRequestFailed { + session_id: self.session_id.clone(), + attempt, + method: method.clone(), + path: path.clone(), + error: error.clone(), + retryable, + attributes: attributes.clone(), + }); + let mut trace_attributes = merge_trace_fields(method, path, attempt, attributes); + trace_attributes.insert("error".to_string(), Value::String(error)); + trace_attributes.insert("retryable".to_string(), Value::Bool(retryable)); + self.record("http_request_failed", trace_attributes); + } + + pub fn record_analytics(&self, event: AnalyticsEvent) { + let mut attributes = event.properties.clone(); + attributes.insert( + "namespace".to_string(), + Value::String(event.namespace.clone()), + ); + attributes.insert("action".to_string(), Value::String(event.action.clone())); + self.sink.record(TelemetryEvent::Analytics(event)); + self.record("analytics", attributes); + } +} + +fn merge_trace_fields( + method: String, + path: String, + attempt: u32, + mut attributes: Map, +) -> Map { + attributes.insert("method".to_string(), Value::String(method)); + attributes.insert("path".to_string(), Value::String(path)); + attributes.insert("attempt".to_string(), Value::from(attempt)); + attributes +} + +fn current_timestamp_ms() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis() + .try_into() + .unwrap_or(u64::MAX) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn request_profile_emits_headers_and_merges_body() { + let profile = AnthropicRequestProfile::new( + ClientIdentity::new("claude-code", "1.2.3").with_runtime("rust-cli"), + ) + .with_beta("tools-2026-04-01") + .with_extra_body("metadata", serde_json::json!({"source": "test"})); + + assert_eq!( + profile.header_pairs(), + vec![ + ( + "anthropic-version".to_string(), + DEFAULT_ANTHROPIC_VERSION.to_string() + ), + ("user-agent".to_string(), "claude-code/1.2.3".to_string()), + ( + "anthropic-beta".to_string(), + "claude-code-20250219,prompt-caching-scope-2026-01-05,tools-2026-04-01" + .to_string(), + ), + ] + ); + + let body = profile + .render_json_body(&serde_json::json!({"model": "claude-sonnet"})) + .expect("body should serialize"); + assert_eq!( + body["metadata"]["source"], + Value::String("test".to_string()) + ); + assert_eq!( + body["betas"], + serde_json::json!([ + "claude-code-20250219", + "prompt-caching-scope-2026-01-05", + "tools-2026-04-01" + ]) + ); + } + + #[test] + fn session_tracer_records_structured_events_and_trace_sequence() { + let sink = Arc::new(MemoryTelemetrySink::default()); + let tracer = SessionTracer::new("session-123", sink.clone()); + + tracer.record_http_request_started(1, "POST", "/v1/messages", Map::new()); + tracer.record_analytics( + AnalyticsEvent::new("cli", "prompt_sent") + .with_property("model", Value::String("claude-opus".to_string())), + ); + + let events = sink.events(); + assert!(matches!( + &events[0], + TelemetryEvent::HttpRequestStarted { + session_id, + attempt: 1, + method, + path, + .. + } if session_id == "session-123" && method == "POST" && path == "/v1/messages" + )); + assert!(matches!( + &events[1], + TelemetryEvent::SessionTrace(SessionTraceRecord { sequence: 0, name, .. }) + if name == "http_request_started" + )); + assert!(matches!(&events[2], TelemetryEvent::Analytics(_))); + assert!(matches!( + &events[3], + TelemetryEvent::SessionTrace(SessionTraceRecord { sequence: 1, name, .. }) + if name == "analytics" + )); + } + + #[test] + fn jsonl_sink_persists_events() { + let path = + std::env::temp_dir().join(format!("telemetry-jsonl-{}.log", current_timestamp_ms())); + let sink = JsonlTelemetrySink::new(&path).expect("sink should create file"); + + sink.record(TelemetryEvent::Analytics( + AnalyticsEvent::new("cli", "turn_completed").with_property("ok", Value::Bool(true)), + )); + + let contents = std::fs::read_to_string(&path).expect("telemetry log should be readable"); + assert!(contents.contains("\"type\":\"analytics\"")); + assert!(contents.contains("\"action\":\"turn_completed\"")); + + let _ = std::fs::remove_file(path); + } +} diff --git a/crates/tools/Cargo.toml b/crates/tools/Cargo.toml index 04d738b..86da4e6 100644 --- a/crates/tools/Cargo.toml +++ b/crates/tools/Cargo.toml @@ -7,6 +7,8 @@ publish.workspace = true [dependencies] api = { path = "../api" } +commands = { path = "../commands" } +flate2 = "1" plugins = { path = "../plugins" } runtime = { path = "../runtime" } reqwest = { version = "0.12", default-features = false, features = ["blocking", "rustls-tls"] } diff --git a/crates/tools/src/lane_completion.rs b/crates/tools/src/lane_completion.rs new file mode 100644 index 0000000..e4eecce --- /dev/null +++ b/crates/tools/src/lane_completion.rs @@ -0,0 +1,181 @@ +//! Lane completion detector — automatically marks lanes as completed when +//! session finishes successfully with green tests and pushed code. +//! +//! This bridges the gap where `LaneContext::completed` was a passive bool +//! that nothing automatically set. Now completion is detected from: +//! - Agent output shows Finished status +//! - No errors/blockers present +//! - Tests passed (green status) +//! - Code pushed (has output file) + +use runtime::{ + evaluate, LaneBlocker, LaneContext, PolicyAction, PolicyCondition, PolicyEngine, PolicyRule, + ReviewStatus, +}; + +use crate::AgentOutput; + +/// Detects if a lane should be automatically marked as completed. +/// +/// Returns `Some(LaneContext)` with `completed = true` if all conditions met, +/// `None` if lane should remain active. +#[allow(dead_code)] +pub(crate) fn detect_lane_completion( + output: &AgentOutput, + test_green: bool, + has_pushed: bool, +) -> Option { + // Must be finished without errors + if output.error.is_some() { + return None; + } + + // Must have finished status + if !output.status.eq_ignore_ascii_case("completed") + && !output.status.eq_ignore_ascii_case("finished") + { + return None; + } + + // Must have no current blocker + if output.current_blocker.is_some() { + return None; + } + + // Must have green tests + if !test_green { + return None; + } + + // Must have pushed code + if !has_pushed { + return None; + } + + // All conditions met — create completed context + Some(LaneContext { + lane_id: output.agent_id.clone(), + green_level: 3, // Workspace green + branch_freshness: std::time::Duration::from_secs(0), + blocker: LaneBlocker::None, + review_status: ReviewStatus::Approved, + diff_scope: runtime::DiffScope::Scoped, + completed: true, + reconciled: false, + }) +} + +/// Evaluates policy actions for a completed lane. +#[allow(dead_code)] +pub(crate) fn evaluate_completed_lane(context: &LaneContext) -> Vec { + let engine = PolicyEngine::new(vec![ + PolicyRule::new( + "closeout-completed-lane", + PolicyCondition::And(vec![ + PolicyCondition::LaneCompleted, + PolicyCondition::GreenAt { level: 3 }, + ]), + PolicyAction::CloseoutLane, + 10, + ), + PolicyRule::new( + "cleanup-completed-session", + PolicyCondition::LaneCompleted, + PolicyAction::CleanupSession, + 5, + ), + ]); + + evaluate(&engine, context) +} + +#[cfg(test)] +mod tests { + use super::*; + use runtime::{DiffScope, LaneBlocker}; + + fn test_output() -> AgentOutput { + AgentOutput { + agent_id: "test-lane-1".to_string(), + name: "Test Agent".to_string(), + description: "Test".to_string(), + subagent_type: None, + model: None, + status: "Finished".to_string(), + output_file: "/tmp/test.output".to_string(), + manifest_file: "/tmp/test.manifest".to_string(), + created_at: "2024-01-01T00:00:00Z".to_string(), + started_at: Some("2024-01-01T00:00:00Z".to_string()), + completed_at: Some("2024-01-01T00:00:00Z".to_string()), + lane_events: vec![], + derived_state: "working".to_string(), + current_blocker: None, + error: None, + } + } + + #[test] + fn detects_completion_when_all_conditions_met() { + let output = test_output(); + let result = detect_lane_completion(&output, true, true); + + assert!(result.is_some()); + let context = result.unwrap(); + assert!(context.completed); + assert_eq!(context.green_level, 3); + assert_eq!(context.blocker, LaneBlocker::None); + } + + #[test] + fn no_completion_when_error_present() { + let mut output = test_output(); + output.error = Some("Build failed".to_string()); + + let result = detect_lane_completion(&output, true, true); + assert!(result.is_none()); + } + + #[test] + fn no_completion_when_not_finished() { + let mut output = test_output(); + output.status = "Running".to_string(); + + let result = detect_lane_completion(&output, true, true); + assert!(result.is_none()); + } + + #[test] + fn no_completion_when_tests_not_green() { + let output = test_output(); + + let result = detect_lane_completion(&output, false, true); + assert!(result.is_none()); + } + + #[test] + fn no_completion_when_not_pushed() { + let output = test_output(); + + let result = detect_lane_completion(&output, true, false); + assert!(result.is_none()); + } + + #[test] + fn evaluate_triggers_closeout_for_completed_lane() { + let context = LaneContext { + lane_id: "completed-lane".to_string(), + green_level: 3, + branch_freshness: std::time::Duration::from_secs(0), + blocker: LaneBlocker::None, + review_status: ReviewStatus::Approved, + diff_scope: DiffScope::Scoped, + completed: true, + reconciled: false, + }; + + let actions = evaluate_completed_lane(&context); + + assert!(actions.contains(&PolicyAction::CloseoutLane)); + assert!(actions.contains(&PolicyAction::CleanupSession)); + } +} diff --git a/crates/tools/src/lib.rs b/crates/tools/src/lib.rs index 5d6c37c..de3c758 100644 --- a/crates/tools/src/lib.rs +++ b/crates/tools/src/lib.rs @@ -4,21 +4,70 @@ use std::process::Command; use std::time::{Duration, Instant}; use api::{ - max_tokens_for_model, resolve_model_alias, ContentBlockDelta, InputContentBlock, InputMessage, - MessageRequest, MessageResponse, OutputContentBlock, ProviderClient, + max_tokens_for_model, resolve_model_alias, ApiError, ContentBlockDelta, InputContentBlock, + InputMessage, MessageRequest, MessageResponse, OutputContentBlock, ProviderClient, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, }; use plugins::PluginTool; use reqwest::blocking::Client; use runtime::{ - edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file, - ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ContentBlock, ConversationMessage, - ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, PermissionPolicy, - RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, + check_freshness, dedupe_superseded_commit_events, edit_file, execute_bash, glob_search, + grep_search, load_system_prompt, + lsp_client::LspRegistry, + mcp_tool_bridge::McpToolRegistry, + permission_enforcer::{EnforcementResult, PermissionEnforcer}, + read_file, + summary_compression::compress_summary_text, + task_registry::TaskRegistry, + team_cron_registry::{CronRegistry, TeamRegistry}, + worker_boot::{WorkerReadySnapshot, WorkerRegistry}, + write_file, ApiClient, ApiRequest, AssistantEvent, BashCommandInput, BashCommandOutput, + BranchFreshness, ConfigLoader, ContentBlock, ConversationMessage, ConversationRuntime, + GrepSearchInput, LaneCommitProvenance, LaneEvent, LaneEventBlocker, LaneEventName, + LaneEventStatus, LaneFailureClass, McpDegradedReport, MessageRole, PermissionMode, + PermissionPolicy, PromptCacheEvent, ProviderFallbackConfig, RuntimeError, Session, TaskPacket, + ToolError, ToolExecutor, }; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; +/// Global task registry shared across tool invocations within a session. +fn global_lsp_registry() -> &'static LspRegistry { + use std::sync::OnceLock; + static REGISTRY: OnceLock = OnceLock::new(); + REGISTRY.get_or_init(LspRegistry::new) +} + +fn global_mcp_registry() -> &'static McpToolRegistry { + use std::sync::OnceLock; + static REGISTRY: OnceLock = OnceLock::new(); + REGISTRY.get_or_init(McpToolRegistry::new) +} + +fn global_team_registry() -> &'static TeamRegistry { + use std::sync::OnceLock; + static REGISTRY: OnceLock = OnceLock::new(); + REGISTRY.get_or_init(TeamRegistry::new) +} + +fn global_cron_registry() -> &'static CronRegistry { + use std::sync::OnceLock; + static REGISTRY: OnceLock = OnceLock::new(); + REGISTRY.get_or_init(CronRegistry::new) +} + +fn global_task_registry() -> &'static TaskRegistry { + use std::sync::OnceLock; + static REGISTRY: OnceLock = OnceLock::new(); + REGISTRY.get_or_init(TaskRegistry::new) +} + +fn global_worker_registry() -> &'static WorkerRegistry { + use std::sync::OnceLock; + static REGISTRY: OnceLock = OnceLock::new(); + REGISTRY.get_or_init(WorkerRegistry::new) +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct ToolManifestEntry { pub name: String, @@ -56,9 +105,19 @@ pub struct ToolSpec { pub required_permission: PermissionMode, } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub struct GlobalToolRegistry { plugin_tools: Vec, + runtime_tools: Vec, + enforcer: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct RuntimeToolDefinition { + pub name: String, + pub description: Option, + pub input_schema: Value, + pub required_permission: PermissionMode, } impl GlobalToolRegistry { @@ -66,6 +125,8 @@ impl GlobalToolRegistry { pub fn builtin() -> Self { Self { plugin_tools: Vec::new(), + runtime_tools: Vec::new(), + enforcer: None, } } @@ -88,10 +149,50 @@ impl GlobalToolRegistry { } } - Ok(Self { plugin_tools }) + Ok(Self { + plugin_tools, + runtime_tools: Vec::new(), + enforcer: None, + }) } - pub fn normalize_allowed_tools(&self, values: &[String]) -> Result>, String> { + pub fn with_runtime_tools( + mut self, + runtime_tools: Vec, + ) -> Result { + let mut seen_names = mvp_tool_specs() + .into_iter() + .map(|spec| spec.name.to_string()) + .chain( + self.plugin_tools + .iter() + .map(|tool| tool.definition().name.clone()), + ) + .collect::>(); + + for tool in &runtime_tools { + if !seen_names.insert(tool.name.clone()) { + return Err(format!( + "runtime tool `{}` conflicts with an existing tool name", + tool.name + )); + } + } + + self.runtime_tools = runtime_tools; + Ok(self) + } + + #[must_use] + pub fn with_enforcer(mut self, enforcer: PermissionEnforcer) -> Self { + self.set_enforcer(enforcer); + self + } + + pub fn normalize_allowed_tools( + &self, + values: &[String], + ) -> Result>, String> { if values.is_empty() { return Ok(None); } @@ -100,7 +201,12 @@ impl GlobalToolRegistry { let canonical_names = builtin_specs .iter() .map(|spec| spec.name.to_string()) - .chain(self.plugin_tools.iter().map(|tool| tool.definition().name.clone())) + .chain( + self.plugin_tools + .iter() + .map(|tool| tool.definition().name.clone()), + ) + .chain(self.runtime_tools.iter().map(|tool| tool.name.clone())) .collect::>(); let mut name_map = canonical_names .iter() @@ -147,47 +253,92 @@ impl GlobalToolRegistry { description: Some(spec.description.to_string()), input_schema: spec.input_schema, }); + let runtime = self + .runtime_tools + .iter() + .filter(|tool| allowed_tools.is_none_or(|allowed| allowed.contains(tool.name.as_str()))) + .map(|tool| ToolDefinition { + name: tool.name.clone(), + description: tool.description.clone(), + input_schema: tool.input_schema.clone(), + }); let plugin = self .plugin_tools .iter() .filter(|tool| { - allowed_tools.is_none_or(|allowed| allowed.contains(tool.definition().name.as_str())) + allowed_tools + .is_none_or(|allowed| allowed.contains(tool.definition().name.as_str())) }) .map(|tool| ToolDefinition { name: tool.definition().name.clone(), description: tool.definition().description.clone(), input_schema: tool.definition().input_schema.clone(), }); - builtin.chain(plugin).collect() + builtin.chain(runtime).chain(plugin).collect() } - #[must_use] pub fn permission_specs( &self, allowed_tools: Option<&BTreeSet>, - ) -> Vec<(String, PermissionMode)> { + ) -> Result, String> { let builtin = mvp_tool_specs() .into_iter() .filter(|spec| allowed_tools.is_none_or(|allowed| allowed.contains(spec.name))) .map(|spec| (spec.name.to_string(), spec.required_permission)); + let runtime = self + .runtime_tools + .iter() + .filter(|tool| allowed_tools.is_none_or(|allowed| allowed.contains(tool.name.as_str()))) + .map(|tool| (tool.name.clone(), tool.required_permission)); let plugin = self .plugin_tools .iter() .filter(|tool| { - allowed_tools.is_none_or(|allowed| allowed.contains(tool.definition().name.as_str())) + allowed_tools + .is_none_or(|allowed| allowed.contains(tool.definition().name.as_str())) }) .map(|tool| { - ( - tool.definition().name.clone(), - permission_mode_from_plugin(tool.required_permission()), - ) - }); - builtin.chain(plugin).collect() + permission_mode_from_plugin(tool.required_permission()) + .map(|permission| (tool.definition().name.clone(), permission)) + }) + .collect::, _>>()?; + Ok(builtin.chain(runtime).chain(plugin).collect()) + } + + #[must_use] + pub fn has_runtime_tool(&self, name: &str) -> bool { + self.runtime_tools.iter().any(|tool| tool.name == name) + } + + #[must_use] + pub fn search( + &self, + query: &str, + max_results: usize, + pending_mcp_servers: Option>, + mcp_degraded: Option, + ) -> ToolSearchOutput { + let query = query.trim().to_string(); + let normalized_query = normalize_tool_search_query(&query); + let matches = search_tool_specs(&query, max_results.max(1), &self.searchable_tool_specs()); + + ToolSearchOutput { + matches, + query, + normalized_query, + total_deferred_tools: self.searchable_tool_specs().len(), + pending_mcp_servers, + mcp_degraded, + } + } + + pub fn set_enforcer(&mut self, enforcer: PermissionEnforcer) { + self.enforcer = Some(enforcer); } pub fn execute(&self, name: &str, input: &Value) -> Result { if mvp_tool_specs().iter().any(|spec| spec.name == name) { - return execute_tool(name, input); + return execute_tool_with_enforcer(self.enforcer.as_ref(), name, input); } self.plugin_tools .iter() @@ -196,18 +347,36 @@ impl GlobalToolRegistry { .execute(input) .map_err(|error| error.to_string()) } + + fn searchable_tool_specs(&self) -> Vec { + let builtin = deferred_tool_specs() + .into_iter() + .map(|spec| SearchableToolSpec { + name: spec.name.to_string(), + description: spec.description.to_string(), + }); + let runtime = self.runtime_tools.iter().map(|tool| SearchableToolSpec { + name: tool.name.clone(), + description: tool.description.clone().unwrap_or_default(), + }); + let plugin = self.plugin_tools.iter().map(|tool| SearchableToolSpec { + name: tool.definition().name.clone(), + description: tool.definition().description.clone().unwrap_or_default(), + }); + builtin.chain(runtime).chain(plugin).collect() + } } fn normalize_tool_name(value: &str) -> String { value.trim().replace('-', "_").to_ascii_lowercase() } -fn permission_mode_from_plugin(value: &str) -> PermissionMode { +fn permission_mode_from_plugin(value: &str) -> Result { match value { - "read-only" => PermissionMode::ReadOnly, - "workspace-write" => PermissionMode::WorkspaceWrite, - "danger-full-access" => PermissionMode::DangerFullAccess, - other => panic!("unsupported plugin permission: {other}"), + "read-only" => Ok(PermissionMode::ReadOnly), + "workspace-write" => Ok(PermissionMode::WorkspaceWrite), + "danger-full-access" => Ok(PermissionMode::DangerFullAccess), + other => Err(format!("unsupported plugin permission: {other}")), } } @@ -225,7 +394,11 @@ pub fn mvp_tool_specs() -> Vec { "timeout": { "type": "integer", "minimum": 1 }, "description": { "type": "string" }, "run_in_background": { "type": "boolean" }, - "dangerouslyDisableSandbox": { "type": "boolean" } + "dangerouslyDisableSandbox": { "type": "boolean" }, + "namespaceRestrictions": { "type": "boolean" }, + "isolateNetwork": { "type": "boolean" }, + "filesystemMode": { "type": "string", "enum": ["off", "workspace-only", "allow-list"] }, + "allowedMounts": { "type": "array", "items": { "type": "string" } } }, "required": ["command"], "additionalProperties": false @@ -479,7 +652,7 @@ pub fn mvp_tool_specs() -> Vec { }, ToolSpec { name: "Config", - description: "Get or set Claw Code settings.", + description: "Get or set Claude Code settings.", input_schema: json!({ "type": "object", "properties": { @@ -493,6 +666,26 @@ pub fn mvp_tool_specs() -> Vec { }), required_permission: PermissionMode::WorkspaceWrite, }, + ToolSpec { + name: "EnterPlanMode", + description: "Enable a worktree-local planning mode override and remember the previous local setting for ExitPlanMode.", + input_schema: json!({ + "type": "object", + "properties": {}, + "additionalProperties": false + }), + required_permission: PermissionMode::WorkspaceWrite, + }, + ToolSpec { + name: "ExitPlanMode", + description: "Restore or clear the worktree-local planning mode override created by EnterPlanMode.", + input_schema: json!({ + "type": "object", + "properties": {}, + "additionalProperties": false + }), + required_permission: PermissionMode::WorkspaceWrite, + }, ToolSpec { name: "StructuredOutput", description: "Return structured output in the requested format.", @@ -533,17 +726,485 @@ pub fn mvp_tool_specs() -> Vec { }), required_permission: PermissionMode::DangerFullAccess, }, + ToolSpec { + name: "AskUserQuestion", + description: "Ask the user a question and wait for their response.", + input_schema: json!({ + "type": "object", + "properties": { + "question": { "type": "string" }, + "options": { + "type": "array", + "items": { "type": "string" } + } + }, + "required": ["question"], + "additionalProperties": false + }), + required_permission: PermissionMode::ReadOnly, + }, + ToolSpec { + name: "TaskCreate", + description: "Create a background task that runs in a separate subprocess.", + input_schema: json!({ + "type": "object", + "properties": { + "prompt": { "type": "string" }, + "description": { "type": "string" } + }, + "required": ["prompt"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "RunTaskPacket", + description: "Create a background task from a structured task packet.", + input_schema: json!({ + "type": "object", + "properties": { + "objective": { "type": "string" }, + "scope": { "type": "string" }, + "repo": { "type": "string" }, + "branch_policy": { "type": "string" }, + "acceptance_tests": { + "type": "array", + "items": { "type": "string" } + }, + "commit_policy": { "type": "string" }, + "reporting_contract": { "type": "string" }, + "escalation_policy": { "type": "string" } + }, + "required": [ + "objective", + "scope", + "repo", + "branch_policy", + "acceptance_tests", + "commit_policy", + "reporting_contract", + "escalation_policy" + ], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "TaskGet", + description: "Get the status and details of a background task by ID.", + input_schema: json!({ + "type": "object", + "properties": { + "task_id": { "type": "string" } + }, + "required": ["task_id"], + "additionalProperties": false + }), + required_permission: PermissionMode::ReadOnly, + }, + ToolSpec { + name: "TaskList", + description: "List all background tasks and their current status.", + input_schema: json!({ + "type": "object", + "properties": {}, + "additionalProperties": false + }), + required_permission: PermissionMode::ReadOnly, + }, + ToolSpec { + name: "TaskStop", + description: "Stop a running background task by ID.", + input_schema: json!({ + "type": "object", + "properties": { + "task_id": { "type": "string" } + }, + "required": ["task_id"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "TaskUpdate", + description: "Send a message or update to a running background task.", + input_schema: json!({ + "type": "object", + "properties": { + "task_id": { "type": "string" }, + "message": { "type": "string" } + }, + "required": ["task_id", "message"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "TaskOutput", + description: "Retrieve the output produced by a background task.", + input_schema: json!({ + "type": "object", + "properties": { + "task_id": { "type": "string" } + }, + "required": ["task_id"], + "additionalProperties": false + }), + required_permission: PermissionMode::ReadOnly, + }, + ToolSpec { + name: "WorkerCreate", + description: "Create a coding worker boot session with trust-gate and prompt-delivery guards.", + input_schema: json!({ + "type": "object", + "properties": { + "cwd": { "type": "string" }, + "trusted_roots": { + "type": "array", + "items": { "type": "string" } + }, + "auto_recover_prompt_misdelivery": { "type": "boolean" } + }, + "required": ["cwd"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "WorkerGet", + description: "Fetch the current worker boot state, last error, and event history.", + input_schema: json!({ + "type": "object", + "properties": { + "worker_id": { "type": "string" } + }, + "required": ["worker_id"], + "additionalProperties": false + }), + required_permission: PermissionMode::ReadOnly, + }, + ToolSpec { + name: "WorkerObserve", + description: "Feed a terminal snapshot into worker boot detection to resolve trust gates, ready handshakes, and prompt misdelivery.", + input_schema: json!({ + "type": "object", + "properties": { + "worker_id": { "type": "string" }, + "screen_text": { "type": "string" } + }, + "required": ["worker_id", "screen_text"], + "additionalProperties": false + }), + required_permission: PermissionMode::ReadOnly, + }, + ToolSpec { + name: "WorkerResolveTrust", + description: "Resolve a detected trust prompt so worker boot can continue.", + input_schema: json!({ + "type": "object", + "properties": { + "worker_id": { "type": "string" } + }, + "required": ["worker_id"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "WorkerAwaitReady", + description: "Return the current ready-handshake verdict for a coding worker.", + input_schema: json!({ + "type": "object", + "properties": { + "worker_id": { "type": "string" } + }, + "required": ["worker_id"], + "additionalProperties": false + }), + required_permission: PermissionMode::ReadOnly, + }, + ToolSpec { + name: "WorkerSendPrompt", + description: "Send a task prompt only after the worker reaches ready_for_prompt; can replay a recovered prompt.", + input_schema: json!({ + "type": "object", + "properties": { + "worker_id": { "type": "string" }, + "prompt": { "type": "string" } + }, + "required": ["worker_id"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "WorkerRestart", + description: "Restart worker boot state after a failed or stale startup.", + input_schema: json!({ + "type": "object", + "properties": { + "worker_id": { "type": "string" } + }, + "required": ["worker_id"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "WorkerTerminate", + description: "Terminate a worker and mark the lane finished from the control plane.", + input_schema: json!({ + "type": "object", + "properties": { + "worker_id": { "type": "string" } + }, + "required": ["worker_id"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "WorkerObserveCompletion", + description: "Report session completion to the worker, classifying finish_reason into Finished or Failed (provider-degraded). Use after the opencode session completes to advance the worker to its terminal state.", + input_schema: json!({ + "type": "object", + "properties": { + "worker_id": { "type": "string" }, + "finish_reason": { "type": "string" }, + "tokens_output": { "type": "integer", "minimum": 0 } + }, + "required": ["worker_id", "finish_reason", "tokens_output"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "TeamCreate", + description: "Create a team of sub-agents for parallel task execution.", + input_schema: json!({ + "type": "object", + "properties": { + "name": { "type": "string" }, + "tasks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "prompt": { "type": "string" }, + "description": { "type": "string" } + }, + "required": ["prompt"] + } + } + }, + "required": ["name", "tasks"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "TeamDelete", + description: "Delete a team and stop all its running tasks.", + input_schema: json!({ + "type": "object", + "properties": { + "team_id": { "type": "string" } + }, + "required": ["team_id"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "CronCreate", + description: "Create a scheduled recurring task.", + input_schema: json!({ + "type": "object", + "properties": { + "schedule": { "type": "string" }, + "prompt": { "type": "string" }, + "description": { "type": "string" } + }, + "required": ["schedule", "prompt"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "CronDelete", + description: "Delete a scheduled recurring task by ID.", + input_schema: json!({ + "type": "object", + "properties": { + "cron_id": { "type": "string" } + }, + "required": ["cron_id"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "CronList", + description: "List all scheduled recurring tasks.", + input_schema: json!({ + "type": "object", + "properties": {}, + "additionalProperties": false + }), + required_permission: PermissionMode::ReadOnly, + }, + ToolSpec { + name: "LSP", + description: "Query Language Server Protocol for code intelligence (symbols, references, diagnostics).", + input_schema: json!({ + "type": "object", + "properties": { + "action": { "type": "string", "enum": ["symbols", "references", "diagnostics", "definition", "hover"] }, + "path": { "type": "string" }, + "line": { "type": "integer", "minimum": 0 }, + "character": { "type": "integer", "minimum": 0 }, + "query": { "type": "string" } + }, + "required": ["action"], + "additionalProperties": false + }), + required_permission: PermissionMode::ReadOnly, + }, + ToolSpec { + name: "ListMcpResources", + description: "List available resources from connected MCP servers.", + input_schema: json!({ + "type": "object", + "properties": { + "server": { "type": "string" } + }, + "additionalProperties": false + }), + required_permission: PermissionMode::ReadOnly, + }, + ToolSpec { + name: "ReadMcpResource", + description: "Read a specific resource from an MCP server by URI.", + input_schema: json!({ + "type": "object", + "properties": { + "server": { "type": "string" }, + "uri": { "type": "string" } + }, + "required": ["uri"], + "additionalProperties": false + }), + required_permission: PermissionMode::ReadOnly, + }, + ToolSpec { + name: "McpAuth", + description: "Authenticate with an MCP server that requires OAuth or credentials.", + input_schema: json!({ + "type": "object", + "properties": { + "server": { "type": "string" } + }, + "required": ["server"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "RemoteTrigger", + description: "Trigger a remote action or webhook endpoint.", + input_schema: json!({ + "type": "object", + "properties": { + "url": { "type": "string" }, + "method": { "type": "string", "enum": ["GET", "POST", "PUT", "DELETE"] }, + "headers": { "type": "object" }, + "body": { "type": "string" } + }, + "required": ["url"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "MCP", + description: "Execute a tool provided by a connected MCP server.", + input_schema: json!({ + "type": "object", + "properties": { + "server": { "type": "string" }, + "tool": { "type": "string" }, + "arguments": { "type": "object" } + }, + "required": ["server", "tool"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, + ToolSpec { + name: "TestingPermission", + description: "Test-only tool for verifying permission enforcement behavior.", + input_schema: json!({ + "type": "object", + "properties": { + "action": { "type": "string" } + }, + "required": ["action"], + "additionalProperties": false + }), + required_permission: PermissionMode::DangerFullAccess, + }, ] } +/// Check permission before executing a tool. Returns Err with denial reason if blocked. +pub fn enforce_permission_check( + enforcer: &PermissionEnforcer, + tool_name: &str, + input: &Value, +) -> Result<(), String> { + let input_str = serde_json::to_string(input).unwrap_or_default(); + let result = enforcer.check(tool_name, &input_str); + + match result { + EnforcementResult::Allowed => Ok(()), + EnforcementResult::Denied { reason, .. } => Err(reason), + } +} + pub fn execute_tool(name: &str, input: &Value) -> Result { + execute_tool_with_enforcer(None, name, input) +} + +fn execute_tool_with_enforcer( + enforcer: Option<&PermissionEnforcer>, + name: &str, + input: &Value, +) -> Result { match name { - "bash" => from_value::(input).and_then(run_bash), - "read_file" => from_value::(input).and_then(run_read_file), - "write_file" => from_value::(input).and_then(run_write_file), - "edit_file" => from_value::(input).and_then(run_edit_file), - "glob_search" => from_value::(input).and_then(run_glob_search), - "grep_search" => from_value::(input).and_then(run_grep_search), + "bash" => { + maybe_enforce_permission_check(enforcer, name, input)?; + from_value::(input).and_then(run_bash) + } + "read_file" => { + maybe_enforce_permission_check(enforcer, name, input)?; + from_value::(input).and_then(run_read_file) + } + "write_file" => { + maybe_enforce_permission_check(enforcer, name, input)?; + from_value::(input).and_then(run_write_file) + } + "edit_file" => { + maybe_enforce_permission_check(enforcer, name, input)?; + from_value::(input).and_then(run_edit_file) + } + "glob_search" => { + maybe_enforce_permission_check(enforcer, name, input)?; + from_value::(input).and_then(run_glob_search) + } + "grep_search" => { + maybe_enforce_permission_check(enforcer, name, input)?; + from_value::(input).and_then(run_grep_search) + } "WebFetch" => from_value::(input).and_then(run_web_fetch), "WebSearch" => from_value::(input).and_then(run_web_search), "TodoWrite" => from_value::(input).and_then(run_todo_write), @@ -554,24 +1215,733 @@ pub fn execute_tool(name: &str, input: &Value) -> Result { "Sleep" => from_value::(input).and_then(run_sleep), "SendUserMessage" | "Brief" => from_value::(input).and_then(run_brief), "Config" => from_value::(input).and_then(run_config), + "EnterPlanMode" => from_value::(input).and_then(run_enter_plan_mode), + "ExitPlanMode" => from_value::(input).and_then(run_exit_plan_mode), "StructuredOutput" => { from_value::(input).and_then(run_structured_output) } "REPL" => from_value::(input).and_then(run_repl), "PowerShell" => from_value::(input).and_then(run_powershell), + "AskUserQuestion" => { + from_value::(input).and_then(run_ask_user_question) + } + "TaskCreate" => from_value::(input).and_then(run_task_create), + "RunTaskPacket" => from_value::(input).and_then(run_task_packet), + "TaskGet" => from_value::(input).and_then(run_task_get), + "TaskList" => run_task_list(input.clone()), + "TaskStop" => from_value::(input).and_then(run_task_stop), + "TaskUpdate" => from_value::(input).and_then(run_task_update), + "TaskOutput" => from_value::(input).and_then(run_task_output), + "WorkerCreate" => from_value::(input).and_then(run_worker_create), + "WorkerGet" => from_value::(input).and_then(run_worker_get), + "WorkerObserve" => from_value::(input).and_then(run_worker_observe), + "WorkerResolveTrust" => { + from_value::(input).and_then(run_worker_resolve_trust) + } + "WorkerAwaitReady" => from_value::(input).and_then(run_worker_await_ready), + "WorkerSendPrompt" => { + from_value::(input).and_then(run_worker_send_prompt) + } + "WorkerRestart" => from_value::(input).and_then(run_worker_restart), + "WorkerTerminate" => from_value::(input).and_then(run_worker_terminate), + "WorkerObserveCompletion" => from_value::(input) + .and_then(run_worker_observe_completion), + "TeamCreate" => from_value::(input).and_then(run_team_create), + "TeamDelete" => from_value::(input).and_then(run_team_delete), + "CronCreate" => from_value::(input).and_then(run_cron_create), + "CronDelete" => from_value::(input).and_then(run_cron_delete), + "CronList" => run_cron_list(input.clone()), + "LSP" => from_value::(input).and_then(run_lsp), + "ListMcpResources" => { + from_value::(input).and_then(run_list_mcp_resources) + } + "ReadMcpResource" => from_value::(input).and_then(run_read_mcp_resource), + "McpAuth" => from_value::(input).and_then(run_mcp_auth), + "RemoteTrigger" => from_value::(input).and_then(run_remote_trigger), + "MCP" => from_value::(input).and_then(run_mcp_tool), + "TestingPermission" => { + from_value::(input).and_then(run_testing_permission) + } _ => Err(format!("unsupported tool: {name}")), } } +fn maybe_enforce_permission_check( + enforcer: Option<&PermissionEnforcer>, + tool_name: &str, + input: &Value, +) -> Result<(), String> { + if let Some(enforcer) = enforcer { + enforce_permission_check(enforcer, tool_name, input)?; + } + Ok(()) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_ask_user_question(input: AskUserQuestionInput) -> Result { + use std::io::{self, BufRead, Write}; + + // Display the question to the user via stdout + let stdout = io::stdout(); + let stdin = io::stdin(); + let mut out = stdout.lock(); + + writeln!(out, "\n[Question] {}", input.question).map_err(|e| e.to_string())?; + + if let Some(ref options) = input.options { + for (i, option) in options.iter().enumerate() { + writeln!(out, " {}. {}", i + 1, option).map_err(|e| e.to_string())?; + } + write!(out, "Enter choice (1-{}): ", options.len()).map_err(|e| e.to_string())?; + } else { + write!(out, "Your answer: ").map_err(|e| e.to_string())?; + } + out.flush().map_err(|e| e.to_string())?; + + // Read user response from stdin + let mut response = String::new(); + stdin + .lock() + .read_line(&mut response) + .map_err(|e| e.to_string())?; + let response = response.trim().to_string(); + + // If options were provided, resolve the numeric choice + let answer = if let Some(ref options) = input.options { + if let Ok(idx) = response.parse::() { + if idx >= 1 && idx <= options.len() { + options[idx - 1].clone() + } else { + response.clone() + } + } else { + response.clone() + } + } else { + response.clone() + }; + + to_pretty_json(json!({ + "question": input.question, + "answer": answer, + "status": "answered" + })) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_task_create(input: TaskCreateInput) -> Result { + let registry = global_task_registry(); + let task = registry.create(&input.prompt, input.description.as_deref()); + to_pretty_json(json!({ + "task_id": task.task_id, + "status": task.status, + "prompt": task.prompt, + "description": task.description, + "task_packet": task.task_packet, + "created_at": task.created_at + })) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_task_packet(input: TaskPacket) -> Result { + let registry = global_task_registry(); + let task = registry + .create_from_packet(input) + .map_err(|error| error.to_string())?; + + to_pretty_json(json!({ + "task_id": task.task_id, + "status": task.status, + "prompt": task.prompt, + "description": task.description, + "task_packet": task.task_packet, + "created_at": task.created_at + })) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_task_get(input: TaskIdInput) -> Result { + let registry = global_task_registry(); + match registry.get(&input.task_id) { + Some(task) => to_pretty_json(json!({ + "task_id": task.task_id, + "status": task.status, + "prompt": task.prompt, + "description": task.description, + "task_packet": task.task_packet, + "created_at": task.created_at, + "updated_at": task.updated_at, + "messages": task.messages, + "team_id": task.team_id + })), + None => Err(format!("task not found: {}", input.task_id)), + } +} + +fn run_task_list(_input: Value) -> Result { + let registry = global_task_registry(); + let tasks: Vec<_> = registry + .list(None) + .into_iter() + .map(|t| { + json!({ + "task_id": t.task_id, + "status": t.status, + "prompt": t.prompt, + "description": t.description, + "task_packet": t.task_packet, + "created_at": t.created_at, + "updated_at": t.updated_at, + "team_id": t.team_id + }) + }) + .collect(); + to_pretty_json(json!({ + "tasks": tasks, + "count": tasks.len() + })) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_task_stop(input: TaskIdInput) -> Result { + let registry = global_task_registry(); + match registry.stop(&input.task_id) { + Ok(task) => to_pretty_json(json!({ + "task_id": task.task_id, + "status": task.status, + "message": "Task stopped" + })), + Err(e) => Err(e), + } +} + +#[allow(clippy::needless_pass_by_value)] +fn run_task_update(input: TaskUpdateInput) -> Result { + let registry = global_task_registry(); + match registry.update(&input.task_id, &input.message) { + Ok(task) => to_pretty_json(json!({ + "task_id": task.task_id, + "status": task.status, + "message_count": task.messages.len(), + "last_message": input.message + })), + Err(e) => Err(e), + } +} + +#[allow(clippy::needless_pass_by_value)] +fn run_task_output(input: TaskIdInput) -> Result { + let registry = global_task_registry(); + match registry.output(&input.task_id) { + Ok(output) => to_pretty_json(json!({ + "task_id": input.task_id, + "output": output, + "has_output": !output.is_empty() + })), + Err(e) => Err(e), + } +} + +#[allow(clippy::needless_pass_by_value)] +fn run_worker_create(input: WorkerCreateInput) -> Result { + // Merge config-level trusted_roots with per-call overrides. + // Config provides the default allowlist; per-call roots add on top. + let config_roots: Vec = ConfigLoader::default_for(&input.cwd) + .load() + .ok() + .map(|c| c.trusted_roots().to_vec()) + .unwrap_or_default(); + let merged_roots: Vec = config_roots + .into_iter() + .chain(input.trusted_roots.iter().cloned()) + .collect(); + let worker = global_worker_registry().create( + &input.cwd, + &merged_roots, + input.auto_recover_prompt_misdelivery, + ); + to_pretty_json(worker) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_worker_get(input: WorkerIdInput) -> Result { + global_worker_registry().get(&input.worker_id).map_or_else( + || Err(format!("worker not found: {}", input.worker_id)), + to_pretty_json, + ) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_worker_observe(input: WorkerObserveInput) -> Result { + let worker = global_worker_registry().observe(&input.worker_id, &input.screen_text)?; + to_pretty_json(worker) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_worker_resolve_trust(input: WorkerIdInput) -> Result { + let worker = global_worker_registry().resolve_trust(&input.worker_id)?; + to_pretty_json(worker) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_worker_await_ready(input: WorkerIdInput) -> Result { + let snapshot: WorkerReadySnapshot = global_worker_registry().await_ready(&input.worker_id)?; + to_pretty_json(snapshot) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_worker_send_prompt(input: WorkerSendPromptInput) -> Result { + let worker = global_worker_registry().send_prompt(&input.worker_id, input.prompt.as_deref())?; + to_pretty_json(worker) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_worker_restart(input: WorkerIdInput) -> Result { + let worker = global_worker_registry().restart(&input.worker_id)?; + to_pretty_json(worker) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_worker_terminate(input: WorkerIdInput) -> Result { + let worker = global_worker_registry().terminate(&input.worker_id)?; + to_pretty_json(worker) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_worker_observe_completion(input: WorkerObserveCompletionInput) -> Result { + let worker = global_worker_registry().observe_completion( + &input.worker_id, + &input.finish_reason, + input.tokens_output, + )?; + to_pretty_json(worker) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_team_create(input: TeamCreateInput) -> Result { + let task_ids: Vec = input + .tasks + .iter() + .filter_map(|t| t.get("task_id").and_then(|v| v.as_str()).map(str::to_owned)) + .collect(); + let team = global_team_registry().create(&input.name, task_ids); + // Register team assignment on each task + for task_id in &team.task_ids { + let _ = global_task_registry().assign_team(task_id, &team.team_id); + } + to_pretty_json(json!({ + "team_id": team.team_id, + "name": team.name, + "task_count": team.task_ids.len(), + "task_ids": team.task_ids, + "status": team.status, + "created_at": team.created_at + })) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_team_delete(input: TeamDeleteInput) -> Result { + match global_team_registry().delete(&input.team_id) { + Ok(team) => to_pretty_json(json!({ + "team_id": team.team_id, + "name": team.name, + "status": team.status, + "message": "Team deleted" + })), + Err(e) => Err(e), + } +} + +#[allow(clippy::needless_pass_by_value)] +fn run_cron_create(input: CronCreateInput) -> Result { + let entry = + global_cron_registry().create(&input.schedule, &input.prompt, input.description.as_deref()); + to_pretty_json(json!({ + "cron_id": entry.cron_id, + "schedule": entry.schedule, + "prompt": entry.prompt, + "description": entry.description, + "enabled": entry.enabled, + "created_at": entry.created_at + })) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_cron_delete(input: CronDeleteInput) -> Result { + match global_cron_registry().delete(&input.cron_id) { + Ok(entry) => to_pretty_json(json!({ + "cron_id": entry.cron_id, + "schedule": entry.schedule, + "status": "deleted", + "message": "Cron entry removed" + })), + Err(e) => Err(e), + } +} + +fn run_cron_list(_input: Value) -> Result { + let entries: Vec<_> = global_cron_registry() + .list(false) + .into_iter() + .map(|e| { + json!({ + "cron_id": e.cron_id, + "schedule": e.schedule, + "prompt": e.prompt, + "description": e.description, + "enabled": e.enabled, + "run_count": e.run_count, + "last_run_at": e.last_run_at, + "created_at": e.created_at + }) + }) + .collect(); + to_pretty_json(json!({ + "crons": entries, + "count": entries.len() + })) +} + +#[allow(clippy::needless_pass_by_value)] +fn run_lsp(input: LspInput) -> Result { + let registry = global_lsp_registry(); + let action = &input.action; + let path = input.path.as_deref(); + let line = input.line; + let character = input.character; + let query = input.query.as_deref(); + + match registry.dispatch(action, path, line, character, query) { + Ok(result) => to_pretty_json(result), + Err(e) => to_pretty_json(json!({ + "action": action, + "error": e, + "status": "error" + })), + } +} + +#[allow(clippy::needless_pass_by_value)] +fn run_list_mcp_resources(input: McpResourceInput) -> Result { + let registry = global_mcp_registry(); + let server = input.server.as_deref().unwrap_or("default"); + match registry.list_resources(server) { + Ok(resources) => { + let items: Vec<_> = resources + .iter() + .map(|r| { + json!({ + "uri": r.uri, + "name": r.name, + "description": r.description, + "mime_type": r.mime_type, + }) + }) + .collect(); + to_pretty_json(json!({ + "server": server, + "resources": items, + "count": items.len() + })) + } + Err(e) => to_pretty_json(json!({ + "server": server, + "resources": [], + "error": e + })), + } +} + +#[allow(clippy::needless_pass_by_value)] +fn run_read_mcp_resource(input: McpResourceInput) -> Result { + let registry = global_mcp_registry(); + let uri = input.uri.as_deref().unwrap_or(""); + let server = input.server.as_deref().unwrap_or("default"); + match registry.read_resource(server, uri) { + Ok(resource) => to_pretty_json(json!({ + "server": server, + "uri": resource.uri, + "name": resource.name, + "description": resource.description, + "mime_type": resource.mime_type + })), + Err(e) => to_pretty_json(json!({ + "server": server, + "uri": uri, + "error": e + })), + } +} + +#[allow(clippy::needless_pass_by_value)] +fn run_mcp_auth(input: McpAuthInput) -> Result { + let registry = global_mcp_registry(); + match registry.get_server(&input.server) { + Some(state) => to_pretty_json(json!({ + "server": input.server, + "status": state.status, + "server_info": state.server_info, + "tool_count": state.tools.len(), + "resource_count": state.resources.len() + })), + None => to_pretty_json(json!({ + "server": input.server, + "status": "disconnected", + "message": "Server not registered. Use MCP tool to connect first." + })), + } +} + +#[allow(clippy::needless_pass_by_value)] +fn run_remote_trigger(input: RemoteTriggerInput) -> Result { + let method = input.method.unwrap_or_else(|| "GET".to_string()); + let client = Client::new(); + + let mut request = match method.to_uppercase().as_str() { + "GET" => client.get(&input.url), + "POST" => client.post(&input.url), + "PUT" => client.put(&input.url), + "DELETE" => client.delete(&input.url), + "PATCH" => client.patch(&input.url), + "HEAD" => client.head(&input.url), + other => return Err(format!("unsupported HTTP method: {other}")), + }; + + // Apply custom headers + if let Some(ref headers) = input.headers { + if let Some(obj) = headers.as_object() { + for (key, value) in obj { + if let Some(val) = value.as_str() { + request = request.header(key.as_str(), val); + } + } + } + } + + // Apply body + if let Some(ref body) = input.body { + request = request.body(body.clone()); + } + + // Execute with a 30-second timeout + let request = request.timeout(Duration::from_secs(30)); + + match request.send() { + Ok(response) => { + let status = response.status().as_u16(); + let body = response.text().unwrap_or_default(); + let truncated_body = if body.len() > 8192 { + format!( + "{}\n\n[response truncated — {} bytes total]", + &body[..8192], + body.len() + ) + } else { + body + }; + to_pretty_json(json!({ + "url": input.url, + "method": method, + "status_code": status, + "body": truncated_body, + "success": (200..300).contains(&status) + })) + } + Err(e) => to_pretty_json(json!({ + "url": input.url, + "method": method, + "error": e.to_string(), + "success": false + })), + } +} + +#[allow(clippy::needless_pass_by_value)] +fn run_mcp_tool(input: McpToolInput) -> Result { + let registry = global_mcp_registry(); + let args = input.arguments.unwrap_or(serde_json::json!({})); + match registry.call_tool(&input.server, &input.tool, &args) { + Ok(result) => to_pretty_json(json!({ + "server": input.server, + "tool": input.tool, + "result": result, + "status": "success" + })), + Err(e) => to_pretty_json(json!({ + "server": input.server, + "tool": input.tool, + "error": e, + "status": "error" + })), + } +} + +#[allow(clippy::needless_pass_by_value)] +fn run_testing_permission(input: TestingPermissionInput) -> Result { + to_pretty_json(json!({ + "action": input.action, + "permitted": true, + "message": "Testing permission tool stub" + })) +} fn from_value Deserialize<'de>>(input: &Value) -> Result { serde_json::from_value(input.clone()).map_err(|error| error.to_string()) } fn run_bash(input: BashCommandInput) -> Result { + if let Some(output) = workspace_test_branch_preflight(&input.command) { + return serde_json::to_string_pretty(&output).map_err(|error| error.to_string()); + } serde_json::to_string_pretty(&execute_bash(input).map_err(|error| error.to_string())?) .map_err(|error| error.to_string()) } +fn workspace_test_branch_preflight(command: &str) -> Option { + if !is_workspace_test_command(command) { + return None; + } + + let branch = git_stdout(&["branch", "--show-current"])?; + let main_ref = resolve_main_ref(&branch)?; + let freshness = check_freshness(&branch, &main_ref); + match freshness { + BranchFreshness::Fresh => None, + BranchFreshness::Stale { + commits_behind, + missing_fixes, + } => Some(branch_divergence_output( + command, + &branch, + &main_ref, + commits_behind, + None, + &missing_fixes, + )), + BranchFreshness::Diverged { + ahead, + behind, + missing_fixes, + } => Some(branch_divergence_output( + command, + &branch, + &main_ref, + behind, + Some(ahead), + &missing_fixes, + )), + } +} + +fn is_workspace_test_command(command: &str) -> bool { + let normalized = normalize_shell_command(command); + [ + "cargo test --workspace", + "cargo test --all", + "cargo nextest run --workspace", + "cargo nextest run --all", + ] + .iter() + .any(|needle| normalized.contains(needle)) +} + +fn normalize_shell_command(command: &str) -> String { + command + .split_whitespace() + .collect::>() + .join(" ") + .to_ascii_lowercase() +} + +fn resolve_main_ref(branch: &str) -> Option { + let has_local_main = git_ref_exists("main"); + let has_remote_main = git_ref_exists("origin/main"); + + if branch == "main" && has_remote_main { + Some("origin/main".to_string()) + } else if has_local_main { + Some("main".to_string()) + } else if has_remote_main { + Some("origin/main".to_string()) + } else { + None + } +} + +fn git_ref_exists(reference: &str) -> bool { + Command::new("git") + .args(["rev-parse", "--verify", "--quiet", reference]) + .output() + .map(|output| output.status.success()) + .unwrap_or(false) +} + +fn git_stdout(args: &[&str]) -> Option { + let output = Command::new("git").args(args).output().ok()?; + if !output.status.success() { + return None; + } + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + (!stdout.is_empty()).then_some(stdout) +} + +fn branch_divergence_output( + command: &str, + branch: &str, + main_ref: &str, + commits_behind: usize, + commits_ahead: Option, + missing_fixes: &[String], +) -> BashCommandOutput { + let relation = commits_ahead.map_or_else( + || format!("is {commits_behind} commit(s) behind"), + |ahead| format!("has diverged ({ahead} ahead, {commits_behind} behind)"), + ); + let missing_summary = if missing_fixes.is_empty() { + "(none surfaced)".to_string() + } else { + missing_fixes.join("; ") + }; + let stderr = format!( + "branch divergence detected before workspace tests: `{branch}` {relation} `{main_ref}`. Missing commits: {missing_summary}. Merge or rebase `{main_ref}` before re-running `{command}`." + ); + + BashCommandOutput { + stdout: String::new(), + stderr: stderr.clone(), + raw_output_path: None, + interrupted: false, + is_image: None, + background_task_id: None, + backgrounded_by_user: None, + assistant_auto_backgrounded: None, + dangerously_disable_sandbox: None, + return_code_interpretation: Some("preflight_blocked:branch_divergence".to_string()), + no_output_expected: Some(false), + structured_content: Some(vec![serde_json::to_value( + LaneEvent::new( + LaneEventName::BranchStaleAgainstMain, + LaneEventStatus::Blocked, + iso8601_now(), + ) + .with_failure_class(LaneFailureClass::BranchDivergence) + .with_detail(stderr.clone()) + .with_data(json!({ + "branch": branch, + "mainRef": main_ref, + "commitsBehind": commits_behind, + "commitsAhead": commits_ahead, + "missingCommits": missing_fixes, + "blockedCommand": command, + "recommendedAction": format!("merge or rebase {main_ref} before workspace tests") + })), + ) + .expect("lane event should serialize")]), + persisted_output_path: None, + persisted_output_size: None, + sandbox_status: None, + } +} + #[allow(clippy::needless_pass_by_value)] fn run_read_file(input: ReadFileInput) -> Result { to_pretty_json(read_file(&input.path, input.offset, input.limit).map_err(io_to_string)?) @@ -636,7 +2006,7 @@ fn run_notebook_edit(input: NotebookEditInput) -> Result { } fn run_sleep(input: SleepInput) -> Result { - to_pretty_json(execute_sleep(input)) + to_pretty_json(execute_sleep(input)?) } fn run_brief(input: BriefInput) -> Result { @@ -647,8 +2017,16 @@ fn run_config(input: ConfigInput) -> Result { to_pretty_json(execute_config(input)?) } +fn run_enter_plan_mode(input: EnterPlanModeInput) -> Result { + to_pretty_json(execute_enter_plan_mode(input)?) +} + +fn run_exit_plan_mode(input: ExitPlanModeInput) -> Result { + to_pretty_json(execute_exit_plan_mode(input)?) +} + fn run_structured_output(input: StructuredOutputInput) -> Result { - to_pretty_json(execute_structured_output(input)) + to_pretty_json(execute_structured_output(input)?) } fn run_repl(input: ReplInput) -> Result { @@ -799,6 +2177,14 @@ struct ConfigInput { value: Option, } +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct EnterPlanModeInput {} + +#[derive(Debug, Default, Deserialize)] +#[serde(default)] +struct ExitPlanModeInput {} + #[derive(Debug, Deserialize)] #[serde(untagged)] enum ConfigValue { @@ -826,6 +2212,143 @@ struct PowerShellInput { run_in_background: Option, } +#[derive(Debug, Deserialize)] +struct AskUserQuestionInput { + question: String, + #[serde(default)] + options: Option>, +} + +#[derive(Debug, Deserialize)] +struct TaskCreateInput { + prompt: String, + #[serde(default)] + description: Option, +} + +#[derive(Debug, Deserialize)] +struct TaskIdInput { + task_id: String, +} + +#[derive(Debug, Deserialize)] +struct TaskUpdateInput { + task_id: String, + message: String, +} + +#[derive(Debug, Deserialize)] +struct WorkerCreateInput { + cwd: String, + #[serde(default)] + trusted_roots: Vec, + #[serde(default = "default_auto_recover_prompt_misdelivery")] + auto_recover_prompt_misdelivery: bool, +} + +#[derive(Debug, Deserialize)] +struct WorkerIdInput { + worker_id: String, +} + +#[derive(Debug, Deserialize)] +struct WorkerObserveCompletionInput { + worker_id: String, + finish_reason: String, + tokens_output: u64, +} + +#[derive(Debug, Deserialize)] +struct WorkerObserveInput { + worker_id: String, + screen_text: String, +} + +#[derive(Debug, Deserialize)] +struct WorkerSendPromptInput { + worker_id: String, + #[serde(default)] + prompt: Option, +} + +const fn default_auto_recover_prompt_misdelivery() -> bool { + true +} + +#[derive(Debug, Deserialize)] +struct TeamCreateInput { + name: String, + tasks: Vec, +} + +#[derive(Debug, Deserialize)] +struct TeamDeleteInput { + team_id: String, +} + +#[derive(Debug, Deserialize)] +struct CronCreateInput { + schedule: String, + prompt: String, + #[serde(default)] + description: Option, +} + +#[derive(Debug, Deserialize)] +struct CronDeleteInput { + cron_id: String, +} + +#[derive(Debug, Deserialize)] +struct LspInput { + action: String, + #[serde(default)] + path: Option, + #[serde(default)] + line: Option, + #[serde(default)] + character: Option, + #[serde(default)] + query: Option, +} + +#[derive(Debug, Deserialize)] +struct McpResourceInput { + #[serde(default)] + server: Option, + #[serde(default)] + uri: Option, +} + +#[derive(Debug, Deserialize)] +struct McpAuthInput { + server: String, +} + +#[derive(Debug, Deserialize)] +struct RemoteTriggerInput { + url: String, + #[serde(default)] + method: Option, + #[serde(default)] + headers: Option, + #[serde(default)] + body: Option, +} + +#[derive(Debug, Deserialize)] +struct McpToolInput { + server: String, + tool: String, + #[serde(default)] + arguments: Option, +} + +#[derive(Debug, Deserialize)] +struct TestingPermissionInput { + action: String, +} + #[derive(Debug, Serialize)] struct WebFetchOutput { bytes: usize, @@ -885,6 +2408,12 @@ struct AgentOutput { started_at: Option, #[serde(rename = "completedAt", skip_serializing_if = "Option::is_none")] completed_at: Option, + #[serde(rename = "laneEvents", default, skip_serializing_if = "Vec::is_empty")] + lane_events: Vec, + #[serde(rename = "currentBlocker", skip_serializing_if = "Option::is_none")] + current_blocker: Option, + #[serde(rename = "derivedState")] + derived_state: String, #[serde(skip_serializing_if = "Option::is_none")] error: Option, } @@ -897,8 +2426,8 @@ struct AgentJob { allowed_tools: BTreeSet, } -#[derive(Debug, Serialize)] -struct ToolSearchOutput { +#[derive(Debug, Clone, Serialize, PartialEq, Eq)] +pub struct ToolSearchOutput { matches: Vec, query: String, normalized_query: String, @@ -906,6 +2435,8 @@ struct ToolSearchOutput { total_deferred_tools: usize, #[serde(rename = "pending_mcp_servers")] pending_mcp_servers: Option>, + #[serde(rename = "mcp_degraded", skip_serializing_if = "Option::is_none")] + mcp_degraded: Option, } #[derive(Debug, Serialize)] @@ -956,6 +2487,39 @@ struct ConfigOutput { error: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +struct PlanModeState { + #[serde(rename = "hadLocalOverride")] + had_local_override: bool, + #[serde(rename = "previousLocalMode")] + previous_local_mode: Option, +} + +#[derive(Debug, Serialize)] +#[allow(clippy::struct_excessive_bools)] +struct PlanModeOutput { + success: bool, + operation: String, + changed: bool, + active: bool, + managed: bool, + message: String, + #[serde(rename = "settingsPath")] + settings_path: String, + #[serde(rename = "statePath")] + state_path: String, + #[serde(rename = "previousLocalMode")] + previous_local_mode: Option, + #[serde(rename = "currentLocalMode")] + current_local_mode: Option, +} + +#[derive(Debug, Clone)] +struct SearchableToolSpec { + name: String, + description: String, +} + #[derive(Debug, Serialize)] struct StructuredOutputResult { data: String, @@ -1081,7 +2645,7 @@ fn build_http_client() -> Result { Client::builder() .timeout(Duration::from_secs(20)) .redirect(reqwest::redirect::Policy::limited(10)) - .user_agent("claw-rust-tools/0.1") + .user_agent("clawd-rust-tools/0.1") .build() .map_err(|error| error.to_string()) } @@ -1102,7 +2666,7 @@ fn normalize_fetch_url(url: &str) -> Result { } fn build_search_url(query: &str) -> Result { - if let Ok(base) = std::env::var("CLAW_WEB_SEARCH_BASE_URL") { + if let Ok(base) = std::env::var("CLAWD_WEB_SEARCH_BASE_URL") { let mut url = reqwest::Url::parse(&base).map_err(|error| error.to_string())?; url.query_pairs_mut().append_pair("q", query); return Ok(url); @@ -1447,57 +3011,274 @@ fn validate_todos(todos: &[TodoItem]) -> Result<(), String> { } fn todo_store_path() -> Result { - if let Ok(path) = std::env::var("CLAW_TODO_STORE") { + if let Ok(path) = std::env::var("CLAWD_TODO_STORE") { return Ok(std::path::PathBuf::from(path)); } let cwd = std::env::current_dir().map_err(|error| error.to_string())?; - Ok(cwd.join(".claw-todos.json")) + Ok(cwd.join(".clawd-todos.json")) } fn resolve_skill_path(skill: &str) -> Result { + let cwd = std::env::current_dir().map_err(|error| error.to_string())?; + match commands::resolve_skill_path(&cwd, skill) { + Ok(path) => Ok(path), + Err(_) => resolve_skill_path_from_compat_roots(skill), + } +} + +fn resolve_skill_path_from_compat_roots(skill: &str) -> Result { let requested = skill.trim().trim_start_matches('/').trim_start_matches('$'); if requested.is_empty() { return Err(String::from("skill must not be empty")); } - let mut candidates = Vec::new(); - if let Ok(codex_home) = std::env::var("CODEX_HOME") { - candidates.push(std::path::PathBuf::from(codex_home).join("skills")); - } - if let Ok(home) = std::env::var("HOME") { - let home = std::path::PathBuf::from(home); - candidates.push(home.join(".agents").join("skills")); - candidates.push(home.join(".config").join("opencode").join("skills")); - candidates.push(home.join(".codex").join("skills")); - } - candidates.push(std::path::PathBuf::from("/home/bellman/.codex/skills")); - - for root in candidates { - let direct = root.join(requested).join("SKILL.md"); - if direct.exists() { - return Ok(direct); - } - - if let Ok(entries) = std::fs::read_dir(&root) { - for entry in entries.flatten() { - let path = entry.path().join("SKILL.md"); - if !path.exists() { - continue; - } - if entry - .file_name() - .to_string_lossy() - .eq_ignore_ascii_case(requested) - { - return Ok(path); - } - } + for root in skill_lookup_roots() { + if let Some(path) = resolve_skill_path_in_root(&root, requested) { + return Ok(path); } } Err(format!("unknown skill: {requested}")) } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SkillLookupOrigin { + SkillsDir, + LegacyCommandsDir, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct SkillLookupRoot { + path: std::path::PathBuf, + origin: SkillLookupOrigin, +} + +fn skill_lookup_roots() -> Vec { + let mut roots = Vec::new(); + + if let Ok(cwd) = std::env::current_dir() { + push_project_skill_lookup_roots(&mut roots, &cwd); + } + + if let Ok(claw_config_home) = std::env::var("CLAW_CONFIG_HOME") { + push_prefixed_skill_lookup_roots(&mut roots, std::path::Path::new(&claw_config_home)); + } + if let Ok(codex_home) = std::env::var("CODEX_HOME") { + push_prefixed_skill_lookup_roots(&mut roots, std::path::Path::new(&codex_home)); + } + if let Ok(home) = std::env::var("HOME").or_else(|_| std::env::var("USERPROFILE")) { + push_home_skill_lookup_roots(&mut roots, std::path::Path::new(&home)); + } + if let Ok(claude_config_dir) = std::env::var("CLAUDE_CONFIG_DIR") { + let claude_config_dir = std::path::PathBuf::from(claude_config_dir); + push_skill_lookup_root( + &mut roots, + claude_config_dir.join("skills"), + SkillLookupOrigin::SkillsDir, + ); + push_skill_lookup_root( + &mut roots, + claude_config_dir.join("skills").join("omc-learned"), + SkillLookupOrigin::SkillsDir, + ); + push_skill_lookup_root( + &mut roots, + claude_config_dir.join("commands"), + SkillLookupOrigin::LegacyCommandsDir, + ); + } + push_skill_lookup_root( + &mut roots, + std::path::PathBuf::from("/home/bellman/.claw/skills"), + SkillLookupOrigin::SkillsDir, + ); + push_skill_lookup_root( + &mut roots, + std::path::PathBuf::from("/home/bellman/.codex/skills"), + SkillLookupOrigin::SkillsDir, + ); + + roots +} + +fn push_project_skill_lookup_roots(roots: &mut Vec, cwd: &std::path::Path) { + for ancestor in cwd.ancestors() { + push_prefixed_skill_lookup_roots(roots, &ancestor.join(".omc")); + push_prefixed_skill_lookup_roots(roots, &ancestor.join(".agents")); + push_prefixed_skill_lookup_roots(roots, &ancestor.join(".claw")); + push_prefixed_skill_lookup_roots(roots, &ancestor.join(".codex")); + push_prefixed_skill_lookup_roots(roots, &ancestor.join(".claude")); + } +} + +fn push_home_skill_lookup_roots(roots: &mut Vec, home: &std::path::Path) { + push_prefixed_skill_lookup_roots(roots, &home.join(".omc")); + push_prefixed_skill_lookup_roots(roots, &home.join(".claw")); + push_prefixed_skill_lookup_roots(roots, &home.join(".codex")); + push_prefixed_skill_lookup_roots(roots, &home.join(".claude")); + push_skill_lookup_root( + roots, + home.join(".agents").join("skills"), + SkillLookupOrigin::SkillsDir, + ); + push_skill_lookup_root( + roots, + home.join(".config").join("opencode").join("skills"), + SkillLookupOrigin::SkillsDir, + ); + push_skill_lookup_root( + roots, + home.join(".claude").join("skills").join("omc-learned"), + SkillLookupOrigin::SkillsDir, + ); +} + +fn push_prefixed_skill_lookup_roots(roots: &mut Vec, prefix: &std::path::Path) { + push_skill_lookup_root(roots, prefix.join("skills"), SkillLookupOrigin::SkillsDir); + push_skill_lookup_root( + roots, + prefix.join("commands"), + SkillLookupOrigin::LegacyCommandsDir, + ); +} + +fn push_skill_lookup_root( + roots: &mut Vec, + path: std::path::PathBuf, + origin: SkillLookupOrigin, +) { + if path.is_dir() && !roots.iter().any(|existing| existing.path == path) { + roots.push(SkillLookupRoot { path, origin }); + } +} + +fn resolve_skill_path_in_root( + root: &SkillLookupRoot, + requested: &str, +) -> Option { + match root.origin { + SkillLookupOrigin::SkillsDir => resolve_skill_path_in_skills_dir(&root.path, requested), + SkillLookupOrigin::LegacyCommandsDir => { + resolve_skill_path_in_legacy_commands_dir(&root.path, requested) + } + } +} + +fn resolve_skill_path_in_skills_dir( + root: &std::path::Path, + requested: &str, +) -> Option { + let direct = root.join(requested).join("SKILL.md"); + if direct.is_file() { + return Some(direct); + } + + let entries = std::fs::read_dir(root).ok()?; + for entry in entries.flatten() { + if !entry.path().is_dir() { + continue; + } + let skill_path = entry.path().join("SKILL.md"); + if !skill_path.is_file() { + continue; + } + if entry + .file_name() + .to_string_lossy() + .eq_ignore_ascii_case(requested) + || skill_frontmatter_name_matches(&skill_path, requested) + { + return Some(skill_path); + } + } + + None +} + +fn resolve_skill_path_in_legacy_commands_dir( + root: &std::path::Path, + requested: &str, +) -> Option { + let direct_dir = root.join(requested).join("SKILL.md"); + if direct_dir.is_file() { + return Some(direct_dir); + } + + let direct_markdown = root.join(format!("{requested}.md")); + if direct_markdown.is_file() { + return Some(direct_markdown); + } + + let entries = std::fs::read_dir(root).ok()?; + for entry in entries.flatten() { + let path = entry.path(); + let candidate_path = if path.is_dir() { + let skill_path = path.join("SKILL.md"); + if !skill_path.is_file() { + continue; + } + skill_path + } else if path + .extension() + .is_some_and(|ext| ext.to_string_lossy().eq_ignore_ascii_case("md")) + { + path + } else { + continue; + }; + + let matches_entry_name = candidate_path + .file_stem() + .is_some_and(|stem| stem.to_string_lossy().eq_ignore_ascii_case(requested)) + || entry + .file_name() + .to_string_lossy() + .trim_end_matches(".md") + .eq_ignore_ascii_case(requested); + if matches_entry_name || skill_frontmatter_name_matches(&candidate_path, requested) { + return Some(candidate_path); + } + } + + None +} + +fn skill_frontmatter_name_matches(path: &std::path::Path, requested: &str) -> bool { + std::fs::read_to_string(path) + .ok() + .and_then(|contents| parse_skill_name(&contents)) + .is_some_and(|name| name.eq_ignore_ascii_case(requested)) +} + +fn parse_skill_name(contents: &str) -> Option { + parse_skill_frontmatter_value(contents, "name") +} + +fn parse_skill_frontmatter_value(contents: &str, key: &str) -> Option { + let mut lines = contents.lines(); + if lines.next().map(str::trim) != Some("---") { + return None; + } + + for line in lines { + let trimmed = line.trim(); + if trimmed == "---" { + break; + } + if let Some(value) = trimmed.strip_prefix(&format!("{key}:")) { + let value = value + .trim() + .trim_matches(|ch| matches!(ch, '"' | '\'')) + .trim(); + if !value.is_empty() { + return Some(value.to_string()); + } + } + } + + None +} + const DEFAULT_AGENT_MODEL: &str = "claude-opus-4-6"; const DEFAULT_AGENT_SYSTEM_DATE: &str = "2026-03-31"; const DEFAULT_AGENT_MAX_ITERATIONS: usize = 32; @@ -1563,6 +3344,9 @@ where created_at: created_at.clone(), started_at: Some(created_at), completed_at: None, + lane_events: vec![LaneEvent::started(iso8601_now())], + current_blocker: None, + derived_state: String::from("working"), error: None, }; write_agent_manifest(&manifest)?; @@ -1584,7 +3368,7 @@ where } fn spawn_agent_job(job: AgentJob) -> Result<(), String> { - let thread_name = format!("claw-agent-{}", job.manifest.agent_id); + let thread_name = format!("clawd-agent-{}", job.manifest.agent_id); std::thread::Builder::new() .name(thread_name) .spawn(move || { @@ -1628,13 +3412,15 @@ fn build_agent_runtime( .clone() .unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string()); let allowed_tools = job.allowed_tools.clone(); - let api_client = ProviderRuntimeClient::new(&model, allowed_tools.clone())?; - let tool_executor = SubagentToolExecutor::new(allowed_tools); + let api_client = ProviderRuntimeClient::new(model, allowed_tools.clone())?; + let permission_policy = agent_permission_policy(); + let tool_executor = SubagentToolExecutor::new(allowed_tools) + .with_enforcer(PermissionEnforcer::new(permission_policy.clone())); Ok(ConversationRuntime::new( Session::new(), api_client, tool_executor, - agent_permission_policy(), + permission_policy, job.system_prompt.clone(), )) } @@ -1751,9 +3537,11 @@ fn agent_permission_policy() -> PermissionPolicy { } fn write_agent_manifest(manifest: &AgentOutput) -> Result<(), String> { + let mut normalized = manifest.clone(); + normalized.lane_events = dedupe_superseded_commit_events(&normalized.lane_events); std::fs::write( - &manifest.manifest_file, - serde_json::to_string_pretty(manifest).map_err(|error| error.to_string())?, + &normalized.manifest_file, + serde_json::to_string_pretty(&normalized).map_err(|error| error.to_string())?, ) .map_err(|error| error.to_string()) } @@ -1764,17 +3552,119 @@ fn persist_agent_terminal_state( result: Option<&str>, error: Option, ) -> Result<(), String> { + let blocker = error.as_deref().map(classify_lane_blocker); append_agent_output( &manifest.output_file, - &format_agent_terminal_output(status, result, error.as_deref()), + &format_agent_terminal_output(status, result, blocker.as_ref(), error.as_deref()), )?; let mut next_manifest = manifest.clone(); next_manifest.status = status.to_string(); next_manifest.completed_at = Some(iso8601_now()); + next_manifest.current_blocker.clone_from(&blocker); + next_manifest.derived_state = + derive_agent_state(status, result, error.as_deref(), blocker.as_ref()).to_string(); next_manifest.error = error; + if let Some(blocker) = blocker { + next_manifest + .lane_events + .push(LaneEvent::blocked(iso8601_now(), &blocker)); + next_manifest + .lane_events + .push(LaneEvent::failed(iso8601_now(), &blocker)); + } else { + next_manifest.current_blocker = None; + let compressed_detail = result + .filter(|value| !value.trim().is_empty()) + .map(|value| compress_summary_text(value.trim())); + next_manifest + .lane_events + .push(LaneEvent::finished(iso8601_now(), compressed_detail)); + if let Some(provenance) = maybe_commit_provenance(result) { + next_manifest.lane_events.push(LaneEvent::commit_created( + iso8601_now(), + Some(format!("commit {}", provenance.commit)), + provenance, + )); + } + } write_agent_manifest(&next_manifest) } +fn derive_agent_state( + status: &str, + result: Option<&str>, + error: Option<&str>, + blocker: Option<&LaneEventBlocker>, +) -> &'static str { + let normalized_status = status.trim().to_ascii_lowercase(); + let normalized_error = error.unwrap_or_default().to_ascii_lowercase(); + + if normalized_status == "running" { + return "working"; + } + if normalized_status == "completed" { + return if result.is_some_and(|value| !value.trim().is_empty()) { + "finished_cleanable" + } else { + "finished_pending_report" + }; + } + if normalized_error.contains("background") { + return "blocked_background_job"; + } + if normalized_error.contains("merge conflict") || normalized_error.contains("cherry-pick") { + return "blocked_merge_conflict"; + } + if normalized_error.contains("mcp") { + return "degraded_mcp"; + } + if normalized_error.contains("transport") + || normalized_error.contains("broken pipe") + || normalized_error.contains("connection") + || normalized_error.contains("interrupted") + { + return "interrupted_transport"; + } + if blocker.is_some() { + return "truly_idle"; + } + "truly_idle" +} + +fn maybe_commit_provenance(result: Option<&str>) -> Option { + let commit = extract_commit_sha(result?)?; + let branch = current_git_branch().unwrap_or_else(|| "unknown".to_string()); + let worktree = std::env::current_dir() + .ok() + .map(|path| path.display().to_string()); + Some(LaneCommitProvenance { + commit: commit.clone(), + branch, + worktree, + canonical_commit: Some(commit.clone()), + superseded_by: None, + lineage: vec![commit], + }) +} + +fn extract_commit_sha(result: &str) -> Option { + result + .split(|c: char| !c.is_ascii_hexdigit()) + .find(|token| token.len() >= 7 && token.len() <= 40) + .map(str::to_string) +} + +fn current_git_branch() -> Option { + let output = Command::new("git") + .args(["rev-parse", "--abbrev-ref", "HEAD"]) + .output() + .ok()?; + output + .status + .success() + .then(|| String::from_utf8_lossy(&output.stdout).trim().to_string()) +} + fn append_agent_output(path: &str, suffix: &str) -> Result<(), String> { use std::io::Write as _; @@ -1786,8 +3676,22 @@ fn append_agent_output(path: &str, suffix: &str) -> Result<(), String> { .map_err(|error| error.to_string()) } -fn format_agent_terminal_output(status: &str, result: Option<&str>, error: Option<&str>) -> String { +fn format_agent_terminal_output( + status: &str, + result: Option<&str>, + blocker: Option<&LaneEventBlocker>, + error: Option<&str>, +) -> String { let mut sections = vec![format!("\n## Result\n\n- status: {status}\n")]; + if let Some(blocker) = blocker { + sections.push(format!( + "\n### Blocker\n\n- failure_class: {}\n- detail: {}\n", + serde_json::to_string(&blocker.failure_class) + .unwrap_or_else(|_| "\"infra\"".to_string()) + .trim_matches('"'), + blocker.detail.trim() + )); + } if let Some(result) = result.filter(|value| !value.trim().is_empty()) { sections.push(format!("\n### Final response\n\n{}\n", result.trim())); } @@ -1797,28 +3701,117 @@ fn format_agent_terminal_output(status: &str, result: Option<&str>, error: Optio sections.join("") } +fn classify_lane_blocker(error: &str) -> LaneEventBlocker { + let detail = error.trim().to_string(); + LaneEventBlocker { + failure_class: classify_lane_failure(error), + detail, + } +} + +fn classify_lane_failure(error: &str) -> LaneFailureClass { + let normalized = error.to_ascii_lowercase(); + + if normalized.contains("prompt") && normalized.contains("deliver") { + LaneFailureClass::PromptDelivery + } else if normalized.contains("trust") { + LaneFailureClass::TrustGate + } else if normalized.contains("branch") + && (normalized.contains("stale") || normalized.contains("diverg")) + { + LaneFailureClass::BranchDivergence + } else if normalized.contains("gateway") || normalized.contains("routing") { + LaneFailureClass::GatewayRouting + } else if normalized.contains("compile") + || normalized.contains("build failed") + || normalized.contains("cargo check") + { + LaneFailureClass::Compile + } else if normalized.contains("test") { + LaneFailureClass::Test + } else if normalized.contains("tool failed") + || normalized.contains("runtime tool") + || normalized.contains("tool runtime") + { + LaneFailureClass::ToolRuntime + } else if normalized.contains("plugin") { + LaneFailureClass::PluginStartup + } else if normalized.contains("mcp") && normalized.contains("handshake") { + LaneFailureClass::McpHandshake + } else if normalized.contains("mcp") { + LaneFailureClass::McpStartup + } else { + LaneFailureClass::Infra + } +} + +struct ProviderEntry { + model: String, + client: ProviderClient, +} + struct ProviderRuntimeClient { runtime: tokio::runtime::Runtime, - client: ProviderClient, - model: String, + chain: Vec, allowed_tools: BTreeSet, } impl ProviderRuntimeClient { - fn new(model: &str, allowed_tools: BTreeSet) -> Result { - let model = resolve_model_alias(model).clone(); - let client = ProviderClient::from_model(&model).map_err(|error| error.to_string())?; + #[allow(clippy::needless_pass_by_value)] + fn new(model: String, allowed_tools: BTreeSet) -> Result { + let fallback_config = load_provider_fallback_config(); + Self::new_with_fallback_config(model, allowed_tools, &fallback_config) + } + + #[allow(clippy::needless_pass_by_value)] + fn new_with_fallback_config( + model: String, + allowed_tools: BTreeSet, + fallback_config: &ProviderFallbackConfig, + ) -> Result { + let primary_model = fallback_config + .primary() + .map(str::to_string) + .unwrap_or(model); + let primary = build_provider_entry(&primary_model)?; + let mut chain = vec![primary]; + for fallback_model in fallback_config.fallbacks() { + match build_provider_entry(fallback_model) { + Ok(entry) => chain.push(entry), + Err(error) => { + eprintln!( + "warning: skipping unavailable fallback provider {fallback_model}: {error}" + ); + } + } + } Ok(Self { runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?, - client, - model, + chain, allowed_tools, }) } } +fn build_provider_entry(model: &str) -> Result { + let resolved = resolve_model_alias(model).clone(); + let client = ProviderClient::from_model(&resolved).map_err(|error| error.to_string())?; + Ok(ProviderEntry { + model: resolved, + client, + }) +} + +fn load_provider_fallback_config() -> ProviderFallbackConfig { + std::env::current_dir() + .ok() + .and_then(|cwd| ConfigLoader::default_for(cwd).load().ok()) + .map_or_else(ProviderFallbackConfig::default, |config| { + config.provider_fallbacks().clone() + }) +} + impl ApiClient for ProviderRuntimeClient { - #[allow(clippy::too_many_lines)] fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools)) .into_iter() @@ -1828,116 +3821,149 @@ impl ApiClient for ProviderRuntimeClient { input_schema: spec.input_schema, }) .collect::>(); - let message_request = MessageRequest { - model: self.model.clone(), - max_tokens: max_tokens_for_model(&self.model), - messages: convert_messages(&request.messages), - system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")), - tools: (!tools.is_empty()).then_some(tools), - tool_choice: (!self.allowed_tools.is_empty()).then_some(ToolChoice::Auto), - stream: true, - }; + let messages = convert_messages(&request.messages); + let system = + (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")); + let tool_choice = (!self.allowed_tools.is_empty()).then_some(ToolChoice::Auto); - self.runtime.block_on(async { - let mut stream = self - .client - .stream_message(&message_request) - .await - .map_err(|error| RuntimeError::new(error.to_string()))?; - let mut events = Vec::new(); - let mut pending_tools: BTreeMap = BTreeMap::new(); - let mut saw_stop = false; + let runtime = &self.runtime; + let chain = &self.chain; + let mut last_error: Option = None; + for (index, entry) in chain.iter().enumerate() { + let message_request = MessageRequest { + model: entry.model.clone(), + max_tokens: max_tokens_for_model(&entry.model), + messages: messages.clone(), + system: system.clone(), + tools: (!tools.is_empty()).then(|| tools.clone()), + tool_choice: tool_choice.clone(), + stream: true, + ..Default::default() + }; - while let Some(event) = stream - .next_event() - .await - .map_err(|error| RuntimeError::new(error.to_string()))? - { - match event { - ApiStreamEvent::MessageStart(start) => { - for block in start.message.content { - push_output_block(block, 0, &mut events, &mut pending_tools, true); - } - } - ApiStreamEvent::ContentBlockStart(start) => { - push_output_block( - start.content_block, - start.index, - &mut events, - &mut pending_tools, - true, - ); - } - ApiStreamEvent::ContentBlockDelta(delta) => match delta.delta { - ContentBlockDelta::TextDelta { text } => { - if !text.is_empty() { - events.push(AssistantEvent::TextDelta(text)); - } - } - ContentBlockDelta::InputJsonDelta { partial_json } => { - if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) { - input.push_str(&partial_json); - } - } - ContentBlockDelta::ThinkingDelta { .. } - | ContentBlockDelta::SignatureDelta { .. } => {} - }, - ApiStreamEvent::ContentBlockStop(stop) => { - if let Some((id, name, input)) = pending_tools.remove(&stop.index) { - events.push(AssistantEvent::ToolUse { id, name, input }); - } - } - ApiStreamEvent::MessageDelta(delta) => { - events.push(AssistantEvent::Usage(TokenUsage { - input_tokens: delta.usage.input_tokens, - output_tokens: delta.usage.output_tokens, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - })); - } - ApiStreamEvent::MessageStop(_) => { - saw_stop = true; - events.push(AssistantEvent::MessageStop); - } + let attempt = runtime.block_on(stream_with_provider(&entry.client, &message_request)); + match attempt { + Ok(events) => return Ok(events), + Err(error) if error.is_retryable() && index + 1 < chain.len() => { + eprintln!( + "provider {} failed with retryable error, falling back: {error}", + entry.model + ); + last_error = Some(error); + continue; + } + Err(error) => return Err(RuntimeError::new(error.to_string())), + } + } + + Err(RuntimeError::new( + last_error + .map(|error| error.to_string()) + .unwrap_or_else(|| String::from("provider chain exhausted with no attempts")), + )) + } +} + +#[allow(clippy::too_many_lines)] +async fn stream_with_provider( + client: &ProviderClient, + message_request: &MessageRequest, +) -> Result, ApiError> { + let mut stream = client.stream_message(message_request).await?; + let mut events = Vec::new(); + let mut pending_tools: BTreeMap = BTreeMap::new(); + let mut saw_stop = false; + + while let Some(event) = stream.next_event().await? { + match event { + ApiStreamEvent::MessageStart(start) => { + for block in start.message.content { + push_output_block(block, 0, &mut events, &mut pending_tools, true); } } - - if !saw_stop - && events.iter().any(|event| { - matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) - || matches!(event, AssistantEvent::ToolUse { .. }) - }) - { + ApiStreamEvent::ContentBlockStart(start) => { + push_output_block( + start.content_block, + start.index, + &mut events, + &mut pending_tools, + true, + ); + } + ApiStreamEvent::ContentBlockDelta(delta) => match delta.delta { + ContentBlockDelta::TextDelta { text } => { + if !text.is_empty() { + events.push(AssistantEvent::TextDelta(text)); + } + } + ContentBlockDelta::InputJsonDelta { partial_json } => { + if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) { + input.push_str(&partial_json); + } + } + ContentBlockDelta::ThinkingDelta { .. } + | ContentBlockDelta::SignatureDelta { .. } => {} + }, + ApiStreamEvent::ContentBlockStop(stop) => { + if let Some((id, name, input)) = pending_tools.remove(&stop.index) { + events.push(AssistantEvent::ToolUse { id, name, input }); + } + } + ApiStreamEvent::MessageDelta(delta) => { + events.push(AssistantEvent::Usage(delta.usage.token_usage())); + } + ApiStreamEvent::MessageStop(_) => { + saw_stop = true; events.push(AssistantEvent::MessageStop); } - - if events - .iter() - .any(|event| matches!(event, AssistantEvent::MessageStop)) - { - return Ok(events); - } - - let response = self - .client - .send_message(&MessageRequest { - stream: false, - ..message_request.clone() - }) - .await - .map_err(|error| RuntimeError::new(error.to_string()))?; - Ok(response_to_events(response)) - }) + } } + + push_prompt_cache_record(client, &mut events); + + if !saw_stop + && events.iter().any(|event| { + matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) + || matches!(event, AssistantEvent::ToolUse { .. }) + }) + { + events.push(AssistantEvent::MessageStop); + } + + if events + .iter() + .any(|event| matches!(event, AssistantEvent::MessageStop)) + { + return Ok(events); + } + + let response = client + .send_message(&MessageRequest { + stream: false, + ..message_request.clone() + }) + .await?; + let mut events = response_to_events(response); + push_prompt_cache_record(client, &mut events); + Ok(events) } struct SubagentToolExecutor { allowed_tools: BTreeSet, + enforcer: Option, } impl SubagentToolExecutor { fn new(allowed_tools: BTreeSet) -> Self { - Self { allowed_tools } + Self { + allowed_tools, + enforcer: None, + } + } + + fn with_enforcer(mut self, enforcer: PermissionEnforcer) -> Self { + self.enforcer = Some(enforcer); + self } } @@ -1950,7 +3976,8 @@ impl ToolExecutor for SubagentToolExecutor { } let value = serde_json::from_str(input) .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; - execute_tool(tool_name, &value).map_err(ToolError::new) + execute_tool_with_enforcer(self.enforcer.as_ref(), tool_name, &value) + .map_err(ToolError::new) } } @@ -1992,13 +4019,6 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec { }], is_error: *is_error, }, - ContentBlock::Thinking { thinking, signature } => InputContentBlock::Thinking { - thinking: thinking.clone(), - signature: signature.clone(), - }, - ContentBlock::RedactedThinking { data } => InputContentBlock::RedactedThinking { - data: serde_json::from_str(&data.render()).unwrap_or(serde_json::Value::Null), - }, }) .collect::>(); (!content.is_empty()).then(|| InputMessage { @@ -2049,16 +4069,32 @@ fn response_to_events(response: MessageResponse) -> Vec { } } - events.push(AssistantEvent::Usage(TokenUsage { - input_tokens: response.usage.input_tokens, - output_tokens: response.usage.output_tokens, - cache_creation_input_tokens: response.usage.cache_creation_input_tokens, - cache_read_input_tokens: response.usage.cache_read_input_tokens, - })); + events.push(AssistantEvent::Usage(response.usage.token_usage())); events.push(AssistantEvent::MessageStop); events } +fn push_prompt_cache_record(client: &ProviderClient, events: &mut Vec) { + if let Some(record) = client.take_last_prompt_cache_record() { + if let Some(event) = prompt_cache_record_to_runtime_event(record) { + events.push(AssistantEvent::PromptCache(event)); + } + } +} + +fn prompt_cache_record_to_runtime_event( + record: api::PromptCacheRecord, +) -> Option { + let cache_break = record.cache_break?; + Some(PromptCacheEvent { + unexpected: cache_break.unexpected, + reason: cache_break.reason, + previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens, + current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens, + token_drop: cache_break.token_drop, + }) +} + fn final_assistant_text(summary: &runtime::TurnSummary) -> String { summary .assistant_messages @@ -2069,10 +4105,7 @@ fn final_assistant_text(summary: &runtime::TurnSummary) -> String { .iter() .filter_map(|block| match block { ContentBlock::Text { text } => Some(text.as_str()), - ContentBlock::Thinking { .. } - | ContentBlock::RedactedThinking { .. } - | ContentBlock::ToolUse { .. } - | ContentBlock::ToolResult { .. } => None, + _ => None, }) .collect::>() .join("") @@ -2082,19 +4115,7 @@ fn final_assistant_text(summary: &runtime::TurnSummary) -> String { #[allow(clippy::needless_pass_by_value)] fn execute_tool_search(input: ToolSearchInput) -> ToolSearchOutput { - let deferred = deferred_tool_specs(); - let max_results = input.max_results.unwrap_or(5).max(1); - let query = input.query.trim().to_string(); - let normalized_query = normalize_tool_search_query(&query); - let matches = search_tool_specs(&query, max_results, &deferred); - - ToolSearchOutput { - matches, - query, - normalized_query, - total_deferred_tools: deferred.len(), - pending_mcp_servers: None, - } + GlobalToolRegistry::builtin().search(&input.query, input.max_results.unwrap_or(5), None, None) } fn deferred_tool_specs() -> Vec { @@ -2109,7 +4130,7 @@ fn deferred_tool_specs() -> Vec { .collect() } -fn search_tool_specs(query: &str, max_results: usize, specs: &[ToolSpec]) -> Vec { +fn search_tool_specs(query: &str, max_results: usize, specs: &[SearchableToolSpec]) -> Vec { let lowered = query.to_lowercase(); if let Some(selection) = lowered.strip_prefix("select:") { return selection @@ -2120,8 +4141,8 @@ fn search_tool_specs(query: &str, max_results: usize, specs: &[ToolSpec]) -> Vec let wanted = canonical_tool_token(wanted); specs .iter() - .find(|spec| canonical_tool_token(spec.name) == wanted) - .map(|spec| spec.name.to_string()) + .find(|spec| canonical_tool_token(&spec.name) == wanted) + .map(|spec| spec.name.clone()) }) .take(max_results) .collect(); @@ -2148,8 +4169,8 @@ fn search_tool_specs(query: &str, max_results: usize, specs: &[ToolSpec]) -> Vec .iter() .filter_map(|spec| { let name = spec.name.to_lowercase(); - let canonical_name = canonical_tool_token(spec.name); - let normalized_description = normalize_tool_search_query(spec.description); + let canonical_name = canonical_tool_token(&spec.name); + let normalized_description = normalize_tool_search_query(&spec.description); let haystack = format!( "{name} {} {canonical_name}", spec.description.to_lowercase() @@ -2182,7 +4203,7 @@ fn search_tool_specs(query: &str, max_results: usize, specs: &[ToolSpec]) -> Vec if score == 0 && !lowered.is_empty() { return None; } - Some((score, spec.name.to_string())) + Some((score, spec.name.clone())) }) .collect::>(); @@ -2217,14 +4238,14 @@ fn canonical_tool_token(value: &str) -> String { } fn agent_store_dir() -> Result { - if let Ok(path) = std::env::var("CLAW_AGENT_STORE") { + if let Ok(path) = std::env::var("CLAWD_AGENT_STORE") { return Ok(std::path::PathBuf::from(path)); } let cwd = std::env::current_dir().map_err(|error| error.to_string())?; if let Some(workspace_root) = cwd.ancestors().nth(2) { - return Ok(workspace_root.join(".claw-agents")); + return Ok(workspace_root.join(".clawd-agents")); } - Ok(cwd.join(".claw-agents")) + Ok(cwd.join(".clawd-agents")) } fn make_agent_id() -> String { @@ -2329,7 +4350,8 @@ fn execute_notebook_edit(input: NotebookEditInput) -> Result { - let resolved_cell_type = resolved_cell_type.expect("insert cell type"); + let resolved_cell_type = resolved_cell_type + .ok_or_else(|| String::from("insert mode requires a cell type"))?; let new_id = make_cell_id(cells.len()); let new_cell = build_notebook_cell(&new_id, resolved_cell_type, &new_source); let insert_at = target_index.map_or(cells.len(), |index| index + 1); @@ -2341,16 +4363,21 @@ fn execute_notebook_edit(input: NotebookEditInput) -> Result { - let removed = cells.remove(target_index.expect("delete target index")); + let idx = target_index + .ok_or_else(|| String::from("delete mode requires a target cell index"))?; + let removed = cells.remove(idx); removed .get("id") .and_then(serde_json::Value::as_str) .map(ToString::to_string) } NotebookEditMode::Replace => { - let resolved_cell_type = resolved_cell_type.expect("replace cell type"); + let resolved_cell_type = resolved_cell_type + .ok_or_else(|| String::from("replace mode requires a cell type"))?; + let idx = target_index + .ok_or_else(|| String::from("replace mode requires a target cell index"))?; let cell = cells - .get_mut(target_index.expect("replace target index")) + .get_mut(idx) .ok_or_else(|| String::from("Cell index out of range"))?; cell["source"] = serde_json::Value::Array(source_lines(&new_source)); cell["cell_type"] = serde_json::Value::String(match resolved_cell_type { @@ -2441,13 +4468,21 @@ fn cell_kind(cell: &serde_json::Value) -> Option { }) } +const MAX_SLEEP_DURATION_MS: u64 = 300_000; + #[allow(clippy::needless_pass_by_value)] -fn execute_sleep(input: SleepInput) -> SleepOutput { +fn execute_sleep(input: SleepInput) -> Result { + if input.duration_ms > MAX_SLEEP_DURATION_MS { + return Err(format!( + "duration_ms {} exceeds maximum allowed sleep of {MAX_SLEEP_DURATION_MS}ms", + input.duration_ms, + )); + } std::thread::sleep(Duration::from_millis(input.duration_ms)); - SleepOutput { + Ok(SleepOutput { duration_ms: input.duration_ms, message: format!("Slept for {}ms", input.duration_ms), - } + }) } fn execute_brief(input: BriefInput) -> Result { @@ -2544,25 +4579,204 @@ fn execute_config(input: ConfigInput) -> Result { } } -fn execute_structured_output(input: StructuredOutputInput) -> StructuredOutputResult { - StructuredOutputResult { +const PERMISSION_DEFAULT_MODE_PATH: &[&str] = &["permissions", "defaultMode"]; + +fn execute_enter_plan_mode(_input: EnterPlanModeInput) -> Result { + let settings_path = config_file_for_scope(ConfigScope::Settings)?; + let state_path = plan_mode_state_file()?; + let mut document = read_json_object(&settings_path)?; + let current_local_mode = get_nested_value(&document, PERMISSION_DEFAULT_MODE_PATH).cloned(); + let current_is_plan = + matches!(current_local_mode.as_ref(), Some(Value::String(value)) if value == "plan"); + + if let Some(state) = read_plan_mode_state(&state_path)? { + if current_is_plan { + return Ok(PlanModeOutput { + success: true, + operation: String::from("enter"), + changed: false, + active: true, + managed: true, + message: String::from("Plan mode override is already active for this worktree."), + settings_path: settings_path.display().to_string(), + state_path: state_path.display().to_string(), + previous_local_mode: state.previous_local_mode, + current_local_mode, + }); + } + clear_plan_mode_state(&state_path)?; + } + + if current_is_plan { + return Ok(PlanModeOutput { + success: true, + operation: String::from("enter"), + changed: false, + active: true, + managed: false, + message: String::from( + "Worktree-local plan mode is already enabled outside EnterPlanMode; leaving it unchanged.", + ), + settings_path: settings_path.display().to_string(), + state_path: state_path.display().to_string(), + previous_local_mode: None, + current_local_mode, + }); + } + + let state = PlanModeState { + had_local_override: current_local_mode.is_some(), + previous_local_mode: current_local_mode.clone(), + }; + write_plan_mode_state(&state_path, &state)?; + set_nested_value( + &mut document, + PERMISSION_DEFAULT_MODE_PATH, + Value::String(String::from("plan")), + ); + write_json_object(&settings_path, &document)?; + + Ok(PlanModeOutput { + success: true, + operation: String::from("enter"), + changed: true, + active: true, + managed: true, + message: String::from("Enabled worktree-local plan mode override."), + settings_path: settings_path.display().to_string(), + state_path: state_path.display().to_string(), + previous_local_mode: state.previous_local_mode, + current_local_mode: get_nested_value(&document, PERMISSION_DEFAULT_MODE_PATH).cloned(), + }) +} + +fn execute_exit_plan_mode(_input: ExitPlanModeInput) -> Result { + let settings_path = config_file_for_scope(ConfigScope::Settings)?; + let state_path = plan_mode_state_file()?; + let mut document = read_json_object(&settings_path)?; + let current_local_mode = get_nested_value(&document, PERMISSION_DEFAULT_MODE_PATH).cloned(); + let current_is_plan = + matches!(current_local_mode.as_ref(), Some(Value::String(value)) if value == "plan"); + + let Some(state) = read_plan_mode_state(&state_path)? else { + return Ok(PlanModeOutput { + success: true, + operation: String::from("exit"), + changed: false, + active: current_is_plan, + managed: false, + message: String::from("No EnterPlanMode override is active for this worktree."), + settings_path: settings_path.display().to_string(), + state_path: state_path.display().to_string(), + previous_local_mode: None, + current_local_mode, + }); + }; + + if !current_is_plan { + clear_plan_mode_state(&state_path)?; + return Ok(PlanModeOutput { + success: true, + operation: String::from("exit"), + changed: false, + active: false, + managed: false, + message: String::from( + "Cleared stale EnterPlanMode state because plan mode was already changed outside the tool.", + ), + settings_path: settings_path.display().to_string(), + state_path: state_path.display().to_string(), + previous_local_mode: state.previous_local_mode, + current_local_mode, + }); + } + + if state.had_local_override { + if let Some(previous_local_mode) = state.previous_local_mode.clone() { + set_nested_value( + &mut document, + PERMISSION_DEFAULT_MODE_PATH, + previous_local_mode, + ); + } else { + remove_nested_value(&mut document, PERMISSION_DEFAULT_MODE_PATH); + } + } else { + remove_nested_value(&mut document, PERMISSION_DEFAULT_MODE_PATH); + } + write_json_object(&settings_path, &document)?; + clear_plan_mode_state(&state_path)?; + + Ok(PlanModeOutput { + success: true, + operation: String::from("exit"), + changed: true, + active: false, + managed: false, + message: String::from("Restored the prior worktree-local plan mode setting."), + settings_path: settings_path.display().to_string(), + state_path: state_path.display().to_string(), + previous_local_mode: state.previous_local_mode, + current_local_mode: get_nested_value(&document, PERMISSION_DEFAULT_MODE_PATH).cloned(), + }) +} + +fn execute_structured_output( + input: StructuredOutputInput, +) -> Result { + if input.0.is_empty() { + return Err(String::from("structured output payload must not be empty")); + } + Ok(StructuredOutputResult { data: String::from("Structured output provided successfully"), structured_output: input.0, - } + }) } fn execute_repl(input: ReplInput) -> Result { if input.code.trim().is_empty() { return Err(String::from("code must not be empty")); } - let _ = input.timeout_ms; let runtime = resolve_repl_runtime(&input.language)?; let started = Instant::now(); - let output = Command::new(runtime.program) + let mut process = Command::new(runtime.program); + process .args(runtime.args) .arg(&input.code) - .output() - .map_err(|error| error.to_string())?; + .stdin(std::process::Stdio::null()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()); + + let output = if let Some(timeout_ms) = input.timeout_ms { + let mut child = process.spawn().map_err(|error| error.to_string())?; + loop { + if child + .try_wait() + .map_err(|error| error.to_string())? + .is_some() + { + break child + .wait_with_output() + .map_err(|error| error.to_string())?; + } + if started.elapsed() >= Duration::from_millis(timeout_ms) { + child.kill().map_err(|error| error.to_string())?; + child + .wait_with_output() + .map_err(|error| error.to_string())?; + return Err(format!( + "REPL execution exceeded timeout of {timeout_ms} ms" + )); + } + std::thread::sleep(Duration::from_millis(10)); + } + } else { + process + .spawn() + .map_err(|error| error.to_string())? + .wait_with_output() + .map_err(|error| error.to_string())? + }; Ok(ReplOutput { language: input.language, @@ -2773,7 +4987,14 @@ fn config_home_dir() -> Result { if let Ok(path) = std::env::var("CLAW_CONFIG_HOME") { return Ok(PathBuf::from(path)); } - let home = std::env::var("HOME").map_err(|_| String::from("HOME is not set"))?; + let home = std::env::var("HOME") + .or_else(|_| std::env::var("USERPROFILE")) + .map_err(|_| { + String::from( + "HOME is not set (on Windows, set USERPROFILE or HOME, \ + or use CLAW_CONFIG_HOME to point directly at the config directory)", + ) + })?; Ok(PathBuf::from(home).join(".claw")) } @@ -2834,6 +5055,72 @@ fn set_nested_value(root: &mut serde_json::Map, path: &[&str], ne set_nested_value(map, rest, new_value); } +fn remove_nested_value(root: &mut serde_json::Map, path: &[&str]) -> bool { + let Some((first, rest)) = path.split_first() else { + return false; + }; + if rest.is_empty() { + return root.remove(*first).is_some(); + } + + let mut should_remove_parent = false; + let removed = root.get_mut(*first).is_some_and(|entry| { + entry.as_object_mut().is_some_and(|map| { + let removed = remove_nested_value(map, rest); + should_remove_parent = removed && map.is_empty(); + removed + }) + }); + + if should_remove_parent { + root.remove(*first); + } + + removed +} + +fn plan_mode_state_file() -> Result { + Ok(config_file_for_scope(ConfigScope::Settings)? + .parent() + .ok_or_else(|| String::from("settings.local.json has no parent directory"))? + .join("tool-state") + .join("plan-mode.json")) +} + +fn read_plan_mode_state(path: &Path) -> Result, String> { + match std::fs::read_to_string(path) { + Ok(contents) => { + if contents.trim().is_empty() { + return Ok(None); + } + serde_json::from_str(&contents) + .map(Some) + .map_err(|error| error.to_string()) + } + Err(error) if error.kind() == std::io::ErrorKind::NotFound => Ok(None), + Err(error) => Err(error.to_string()), + } +} + +fn write_plan_mode_state(path: &Path, state: &PlanModeState) -> Result<(), String> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent).map_err(|error| error.to_string())?; + } + std::fs::write( + path, + serde_json::to_string_pretty(state).map_err(|error| error.to_string())?, + ) + .map_err(|error| error.to_string()) +} + +fn clear_plan_mode_state(path: &Path) -> Result<(), String> { + match std::fs::remove_file(path) { + Ok(()) => Ok(()), + Err(error) if error.kind() == std::io::ErrorKind::NotFound => Ok(()), + Err(error) => Err(error.to_string()), + } +} + fn iso8601_timestamp() -> String { if let Ok(output) = Command::new("date") .args(["-u", "+%Y-%m-%dT%H:%M:%SZ"]) @@ -2849,6 +5136,9 @@ fn iso8601_timestamp() -> String { #[allow(clippy::needless_pass_by_value)] fn execute_powershell(input: PowerShellInput) -> std::io::Result { let _ = &input.description; + if let Some(output) = workspace_test_branch_preflight(&input.command) { + return Ok(output); + } let shell = detect_powershell_shell()?; execute_shell_command( shell, @@ -3069,6 +5359,9 @@ fn parse_skill_description(contents: &str) -> Option { None } +pub mod lane_completion; +pub mod pdf_extract; + #[cfg(test)] mod tests { use std::collections::BTreeMap; @@ -3076,18 +5369,26 @@ mod tests { use std::fs; use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener}; - use std::path::PathBuf; + use std::path::{Path, PathBuf}; + use std::process::Command; use std::sync::{Arc, Mutex, OnceLock}; use std::thread; use std::time::Duration; use super::{ - agent_permission_policy, allowed_tools_for_subagent, execute_agent_with_spawn, - execute_tool, final_assistant_text, mvp_tool_specs, persist_agent_terminal_state, - push_output_block, AgentInput, AgentJob, SubagentToolExecutor, + agent_permission_policy, allowed_tools_for_subagent, classify_lane_failure, + derive_agent_state, execute_agent_with_spawn, execute_tool, final_assistant_text, + maybe_commit_provenance, mvp_tool_specs, permission_mode_from_plugin, + persist_agent_terminal_state, push_output_block, run_task_packet, AgentInput, AgentJob, + GlobalToolRegistry, LaneEventName, LaneFailureClass, ProviderRuntimeClient, + SubagentToolExecutor, }; use api::OutputContentBlock; - use runtime::{ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, Session}; + use runtime::ProviderFallbackConfig; + use runtime::{ + permission_enforcer::PermissionEnforcer, ApiRequest, AssistantEvent, ConversationRuntime, + PermissionMode, PermissionPolicy, RuntimeError, Session, TaskPacket, ToolExecutor, + }; use serde_json::json; fn env_lock() -> &'static Mutex<()> { @@ -3100,7 +5401,44 @@ mod tests { .duration_since(std::time::UNIX_EPOCH) .expect("time") .as_nanos(); - std::env::temp_dir().join(format!("claw-tools-{unique}-{name}")) + std::env::temp_dir().join(format!("clawd-tools-{unique}-{name}")) + } + + fn run_git(cwd: &Path, args: &[&str]) { + let status = Command::new("git") + .args(args) + .current_dir(cwd) + .status() + .unwrap_or_else(|error| panic!("git {} failed: {error}", args.join(" "))); + assert!( + status.success(), + "git {} exited with {status}", + args.join(" ") + ); + } + + fn init_git_repo(path: &Path) { + std::fs::create_dir_all(path).expect("create repo"); + run_git(path, &["init", "--quiet", "-b", "main"]); + run_git(path, &["config", "user.email", "tests@example.com"]); + run_git(path, &["config", "user.name", "Tools Tests"]); + std::fs::write(path.join("README.md"), "initial\n").expect("write readme"); + run_git(path, &["add", "README.md"]); + run_git(path, &["commit", "-m", "initial commit", "--quiet"]); + } + + fn commit_file(path: &Path, file: &str, contents: &str, message: &str) { + std::fs::write(path.join(file), contents).expect("write file"); + run_git(path, &["add", file]); + run_git(path, &["commit", "-m", message, "--quiet"]); + } + + fn permission_policy_for_mode(mode: PermissionMode) -> PermissionPolicy { + mvp_tool_specs() + .into_iter() + .fold(PermissionPolicy::new(mode), |policy, spec| { + policy.with_tool_requirement(spec.name, spec.required_permission) + }) } #[test] @@ -3121,9 +5459,15 @@ mod tests { assert!(names.contains(&"Sleep")); assert!(names.contains(&"SendUserMessage")); assert!(names.contains(&"Config")); + assert!(names.contains(&"EnterPlanMode")); + assert!(names.contains(&"ExitPlanMode")); assert!(names.contains(&"StructuredOutput")); assert!(names.contains(&"REPL")); assert!(names.contains(&"PowerShell")); + assert!(names.contains(&"WorkerCreate")); + assert!(names.contains(&"WorkerObserve")); + assert!(names.contains(&"WorkerAwaitReady")); + assert!(names.contains(&"WorkerSendPrompt")); } #[test] @@ -3132,6 +5476,710 @@ mod tests { assert!(error.contains("unsupported tool")); } + #[test] + fn worker_tools_gate_prompt_delivery_until_ready_and_support_auto_trust() { + let created = execute_tool( + "WorkerCreate", + &json!({ + "cwd": "/tmp/worktree/repo", + "trusted_roots": ["/tmp/worktree"] + }), + ) + .expect("WorkerCreate should succeed"); + let created_output: serde_json::Value = serde_json::from_str(&created).expect("json"); + let worker_id = created_output["worker_id"] + .as_str() + .expect("worker id") + .to_string(); + assert_eq!(created_output["status"], "spawning"); + assert_eq!(created_output["trust_auto_resolve"], true); + + let gated = execute_tool( + "WorkerSendPrompt", + &json!({ + "worker_id": worker_id, + "prompt": "ship the change" + }), + ) + .expect_err("prompt delivery before ready should fail"); + assert!(gated.contains("not ready for prompt delivery")); + + let observed = execute_tool( + "WorkerObserve", + &json!({ + "worker_id": created_output["worker_id"], + "screen_text": "Do you trust the files in this folder?\n1. Yes, proceed\n2. No" + }), + ) + .expect("WorkerObserve should auto-resolve trust"); + let observed_output: serde_json::Value = serde_json::from_str(&observed).expect("json"); + assert_eq!(observed_output["status"], "spawning"); + assert_eq!(observed_output["trust_gate_cleared"], true); + assert_eq!( + observed_output["events"][1]["payload"]["type"], + "trust_prompt" + ); + assert_eq!( + observed_output["events"][2]["payload"]["resolution"], + "auto_allowlisted" + ); + + let ready = execute_tool( + "WorkerObserve", + &json!({ + "worker_id": created_output["worker_id"], + "screen_text": "Ready for your input\n>" + }), + ) + .expect("WorkerObserve should mark worker ready"); + let ready_output: serde_json::Value = serde_json::from_str(&ready).expect("json"); + assert_eq!(ready_output["status"], "ready_for_prompt"); + + let await_ready = execute_tool( + "WorkerAwaitReady", + &json!({ + "worker_id": created_output["worker_id"] + }), + ) + .expect("WorkerAwaitReady should succeed"); + let await_ready_output: serde_json::Value = + serde_json::from_str(&await_ready).expect("json"); + assert_eq!(await_ready_output["ready"], true); + + let accepted = execute_tool( + "WorkerSendPrompt", + &json!({ + "worker_id": created_output["worker_id"], + "prompt": "ship the change" + }), + ) + .expect("WorkerSendPrompt should succeed after ready"); + let accepted_output: serde_json::Value = serde_json::from_str(&accepted).expect("json"); + assert_eq!(accepted_output["status"], "running"); + assert_eq!(accepted_output["prompt_delivery_attempts"], 1); + assert_eq!(accepted_output["prompt_in_flight"], true); + } + + #[test] + fn worker_create_merges_config_trusted_roots_without_per_call_override() { + use std::fs; + // Write a .claw/settings.json in a temp dir with trustedRoots + let worktree = temp_path("config-trust-worktree"); + let claw_dir = worktree.join(".claw"); + fs::create_dir_all(&claw_dir).expect("create .claw dir"); + // Use the actual OS temp dir so the worktree path matches the allowlist + let tmp_root = std::env::temp_dir().to_str().expect("utf-8").to_string(); + let settings = format!("{{\"trustedRoots\": [\"{tmp_root}\"]}}"); + fs::write(claw_dir.join("settings.json"), settings).expect("write settings"); + + // WorkerCreate with no per-call trusted_roots — config should supply them + let cwd = worktree.to_str().expect("valid utf-8").to_string(); + let created = execute_tool( + "WorkerCreate", + &json!({ + "cwd": cwd + // trusted_roots intentionally omitted + }), + ) + .expect("WorkerCreate should succeed"); + let output: serde_json::Value = serde_json::from_str(&created).expect("json"); + + // worktree is under /tmp, so config roots auto-resolve trust + assert_eq!( + output["trust_auto_resolve"], true, + "config-level trustedRoots should auto-resolve trust without per-call override" + ); + + fs::remove_dir_all(&worktree).ok(); + } + + #[test] + fn worker_terminate_sets_finished_status() { + // Create a worker in running state + let created = execute_tool( + "WorkerCreate", + &json!({"cwd": "/tmp/terminate-test", "trusted_roots": ["/tmp"]}), + ) + .expect("WorkerCreate should succeed"); + let output: serde_json::Value = serde_json::from_str(&created).expect("json"); + let worker_id = output["worker_id"].as_str().expect("worker_id").to_string(); + + // Terminate + let terminated = execute_tool("WorkerTerminate", &json!({"worker_id": worker_id})) + .expect("WorkerTerminate should succeed"); + let term_output: serde_json::Value = serde_json::from_str(&terminated).expect("json"); + assert_eq!( + term_output["status"], "finished", + "terminated worker should be finished" + ); + assert_eq!( + term_output["prompt_in_flight"], false, + "prompt_in_flight should be cleared on termination" + ); + } + + #[test] + fn worker_restart_resets_to_spawning() { + // Create and advance worker to ready_for_prompt + let created = execute_tool( + "WorkerCreate", + &json!({"cwd": "/tmp/restart-test", "trusted_roots": ["/tmp"]}), + ) + .expect("WorkerCreate should succeed"); + let output: serde_json::Value = serde_json::from_str(&created).expect("json"); + let worker_id = output["worker_id"].as_str().expect("worker_id").to_string(); + + // Advance to ready_for_prompt via observe + execute_tool( + "WorkerObserve", + &json!({"worker_id": worker_id, "screen_text": "Ready for input\n>"}), + ) + .expect("WorkerObserve should succeed"); + + // Restart + let restarted = execute_tool("WorkerRestart", &json!({"worker_id": worker_id})) + .expect("WorkerRestart should succeed"); + let restart_output: serde_json::Value = serde_json::from_str(&restarted).expect("json"); + assert_eq!( + restart_output["status"], "spawning", + "restarted worker should return to spawning" + ); + assert_eq!( + restart_output["prompt_in_flight"], false, + "prompt_in_flight should be cleared on restart" + ); + assert_eq!( + restart_output["trust_gate_cleared"], false, + "trust_gate_cleared should be reset on restart (re-trust required)" + ); + } + + #[test] + fn worker_get_returns_worker_state() { + let created = execute_tool( + "WorkerCreate", + &json!({"cwd": "/tmp/worker-get-test", "trusted_roots": ["/tmp"]}), + ) + .expect("WorkerCreate should succeed"); + let created_output: serde_json::Value = serde_json::from_str(&created).expect("json"); + let worker_id = created_output["worker_id"].as_str().expect("worker_id"); + + let fetched = execute_tool("WorkerGet", &json!({"worker_id": worker_id})) + .expect("WorkerGet should succeed"); + let fetched_output: serde_json::Value = serde_json::from_str(&fetched).expect("json"); + assert_eq!(fetched_output["worker_id"], worker_id); + assert_eq!(fetched_output["status"], "spawning"); + assert_eq!(fetched_output["cwd"], "/tmp/worker-get-test"); + } + + #[test] + fn worker_get_on_unknown_id_returns_error() { + let result = execute_tool( + "WorkerGet", + &json!({"worker_id": "worker_nonexistent_get_00000000"}), + ); + assert!( + result.is_err(), + "WorkerGet on unknown id should return error" + ); + assert!( + result.unwrap_err().contains("worker not found"), + "error should mention worker not found" + ); + } + + #[test] + fn worker_await_ready_on_spawning_worker_returns_not_ready() { + let created = execute_tool( + "WorkerCreate", + &json!({"cwd": "/tmp/worker-await-not-ready"}), + ) + .expect("WorkerCreate should succeed"); + let created_output: serde_json::Value = serde_json::from_str(&created).expect("json"); + let worker_id = created_output["worker_id"].as_str().expect("worker_id"); + + // Worker is still in spawning — await_ready should return not-ready snapshot + let snapshot = execute_tool("WorkerAwaitReady", &json!({"worker_id": worker_id})) + .expect("WorkerAwaitReady should succeed even when not ready"); + let snap_output: serde_json::Value = serde_json::from_str(&snapshot).expect("json"); + assert_eq!( + snap_output["ready"], false, + "WorkerAwaitReady on a spawning worker must return ready=false" + ); + assert_eq!(snap_output["worker_id"], worker_id); + } + + #[test] + fn worker_send_prompt_on_non_ready_worker_returns_error() { + let created = execute_tool( + "WorkerCreate", + &json!({"cwd": "/tmp/worker-send-not-ready"}), + ) + .expect("WorkerCreate should succeed"); + let created_output: serde_json::Value = serde_json::from_str(&created).expect("json"); + let worker_id = created_output["worker_id"].as_str().expect("worker_id"); + + let result = execute_tool( + "WorkerSendPrompt", + &json!({"worker_id": worker_id, "prompt": "too early"}), + ); + assert!( + result.is_err(), + "WorkerSendPrompt on a non-ready worker should fail" + ); + } + + #[test] + fn recovery_loop_state_file_reflects_transitions() { + // End-to-end proof: .claw/worker-state.json reflects every transition + // through the stall-detect -> resolve-trust -> ready loop. + use std::fs; + + // Use a real temp CWD so state file can be written + let worktree = temp_path("recovery-loop-state"); + fs::create_dir_all(&worktree).expect("create worktree"); + let cwd = worktree.to_str().expect("utf-8").to_string(); + let state_path = worktree.join(".claw").join("worker-state.json"); + + // 1. Create worker WITHOUT trusted_roots + let created = execute_tool("WorkerCreate", &json!({"cwd": cwd})) + .expect("WorkerCreate should succeed"); + let created_output: serde_json::Value = serde_json::from_str(&created).expect("json"); + let worker_id = created_output["worker_id"] + .as_str() + .expect("worker_id") + .to_string(); + // State file should exist after create + assert!( + state_path.exists(), + "state file should be written after WorkerCreate" + ); + let state: serde_json::Value = + serde_json::from_str(&fs::read_to_string(&state_path).expect("read state")) + .expect("parse state"); + assert_eq!(state["status"], "spawning"); + assert_eq!(state["is_ready"], false); + assert!( + state["seconds_since_update"].is_number(), + "seconds_since_update must be present" + ); + + // 2. Force trust_required via observe + execute_tool( + "WorkerObserve", + &json!({"worker_id": worker_id, "screen_text": "Do you trust the files in this folder?"}), + ) + .expect("WorkerObserve should succeed"); + let state: serde_json::Value = + serde_json::from_str(&fs::read_to_string(&state_path).expect("read state")) + .expect("parse state"); + assert_eq!( + state["status"], "trust_required", + "state file must reflect trust_required stall" + ); + assert_eq!(state["is_ready"], false); + assert_eq!(state["trust_gate_cleared"], false); + assert!(state["seconds_since_update"].is_number()); + + // 3. WorkerResolveTrust -> state file reflects recovery + execute_tool("WorkerResolveTrust", &json!({"worker_id": worker_id})) + .expect("WorkerResolveTrust should succeed"); + let state: serde_json::Value = + serde_json::from_str(&fs::read_to_string(&state_path).expect("read state")) + .expect("parse state"); + assert_eq!( + state["status"], "spawning", + "state file must show spawning after trust resolved" + ); + assert_eq!(state["trust_gate_cleared"], true); + + // 4. Observe ready screen -> state file shows ready_for_prompt + execute_tool( + "WorkerObserve", + &json!({"worker_id": worker_id, "screen_text": "Ready for input\n>"}), + ) + .expect("WorkerObserve ready should succeed"); + let state: serde_json::Value = + serde_json::from_str(&fs::read_to_string(&state_path).expect("read state")) + .expect("parse state"); + assert_eq!( + state["status"], "ready_for_prompt", + "state file must show ready_for_prompt after ready screen" + ); + assert_eq!( + state["is_ready"], true, + "is_ready must be true in state file at ready_for_prompt" + ); + + fs::remove_dir_all(&worktree).ok(); + } + + #[test] + fn stall_detect_and_resolve_trust_end_to_end() { + // 1. Create worker WITHOUT trusted_roots so trust won't auto-resolve + let created = execute_tool("WorkerCreate", &json!({"cwd": "/no/trusted/root/here"})) + .expect("WorkerCreate should succeed"); + let created_output: serde_json::Value = serde_json::from_str(&created).expect("json"); + let worker_id = created_output["worker_id"] + .as_str() + .expect("worker_id") + .to_string(); + assert_eq!(created_output["trust_auto_resolve"], false); + + // 2. Observe trust prompt screen text -> worker stalls at trust_required + let stalled = execute_tool( + "WorkerObserve", + &json!({ + "worker_id": worker_id, + "screen_text": "Do you trust the files in this folder?\n[Allow] [Deny]" + }), + ) + .expect("WorkerObserve should succeed"); + let stalled_output: serde_json::Value = serde_json::from_str(&stalled).expect("json"); + assert_eq!( + stalled_output["status"], "trust_required", + "worker should stall at trust_required when trust prompt seen without allowlist" + ); + assert_eq!(stalled_output["trust_gate_cleared"], false); + // 3. Clawhip calls WorkerResolveTrust to unblock + let resolved = execute_tool("WorkerResolveTrust", &json!({"worker_id": worker_id})) + .expect("WorkerResolveTrust should succeed"); + let resolved_output: serde_json::Value = serde_json::from_str(&resolved).expect("json"); + assert_eq!( + resolved_output["status"], "spawning", + "worker should return to spawning after trust resolved" + ); + assert_eq!(resolved_output["trust_gate_cleared"], true); + + // 4. Ready screen text now advances worker normally + let ready = execute_tool( + "WorkerObserve", + &json!({ + "worker_id": worker_id, + "screen_text": "Ready for input\n>" + }), + ) + .expect("WorkerObserve should succeed after trust resolved"); + let ready_output: serde_json::Value = serde_json::from_str(&ready).expect("json"); + assert_eq!( + ready_output["status"], "ready_for_prompt", + "worker should reach ready_for_prompt after trust resolved and ready screen seen" + ); + } + + #[test] + fn stall_detect_and_restart_recovery_end_to_end() { + // Worker stalls at trust_required, clawhip restarts instead of resolving + let created = execute_tool( + "WorkerCreate", + &json!({"cwd": "/no/trusted/root/restart-test"}), + ) + .expect("WorkerCreate should succeed"); + let created_output: serde_json::Value = serde_json::from_str(&created).expect("json"); + let worker_id = created_output["worker_id"] + .as_str() + .expect("worker_id") + .to_string(); + + // Force trust_required + let stalled = execute_tool( + "WorkerObserve", + &json!({ + "worker_id": worker_id, + "screen_text": "trust this folder? [Yes] [No]" + }), + ) + .expect("WorkerObserve should succeed"); + let stalled_output: serde_json::Value = serde_json::from_str(&stalled).expect("json"); + assert_eq!(stalled_output["status"], "trust_required"); + + // WorkerRestart resets the worker + let restarted = execute_tool("WorkerRestart", &json!({"worker_id": worker_id})) + .expect("WorkerRestart should succeed"); + let restarted_output: serde_json::Value = serde_json::from_str(&restarted).expect("json"); + assert_eq!( + restarted_output["status"], "spawning", + "restarted worker should be back at spawning" + ); + assert_eq!( + restarted_output["trust_gate_cleared"], false, + "restart clears trust — next observe loop must re-acquire trust" + ); + } + + #[test] + fn worker_terminate_on_unknown_id_returns_error() { + let result = execute_tool( + "WorkerTerminate", + &json!({"worker_id": "worker_nonexistent_00000000"}), + ); + assert!(result.is_err(), "terminating unknown worker should fail"); + assert!( + result.unwrap_err().contains("worker not found"), + "error should mention worker not found" + ); + } + + #[test] + fn worker_restart_on_unknown_id_returns_error() { + let result = execute_tool( + "WorkerRestart", + &json!({"worker_id": "worker_nonexistent_00000001"}), + ); + assert!(result.is_err(), "restarting unknown worker should fail"); + assert!( + result.unwrap_err().contains("worker not found"), + "error should mention worker not found" + ); + } + + #[test] + fn worker_observe_completion_success_finish_sets_finished_status() { + let created = execute_tool( + "WorkerCreate", + &json!({"cwd": "/tmp/observe-completion-test", "trusted_roots": ["/tmp"]}), + ) + .expect("WorkerCreate should succeed"); + let output: serde_json::Value = serde_json::from_str(&created).expect("json"); + let worker_id = output["worker_id"].as_str().expect("worker_id").to_string(); + + let completed = execute_tool( + "WorkerObserveCompletion", + &json!({ + "worker_id": worker_id, + "finish_reason": "end_turn", + "tokens_output": 512 + }), + ) + .expect("WorkerObserveCompletion should succeed"); + let completed_output: serde_json::Value = serde_json::from_str(&completed).expect("json"); + assert_eq!(completed_output["status"], "finished"); + assert_eq!(completed_output["prompt_in_flight"], false); + } + + #[test] + fn worker_observe_completion_degraded_provider_sets_failed_status() { + let created = execute_tool( + "WorkerCreate", + &json!({"cwd": "/tmp/observe-degraded-test", "trusted_roots": ["/tmp"]}), + ) + .expect("WorkerCreate should succeed"); + let output: serde_json::Value = serde_json::from_str(&created).expect("json"); + let worker_id = output["worker_id"].as_str().expect("worker_id").to_string(); + + // finish=unknown + 0 tokens = degraded provider classification + let failed = execute_tool( + "WorkerObserveCompletion", + &json!({ + "worker_id": worker_id, + "finish_reason": "unknown", + "tokens_output": 0 + }), + ) + .expect("WorkerObserveCompletion should succeed"); + let failed_output: serde_json::Value = serde_json::from_str(&failed).expect("json"); + assert_eq!( + failed_output["status"], "failed", + "finish=unknown + 0 tokens should classify as provider failure" + ); + assert_eq!(failed_output["prompt_in_flight"], false); + // last_error should be set with provider failure message + assert!( + !failed_output["last_error"].is_null(), + "last_error should be populated for provider failure" + ); + } + + #[test] + fn worker_tools_detect_misdelivery_and_arm_prompt_replay() { + let created = execute_tool( + "WorkerCreate", + &json!({ + "cwd": "/tmp/repo/worker-misdelivery" + }), + ) + .expect("WorkerCreate should succeed"); + let created_output: serde_json::Value = serde_json::from_str(&created).expect("json"); + let worker_id = created_output["worker_id"] + .as_str() + .expect("worker id") + .to_string(); + + execute_tool( + "WorkerObserve", + &json!({ + "worker_id": worker_id, + "screen_text": "Ready for input\n>" + }), + ) + .expect("worker should become ready"); + + execute_tool( + "WorkerSendPrompt", + &json!({ + "worker_id": worker_id, + "prompt": "Investigate flaky boot" + }), + ) + .expect("prompt send should succeed"); + + let recovered = execute_tool( + "WorkerObserve", + &json!({ + "worker_id": worker_id, + "screen_text": "% Investigate flaky boot\nzsh: command not found: Investigate" + }), + ) + .expect("misdelivery observe should succeed"); + let recovered_output: serde_json::Value = serde_json::from_str(&recovered).expect("json"); + assert_eq!(recovered_output["status"], "ready_for_prompt"); + assert_eq!(recovered_output["last_error"]["kind"], "prompt_delivery"); + assert_eq!(recovered_output["replay_prompt"], "Investigate flaky boot"); + assert_eq!( + recovered_output["events"][3]["payload"]["observed_target"], + "shell" + ); + assert_eq!( + recovered_output["events"][4]["payload"]["recovery_armed"], + true + ); + + let replayed = execute_tool( + "WorkerSendPrompt", + &json!({ + "worker_id": worker_id + }), + ) + .expect("WorkerSendPrompt should replay recovered prompt"); + let replayed_output: serde_json::Value = serde_json::from_str(&replayed).expect("json"); + assert_eq!(replayed_output["status"], "running"); + assert_eq!(replayed_output["prompt_delivery_attempts"], 2); + assert_eq!(replayed_output["prompt_in_flight"], true); + } + + #[test] + fn global_tool_registry_denies_blocked_tool_before_dispatch() { + // given + let policy = permission_policy_for_mode(PermissionMode::ReadOnly); + let registry = GlobalToolRegistry::builtin().with_enforcer(PermissionEnforcer::new(policy)); + + // when + let error = registry + .execute( + "write_file", + &json!({ + "path": "blocked.txt", + "content": "blocked" + }), + ) + .expect_err("write tool should be denied before dispatch"); + + // then + assert!(error.contains("requires workspace-write permission")); + } + + #[test] + fn subagent_tool_executor_denies_blocked_tool_before_dispatch() { + // given + let policy = permission_policy_for_mode(PermissionMode::ReadOnly); + let mut executor = SubagentToolExecutor::new(BTreeSet::from([String::from("write_file")])) + .with_enforcer(PermissionEnforcer::new(policy)); + + // when + let error = executor + .execute( + "write_file", + &json!({ + "path": "blocked.txt", + "content": "blocked" + }) + .to_string(), + ) + .expect_err("subagent write tool should be denied before dispatch"); + + // then + assert!(error + .to_string() + .contains("requires workspace-write permission")); + } + + #[test] + fn permission_mode_from_plugin_rejects_invalid_inputs() { + let unknown_permission = permission_mode_from_plugin("admin") + .expect_err("unknown plugin permission should fail"); + assert!(unknown_permission.contains("unsupported plugin permission: admin")); + + let empty_permission = + permission_mode_from_plugin("").expect_err("empty plugin permission should fail"); + assert!(empty_permission.contains("unsupported plugin permission: ")); + } + + #[test] + fn runtime_tools_extend_registry_definitions_permissions_and_search() { + let registry = GlobalToolRegistry::builtin() + .with_runtime_tools(vec![super::RuntimeToolDefinition { + name: "mcp__demo__echo".to_string(), + description: Some("Echo text from the demo MCP server".to_string()), + input_schema: json!({ + "type": "object", + "properties": { "text": { "type": "string" } }, + "additionalProperties": false + }), + required_permission: runtime::PermissionMode::ReadOnly, + }]) + .expect("runtime tools should register"); + + let allowed = registry + .normalize_allowed_tools(&["mcp__demo__echo".to_string()]) + .expect("runtime tool should be allow-listable") + .expect("allow-list should be populated"); + assert!(allowed.contains("mcp__demo__echo")); + + let definitions = registry.definitions(Some(&allowed)); + assert_eq!(definitions.len(), 1); + assert_eq!(definitions[0].name, "mcp__demo__echo"); + + let permissions = registry + .permission_specs(Some(&allowed)) + .expect("runtime tool permissions should resolve"); + assert_eq!( + permissions, + vec![( + "mcp__demo__echo".to_string(), + runtime::PermissionMode::ReadOnly + )] + ); + + let search = registry.search( + "demo echo", + 5, + Some(vec!["pending-server".to_string()]), + Some(runtime::McpDegradedReport::new( + vec!["demo".to_string()], + vec![runtime::McpFailedServer { + server_name: "pending-server".to_string(), + phase: runtime::McpLifecyclePhase::ToolDiscovery, + error: runtime::McpErrorSurface::new( + runtime::McpLifecyclePhase::ToolDiscovery, + Some("pending-server".to_string()), + "tool discovery failed", + BTreeMap::new(), + true, + ), + }], + vec!["mcp__demo__echo".to_string()], + vec!["mcp__demo__echo".to_string()], + )), + ); + let output = serde_json::to_value(search).expect("search output should serialize"); + assert_eq!(output["matches"][0], "mcp__demo__echo"); + assert_eq!(output["pending_mcp_servers"][0], "pending-server"); + assert_eq!( + output["mcp_degraded"]["failed_servers"][0]["phase"], + "tool_discovery" + ); + } + #[test] fn web_fetch_returns_prompt_aware_summary() { let server = TestServer::spawn(Arc::new(|request_line: &str| { @@ -3208,6 +6256,14 @@ mod tests { #[test] fn web_search_extracts_and_filters_results() { + // Serialize env-var mutation so this test cannot race with the sibling + // web_search_handles_generic_links_and_invalid_base_url test that also + // sets CLAWD_WEB_SEARCH_BASE_URL. Without the lock, parallel test + // runners can interleave the set/remove calls and cause assertion + // failures on the wrong port. + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); let server = TestServer::spawn(Arc::new(|request_line: &str| { assert!(request_line.contains("GET /search?q=rust+web+search ")); HttpResponse::html( @@ -3223,7 +6279,7 @@ mod tests { })); std::env::set_var( - "CLAW_WEB_SEARCH_BASE_URL", + "CLAWD_WEB_SEARCH_BASE_URL", format!("http://{}/search", server.addr()), ); let result = execute_tool( @@ -3235,7 +6291,7 @@ mod tests { }), ) .expect("WebSearch should succeed"); - std::env::remove_var("CLAW_WEB_SEARCH_BASE_URL"); + std::env::remove_var("CLAWD_WEB_SEARCH_BASE_URL"); let output: serde_json::Value = serde_json::from_str(&result).expect("valid json"); assert_eq!(output["query"], "rust web search"); @@ -3271,7 +6327,7 @@ mod tests { })); std::env::set_var( - "CLAW_WEB_SEARCH_BASE_URL", + "CLAWD_WEB_SEARCH_BASE_URL", format!("http://{}/fallback", server.addr()), ); let result = execute_tool( @@ -3281,7 +6337,7 @@ mod tests { }), ) .expect("WebSearch fallback parsing should succeed"); - std::env::remove_var("CLAW_WEB_SEARCH_BASE_URL"); + std::env::remove_var("CLAWD_WEB_SEARCH_BASE_URL"); let output: serde_json::Value = serde_json::from_str(&result).expect("valid json"); let results = output["results"].as_array().expect("results array"); @@ -3294,10 +6350,10 @@ mod tests { assert_eq!(content[0]["url"], "https://example.com/one"); assert_eq!(content[1]["url"], "https://docs.rs/tokio"); - std::env::set_var("CLAW_WEB_SEARCH_BASE_URL", "://bad-base-url"); + std::env::set_var("CLAWD_WEB_SEARCH_BASE_URL", "://bad-base-url"); let error = execute_tool("WebSearch", &json!({ "query": "generic links" })) .expect_err("invalid base URL should fail"); - std::env::remove_var("CLAW_WEB_SEARCH_BASE_URL"); + std::env::remove_var("CLAWD_WEB_SEARCH_BASE_URL"); assert!(error.contains("relative URL without a base") || error.contains("empty host")); } @@ -3364,7 +6420,7 @@ mod tests { .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let path = temp_path("todos.json"); - std::env::set_var("CLAW_TODO_STORE", &path); + std::env::set_var("CLAWD_TODO_STORE", &path); let first = execute_tool( "TodoWrite", @@ -3390,7 +6446,7 @@ mod tests { }), ) .expect("TodoWrite should succeed"); - std::env::remove_var("CLAW_TODO_STORE"); + std::env::remove_var("CLAWD_TODO_STORE"); let _ = std::fs::remove_file(path); let second_output: serde_json::Value = serde_json::from_str(&second).expect("valid json"); @@ -3411,7 +6467,7 @@ mod tests { .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let path = temp_path("todos-errors.json"); - std::env::set_var("CLAW_TODO_STORE", &path); + std::env::set_var("CLAWD_TODO_STORE", &path); let empty = execute_tool("TodoWrite", &json!({ "todos": [] })) .expect_err("empty todos should fail"); @@ -3451,7 +6507,7 @@ mod tests { }), ) .expect("completed todos should succeed"); - std::env::remove_var("CLAW_TODO_STORE"); + std::env::remove_var("CLAWD_TODO_STORE"); let _ = fs::remove_file(path); let output: serde_json::Value = serde_json::from_str(&nudge).expect("valid json"); @@ -3460,9 +6516,18 @@ mod tests { #[test] fn skill_loads_local_skill_prompt() { - let _guard = env_lock() - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let _guard = env_lock().lock().expect("env lock should acquire"); + let home = temp_path("skills-home"); + let skill_dir = home.join(".agents").join("skills").join("help"); + fs::create_dir_all(&skill_dir).expect("skill dir should exist"); + fs::write( + skill_dir.join("SKILL.md"), + "# help\n\nGuide on using oh-my-codex plugin\n", + ) + .expect("skill file should exist"); + let original_home = std::env::var("HOME").ok(); + std::env::set_var("HOME", &home); + let result = execute_tool( "Skill", &json!({ @@ -3497,6 +6562,356 @@ mod tests { .as_str() .expect("path") .ends_with("/help/SKILL.md")); + + if let Some(home) = original_home { + std::env::set_var("HOME", home); + } else { + std::env::remove_var("HOME"); + } + fs::remove_dir_all(home).expect("temp home should clean up"); + } + + #[test] + fn skill_resolves_project_local_skills_and_legacy_commands() { + let _guard = env_lock().lock().expect("env lock should acquire"); + let root = temp_path("project-skills"); + let skill_dir = root.join(".claw").join("skills").join("plan"); + let command_dir = root.join(".claw").join("commands"); + fs::create_dir_all(&skill_dir).expect("skill dir should exist"); + fs::create_dir_all(&command_dir).expect("command dir should exist"); + fs::write( + skill_dir.join("SKILL.md"), + "---\nname: plan\ndescription: Project planning guidance\n---\n\n# plan\n", + ) + .expect("skill file should exist"); + fs::write( + command_dir.join("handoff.md"), + "---\nname: handoff\ndescription: Legacy handoff guidance\n---\n\n# handoff\n", + ) + .expect("command file should exist"); + + let original_dir = std::env::current_dir().expect("cwd"); + std::env::set_current_dir(&root).expect("set cwd"); + + let skill_result = execute_tool("Skill", &json!({ "skill": "$plan" })) + .expect("project-local skill should resolve"); + let skill_output: serde_json::Value = + serde_json::from_str(&skill_result).expect("valid json"); + assert!(skill_output["path"] + .as_str() + .expect("path") + .ends_with(".claw/skills/plan/SKILL.md")); + + let command_result = execute_tool("Skill", &json!({ "skill": "/handoff" })) + .expect("legacy command should resolve"); + let command_output: serde_json::Value = + serde_json::from_str(&command_result).expect("valid json"); + assert!(command_output["path"] + .as_str() + .expect("path") + .ends_with(".claw/commands/handoff.md")); + + std::env::set_current_dir(&original_dir).expect("restore cwd"); + fs::remove_dir_all(root).expect("temp project should clean up"); + } + + #[test] + fn skill_loads_project_local_claude_skill_prompt() { + let _guard = env_lock().lock().expect("env lock should acquire"); + let root = temp_path("project-skills"); + let home = root.join("home"); + let workspace = root.join("workspace"); + let nested = workspace.join("nested"); + let skill_dir = workspace.join(".claude").join("skills").join("trace"); + fs::create_dir_all(&skill_dir).expect("skill dir should exist"); + fs::create_dir_all(&nested).expect("nested cwd should exist"); + fs::write( + skill_dir.join("SKILL.md"), + "---\nname: trace\ndescription: Project-local trace helper\n---\n# trace\n", + ) + .expect("skill file should exist"); + + let original_home = std::env::var("HOME").ok(); + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); + let original_codex_home = std::env::var("CODEX_HOME").ok(); + let original_dir = std::env::current_dir().expect("cwd"); + std::env::set_var("HOME", &home); + std::env::remove_var("CLAW_CONFIG_HOME"); + std::env::remove_var("CODEX_HOME"); + std::env::set_current_dir(&nested).expect("set cwd"); + + let result = execute_tool("Skill", &json!({ "skill": "trace" })) + .expect("project-local skill should resolve"); + + let output: serde_json::Value = serde_json::from_str(&result).expect("valid json"); + assert!(output["path"] + .as_str() + .expect("path") + .ends_with(".claude/skills/trace/SKILL.md")); + assert_eq!(output["description"], "Project-local trace helper"); + + std::env::set_current_dir(&original_dir).expect("restore cwd"); + match original_home { + Some(value) => std::env::set_var("HOME", value), + None => std::env::remove_var("HOME"), + } + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), + } + match original_codex_home { + Some(value) => std::env::set_var("CODEX_HOME", value), + None => std::env::remove_var("CODEX_HOME"), + } + fs::remove_dir_all(root).expect("temp tree should clean up"); + } + + #[test] + fn skill_loads_project_local_omc_and_agents_skill_prompts() { + let _guard = env_lock().lock().expect("env lock should acquire"); + let root = temp_path("project-omc-skills"); + let home = root.join("home"); + let workspace = root.join("workspace"); + let nested = workspace.join("nested"); + let omc_skill_dir = workspace.join(".omc").join("skills").join("hud"); + let agents_skill_dir = workspace.join(".agents").join("skills").join("trace"); + fs::create_dir_all(&omc_skill_dir).expect("omc skill dir should exist"); + fs::create_dir_all(&agents_skill_dir).expect("agents skill dir should exist"); + fs::create_dir_all(&nested).expect("nested cwd should exist"); + fs::write( + omc_skill_dir.join("SKILL.md"), + "---\nname: hud\ndescription: Project-local OMC HUD helper\n---\n# hud\n", + ) + .expect("omc skill file should exist"); + fs::write( + agents_skill_dir.join("SKILL.md"), + "---\nname: trace\ndescription: Project-local agents compatibility helper\n---\n# trace\n", + ) + .expect("agents skill file should exist"); + + let original_home = std::env::var("HOME").ok(); + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); + let original_codex_home = std::env::var("CODEX_HOME").ok(); + let original_dir = std::env::current_dir().expect("cwd"); + std::env::set_var("HOME", &home); + std::env::remove_var("CLAW_CONFIG_HOME"); + std::env::remove_var("CODEX_HOME"); + std::env::set_current_dir(&nested).expect("set cwd"); + + let omc_result = + execute_tool("Skill", &json!({ "skill": "hud" })).expect("omc skill should resolve"); + let agents_result = execute_tool("Skill", &json!({ "skill": "trace" })) + .expect("agents skill should resolve"); + + let omc_output: serde_json::Value = serde_json::from_str(&omc_result).expect("valid json"); + let agents_output: serde_json::Value = + serde_json::from_str(&agents_result).expect("valid json"); + assert!(omc_output["path"] + .as_str() + .expect("path") + .ends_with(".omc/skills/hud/SKILL.md")); + assert_eq!(omc_output["description"], "Project-local OMC HUD helper"); + assert!(agents_output["path"] + .as_str() + .expect("path") + .ends_with(".agents/skills/trace/SKILL.md")); + assert_eq!( + agents_output["description"], + "Project-local agents compatibility helper" + ); + + std::env::set_current_dir(&original_dir).expect("restore cwd"); + match original_home { + Some(value) => std::env::set_var("HOME", value), + None => std::env::remove_var("HOME"), + } + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), + } + match original_codex_home { + Some(value) => std::env::set_var("CODEX_HOME", value), + None => std::env::remove_var("CODEX_HOME"), + } + fs::remove_dir_all(root).expect("temp tree should clean up"); + } + + #[test] + fn skill_loads_learned_skill_from_claude_config_dir() { + let _guard = env_lock().lock().expect("env lock should acquire"); + let root = temp_path("claude-config-learned-skill"); + let home = root.join("home"); + let claude_config_dir = root.join("claude-config"); + let learned_skill_dir = claude_config_dir + .join("skills") + .join("omc-learned") + .join("learned"); + fs::create_dir_all(&learned_skill_dir).expect("learned skill dir should exist"); + fs::write( + learned_skill_dir.join("SKILL.md"), + "---\nname: learned\ndescription: Learned OMC skill\n---\n# learned\n", + ) + .expect("learned skill file should exist"); + + let original_home = std::env::var("HOME").ok(); + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); + let original_codex_home = std::env::var("CODEX_HOME").ok(); + let original_claude_config_dir = std::env::var("CLAUDE_CONFIG_DIR").ok(); + std::env::set_var("HOME", &home); + std::env::remove_var("CLAW_CONFIG_HOME"); + std::env::remove_var("CODEX_HOME"); + std::env::set_var("CLAUDE_CONFIG_DIR", &claude_config_dir); + + let result = execute_tool("Skill", &json!({ "skill": "learned" })) + .expect("learned skill should resolve"); + + let output: serde_json::Value = serde_json::from_str(&result).expect("valid json"); + assert!(output["path"] + .as_str() + .expect("path") + .ends_with("skills/omc-learned/learned/SKILL.md")); + assert_eq!(output["description"], "Learned OMC skill"); + + match original_home { + Some(value) => std::env::set_var("HOME", value), + None => std::env::remove_var("HOME"), + } + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), + } + match original_codex_home { + Some(value) => std::env::set_var("CODEX_HOME", value), + None => std::env::remove_var("CODEX_HOME"), + } + match original_claude_config_dir { + Some(value) => std::env::set_var("CLAUDE_CONFIG_DIR", value), + None => std::env::remove_var("CLAUDE_CONFIG_DIR"), + } + fs::remove_dir_all(root).expect("temp tree should clean up"); + } + + #[test] + fn skill_loads_direct_skill_and_legacy_command_from_claude_config_dir() { + let _guard = env_lock().lock().expect("env lock should acquire"); + let root = temp_path("claude-config-direct-skill"); + let home = root.join("home"); + let claude_config_dir = root.join("claude-config"); + let skill_dir = claude_config_dir.join("skills").join("statusline"); + let command_dir = claude_config_dir.join("commands"); + fs::create_dir_all(&skill_dir).expect("direct skill dir should exist"); + fs::create_dir_all(&command_dir).expect("command dir should exist"); + fs::write( + skill_dir.join("SKILL.md"), + "---\nname: statusline\ndescription: Claude config skill\n---\n# statusline\n", + ) + .expect("direct skill file should exist"); + fs::write( + command_dir.join("doctor-check.md"), + "---\nname: doctor-check\ndescription: Claude config command\n---\n# doctor-check\n", + ) + .expect("direct command file should exist"); + + let original_home = std::env::var("HOME").ok(); + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); + let original_codex_home = std::env::var("CODEX_HOME").ok(); + let original_claude_config_dir = std::env::var("CLAUDE_CONFIG_DIR").ok(); + std::env::set_var("HOME", &home); + std::env::remove_var("CLAW_CONFIG_HOME"); + std::env::remove_var("CODEX_HOME"); + std::env::set_var("CLAUDE_CONFIG_DIR", &claude_config_dir); + + let direct_skill = + execute_tool("Skill", &json!({ "skill": "statusline" })).expect("direct skill"); + let direct_skill_output: serde_json::Value = + serde_json::from_str(&direct_skill).expect("valid skill json"); + assert!(direct_skill_output["path"] + .as_str() + .expect("path") + .ends_with("skills/statusline/SKILL.md")); + assert_eq!(direct_skill_output["description"], "Claude config skill"); + + let legacy_command = + execute_tool("Skill", &json!({ "skill": "doctor-check" })).expect("direct command"); + let legacy_command_output: serde_json::Value = + serde_json::from_str(&legacy_command).expect("valid command json"); + assert!(legacy_command_output["path"] + .as_str() + .expect("path") + .ends_with("commands/doctor-check.md")); + assert_eq!( + legacy_command_output["description"], + "Claude config command" + ); + + match original_home { + Some(value) => std::env::set_var("HOME", value), + None => std::env::remove_var("HOME"), + } + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), + } + match original_codex_home { + Some(value) => std::env::set_var("CODEX_HOME", value), + None => std::env::remove_var("CODEX_HOME"), + } + match original_claude_config_dir { + Some(value) => std::env::set_var("CLAUDE_CONFIG_DIR", value), + None => std::env::remove_var("CLAUDE_CONFIG_DIR"), + } + fs::remove_dir_all(root).expect("temp tree should clean up"); + } + + #[test] + fn skill_loads_project_local_legacy_command_markdown() { + let _guard = env_lock().lock().expect("env lock should acquire"); + let root = temp_path("project-legacy-command"); + let home = root.join("home"); + let workspace = root.join("workspace"); + let nested = workspace.join("nested"); + let command_dir = workspace.join(".claude").join("commands"); + fs::create_dir_all(&command_dir).expect("legacy command dir should exist"); + fs::create_dir_all(&nested).expect("nested cwd should exist"); + fs::write( + command_dir.join("team.md"), + "---\nname: team\ndescription: Legacy team workflow\n---\n# team\n", + ) + .expect("legacy command file should exist"); + + let original_home = std::env::var("HOME").ok(); + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); + let original_codex_home = std::env::var("CODEX_HOME").ok(); + let original_dir = std::env::current_dir().expect("cwd"); + std::env::set_var("HOME", &home); + std::env::remove_var("CLAW_CONFIG_HOME"); + std::env::remove_var("CODEX_HOME"); + std::env::set_current_dir(&nested).expect("set cwd"); + + let result = execute_tool("Skill", &json!({ "skill": "team" })) + .expect("legacy command markdown should resolve"); + + let output: serde_json::Value = serde_json::from_str(&result).expect("valid json"); + assert!(output["path"] + .as_str() + .expect("path") + .ends_with(".claude/commands/team.md")); + assert_eq!(output["description"], "Legacy team workflow"); + + std::env::set_current_dir(&original_dir).expect("restore cwd"); + match original_home { + Some(value) => std::env::set_var("HOME", value), + None => std::env::remove_var("HOME"), + } + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), + } + match original_codex_home { + Some(value) => std::env::set_var("CODEX_HOME", value), + None => std::env::remove_var("CODEX_HOME"), + } + fs::remove_dir_all(root).expect("temp tree should clean up"); } #[test] @@ -3538,7 +6953,7 @@ mod tests { .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let dir = temp_path("agent-store"); - std::env::set_var("CLAW_AGENT_STORE", &dir); + std::env::set_var("CLAWD_AGENT_STORE", &dir); let captured = Arc::new(Mutex::new(None::)); let captured_for_spawn = Arc::clone(&captured); @@ -3558,7 +6973,7 @@ mod tests { }, ) .expect("Agent should succeed"); - std::env::remove_var("CLAW_AGENT_STORE"); + std::env::remove_var("CLAWD_AGENT_STORE"); assert_eq!(manifest.name, "ship-audit"); assert_eq!(manifest.subagent_type.as_deref(), Some("Explore")); @@ -3569,10 +6984,15 @@ mod tests { let contents = std::fs::read_to_string(&manifest.output_file).expect("agent file exists"); let manifest_contents = std::fs::read_to_string(&manifest.manifest_file).expect("manifest file exists"); + let manifest_json: serde_json::Value = + serde_json::from_str(&manifest_contents).expect("manifest should be valid json"); assert!(contents.contains("Audit the branch")); assert!(contents.contains("Check tests and outstanding work.")); assert!(manifest_contents.contains("\"subagentType\": \"Explore\"")); assert!(manifest_contents.contains("\"status\": \"running\"")); + assert_eq!(manifest_json["laneEvents"][0]["event"], "lane.started"); + assert_eq!(manifest_json["laneEvents"][0]["status"], "running"); + assert!(manifest_json["currentBlocker"].is_null()); let captured_job = captured .lock() .unwrap_or_else(std::sync::PoisonError::into_inner) @@ -3610,12 +7030,13 @@ mod tests { } #[test] + #[allow(clippy::too_many_lines)] fn agent_fake_runner_can_persist_completion_and_failure() { let _guard = env_lock() .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let dir = temp_path("agent-runner"); - std::env::set_var("CLAW_AGENT_STORE", &dir); + std::env::set_var("CLAWD_AGENT_STORE", &dir); let completed = execute_agent_with_spawn( AgentInput { @@ -3629,7 +7050,7 @@ mod tests { persist_agent_terminal_state( &job.manifest, "completed", - Some("Finished successfully"), + Some("Finished successfully in commit abc1234"), None, ) }, @@ -3638,10 +7059,33 @@ mod tests { let completed_manifest = std::fs::read_to_string(&completed.manifest_file) .expect("completed manifest should exist"); + let completed_manifest_json: serde_json::Value = + serde_json::from_str(&completed_manifest).expect("completed manifest json"); let completed_output = std::fs::read_to_string(&completed.output_file).expect("completed output should exist"); assert!(completed_manifest.contains("\"status\": \"completed\"")); assert!(completed_output.contains("Finished successfully")); + assert_eq!( + completed_manifest_json["laneEvents"][0]["event"], + "lane.started" + ); + assert_eq!( + completed_manifest_json["laneEvents"][1]["event"], + "lane.finished" + ); + assert_eq!( + completed_manifest_json["laneEvents"][2]["event"], + "lane.commit.created" + ); + assert_eq!( + completed_manifest_json["laneEvents"][2]["data"]["commit"], + "abc1234" + ); + assert!(completed_manifest_json["currentBlocker"].is_null()); + assert_eq!( + completed_manifest_json["derivedState"], + "finished_cleanable" + ); let failed = execute_agent_with_spawn( AgentInput { @@ -3656,7 +7100,7 @@ mod tests { &job.manifest, "failed", None, - Some(String::from("simulated failure")), + Some(String::from("tool failed: simulated failure")), ) }, ) @@ -3664,11 +7108,31 @@ mod tests { let failed_manifest = std::fs::read_to_string(&failed.manifest_file).expect("failed manifest should exist"); + let failed_manifest_json: serde_json::Value = + serde_json::from_str(&failed_manifest).expect("failed manifest json"); let failed_output = std::fs::read_to_string(&failed.output_file).expect("failed output should exist"); assert!(failed_manifest.contains("\"status\": \"failed\"")); assert!(failed_manifest.contains("simulated failure")); assert!(failed_output.contains("simulated failure")); + assert!(failed_output.contains("failure_class: tool_runtime")); + assert_eq!( + failed_manifest_json["currentBlocker"]["failureClass"], + "tool_runtime" + ); + assert_eq!( + failed_manifest_json["laneEvents"][1]["event"], + "lane.blocked" + ); + assert_eq!( + failed_manifest_json["laneEvents"][2]["event"], + "lane.failed" + ); + assert_eq!( + failed_manifest_json["laneEvents"][2]["failureClass"], + "tool_runtime" + ); + assert_eq!(failed_manifest_json["derivedState"], "truly_idle"); let spawn_error = execute_agent_with_spawn( AgentInput { @@ -3694,13 +7158,137 @@ mod tests { .then_some(contents) }) .expect("failed manifest should still be written"); + let spawn_error_manifest_json: serde_json::Value = + serde_json::from_str(&spawn_error_manifest).expect("spawn error manifest json"); assert!(spawn_error_manifest.contains("\"status\": \"failed\"")); assert!(spawn_error_manifest.contains("thread creation failed")); + assert_eq!( + spawn_error_manifest_json["currentBlocker"]["failureClass"], + "infra" + ); + assert_eq!(spawn_error_manifest_json["derivedState"], "truly_idle"); - std::env::remove_var("CLAW_AGENT_STORE"); + std::env::remove_var("CLAWD_AGENT_STORE"); let _ = std::fs::remove_dir_all(dir); } + #[test] + fn agent_state_classification_covers_finished_and_specific_blockers() { + assert_eq!(derive_agent_state("running", None, None, None), "working"); + assert_eq!( + derive_agent_state("completed", Some("done"), None, None), + "finished_cleanable" + ); + assert_eq!( + derive_agent_state("completed", None, None, None), + "finished_pending_report" + ); + assert_eq!( + derive_agent_state("failed", None, Some("mcp handshake timed out"), None), + "degraded_mcp" + ); + assert_eq!( + derive_agent_state( + "failed", + None, + Some("background terminal still running"), + None + ), + "blocked_background_job" + ); + assert_eq!( + derive_agent_state("failed", None, Some("merge conflict while rebasing"), None), + "blocked_merge_conflict" + ); + assert_eq!( + derive_agent_state( + "failed", + None, + Some("transport interrupted after partial progress"), + None + ), + "interrupted_transport" + ); + } + + #[test] + fn commit_provenance_is_extracted_from_agent_results() { + let provenance = maybe_commit_provenance(Some("landed as commit deadbee with clean push")) + .expect("commit provenance"); + assert_eq!(provenance.commit, "deadbee"); + assert_eq!(provenance.canonical_commit.as_deref(), Some("deadbee")); + assert_eq!(provenance.lineage, vec!["deadbee".to_string()]); + } + #[test] + fn lane_failure_taxonomy_normalizes_common_blockers() { + let cases = [ + ( + "prompt delivery failed in tmux pane", + LaneFailureClass::PromptDelivery, + ), + ( + "trust prompt is still blocking startup", + LaneFailureClass::TrustGate, + ), + ( + "branch stale against main after divergence", + LaneFailureClass::BranchDivergence, + ), + ( + "compile failed after cargo check", + LaneFailureClass::Compile, + ), + ("targeted tests failed", LaneFailureClass::Test), + ("plugin bootstrap failed", LaneFailureClass::PluginStartup), + ("mcp handshake timed out", LaneFailureClass::McpHandshake), + ( + "mcp startup failed before listing tools", + LaneFailureClass::McpStartup, + ), + ( + "gateway routing rejected the request", + LaneFailureClass::GatewayRouting, + ), + ( + "tool failed: denied tool execution from hook", + LaneFailureClass::ToolRuntime, + ), + ("thread creation failed", LaneFailureClass::Infra), + ]; + + for (message, expected) in cases { + assert_eq!(classify_lane_failure(message), expected, "{message}"); + } + } + + #[test] + fn lane_event_schema_serializes_to_canonical_names() { + let cases = [ + (LaneEventName::Started, "lane.started"), + (LaneEventName::Ready, "lane.ready"), + (LaneEventName::PromptMisdelivery, "lane.prompt_misdelivery"), + (LaneEventName::Blocked, "lane.blocked"), + (LaneEventName::Red, "lane.red"), + (LaneEventName::Green, "lane.green"), + (LaneEventName::CommitCreated, "lane.commit.created"), + (LaneEventName::PrOpened, "lane.pr.opened"), + (LaneEventName::MergeReady, "lane.merge.ready"), + (LaneEventName::Finished, "lane.finished"), + (LaneEventName::Failed, "lane.failed"), + ( + LaneEventName::BranchStaleAgainstMain, + "branch.stale_against_main", + ), + ]; + + for (event, expected) in cases { + assert_eq!( + serde_json::to_value(event).expect("serialize lane event"), + json!(expected) + ); + } + } + #[test] fn agent_tool_subset_mapping_is_expected() { let general = allowed_tools_for_subagent("general-purpose"); @@ -3752,7 +7340,7 @@ mod tests { AssistantEvent::MessageStop, ]) } - _ => panic!("unexpected mock stream call"), + _ => unreachable!("extra mock stream call"), } } } @@ -3982,6 +7570,90 @@ mod tests { assert_eq!(background_output["noOutputExpected"], true); } + #[test] + fn bash_workspace_tests_are_blocked_when_branch_is_behind_main() { + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let root = temp_path("workspace-test-preflight"); + let original_dir = std::env::current_dir().expect("cwd"); + init_git_repo(&root); + run_git(&root, &["checkout", "-b", "feature/stale-tests"]); + run_git(&root, &["checkout", "main"]); + commit_file( + &root, + "hotfix.txt", + "fix from main\n", + "fix: unblock workspace tests", + ); + run_git(&root, &["checkout", "feature/stale-tests"]); + std::env::set_current_dir(&root).expect("set cwd"); + + let output = execute_tool( + "bash", + &json!({ "command": "cargo test --workspace --all-targets" }), + ) + .expect("preflight should return structured output"); + let output_json: serde_json::Value = serde_json::from_str(&output).expect("json"); + assert_eq!( + output_json["returnCodeInterpretation"], + "preflight_blocked:branch_divergence" + ); + assert!(output_json["stderr"] + .as_str() + .expect("stderr") + .contains("branch divergence detected before workspace tests")); + assert_eq!( + output_json["structuredContent"][0]["event"], + "branch.stale_against_main" + ); + assert_eq!( + output_json["structuredContent"][0]["failureClass"], + "branch_divergence" + ); + assert_eq!( + output_json["structuredContent"][0]["data"]["missingCommits"][0], + "fix: unblock workspace tests" + ); + + std::env::set_current_dir(&original_dir).expect("restore cwd"); + let _ = std::fs::remove_dir_all(root); + } + + #[test] + fn bash_targeted_tests_skip_branch_preflight() { + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let root = temp_path("targeted-test-no-preflight"); + let original_dir = std::env::current_dir().expect("cwd"); + init_git_repo(&root); + run_git(&root, &["checkout", "-b", "feature/targeted-tests"]); + run_git(&root, &["checkout", "main"]); + commit_file( + &root, + "hotfix.txt", + "fix from main\n", + "fix: only broad tests should block", + ); + run_git(&root, &["checkout", "feature/targeted-tests"]); + std::env::set_current_dir(&root).expect("set cwd"); + + let output = execute_tool( + "bash", + &json!({ "command": "printf 'targeted ok'; cargo test -p runtime stale_branch" }), + ) + .expect("targeted commands should still execute"); + let output_json: serde_json::Value = serde_json::from_str(&output).expect("json"); + assert_ne!( + output_json["returnCodeInterpretation"], + "preflight_blocked:branch_divergence" + ); + + std::env::set_current_dir(&original_dir).expect("restore cwd"); + let _ = std::fs::remove_dir_all(root); + } + #[test] fn file_tools_cover_read_write_and_edit_behaviors() { let _guard = env_lock() @@ -4180,10 +7852,25 @@ mod tests { assert!(elapsed >= Duration::from_millis(15)); } + #[test] + fn given_excessive_duration_when_sleep_then_rejects_with_error() { + let result = execute_tool("Sleep", &json!({"duration_ms": 999_999_999_u64})); + let error = result.expect_err("excessive sleep should fail"); + assert!(error.contains("exceeds maximum allowed sleep")); + } + + #[test] + fn given_zero_duration_when_sleep_then_succeeds() { + let result = + execute_tool("Sleep", &json!({"duration_ms": 0})).expect("0ms sleep should succeed"); + let output: serde_json::Value = serde_json::from_str(&result).expect("json"); + assert_eq!(output["duration_ms"], 0); + } + #[test] fn brief_returns_sent_message_and_attachment_metadata() { let attachment = std::env::temp_dir().join(format!( - "claw-brief-{}.png", + "clawd-brief-{}.png", std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .expect("time") @@ -4214,7 +7901,7 @@ mod tests { .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let root = std::env::temp_dir().join(format!( - "claw-config-{}", + "clawd-config-{}", std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .expect("time") @@ -4274,6 +7961,140 @@ mod tests { let _ = std::fs::remove_dir_all(root); } + #[test] + fn enter_and_exit_plan_mode_round_trip_existing_local_override() { + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let root = std::env::temp_dir().join(format!( + "clawd-plan-mode-{}", + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + let home = root.join("home"); + let cwd = root.join("cwd"); + std::fs::create_dir_all(home.join(".claw")).expect("home dir"); + std::fs::create_dir_all(cwd.join(".claw")).expect("cwd dir"); + std::fs::write( + cwd.join(".claw").join("settings.local.json"), + r#"{"permissions":{"defaultMode":"acceptEdits"}}"#, + ) + .expect("write local settings"); + + let original_home = std::env::var("HOME").ok(); + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); + let original_dir = std::env::current_dir().expect("cwd"); + std::env::set_var("HOME", &home); + std::env::remove_var("CLAW_CONFIG_HOME"); + std::env::set_current_dir(&cwd).expect("set cwd"); + + let enter = execute_tool("EnterPlanMode", &json!({})).expect("enter plan mode"); + let enter_output: serde_json::Value = serde_json::from_str(&enter).expect("json"); + assert_eq!(enter_output["changed"], true); + assert_eq!(enter_output["managed"], true); + assert_eq!(enter_output["previousLocalMode"], "acceptEdits"); + assert_eq!(enter_output["currentLocalMode"], "plan"); + + let local_settings = std::fs::read_to_string(cwd.join(".claw").join("settings.local.json")) + .expect("local settings after enter"); + assert!(local_settings.contains(r#""defaultMode": "plan""#)); + let state = + std::fs::read_to_string(cwd.join(".claw").join("tool-state").join("plan-mode.json")) + .expect("plan mode state"); + assert!(state.contains(r#""hadLocalOverride": true"#)); + assert!(state.contains(r#""previousLocalMode": "acceptEdits""#)); + + let exit = execute_tool("ExitPlanMode", &json!({})).expect("exit plan mode"); + let exit_output: serde_json::Value = serde_json::from_str(&exit).expect("json"); + assert_eq!(exit_output["changed"], true); + assert_eq!(exit_output["managed"], false); + assert_eq!(exit_output["previousLocalMode"], "acceptEdits"); + assert_eq!(exit_output["currentLocalMode"], "acceptEdits"); + + let local_settings = std::fs::read_to_string(cwd.join(".claw").join("settings.local.json")) + .expect("local settings after exit"); + assert!(local_settings.contains(r#""defaultMode": "acceptEdits""#)); + assert!(!cwd + .join(".claw") + .join("tool-state") + .join("plan-mode.json") + .exists()); + + std::env::set_current_dir(&original_dir).expect("restore cwd"); + match original_home { + Some(value) => std::env::set_var("HOME", value), + None => std::env::remove_var("HOME"), + } + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), + } + let _ = std::fs::remove_dir_all(root); + } + + #[test] + fn exit_plan_mode_clears_override_when_enter_created_it_from_empty_local_state() { + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let root = std::env::temp_dir().join(format!( + "clawd-plan-mode-empty-{}", + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + let home = root.join("home"); + let cwd = root.join("cwd"); + std::fs::create_dir_all(home.join(".claw")).expect("home dir"); + std::fs::create_dir_all(cwd.join(".claw")).expect("cwd dir"); + + let original_home = std::env::var("HOME").ok(); + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); + let original_dir = std::env::current_dir().expect("cwd"); + std::env::set_var("HOME", &home); + std::env::remove_var("CLAW_CONFIG_HOME"); + std::env::set_current_dir(&cwd).expect("set cwd"); + + let enter = execute_tool("EnterPlanMode", &json!({})).expect("enter plan mode"); + let enter_output: serde_json::Value = serde_json::from_str(&enter).expect("json"); + assert_eq!(enter_output["previousLocalMode"], serde_json::Value::Null); + assert_eq!(enter_output["currentLocalMode"], "plan"); + + let exit = execute_tool("ExitPlanMode", &json!({})).expect("exit plan mode"); + let exit_output: serde_json::Value = serde_json::from_str(&exit).expect("json"); + assert_eq!(exit_output["changed"], true); + assert_eq!(exit_output["currentLocalMode"], serde_json::Value::Null); + + let local_settings = std::fs::read_to_string(cwd.join(".claw").join("settings.local.json")) + .expect("local settings after exit"); + let local_settings_json: serde_json::Value = + serde_json::from_str(&local_settings).expect("valid settings json"); + assert_eq!( + local_settings_json.get("permissions"), + None, + "permissions override should be removed on exit" + ); + assert!(!cwd + .join(".claw") + .join("tool-state") + .join("plan-mode.json") + .exists()); + + std::env::set_current_dir(&original_dir).expect("restore cwd"); + match original_home { + Some(value) => std::env::set_var("HOME", value), + None => std::env::remove_var("HOME"), + } + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), + } + let _ = std::fs::remove_dir_all(root); + } + #[test] fn structured_output_echoes_input_payload() { let result = execute_tool("StructuredOutput", &json!({"ok": true, "items": [1, 2, 3]})) @@ -4284,6 +8105,13 @@ mod tests { assert_eq!(output["structured_output"]["items"][1], 2); } + #[test] + fn given_empty_payload_when_structured_output_then_rejects_with_error() { + let result = execute_tool("StructuredOutput", &json!({})); + let error = result.expect_err("empty payload should fail"); + assert!(error.contains("must not be empty")); + } + #[test] fn repl_executes_python_code() { let result = execute_tool( @@ -4297,13 +8125,44 @@ mod tests { assert!(output["stdout"].as_str().expect("stdout").contains('2')); } + #[test] + fn given_empty_code_when_repl_then_rejects_with_error() { + let result = execute_tool("REPL", &json!({"language": "python", "code": " "})); + + let error = result.expect_err("empty REPL code should fail"); + assert!(error.contains("code must not be empty")); + } + + #[test] + fn given_unsupported_language_when_repl_then_rejects_with_error() { + let result = execute_tool("REPL", &json!({"language": "ruby", "code": "puts 1"})); + + let error = result.expect_err("unsupported REPL language should fail"); + assert!(error.contains("unsupported REPL language: ruby")); + } + + #[test] + fn given_timeout_ms_when_repl_blocks_then_returns_timeout_error() { + let result = execute_tool( + "REPL", + &json!({ + "language": "python", + "code": "import time\ntime.sleep(1)", + "timeout_ms": 10 + }), + ); + + let error = result.expect_err("timed out REPL execution should fail"); + assert!(error.contains("REPL execution exceeded timeout of 10 ms")); + } + #[test] fn powershell_runs_via_stub_shell() { let _guard = env_lock() .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let dir = std::env::temp_dir().join(format!( - "claw-pwsh-bin-{}", + "clawd-pwsh-bin-{}", std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .expect("time") @@ -4360,7 +8219,7 @@ printf 'pwsh:%s' "$1" .unwrap_or_else(std::sync::PoisonError::into_inner); let original_path = std::env::var("PATH").unwrap_or_default(); let empty_dir = std::env::temp_dir().join(format!( - "claw-empty-bin-{}", + "clawd-empty-bin-{}", std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .expect("time") @@ -4378,6 +8237,274 @@ printf 'pwsh:%s' "$1" assert!(err.contains("PowerShell executable not found")); } + fn read_only_registry() -> super::GlobalToolRegistry { + use runtime::permission_enforcer::PermissionEnforcer; + use runtime::PermissionPolicy; + + let policy = mvp_tool_specs().into_iter().fold( + PermissionPolicy::new(runtime::PermissionMode::ReadOnly), + |policy, spec| policy.with_tool_requirement(spec.name, spec.required_permission), + ); + let mut registry = super::GlobalToolRegistry::builtin(); + registry.set_enforcer(PermissionEnforcer::new(policy)); + registry + } + + #[test] + fn given_read_only_enforcer_when_bash_then_denied() { + let registry = read_only_registry(); + let err = registry + .execute("bash", &json!({ "command": "echo hi" })) + .expect_err("bash should be denied in read-only mode"); + assert!( + err.contains("current mode is read-only"), + "should cite active mode: {err}" + ); + } + + #[test] + fn given_read_only_enforcer_when_write_file_then_denied() { + let registry = read_only_registry(); + let err = registry + .execute( + "write_file", + &json!({ "path": "/tmp/x.txt", "content": "x" }), + ) + .expect_err("write_file should be denied in read-only mode"); + assert!( + err.contains("current mode is read-only"), + "should cite active mode: {err}" + ); + } + + #[test] + fn given_read_only_enforcer_when_edit_file_then_denied() { + let registry = read_only_registry(); + let err = registry + .execute( + "edit_file", + &json!({ "path": "/tmp/x.txt", "old_string": "a", "new_string": "b" }), + ) + .expect_err("edit_file should be denied in read-only mode"); + assert!( + err.contains("current mode is read-only"), + "should cite active mode: {err}" + ); + } + + #[test] + fn given_read_only_enforcer_when_read_file_then_not_permission_denied() { + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let root = temp_path("perm-read"); + fs::create_dir_all(&root).expect("create root"); + let file = root.join("readable.txt"); + fs::write(&file, "content\n").expect("write test file"); + + let registry = read_only_registry(); + let result = registry.execute("read_file", &json!({ "path": file.display().to_string() })); + assert!(result.is_ok(), "read_file should be allowed: {result:?}"); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn given_read_only_enforcer_when_glob_search_then_not_permission_denied() { + let registry = read_only_registry(); + let result = registry.execute("glob_search", &json!({ "pattern": "*.rs" })); + assert!( + result.is_ok(), + "glob_search should be allowed in read-only mode: {result:?}" + ); + } + + #[test] + fn given_no_enforcer_when_bash_then_executes_normally() { + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let registry = super::GlobalToolRegistry::builtin(); + let result = registry + .execute("bash", &json!({ "command": "printf 'ok'" })) + .expect("bash should succeed without enforcer"); + let output: serde_json::Value = serde_json::from_str(&result).expect("json"); + assert_eq!(output["stdout"], "ok"); + } + + #[test] + fn provider_runtime_client_chain_uses_only_primary_when_no_fallbacks_configured() { + // given + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let original_anthropic = std::env::var_os("ANTHROPIC_API_KEY"); + std::env::set_var("ANTHROPIC_API_KEY", "anthropic-test-key"); + let fallback_config = ProviderFallbackConfig::default(); + + // when + let client = ProviderRuntimeClient::new_with_fallback_config( + "claude-sonnet-4-6".to_string(), + BTreeSet::new(), + &fallback_config, + ) + .expect("primary-only chain should construct"); + + // then + assert_eq!(client.chain.len(), 1); + assert_eq!(client.chain[0].model, "claude-sonnet-4-6"); + + match original_anthropic { + Some(value) => std::env::set_var("ANTHROPIC_API_KEY", value), + None => std::env::remove_var("ANTHROPIC_API_KEY"), + } + } + + #[test] + fn provider_runtime_client_chain_appends_configured_fallbacks_in_order() { + // given + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let original_anthropic = std::env::var_os("ANTHROPIC_API_KEY"); + let original_xai = std::env::var_os("XAI_API_KEY"); + std::env::set_var("ANTHROPIC_API_KEY", "anthropic-test-key"); + std::env::set_var("XAI_API_KEY", "xai-test-key"); + let fallback_config = ProviderFallbackConfig::new( + None, + vec!["grok-3".to_string(), "grok-3-mini".to_string()], + ); + + // when + let client = ProviderRuntimeClient::new_with_fallback_config( + "claude-sonnet-4-6".to_string(), + BTreeSet::new(), + &fallback_config, + ) + .expect("chain with fallbacks should construct"); + + // then + assert_eq!(client.chain.len(), 3); + assert_eq!(client.chain[0].model, "claude-sonnet-4-6"); + assert_eq!(client.chain[1].model, "grok-3"); + assert_eq!(client.chain[2].model, "grok-3-mini"); + + match original_anthropic { + Some(value) => std::env::set_var("ANTHROPIC_API_KEY", value), + None => std::env::remove_var("ANTHROPIC_API_KEY"), + } + match original_xai { + Some(value) => std::env::set_var("XAI_API_KEY", value), + None => std::env::remove_var("XAI_API_KEY"), + } + } + + #[test] + fn provider_runtime_client_chain_primary_override_replaces_constructor_model() { + // given + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let original_anthropic = std::env::var_os("ANTHROPIC_API_KEY"); + let original_xai = std::env::var_os("XAI_API_KEY"); + std::env::set_var("ANTHROPIC_API_KEY", "anthropic-test-key"); + std::env::set_var("XAI_API_KEY", "xai-test-key"); + let fallback_config = ProviderFallbackConfig::new( + Some("grok-3".to_string()), + vec!["claude-sonnet-4-6".to_string()], + ); + + // when + let client = ProviderRuntimeClient::new_with_fallback_config( + "claude-haiku-4-5-20251213".to_string(), + BTreeSet::new(), + &fallback_config, + ) + .expect("chain with primary override should construct"); + + // then + assert_eq!(client.chain.len(), 2); + assert_eq!(client.chain[0].model, "grok-3"); + assert_eq!(client.chain[1].model, "claude-sonnet-4-6"); + + match original_anthropic { + Some(value) => std::env::set_var("ANTHROPIC_API_KEY", value), + None => std::env::remove_var("ANTHROPIC_API_KEY"), + } + match original_xai { + Some(value) => std::env::set_var("XAI_API_KEY", value), + None => std::env::remove_var("XAI_API_KEY"), + } + } + + #[test] + fn provider_runtime_client_chain_skips_fallbacks_missing_credentials() { + // given + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let original_anthropic = std::env::var_os("ANTHROPIC_API_KEY"); + let original_xai = std::env::var_os("XAI_API_KEY"); + std::env::set_var("ANTHROPIC_API_KEY", "anthropic-test-key"); + std::env::remove_var("XAI_API_KEY"); + let fallback_config = ProviderFallbackConfig::new( + None, + vec![ + "grok-3".to_string(), + "claude-haiku-4-5-20251213".to_string(), + ], + ); + + // when + let client = ProviderRuntimeClient::new_with_fallback_config( + "claude-sonnet-4-6".to_string(), + BTreeSet::new(), + &fallback_config, + ) + .expect("chain construction should not fail when only some fallbacks are unavailable"); + + // then + assert_eq!(client.chain.len(), 2); + assert_eq!(client.chain[0].model, "claude-sonnet-4-6"); + assert_eq!(client.chain[1].model, "claude-haiku-4-5-20251213"); + + match original_anthropic { + Some(value) => std::env::set_var("ANTHROPIC_API_KEY", value), + None => std::env::remove_var("ANTHROPIC_API_KEY"), + } + if let Some(value) = original_xai { + std::env::set_var("XAI_API_KEY", value); + } + } + + #[test] + fn run_task_packet_creates_packet_backed_task() { + let result = run_task_packet(TaskPacket { + objective: "Ship packetized runtime task".to_string(), + scope: "runtime/task system".to_string(), + repo: "claw-code-parity".to_string(), + branch_policy: "origin/main only".to_string(), + acceptance_tests: vec![ + "cargo build --workspace".to_string(), + "cargo test --workspace".to_string(), + ], + commit_policy: "single commit".to_string(), + reporting_contract: "print build/test result and sha".to_string(), + escalation_policy: "manual escalation".to_string(), + }) + .expect("task packet should create a task"); + + let output: serde_json::Value = serde_json::from_str(&result).expect("json"); + assert_eq!(output["status"], "created"); + assert_eq!(output["prompt"], "Ship packetized runtime task"); + assert_eq!(output["description"], "runtime/task system"); + assert_eq!(output["task_packet"]["repo"], "claw-code-parity"); + assert_eq!( + output["task_packet"]["acceptance_tests"][1], + "cargo test --workspace" + ); + } + struct TestServer { addr: SocketAddr, shutdown: Option>, diff --git a/crates/tools/src/pdf_extract.rs b/crates/tools/src/pdf_extract.rs new file mode 100644 index 0000000..caa9662 --- /dev/null +++ b/crates/tools/src/pdf_extract.rs @@ -0,0 +1,548 @@ +//! Minimal PDF text extraction. +//! +//! Reads a PDF file, locates `/Contents` stream objects, decompresses with +//! flate2 when the stream uses `/FlateDecode`, and extracts text operators +//! found between `BT` / `ET` markers. + +use std::io::Read as _; +use std::path::Path; + +/// Extract all readable text from a PDF file. +/// +/// Returns the concatenated text found inside BT/ET operators across all +/// content streams. Non-text pages or encrypted PDFs yield an empty string +/// rather than an error. +pub fn extract_text(path: &Path) -> Result { + let data = std::fs::read(path).map_err(|e| format!("failed to read PDF: {e}"))?; + Ok(extract_text_from_bytes(&data)) +} + +/// Core extraction from raw PDF bytes — useful for testing without touching the +/// filesystem. +pub(crate) fn extract_text_from_bytes(data: &[u8]) -> String { + let mut all_text = String::new(); + let mut offset = 0; + + while offset < data.len() { + let Some(stream_start) = find_subsequence(&data[offset..], b"stream") else { + break; + }; + let abs_start = offset + stream_start; + + // Determine the byte offset right after "stream\r\n" or "stream\n". + let content_start = skip_stream_eol(data, abs_start + b"stream".len()); + + let Some(end_rel) = find_subsequence(&data[content_start..], b"endstream") else { + break; + }; + let content_end = content_start + end_rel; + + // Look backwards from "stream" for a FlateDecode hint in the object + // dictionary. We scan at most 512 bytes before the stream keyword. + let dict_window_start = abs_start.saturating_sub(512); + let dict_window = &data[dict_window_start..abs_start]; + let is_flate = find_subsequence(dict_window, b"FlateDecode").is_some(); + + // Only process streams whose parent dictionary references /Contents or + // looks like a page content stream (contains /Length). We intentionally + // keep this loose to cover both inline and referenced content streams. + let raw = &data[content_start..content_end]; + let decompressed; + let stream_bytes: &[u8] = if is_flate { + if let Ok(buf) = inflate(raw) { + decompressed = buf; + &decompressed + } else { + offset = content_end; + continue; + } + } else { + raw + }; + + let text = extract_bt_et_text(stream_bytes); + if !text.is_empty() { + if !all_text.is_empty() { + all_text.push('\n'); + } + all_text.push_str(&text); + } + + offset = content_end; + } + + all_text +} + +/// Inflate (zlib / deflate) compressed data via `flate2`. +fn inflate(data: &[u8]) -> Result, String> { + let mut decoder = flate2::read::ZlibDecoder::new(data); + let mut buf = Vec::new(); + decoder + .read_to_end(&mut buf) + .map_err(|e| format!("flate2 inflate error: {e}"))?; + Ok(buf) +} + +/// Extract text from PDF content-stream operators between BT and ET markers. +/// +/// Handles the common text-showing operators: +/// - `Tj` — show a string +/// - `TJ` — show an array of strings/numbers +/// - `'` — move to next line and show string +/// - `"` — set spacing, move to next line and show string +fn extract_bt_et_text(stream: &[u8]) -> String { + let text = String::from_utf8_lossy(stream); + let mut result = String::new(); + let mut in_bt = false; + + for line in text.lines() { + let trimmed = line.trim(); + if trimmed == "BT" { + in_bt = true; + continue; + } + if trimmed == "ET" { + in_bt = false; + continue; + } + if !in_bt { + continue; + } + + // Tj operator: (text) Tj + if trimmed.ends_with("Tj") { + if let Some(s) = extract_parenthesized_string(trimmed) { + if !result.is_empty() && !result.ends_with('\n') { + result.push(' '); + } + result.push_str(&s); + } + } + // TJ operator: [ (text) 123 (text) ] TJ + else if trimmed.ends_with("TJ") { + let extracted = extract_tj_array(trimmed); + if !extracted.is_empty() { + if !result.is_empty() && !result.ends_with('\n') { + result.push(' '); + } + result.push_str(&extracted); + } + } + // ' operator: (text) ' and " operator: aw ac (text) " + else if is_newline_show_operator(trimmed) { + if let Some(s) = extract_parenthesized_string(trimmed) { + if !result.is_empty() { + result.push('\n'); + } + result.push_str(&s); + } + } + } + + result +} + +/// Returns `true` when `trimmed` looks like a `'` or `"` text-show operator. +fn is_newline_show_operator(trimmed: &str) -> bool { + (trimmed.ends_with('\'') && trimmed.len() > 1) + || (trimmed.ends_with('"') && trimmed.contains('(')) +} + +/// Pull the text from the first `(…)` group, handling escaped parens and +/// common PDF escape sequences. +fn extract_parenthesized_string(input: &str) -> Option { + let open = input.find('(')?; + let bytes = input.as_bytes(); + let mut depth = 0; + let mut result = String::new(); + let mut i = open; + + while i < bytes.len() { + match bytes[i] { + b'(' => { + if depth > 0 { + result.push('('); + } + depth += 1; + } + b')' => { + depth -= 1; + if depth == 0 { + return Some(result); + } + result.push(')'); + } + b'\\' if i + 1 < bytes.len() => { + i += 1; + match bytes[i] { + b'n' => result.push('\n'), + b'r' => result.push('\r'), + b't' => result.push('\t'), + b'\\' => result.push('\\'), + b'(' => result.push('('), + b')' => result.push(')'), + // Octal sequences — up to 3 digits. + d @ b'0'..=b'7' => { + let mut octal = u32::from(d - b'0'); + for _ in 0..2 { + if i + 1 < bytes.len() + && bytes[i + 1].is_ascii_digit() + && bytes[i + 1] <= b'7' + { + i += 1; + octal = octal * 8 + u32::from(bytes[i] - b'0'); + } else { + break; + } + } + if let Some(ch) = char::from_u32(octal) { + result.push(ch); + } + } + other => result.push(char::from(other)), + } + } + ch => result.push(char::from(ch)), + } + i += 1; + } + + None // unbalanced +} + +/// Extract concatenated strings from a TJ array like `[ (Hello) -120 (World) ] TJ`. +fn extract_tj_array(input: &str) -> String { + let mut result = String::new(); + let Some(bracket_start) = input.find('[') else { + return result; + }; + let Some(bracket_end) = input.rfind(']') else { + return result; + }; + let inner = &input[bracket_start + 1..bracket_end]; + + let mut i = 0; + let bytes = inner.as_bytes(); + while i < bytes.len() { + if bytes[i] == b'(' { + // Reconstruct the parenthesized string and extract it. + if let Some(s) = extract_parenthesized_string(&inner[i..]) { + result.push_str(&s); + // Skip past the closing paren. + let mut depth = 0u32; + for &b in &bytes[i..] { + i += 1; + if b == b'(' { + depth += 1; + } else if b == b')' { + depth -= 1; + if depth == 0 { + break; + } + } + } + continue; + } + } + i += 1; + } + + result +} + +/// Skip past the end-of-line marker that immediately follows the `stream` +/// keyword. Per the PDF spec this is either `\r\n` or `\n`. +fn skip_stream_eol(data: &[u8], pos: usize) -> usize { + if pos < data.len() && data[pos] == b'\r' { + if pos + 1 < data.len() && data[pos + 1] == b'\n' { + return pos + 2; + } + return pos + 1; + } + if pos < data.len() && data[pos] == b'\n' { + return pos + 1; + } + pos +} + +/// Simple byte-subsequence search. +fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option { + haystack + .windows(needle.len()) + .position(|window| window == needle) +} + +/// Check if a user-supplied path looks like a PDF file reference. +#[must_use] +pub fn looks_like_pdf_path(text: &str) -> Option<&str> { + for token in text.split_whitespace() { + let cleaned = token.trim_matches(|c: char| c == '\'' || c == '"' || c == '`'); + if let Some(dot_pos) = cleaned.rfind('.') { + if cleaned[dot_pos + 1..].eq_ignore_ascii_case("pdf") && dot_pos > 0 { + return Some(cleaned); + } + } + } + None +} + +/// Auto-extract text from a PDF path mentioned in a user prompt. +/// +/// Returns `Some((path, extracted_text))` when a `.pdf` path is detected and +/// the file exists, otherwise `None`. +#[must_use] +pub fn maybe_extract_pdf_from_prompt(prompt: &str) -> Option<(String, String)> { + let pdf_path = looks_like_pdf_path(prompt)?; + let path = Path::new(pdf_path); + if !path.exists() { + return None; + } + let text = extract_text(path).ok()?; + if text.is_empty() { + return None; + } + Some((pdf_path.to_string(), text)) +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Build a minimal valid PDF with a single page containing uncompressed + /// text. This is the smallest PDF structure that exercises the BT/ET + /// extraction path. + fn build_simple_pdf(text: &str) -> Vec { + let content_stream = format!("BT\n/F1 12 Tf\n({text}) Tj\nET"); + let stream_bytes = content_stream.as_bytes(); + let mut pdf = Vec::new(); + + // Header + pdf.extend_from_slice(b"%PDF-1.4\n"); + + // Object 1 — Catalog + let obj1_offset = pdf.len(); + pdf.extend_from_slice(b"1 0 obj\n<< /Type /Catalog /Pages 2 0 R >>\nendobj\n"); + + // Object 2 — Pages + let obj2_offset = pdf.len(); + pdf.extend_from_slice(b"2 0 obj\n<< /Type /Pages /Kids [3 0 R] /Count 1 >>\nendobj\n"); + + // Object 3 — Page + let obj3_offset = pdf.len(); + pdf.extend_from_slice( + b"3 0 obj\n<< /Type /Page /Parent 2 0 R /Contents 4 0 R >>\nendobj\n", + ); + + // Object 4 — Content stream (uncompressed) + let obj4_offset = pdf.len(); + let length = stream_bytes.len(); + let header = format!("4 0 obj\n<< /Length {length} >>\nstream\n"); + pdf.extend_from_slice(header.as_bytes()); + pdf.extend_from_slice(stream_bytes); + pdf.extend_from_slice(b"\nendstream\nendobj\n"); + + // Cross-reference table + let xref_offset = pdf.len(); + pdf.extend_from_slice(b"xref\n0 5\n"); + pdf.extend_from_slice(b"0000000000 65535 f \n"); + pdf.extend_from_slice(format!("{obj1_offset:010} 00000 n \n").as_bytes()); + pdf.extend_from_slice(format!("{obj2_offset:010} 00000 n \n").as_bytes()); + pdf.extend_from_slice(format!("{obj3_offset:010} 00000 n \n").as_bytes()); + pdf.extend_from_slice(format!("{obj4_offset:010} 00000 n \n").as_bytes()); + + // Trailer + pdf.extend_from_slice(b"trailer\n<< /Size 5 /Root 1 0 R >>\n"); + pdf.extend_from_slice(format!("startxref\n{xref_offset}\n%%EOF\n").as_bytes()); + + pdf + } + + /// Build a minimal PDF with flate-compressed content stream. + fn build_flate_pdf(text: &str) -> Vec { + use flate2::write::ZlibEncoder; + use flate2::Compression; + use std::io::Write as _; + + let content_stream = format!("BT\n/F1 12 Tf\n({text}) Tj\nET"); + let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default()); + encoder + .write_all(content_stream.as_bytes()) + .expect("compress"); + let compressed = encoder.finish().expect("finish"); + + let mut pdf = Vec::new(); + pdf.extend_from_slice(b"%PDF-1.4\n"); + + let obj1_offset = pdf.len(); + pdf.extend_from_slice(b"1 0 obj\n<< /Type /Catalog /Pages 2 0 R >>\nendobj\n"); + + let obj2_offset = pdf.len(); + pdf.extend_from_slice(b"2 0 obj\n<< /Type /Pages /Kids [3 0 R] /Count 1 >>\nendobj\n"); + + let obj3_offset = pdf.len(); + pdf.extend_from_slice( + b"3 0 obj\n<< /Type /Page /Parent 2 0 R /Contents 4 0 R >>\nendobj\n", + ); + + let obj4_offset = pdf.len(); + let length = compressed.len(); + let header = format!("4 0 obj\n<< /Length {length} /Filter /FlateDecode >>\nstream\n"); + pdf.extend_from_slice(header.as_bytes()); + pdf.extend_from_slice(&compressed); + pdf.extend_from_slice(b"\nendstream\nendobj\n"); + + let xref_offset = pdf.len(); + pdf.extend_from_slice(b"xref\n0 5\n"); + pdf.extend_from_slice(b"0000000000 65535 f \n"); + pdf.extend_from_slice(format!("{obj1_offset:010} 00000 n \n").as_bytes()); + pdf.extend_from_slice(format!("{obj2_offset:010} 00000 n \n").as_bytes()); + pdf.extend_from_slice(format!("{obj3_offset:010} 00000 n \n").as_bytes()); + pdf.extend_from_slice(format!("{obj4_offset:010} 00000 n \n").as_bytes()); + + pdf.extend_from_slice(b"trailer\n<< /Size 5 /Root 1 0 R >>\n"); + pdf.extend_from_slice(format!("startxref\n{xref_offset}\n%%EOF\n").as_bytes()); + + pdf + } + + #[test] + fn extracts_uncompressed_text_from_minimal_pdf() { + // given + let pdf_bytes = build_simple_pdf("Hello World"); + + // when + let text = extract_text_from_bytes(&pdf_bytes); + + // then + assert_eq!(text, "Hello World"); + } + + #[test] + fn extracts_text_from_flate_compressed_stream() { + // given + let pdf_bytes = build_flate_pdf("Compressed PDF Text"); + + // when + let text = extract_text_from_bytes(&pdf_bytes); + + // then + assert_eq!(text, "Compressed PDF Text"); + } + + #[test] + fn handles_tj_array_operator() { + // given + let stream = b"BT\n/F1 12 Tf\n[ (Hello) -120 ( World) ] TJ\nET"; + // Build a raw PDF with TJ array operator instead of simple Tj. + let content_stream = std::str::from_utf8(stream).unwrap(); + let raw = format!( + "%PDF-1.4\n1 0 obj\n<< /Type /Catalog >>\nendobj\n\ + 2 0 obj\n<< /Length {} >>\nstream\n{}\nendstream\nendobj\n%%EOF\n", + content_stream.len(), + content_stream + ); + let pdf_bytes = raw.into_bytes(); + + // when + let text = extract_text_from_bytes(&pdf_bytes); + + // then + assert_eq!(text, "Hello World"); + } + + #[test] + fn handles_escaped_parentheses() { + // given + let content = b"BT\n(Hello \\(World\\)) Tj\nET"; + let raw = format!( + "%PDF-1.4\n1 0 obj\n<< /Length {} >>\nstream\n", + content.len() + ); + let mut pdf_bytes = raw.into_bytes(); + pdf_bytes.extend_from_slice(content); + pdf_bytes.extend_from_slice(b"\nendstream\nendobj\n%%EOF\n"); + + // when + let text = extract_text_from_bytes(&pdf_bytes); + + // then + assert_eq!(text, "Hello (World)"); + } + + #[test] + fn returns_empty_for_non_pdf_data() { + // given + let data = b"This is not a PDF file at all"; + + // when + let text = extract_text_from_bytes(data); + + // then + assert!(text.is_empty()); + } + + #[test] + fn extracts_text_from_file_on_disk() { + // given + let pdf_bytes = build_simple_pdf("Disk Test"); + let dir = std::env::temp_dir().join("clawd-pdf-extract-test"); + std::fs::create_dir_all(&dir).unwrap(); + let pdf_path = dir.join("test.pdf"); + std::fs::write(&pdf_path, &pdf_bytes).unwrap(); + + // when + let text = extract_text(&pdf_path).unwrap(); + + // then + assert_eq!(text, "Disk Test"); + + // cleanup + let _ = std::fs::remove_dir_all(&dir); + } + + #[test] + fn looks_like_pdf_path_detects_pdf_references() { + // given / when / then + assert_eq!( + looks_like_pdf_path("Please read /tmp/report.pdf"), + Some("/tmp/report.pdf") + ); + assert_eq!(looks_like_pdf_path("Check file.PDF now"), Some("file.PDF")); + assert_eq!(looks_like_pdf_path("no pdf here"), None); + } + + #[test] + fn maybe_extract_pdf_from_prompt_returns_none_for_missing_file() { + // given + let prompt = "Read /tmp/nonexistent-abc123.pdf please"; + + // when + let result = maybe_extract_pdf_from_prompt(prompt); + + // then + assert!(result.is_none()); + } + + #[test] + fn maybe_extract_pdf_from_prompt_extracts_existing_file() { + // given + let pdf_bytes = build_simple_pdf("Auto Extracted"); + let dir = std::env::temp_dir().join("clawd-pdf-auto-extract-test"); + std::fs::create_dir_all(&dir).unwrap(); + let pdf_path = dir.join("auto.pdf"); + std::fs::write(&pdf_path, &pdf_bytes).unwrap(); + let prompt = format!("Summarize {}", pdf_path.display()); + + // when + let result = maybe_extract_pdf_from_prompt(&prompt); + + // then + let (path, text) = result.expect("should extract"); + assert_eq!(path, pdf_path.display().to_string()); + assert_eq!(text, "Auto Extracted"); + + // cleanup + let _ = std::fs::remove_dir_all(&dir); + } +} diff --git a/frontend/src/components/ChatView.tsx b/frontend/src/components/ChatView.tsx index d45ac57..f8a8565 100644 --- a/frontend/src/components/ChatView.tsx +++ b/frontend/src/components/ChatView.tsx @@ -5,9 +5,9 @@ import { theme, Skeleton, Spin, Popover } from 'antd'; import { XMarkdown } from '@ant-design/x-markdown'; // 助手气泡 body 撑满可用宽度,避免 Mermaid 等内容宽度受文本行长度影响 -const bubbleStyle = document.createElement('style'); -bubbleStyle.textContent = '.ant-bubble-start > .ant-bubble-body { width: 80%; }'; -document.head.appendChild(bubbleStyle); +// const bubbleStyle = document.createElement('style'); +// bubbleStyle.textContent = '.ant-bubble-start > .ant-bubble-body { width: 100%; }'; +// document.head.appendChild(bubbleStyle); import type { ComponentProps, Token } from '@ant-design/x-markdown'; import Latex from '@ant-design/x-markdown/plugins/latex'; import '@ant-design/x-markdown/themes/light.css'; @@ -18,8 +18,6 @@ import WelcomeScreen from './WelcomeScreen'; // ── XMarkdown 插件配置 ──────────────────────────────────────────────── -// LaTeX 数学公式插件:解析 $...$ / $$...$$ / \(...\) / \[...\] -// 自定义脚注插件:解析 [^1] 语法 → 标签 const footnoteExtension = { name: 'footnote', level: 'inline' as const, @@ -281,14 +279,12 @@ const ChatView: React.FC = ({ variant: 'filled' as const, shape: 'round' as const, avatar: , - // styles: { content: { width: '80%' } }, }, assistant: { placement: 'start' as const, variant: 'borderless' as const, avatar: , streaming: true, - styles: { content: { width: '80%' } }, header: (_content: unknown, { status }: { status?: string }) => { if (status === 'loading' || status === 'updating') { return (