feat: 合并上游 Rust 实现,扩展 API/运行时/工具链能力

将 claw-code/rust/crates 的完整实现合并到主 workspace,涵盖
  9 个 crate 的更新与 2 个新 crate 的引入。

  API 层:
  - 用原生 Anthropic 客户端(anthropic.rs)替换 claw_provider,
    新增 prompt cache 减少重复请求开销
  - 新增 HTTP 客户端构建器统一代理配置,OpenAI 兼容端增加
    DashScope/Qwen 支持与抖动重试
  - MessageRequest 扩展 temperature/top_p 等模型调参字段
  - SSE 解析器增加 provider 上下文感知的错误信息

  运行时(~11,000 行新增):
  - 新增 bash 命令安全校验、分支锁碰撞检测、配置文件校验
  - 新增会话存储与控制面、MCP 生命周期状态机与服务端实现
  - 新增权限执行引擎、策略引擎、插件生命周期管理
  - 新增 worker 启动编排、任务/定时任务注册表、信任解析器
  - 保留 Windows cmd /C fallback

  命令/插件/工具:
  - commands 大幅重写,扩展 sandbox、doctor、plan 等 slash 命令
  - plugins 新增 PostToolUseFailure hook 与宽容加载机制
  - tools 新增 PDF 提取与 lane 补全工具

  新增 crate:mock-anthropic-service(测试)、telemetry(遥测)

  适配 claw-cli/server:ClawApiClient→AnthropicClient 重命名,
  SlashCommand::parse 返回 Result,移除 session 级 Thinking 变体,
  TokenUsage/ConversationMessage 补充序列化支持
This commit is contained in:
fengmengqi 2026-04-13 14:39:17 +08:00
parent 4a04faf926
commit d8d77824f4
94 changed files with 49049 additions and 4429 deletions

56
Cargo.lock generated
View File

@ -34,6 +34,7 @@ dependencies = [
"runtime",
"serde",
"serde_json",
"telemetry",
"tokio",
]
@ -347,12 +348,24 @@ version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
[[package]]
name = "either"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]]
name = "endian-type"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d"
[[package]]
name = "env_home"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7f84e12ccf0a7ddc17a6c41c93326024c42920d7ee630d04950e6926645c0fe"
[[package]]
name = "equivalent"
version = "1.0.2"
@ -945,6 +958,15 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "mock-anthropic-service"
version = "0.1.0"
dependencies = [
"api",
"serde_json",
"tokio",
]
[[package]]
name = "nibble_vec"
version = "0.1.0"
@ -1340,14 +1362,15 @@ name = "runtime"
version = "0.1.0"
dependencies = [
"glob",
"lsp",
"plugins",
"regex",
"serde",
"serde_json",
"sha2",
"telemetry",
"tokio",
"walkdir",
"which",
]
[[package]]
@ -1431,9 +1454,12 @@ dependencies = [
"commands",
"compat-harness",
"crossterm",
"mock-anthropic-service",
"plugins",
"pulldown-cmark",
"runtime",
"rustyline",
"serde",
"serde_json",
"syntect",
"tokio",
@ -1721,6 +1747,14 @@ dependencies = [
"yaml-rust",
]
[[package]]
name = "telemetry"
version = "0.1.0"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "thiserror"
version = "2.0.18"
@ -1852,6 +1886,8 @@ name = "tools"
version = "0.1.0"
dependencies = [
"api",
"commands",
"flate2",
"plugins",
"reqwest",
"runtime",
@ -2129,6 +2165,18 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "which"
version = "7.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d643ce3fd3e5b54854602a080f34fb10ab75e0b813ee32d00ca2b44fa74762"
dependencies = [
"either",
"env_home",
"rustix 1.1.4",
"winsafe",
]
[[package]]
name = "winapi"
version = "0.3.9"
@ -2384,6 +2432,12 @@ version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650"
[[package]]
name = "winsafe"
version = "0.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904"
[[package]]
name = "wit-bindgen"
version = "0.51.0"

View File

@ -21,3 +21,35 @@ pedantic = { level = "warn", priority = -1 }
module_name_repetitions = "allow"
missing_panics_doc = "allow"
missing_errors_doc = "allow"
uninlined_format_args = "allow"
map_unwrap_or = "allow"
doc_markdown = "allow"
redundant_pattern_matching = "allow"
items_after_statements = "allow"
must_use_candidate = "allow"
io_other_error = "allow"
implicit_clone = "allow"
redundant_closure_for_method_calls = "allow"
unnecessary_map_or = "allow"
manual_let_else = "allow"
format_push_string = "allow"
match_same_arms = "allow"
similar_names = "allow"
needless_continue = "allow"
assigning_clones = "allow"
cast_possible_truncation = "allow"
cast_possible_wrap = "allow"
cast_sign_loss = "allow"
cmp_owned = "allow"
collapsible_if = "allow"
too_many_lines = "allow"
wildcard_in_or_patterns = "allow"
explicit_counter_loop = "allow"
manual_repeat_n = "allow"
manual_str_repeat = "allow"
needless_borrow = "allow"
needless_pass_by_value = "allow"
single_match_else = "allow"
too_many_arguments = "allow"
unnested_or_patterns = "allow"
unused_self = "allow"

View File

@ -10,6 +10,7 @@ reqwest = { version = "0.12", default-features = false, features = ["json", "rus
runtime = { path = "../runtime" }
serde = { version = "1", features = ["derive"] }
serde_json.workspace = true
telemetry = { path = "../telemetry" }
tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] }
[lints]

View File

@ -1,70 +1,91 @@
use crate::error::ApiError;
use crate::providers::claw_provider::{self, AuthSource, ClawApiClient};
use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
use crate::providers::anthropic::{self, AnthropicClient, AuthSource};
use crate::providers::openai_compat::{self, OpenAiCompatClient, OpenAiCompatConfig};
use crate::providers::{self, Provider, ProviderKind};
use crate::providers::{self, ProviderKind};
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
async fn send_via_provider<P: Provider>(
provider: &P,
request: &MessageRequest,
) -> Result<MessageResponse, ApiError> {
provider.send_message(request).await
}
async fn stream_via_provider<P: Provider>(
provider: &P,
request: &MessageRequest,
) -> Result<P::Stream, ApiError> {
provider.stream_message(request).await
}
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone)]
pub enum ProviderClient {
ClawApi(ClawApiClient),
Anthropic(AnthropicClient),
Xai(OpenAiCompatClient),
OpenAi(OpenAiCompatClient),
}
impl ProviderClient {
pub fn from_model(model: &str) -> Result<Self, ApiError> {
Self::from_model_with_default_auth(model, None)
Self::from_model_with_anthropic_auth(model, None)
}
pub fn from_model_with_default_auth(
pub fn from_model_with_anthropic_auth(
model: &str,
default_auth: Option<AuthSource>,
anthropic_auth: Option<AuthSource>,
) -> Result<Self, ApiError> {
let resolved_model = providers::resolve_model_alias(model);
match providers::detect_provider_kind(&resolved_model) {
ProviderKind::ClawApi => Ok(Self::ClawApi(match default_auth {
Some(auth) => ClawApiClient::from_auth(auth),
None => ClawApiClient::from_env()?,
ProviderKind::Anthropic => Ok(Self::Anthropic(match anthropic_auth {
Some(auth) => AnthropicClient::from_auth(auth),
None => AnthropicClient::from_env()?,
})),
ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env(
OpenAiCompatConfig::xai(),
)?)),
ProviderKind::OpenAi => Ok(Self::OpenAi(OpenAiCompatClient::from_env(
OpenAiCompatConfig::openai(),
)?)),
ProviderKind::OpenAi => {
// DashScope models (qwen-*) also return ProviderKind::OpenAi because they
// speak the OpenAI wire format, but they need the DashScope config which
// reads DASHSCOPE_API_KEY and points at dashscope.aliyuncs.com.
let config = match providers::metadata_for_model(&resolved_model) {
Some(meta) if meta.auth_env == "DASHSCOPE_API_KEY" => {
OpenAiCompatConfig::dashscope()
}
_ => OpenAiCompatConfig::openai(),
};
Ok(Self::OpenAi(OpenAiCompatClient::from_env(config)?))
}
}
}
#[must_use]
pub const fn provider_kind(&self) -> ProviderKind {
match self {
Self::ClawApi(_) => ProviderKind::ClawApi,
Self::Anthropic(_) => ProviderKind::Anthropic,
Self::Xai(_) => ProviderKind::Xai,
Self::OpenAi(_) => ProviderKind::OpenAi,
}
}
#[must_use]
pub fn with_prompt_cache(self, prompt_cache: PromptCache) -> Self {
match self {
Self::Anthropic(client) => Self::Anthropic(client.with_prompt_cache(prompt_cache)),
other => other,
}
}
#[must_use]
pub fn prompt_cache_stats(&self) -> Option<PromptCacheStats> {
match self {
Self::Anthropic(client) => client.prompt_cache_stats(),
Self::Xai(_) | Self::OpenAi(_) => None,
}
}
#[must_use]
pub fn take_last_prompt_cache_record(&self) -> Option<PromptCacheRecord> {
match self {
Self::Anthropic(client) => client.take_last_prompt_cache_record(),
Self::Xai(_) | Self::OpenAi(_) => None,
}
}
pub async fn send_message(
&self,
request: &MessageRequest,
) -> Result<MessageResponse, ApiError> {
match self {
Self::ClawApi(client) => send_via_provider(client, request).await,
Self::Xai(client) | Self::OpenAi(client) => send_via_provider(client, request).await,
Self::Anthropic(client) => client.send_message(request).await,
Self::Xai(client) | Self::OpenAi(client) => client.send_message(request).await,
}
}
@ -73,10 +94,12 @@ impl ProviderClient {
request: &MessageRequest,
) -> Result<MessageStream, ApiError> {
match self {
Self::ClawApi(client) => stream_via_provider(client, request)
Self::Anthropic(client) => client
.stream_message(request)
.await
.map(MessageStream::ClawApi),
Self::Xai(client) | Self::OpenAi(client) => stream_via_provider(client, request)
.map(MessageStream::Anthropic),
Self::Xai(client) | Self::OpenAi(client) => client
.stream_message(request)
.await
.map(MessageStream::OpenAiCompat),
}
@ -85,7 +108,7 @@ impl ProviderClient {
#[derive(Debug)]
pub enum MessageStream {
ClawApi(claw_provider::MessageStream),
Anthropic(anthropic::MessageStream),
OpenAiCompat(openai_compat::MessageStream),
}
@ -93,25 +116,25 @@ impl MessageStream {
#[must_use]
pub fn request_id(&self) -> Option<&str> {
match self {
Self::ClawApi(stream) => stream.request_id(),
Self::Anthropic(stream) => stream.request_id(),
Self::OpenAiCompat(stream) => stream.request_id(),
}
}
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
match self {
Self::ClawApi(stream) => stream.next_event().await,
Self::Anthropic(stream) => stream.next_event().await,
Self::OpenAiCompat(stream) => stream.next_event().await,
}
}
}
pub use claw_provider::{
pub use anthropic::{
oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source, OAuthTokenSet,
};
#[must_use]
pub fn read_base_url() -> String {
claw_provider::read_base_url()
anthropic::read_base_url()
}
#[must_use]
@ -121,8 +144,21 @@ pub fn read_xai_base_url() -> String {
#[cfg(test)]
mod tests {
use std::sync::{Mutex, OnceLock};
use super::ProviderClient;
use crate::providers::{detect_provider_kind, resolve_model_alias, ProviderKind};
/// Serializes every test in this module that mutates process-wide
/// environment variables so concurrent test threads cannot observe
/// each other's partially-applied state.
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
#[test]
fn resolves_existing_and_grok_aliases() {
assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6");
@ -135,7 +171,71 @@ mod tests {
assert_eq!(detect_provider_kind("grok-3"), ProviderKind::Xai);
assert_eq!(
detect_provider_kind("claude-sonnet-4-6"),
ProviderKind::ClawApi
ProviderKind::Anthropic
);
}
/// Snapshot-restore guard for a single environment variable. Mirrors
/// the pattern used in `providers/mod.rs` tests: captures the original
/// value on construction, applies the override, and restores on drop so
/// tests leave the process env untouched even when they panic.
struct EnvVarGuard {
key: &'static str,
original: Option<std::ffi::OsString>,
}
impl EnvVarGuard {
fn set(key: &'static str, value: Option<&str>) -> Self {
let original = std::env::var_os(key);
match value {
Some(value) => std::env::set_var(key, value),
None => std::env::remove_var(key),
}
Self { key, original }
}
}
impl Drop for EnvVarGuard {
fn drop(&mut self) {
match self.original.take() {
Some(value) => std::env::set_var(self.key, value),
None => std::env::remove_var(self.key),
}
}
}
#[test]
fn dashscope_model_uses_dashscope_config_not_openai() {
// Regression: qwen-plus was being routed to OpenAiCompatConfig::openai()
// which reads OPENAI_API_KEY and points at api.openai.com, when it should
// use OpenAiCompatConfig::dashscope() which reads DASHSCOPE_API_KEY and
// points at dashscope.aliyuncs.com.
let _lock = env_lock();
let _dashscope = EnvVarGuard::set("DASHSCOPE_API_KEY", Some("test-dashscope-key"));
let _openai = EnvVarGuard::set("OPENAI_API_KEY", None);
let client = ProviderClient::from_model("qwen-plus");
// Must succeed (not fail with "missing OPENAI_API_KEY")
assert!(
client.is_ok(),
"qwen-plus with DASHSCOPE_API_KEY set should build successfully, got: {:?}",
client.err()
);
// Verify it's the OpenAi variant pointed at the DashScope base URL.
match client.unwrap() {
ProviderClient::OpenAi(openai_client) => {
assert!(
openai_client.base_url().contains("dashscope.aliyuncs.com"),
"qwen-plus should route to DashScope base URL (contains 'dashscope.aliyuncs.com'), got: {}",
openai_client.base_url()
);
}
other => panic!(
"Expected ProviderClient::OpenAi for qwen-plus, got: {:?}",
other
),
}
}
}

View File

@ -2,22 +2,55 @@ use std::env::VarError;
use std::fmt::{Display, Formatter};
use std::time::Duration;
const GENERIC_FATAL_WRAPPER_MARKERS: &[&str] = &[
"something went wrong while processing your request",
"please try again, or use /new to start a fresh session",
];
const CONTEXT_WINDOW_ERROR_MARKERS: &[&str] = &[
"maximum context length",
"context window",
"context length",
"too many tokens",
"prompt is too long",
"input is too long",
"request is too large",
];
#[derive(Debug)]
pub enum ApiError {
MissingCredentials {
provider: &'static str,
env_vars: &'static [&'static str],
/// Optional, runtime-computed hint appended to the error Display
/// output. Populated when the provider resolver can infer what the
/// user probably intended (e.g. an OpenAI key is set but Anthropic
/// was selected because no Anthropic credentials exist).
hint: Option<String>,
},
ContextWindowExceeded {
model: String,
estimated_input_tokens: u32,
requested_output_tokens: u32,
estimated_total_tokens: u32,
context_window_tokens: u32,
},
ExpiredOAuthToken,
Auth(String),
InvalidApiKeyEnv(VarError),
Http(reqwest::Error),
Io(std::io::Error),
Json(serde_json::Error),
Json {
provider: String,
model: String,
body_snippet: String,
source: serde_json::Error,
},
Api {
status: reqwest::StatusCode,
error_type: Option<String>,
message: Option<String>,
request_id: Option<String>,
body: String,
retryable: bool,
},
@ -38,7 +71,48 @@ impl ApiError {
provider: &'static str,
env_vars: &'static [&'static str],
) -> Self {
Self::MissingCredentials { provider, env_vars }
Self::MissingCredentials {
provider,
env_vars,
hint: None,
}
}
/// Build a `MissingCredentials` error carrying an extra, runtime-computed
/// hint string that the Display impl appends after the canonical "missing
/// <provider> credentials" message. Used by the provider resolver to
/// suggest the likely fix when the user has credentials for a different
/// provider already in the environment.
#[must_use]
pub fn missing_credentials_with_hint(
provider: &'static str,
env_vars: &'static [&'static str],
hint: impl Into<String>,
) -> Self {
Self::MissingCredentials {
provider,
env_vars,
hint: Some(hint.into()),
}
}
/// Build a `Self::Json` enriched with the provider name, the model that
/// was requested, and the first 200 characters of the raw response body so
/// that callers can diagnose deserialization failures without re-running
/// the request.
#[must_use]
pub fn json_deserialize(
provider: impl Into<String>,
model: impl Into<String>,
body: &str,
source: serde_json::Error,
) -> Self {
Self::Json {
provider: provider.into(),
model: model.into(),
body_snippet: truncate_body_snippet(body, 200),
source,
}
}
#[must_use]
@ -48,11 +122,106 @@ impl ApiError {
Self::Api { retryable, .. } => *retryable,
Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(),
Self::MissingCredentials { .. }
| Self::ContextWindowExceeded { .. }
| Self::ExpiredOAuthToken
| Self::Auth(_)
| Self::InvalidApiKeyEnv(_)
| Self::Io(_)
| Self::Json(_)
| Self::Json { .. }
| Self::InvalidSseFrame(_)
| Self::BackoffOverflow { .. } => false,
}
}
#[must_use]
pub fn request_id(&self) -> Option<&str> {
match self {
Self::Api { request_id, .. } => request_id.as_deref(),
Self::RetriesExhausted { last_error, .. } => last_error.request_id(),
Self::MissingCredentials { .. }
| Self::ContextWindowExceeded { .. }
| Self::ExpiredOAuthToken
| Self::Auth(_)
| Self::InvalidApiKeyEnv(_)
| Self::Http(_)
| Self::Io(_)
| Self::Json { .. }
| Self::InvalidSseFrame(_)
| Self::BackoffOverflow { .. } => None,
}
}
#[must_use]
pub fn safe_failure_class(&self) -> &'static str {
match self {
Self::RetriesExhausted { .. } if self.is_context_window_failure() => "context_window",
Self::RetriesExhausted { .. } if self.is_generic_fatal_wrapper() => {
"provider_retry_exhausted"
}
Self::RetriesExhausted { last_error, .. } => last_error.safe_failure_class(),
Self::MissingCredentials { .. } | Self::ExpiredOAuthToken | Self::Auth(_) => {
"provider_auth"
}
Self::Api { status, .. } if matches!(status.as_u16(), 401 | 403) => "provider_auth",
Self::ContextWindowExceeded { .. } => "context_window",
Self::Api { .. } if self.is_context_window_failure() => "context_window",
Self::Api { status, .. } if status.as_u16() == 429 => "provider_rate_limit",
Self::Api { .. } if self.is_generic_fatal_wrapper() => "provider_internal",
Self::Api { .. } => "provider_error",
Self::Http(_) | Self::InvalidSseFrame(_) | Self::BackoffOverflow { .. } => {
"provider_transport"
}
Self::InvalidApiKeyEnv(_) | Self::Io(_) | Self::Json { .. } => "runtime_io",
}
}
#[must_use]
pub fn is_generic_fatal_wrapper(&self) -> bool {
match self {
Self::Api { message, body, .. } => {
message
.as_deref()
.is_some_and(looks_like_generic_fatal_wrapper)
|| looks_like_generic_fatal_wrapper(body)
}
Self::RetriesExhausted { last_error, .. } => last_error.is_generic_fatal_wrapper(),
Self::MissingCredentials { .. }
| Self::ContextWindowExceeded { .. }
| Self::ExpiredOAuthToken
| Self::Auth(_)
| Self::InvalidApiKeyEnv(_)
| Self::Http(_)
| Self::Io(_)
| Self::Json { .. }
| Self::InvalidSseFrame(_)
| Self::BackoffOverflow { .. } => false,
}
}
#[must_use]
pub fn is_context_window_failure(&self) -> bool {
match self {
Self::ContextWindowExceeded { .. } => true,
Self::Api {
status,
message,
body,
..
} => {
matches!(status.as_u16(), 400 | 413 | 422)
&& (message
.as_deref()
.is_some_and(looks_like_context_window_error)
|| looks_like_context_window_error(body))
}
Self::RetriesExhausted { last_error, .. } => last_error.is_context_window_failure(),
Self::MissingCredentials { .. }
| Self::ExpiredOAuthToken
| Self::Auth(_)
| Self::InvalidApiKeyEnv(_)
| Self::Http(_)
| Self::Io(_)
| Self::Json { .. }
| Self::InvalidSseFrame(_)
| Self::BackoffOverflow { .. } => false,
}
@ -62,10 +231,43 @@ impl ApiError {
impl Display for ApiError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::MissingCredentials { provider, env_vars } => write!(
Self::MissingCredentials {
provider,
env_vars,
hint,
} => {
write!(
f,
"missing {provider} credentials; export {} before calling the {provider} API",
env_vars.join(" or ")
)?;
if cfg!(target_os = "windows") {
if let Some(primary) = env_vars.first() {
write!(
f,
" (on Windows, environment variables set in PowerShell only persist for the current session; use `setx {primary} <value>` to make it permanent, then open a new terminal, or place a `.env` file containing `{primary}=<value>` in the current working directory)"
)?;
} else {
write!(
f,
" (on Windows, environment variables set in PowerShell only persist for the current session; use `setx` to make them permanent, then open a new terminal, or place a `.env` file in the current working directory)"
)?;
}
}
if let Some(hint) = hint {
write!(f, " — hint: {hint}")?;
}
Ok(())
}
Self::ContextWindowExceeded {
model,
estimated_input_tokens,
requested_output_tokens,
estimated_total_tokens,
context_window_tokens,
} => write!(
f,
"context_window_blocked for {model}: estimated input {estimated_input_tokens} + requested output {requested_output_tokens} = {estimated_total_tokens} tokens exceeds the {context_window_tokens}-token context window; compact the session or reduce request size before retrying"
),
Self::ExpiredOAuthToken => {
write!(
@ -79,19 +281,37 @@ impl Display for ApiError {
}
Self::Http(error) => write!(f, "http error: {error}"),
Self::Io(error) => write!(f, "io error: {error}"),
Self::Json(error) => write!(f, "json error: {error}"),
Self::Json {
provider,
model,
body_snippet,
source,
} => write!(
f,
"failed to parse {provider} response for model {model}: {source}; first 200 chars of body: {body_snippet}"
),
Self::Api {
status,
error_type,
message,
request_id,
body,
..
} => match (error_type, message) {
(Some(error_type), Some(message)) => {
write!(f, "api returned {status} ({error_type}): {message}")
} => {
if let (Some(error_type), Some(message)) = (error_type, message) {
write!(f, "api returned {status} ({error_type})")?;
if let Some(request_id) = request_id {
write!(f, " [trace {request_id}]")?;
}
write!(f, ": {message}")
} else {
write!(f, "api returned {status}")?;
if let Some(request_id) = request_id {
write!(f, " [trace {request_id}]")?;
}
write!(f, ": {body}")
}
}
_ => write!(f, "api returned {status}: {body}"),
},
Self::RetriesExhausted {
attempts,
last_error,
@ -124,7 +344,12 @@ impl From<std::io::Error> for ApiError {
impl From<serde_json::Error> for ApiError {
fn from(value: serde_json::Error) -> Self {
Self::Json(value)
Self::Json {
provider: "unknown".to_string(),
model: "unknown".to_string(),
body_snippet: String::new(),
source: value,
}
}
}
@ -133,3 +358,215 @@ impl From<VarError> for ApiError {
Self::InvalidApiKeyEnv(value)
}
}
fn looks_like_generic_fatal_wrapper(text: &str) -> bool {
let lowered = text.to_ascii_lowercase();
GENERIC_FATAL_WRAPPER_MARKERS
.iter()
.any(|marker| lowered.contains(marker))
}
fn looks_like_context_window_error(text: &str) -> bool {
let lowered = text.to_ascii_lowercase();
CONTEXT_WINDOW_ERROR_MARKERS
.iter()
.any(|marker| lowered.contains(marker))
}
/// Truncate `body` so the resulting snippet contains at most `max_chars`
/// characters (counted by Unicode scalar values, not bytes), preserving the
/// leading slice of the body that the caller most often needs to inspect.
fn truncate_body_snippet(body: &str, max_chars: usize) -> String {
let mut taken_chars = 0;
let mut byte_end = 0;
for (offset, character) in body.char_indices() {
if taken_chars >= max_chars {
break;
}
taken_chars += 1;
byte_end = offset + character.len_utf8();
}
if taken_chars >= max_chars && byte_end < body.len() {
format!("{}", &body[..byte_end])
} else {
body[..byte_end].to_string()
}
}
#[cfg(test)]
mod tests {
use super::{truncate_body_snippet, ApiError};
#[test]
fn json_deserialize_error_includes_provider_model_and_truncated_body_snippet() {
let raw_body = format!("{}{}", "x".repeat(190), "_TAIL_PAST_200_CHARS_MARKER_");
let source = serde_json::from_str::<serde_json::Value>("{not json")
.expect_err("invalid json should fail to parse");
let error = ApiError::json_deserialize("Anthropic", "claude-opus-4-6", &raw_body, source);
let rendered = error.to_string();
assert!(
rendered.starts_with("failed to parse Anthropic response for model claude-opus-4-6: "),
"rendered error should lead with provider and model: {rendered}"
);
assert!(
rendered.contains("first 200 chars of body: "),
"rendered error should label the body snippet: {rendered}"
);
let snippet = rendered
.split("first 200 chars of body: ")
.nth(1)
.expect("snippet section should be present");
assert!(
snippet.starts_with(&"x".repeat(190)),
"snippet should preserve the leading characters of the body: {snippet}"
);
assert!(
snippet.ends_with('…'),
"snippet should signal truncation with an ellipsis: {snippet}"
);
assert!(
!snippet.contains("_TAIL_PAST_200_CHARS_MARKER_"),
"snippet should drop characters past the 200-char cap: {snippet}"
);
assert_eq!(error.safe_failure_class(), "runtime_io");
assert_eq!(error.request_id(), None);
assert!(!error.is_retryable());
}
#[test]
fn truncate_body_snippet_keeps_short_bodies_intact() {
assert_eq!(truncate_body_snippet("hello", 200), "hello");
assert_eq!(truncate_body_snippet("", 200), "");
}
#[test]
fn truncate_body_snippet_caps_long_bodies_at_max_chars() {
let body = "a".repeat(250);
let snippet = truncate_body_snippet(&body, 200);
assert_eq!(snippet.chars().count(), 201, "200 chars + ellipsis");
assert!(snippet.ends_with('…'));
assert!(snippet.starts_with(&"a".repeat(200)));
}
#[test]
fn truncate_body_snippet_does_not_split_multibyte_characters() {
let body = "한글한글한글한글한글한글";
let snippet = truncate_body_snippet(body, 4);
assert_eq!(snippet, "한글한글…");
}
#[test]
fn detects_generic_fatal_wrapper_and_classifies_it_as_provider_internal() {
let error = ApiError::Api {
status: reqwest::StatusCode::INTERNAL_SERVER_ERROR,
error_type: Some("api_error".to_string()),
message: Some(
"Something went wrong while processing your request. Please try again, or use /new to start a fresh session."
.to_string(),
),
request_id: Some("req_jobdori_123".to_string()),
body: String::new(),
retryable: true,
};
assert!(error.is_generic_fatal_wrapper());
assert_eq!(error.safe_failure_class(), "provider_internal");
assert_eq!(error.request_id(), Some("req_jobdori_123"));
assert!(error.to_string().contains("[trace req_jobdori_123]"));
}
#[test]
fn retries_exhausted_preserves_nested_request_id_and_failure_class() {
let error = ApiError::RetriesExhausted {
attempts: 3,
last_error: Box::new(ApiError::Api {
status: reqwest::StatusCode::BAD_GATEWAY,
error_type: Some("api_error".to_string()),
message: Some(
"Something went wrong while processing your request. Please try again, or use /new to start a fresh session."
.to_string(),
),
request_id: Some("req_nested_456".to_string()),
body: String::new(),
retryable: true,
}),
};
assert!(error.is_generic_fatal_wrapper());
assert_eq!(error.safe_failure_class(), "provider_retry_exhausted");
assert_eq!(error.request_id(), Some("req_nested_456"));
}
#[test]
fn classifies_provider_context_window_errors() {
let error = ApiError::Api {
status: reqwest::StatusCode::BAD_REQUEST,
error_type: Some("invalid_request_error".to_string()),
message: Some(
"This model's maximum context length is 200000 tokens, but your request used 230000 tokens."
.to_string(),
),
request_id: Some("req_ctx_123".to_string()),
body: String::new(),
retryable: false,
};
assert!(error.is_context_window_failure());
assert_eq!(error.safe_failure_class(), "context_window");
assert_eq!(error.request_id(), Some("req_ctx_123"));
}
#[test]
fn missing_credentials_without_hint_renders_the_canonical_message() {
// given
let error = ApiError::missing_credentials(
"Anthropic",
&["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
);
// when
let rendered = error.to_string();
// then
assert!(
rendered.starts_with(
"missing Anthropic credentials; export ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY before calling the Anthropic API"
),
"rendered error should lead with the canonical missing-credential message: {rendered}"
);
assert!(
!rendered.contains(" — hint: "),
"no hint should be appended when none is supplied: {rendered}"
);
}
#[test]
fn missing_credentials_with_hint_appends_the_hint_after_base_message() {
// given
let error = ApiError::missing_credentials_with_hint(
"Anthropic",
&["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
"I see OPENAI_API_KEY is set — if you meant to use the OpenAI-compat provider, prefix your model name with `openai/` so prefix routing selects it.",
);
// when
let rendered = error.to_string();
// then
assert!(
rendered.starts_with("missing Anthropic credentials;"),
"hint should be appended, not replace the base message: {rendered}"
);
let hint_marker = " — hint: I see OPENAI_API_KEY is set — if you meant to use the OpenAI-compat provider, prefix your model name with `openai/` so prefix routing selects it.";
assert!(
rendered.ends_with(hint_marker),
"rendered error should end with the hint: {rendered}"
);
// Classification semantics are unaffected by the presence of a hint.
assert_eq!(error.safe_failure_class(), "provider_auth");
assert!(!error.is_retryable());
assert_eq!(error.request_id(), None);
}
}

View File

@ -0,0 +1,344 @@
use crate::error::ApiError;
const HTTP_PROXY_KEYS: [&str; 2] = ["HTTP_PROXY", "http_proxy"];
const HTTPS_PROXY_KEYS: [&str; 2] = ["HTTPS_PROXY", "https_proxy"];
const NO_PROXY_KEYS: [&str; 2] = ["NO_PROXY", "no_proxy"];
/// Snapshot of the proxy-related environment variables that influence the
/// outbound HTTP client. Captured up front so callers can inspect, log, and
/// test the resolved configuration without re-reading the process environment.
///
/// When `proxy_url` is set it acts as a single catch-all proxy for both
/// HTTP and HTTPS traffic, taking precedence over the per-scheme fields.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProxyConfig {
pub http_proxy: Option<String>,
pub https_proxy: Option<String>,
pub no_proxy: Option<String>,
/// Optional unified proxy URL that applies to both HTTP and HTTPS.
/// When set, this takes precedence over `http_proxy` and `https_proxy`.
pub proxy_url: Option<String>,
}
impl ProxyConfig {
/// Read proxy settings from the live process environment, honouring both
/// the upper- and lower-case spellings used by curl, git, and friends.
#[must_use]
pub fn from_env() -> Self {
Self::from_lookup(|key| std::env::var(key).ok())
}
/// Create a proxy configuration from a single URL that applies to both
/// HTTP and HTTPS traffic. This is the config-file alternative to setting
/// `HTTP_PROXY` and `HTTPS_PROXY` environment variables separately.
#[must_use]
pub fn from_proxy_url(url: impl Into<String>) -> Self {
Self {
proxy_url: Some(url.into()),
..Self::default()
}
}
fn from_lookup<F>(mut lookup: F) -> Self
where
F: FnMut(&str) -> Option<String>,
{
Self {
http_proxy: first_non_empty(&HTTP_PROXY_KEYS, &mut lookup),
https_proxy: first_non_empty(&HTTPS_PROXY_KEYS, &mut lookup),
no_proxy: first_non_empty(&NO_PROXY_KEYS, &mut lookup),
proxy_url: None,
}
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.proxy_url.is_none() && self.http_proxy.is_none() && self.https_proxy.is_none()
}
}
/// Build a `reqwest::Client` that honours the standard `HTTP_PROXY`,
/// `HTTPS_PROXY`, and `NO_PROXY` environment variables. When no proxy is
/// configured the client behaves identically to `reqwest::Client::new()`.
pub fn build_http_client() -> Result<reqwest::Client, ApiError> {
build_http_client_with(&ProxyConfig::from_env())
}
/// Infallible counterpart to [`build_http_client`] for constructors that
/// historically returned `Self` rather than `Result<Self, _>`. When the proxy
/// configuration is malformed we fall back to a default client so that
/// callers retain the previous behaviour and the failure surfaces on the
/// first outbound request instead of at construction time.
#[must_use]
pub fn build_http_client_or_default() -> reqwest::Client {
build_http_client().unwrap_or_else(|_| reqwest::Client::new())
}
/// Build a `reqwest::Client` from an explicit [`ProxyConfig`]. Used by tests
/// and by callers that want to override process-level environment lookups.
///
/// When `config.proxy_url` is set it overrides the per-scheme `http_proxy`
/// and `https_proxy` fields and is registered as both an HTTP and HTTPS
/// proxy so a single value can route every outbound request.
pub fn build_http_client_with(config: &ProxyConfig) -> Result<reqwest::Client, ApiError> {
let mut builder = reqwest::Client::builder().no_proxy();
let no_proxy = config
.no_proxy
.as_deref()
.and_then(reqwest::NoProxy::from_string);
let (http_proxy_url, https_proxy_url) = match config.proxy_url.as_deref() {
Some(unified) => (Some(unified), Some(unified)),
None => (config.http_proxy.as_deref(), config.https_proxy.as_deref()),
};
if let Some(url) = https_proxy_url {
let mut proxy = reqwest::Proxy::https(url)?;
if let Some(filter) = no_proxy.clone() {
proxy = proxy.no_proxy(Some(filter));
}
builder = builder.proxy(proxy);
}
if let Some(url) = http_proxy_url {
let mut proxy = reqwest::Proxy::http(url)?;
if let Some(filter) = no_proxy.clone() {
proxy = proxy.no_proxy(Some(filter));
}
builder = builder.proxy(proxy);
}
Ok(builder.build()?)
}
fn first_non_empty<F>(keys: &[&str], lookup: &mut F) -> Option<String>
where
F: FnMut(&str) -> Option<String>,
{
keys.iter()
.find_map(|key| lookup(key).filter(|value| !value.is_empty()))
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::{build_http_client_with, ProxyConfig};
fn config_from_map(pairs: &[(&str, &str)]) -> ProxyConfig {
let map: HashMap<String, String> = pairs
.iter()
.map(|(key, value)| ((*key).to_string(), (*value).to_string()))
.collect();
ProxyConfig::from_lookup(|key| map.get(key).cloned())
}
#[test]
fn proxy_config_is_empty_when_no_env_vars_are_set() {
// given
let config = config_from_map(&[]);
// when
let empty = config.is_empty();
// then
assert!(empty);
assert_eq!(config, ProxyConfig::default());
}
#[test]
fn proxy_config_reads_uppercase_http_https_and_no_proxy() {
// given
let pairs = [
("HTTP_PROXY", "http://proxy.internal:3128"),
("HTTPS_PROXY", "http://secure.internal:3129"),
("NO_PROXY", "localhost,127.0.0.1,.corp"),
];
// when
let config = config_from_map(&pairs);
// then
assert_eq!(
config.http_proxy.as_deref(),
Some("http://proxy.internal:3128")
);
assert_eq!(
config.https_proxy.as_deref(),
Some("http://secure.internal:3129")
);
assert_eq!(
config.no_proxy.as_deref(),
Some("localhost,127.0.0.1,.corp")
);
assert!(!config.is_empty());
}
#[test]
fn proxy_config_falls_back_to_lowercase_keys() {
// given
let pairs = [
("http_proxy", "http://lower.internal:3128"),
("https_proxy", "http://lower-secure.internal:3129"),
("no_proxy", ".lower"),
];
// when
let config = config_from_map(&pairs);
// then
assert_eq!(
config.http_proxy.as_deref(),
Some("http://lower.internal:3128")
);
assert_eq!(
config.https_proxy.as_deref(),
Some("http://lower-secure.internal:3129")
);
assert_eq!(config.no_proxy.as_deref(), Some(".lower"));
}
#[test]
fn proxy_config_prefers_uppercase_over_lowercase_when_both_set() {
// given
let pairs = [
("HTTP_PROXY", "http://upper.internal:3128"),
("http_proxy", "http://lower.internal:3128"),
];
// when
let config = config_from_map(&pairs);
// then
assert_eq!(
config.http_proxy.as_deref(),
Some("http://upper.internal:3128")
);
}
#[test]
fn proxy_config_treats_empty_strings_as_unset() {
// given
let pairs = [("HTTP_PROXY", ""), ("http_proxy", "")];
// when
let config = config_from_map(&pairs);
// then
assert!(config.http_proxy.is_none());
}
#[test]
fn build_http_client_succeeds_when_no_proxy_is_configured() {
// given
let config = ProxyConfig::default();
// when
let result = build_http_client_with(&config);
// then
assert!(result.is_ok());
}
#[test]
fn build_http_client_succeeds_with_valid_http_and_https_proxies() {
// given
let config = ProxyConfig {
http_proxy: Some("http://proxy.internal:3128".to_string()),
https_proxy: Some("http://secure.internal:3129".to_string()),
no_proxy: Some("localhost,127.0.0.1".to_string()),
proxy_url: None,
};
// when
let result = build_http_client_with(&config);
// then
assert!(result.is_ok());
}
#[test]
fn build_http_client_returns_http_error_for_invalid_proxy_url() {
// given
let config = ProxyConfig {
http_proxy: None,
https_proxy: Some("not a url".to_string()),
no_proxy: None,
proxy_url: None,
};
// when
let result = build_http_client_with(&config);
// then
let error = result.expect_err("invalid proxy URL must be reported as a build failure");
assert!(
matches!(error, crate::error::ApiError::Http(_)),
"expected ApiError::Http for invalid proxy URL, got: {error:?}"
);
}
#[test]
fn from_proxy_url_sets_unified_field_and_leaves_per_scheme_empty() {
// given / when
let config = ProxyConfig::from_proxy_url("http://unified.internal:3128");
// then
assert_eq!(
config.proxy_url.as_deref(),
Some("http://unified.internal:3128")
);
assert!(config.http_proxy.is_none());
assert!(config.https_proxy.is_none());
assert!(!config.is_empty());
}
#[test]
fn build_http_client_succeeds_with_unified_proxy_url() {
// given
let config = ProxyConfig {
proxy_url: Some("http://unified.internal:3128".to_string()),
no_proxy: Some("localhost".to_string()),
..ProxyConfig::default()
};
// when
let result = build_http_client_with(&config);
// then
assert!(result.is_ok());
}
#[test]
fn proxy_url_takes_precedence_over_per_scheme_fields() {
// given both per-scheme and unified are set
let config = ProxyConfig {
http_proxy: Some("http://per-scheme.internal:1111".to_string()),
https_proxy: Some("http://per-scheme.internal:2222".to_string()),
no_proxy: None,
proxy_url: Some("http://unified.internal:3128".to_string()),
};
// when building succeeds (the unified URL is valid)
let result = build_http_client_with(&config);
// then
assert!(result.is_ok());
}
#[test]
fn build_http_client_returns_error_for_invalid_unified_proxy_url() {
// given
let config = ProxyConfig::from_proxy_url("not a url");
// when
let result = build_http_client_with(&config);
// then
assert!(
matches!(result, Err(crate::error::ApiError::Http(_))),
"invalid unified proxy URL should fail: {result:?}"
);
}
}

View File

@ -1,5 +1,7 @@
mod client;
mod error;
mod http_client;
mod prompt_cache;
mod providers;
mod sse;
mod types;
@ -9,10 +11,18 @@ pub use client::{
resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient,
};
pub use error::ApiError;
pub use providers::claw_provider::{AuthSource, ClawApiClient, ClawApiClient as ApiClient};
pub use http_client::{
build_http_client, build_http_client_or_default, build_http_client_with, ProxyConfig,
};
pub use prompt_cache::{
CacheBreakEvent, PromptCache, PromptCacheConfig, PromptCachePaths, PromptCacheRecord,
PromptCacheStats,
};
pub use providers::anthropic::{AnthropicClient, AnthropicClient as ApiClient, AuthSource};
pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig};
pub use providers::{
detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind,
detect_provider_kind, max_tokens_for_model, max_tokens_for_model_with_override,
resolve_model_alias, ProviderKind,
};
pub use sse::{parse_frame, SseParser};
pub use types::{
@ -21,3 +31,9 @@ pub use types::{
MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
};
pub use telemetry::{
AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, JsonlTelemetrySink,
MemoryTelemetrySink, SessionTraceRecord, SessionTracer, TelemetryEvent, TelemetrySink,
DEFAULT_ANTHROPIC_VERSION,
};

View File

@ -0,0 +1,735 @@
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::types::{MessageRequest, MessageResponse, Usage};
const DEFAULT_COMPLETION_TTL_SECS: u64 = 30;
const DEFAULT_PROMPT_TTL_SECS: u64 = 5 * 60;
const DEFAULT_BREAK_MIN_DROP: u32 = 2_000;
const MAX_SANITIZED_LENGTH: usize = 80;
const REQUEST_FINGERPRINT_VERSION: u32 = 1;
const REQUEST_FINGERPRINT_PREFIX: &str = "v1";
const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
#[derive(Debug, Clone)]
pub struct PromptCacheConfig {
pub session_id: String,
pub completion_ttl: Duration,
pub prompt_ttl: Duration,
pub cache_break_min_drop: u32,
}
impl PromptCacheConfig {
#[must_use]
pub fn new(session_id: impl Into<String>) -> Self {
Self {
session_id: session_id.into(),
completion_ttl: Duration::from_secs(DEFAULT_COMPLETION_TTL_SECS),
prompt_ttl: Duration::from_secs(DEFAULT_PROMPT_TTL_SECS),
cache_break_min_drop: DEFAULT_BREAK_MIN_DROP,
}
}
}
impl Default for PromptCacheConfig {
fn default() -> Self {
Self::new("default")
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PromptCachePaths {
pub root: PathBuf,
pub session_dir: PathBuf,
pub completion_dir: PathBuf,
pub session_state_path: PathBuf,
pub stats_path: PathBuf,
}
impl PromptCachePaths {
#[must_use]
pub fn for_session(session_id: &str) -> Self {
let root = base_cache_root();
let session_dir = root.join(sanitize_path_segment(session_id));
let completion_dir = session_dir.join("completions");
Self {
root,
session_state_path: session_dir.join("session-state.json"),
stats_path: session_dir.join("stats.json"),
session_dir,
completion_dir,
}
}
#[must_use]
pub fn completion_entry_path(&self, request_hash: &str) -> PathBuf {
self.completion_dir.join(format!("{request_hash}.json"))
}
}
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct PromptCacheStats {
pub tracked_requests: u64,
pub completion_cache_hits: u64,
pub completion_cache_misses: u64,
pub completion_cache_writes: u64,
pub expected_invalidations: u64,
pub unexpected_cache_breaks: u64,
pub total_cache_creation_input_tokens: u64,
pub total_cache_read_input_tokens: u64,
pub last_cache_creation_input_tokens: Option<u32>,
pub last_cache_read_input_tokens: Option<u32>,
pub last_request_hash: Option<String>,
pub last_completion_cache_key: Option<String>,
pub last_break_reason: Option<String>,
pub last_cache_source: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CacheBreakEvent {
pub unexpected: bool,
pub reason: String,
pub previous_cache_read_input_tokens: u32,
pub current_cache_read_input_tokens: u32,
pub token_drop: u32,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptCacheRecord {
pub cache_break: Option<CacheBreakEvent>,
pub stats: PromptCacheStats,
}
#[derive(Debug, Clone)]
pub struct PromptCache {
inner: Arc<Mutex<PromptCacheInner>>,
}
impl PromptCache {
#[must_use]
pub fn new(session_id: impl Into<String>) -> Self {
Self::with_config(PromptCacheConfig::new(session_id))
}
#[must_use]
pub fn with_config(config: PromptCacheConfig) -> Self {
let paths = PromptCachePaths::for_session(&config.session_id);
let stats = read_json::<PromptCacheStats>(&paths.stats_path).unwrap_or_default();
let previous = read_json::<TrackedPromptState>(&paths.session_state_path);
Self {
inner: Arc::new(Mutex::new(PromptCacheInner {
config,
paths,
stats,
previous,
})),
}
}
#[must_use]
pub fn paths(&self) -> PromptCachePaths {
self.lock().paths.clone()
}
#[must_use]
pub fn stats(&self) -> PromptCacheStats {
self.lock().stats.clone()
}
#[must_use]
pub fn lookup_completion(&self, request: &MessageRequest) -> Option<MessageResponse> {
let request_hash = request_hash_hex(request);
let (paths, ttl) = {
let inner = self.lock();
(inner.paths.clone(), inner.config.completion_ttl)
};
let entry_path = paths.completion_entry_path(&request_hash);
let entry = read_json::<CompletionCacheEntry>(&entry_path);
let Some(entry) = entry else {
let mut inner = self.lock();
inner.stats.completion_cache_misses += 1;
inner.stats.last_completion_cache_key = Some(request_hash);
persist_state(&inner);
return None;
};
if entry.fingerprint_version != current_fingerprint_version() {
let mut inner = self.lock();
inner.stats.completion_cache_misses += 1;
inner.stats.last_completion_cache_key = Some(request_hash.clone());
let _ = fs::remove_file(entry_path);
persist_state(&inner);
return None;
}
let expired = now_unix_secs().saturating_sub(entry.cached_at_unix_secs) >= ttl.as_secs();
let mut inner = self.lock();
inner.stats.last_completion_cache_key = Some(request_hash.clone());
if expired {
inner.stats.completion_cache_misses += 1;
let _ = fs::remove_file(entry_path);
persist_state(&inner);
return None;
}
inner.stats.completion_cache_hits += 1;
apply_usage_to_stats(
&mut inner.stats,
&entry.response.usage,
&request_hash,
"completion-cache",
);
inner.previous = Some(TrackedPromptState::from_usage(
request,
&entry.response.usage,
));
persist_state(&inner);
Some(entry.response)
}
#[must_use]
pub fn record_response(
&self,
request: &MessageRequest,
response: &MessageResponse,
) -> PromptCacheRecord {
self.record_usage_internal(request, &response.usage, Some(response))
}
#[must_use]
pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord {
self.record_usage_internal(request, usage, None)
}
fn record_usage_internal(
&self,
request: &MessageRequest,
usage: &Usage,
response: Option<&MessageResponse>,
) -> PromptCacheRecord {
let request_hash = request_hash_hex(request);
let mut inner = self.lock();
let previous = inner.previous.clone();
let current = TrackedPromptState::from_usage(request, usage);
let cache_break = detect_cache_break(&inner.config, previous.as_ref(), &current);
inner.stats.tracked_requests += 1;
apply_usage_to_stats(&mut inner.stats, usage, &request_hash, "api-response");
if let Some(event) = &cache_break {
if event.unexpected {
inner.stats.unexpected_cache_breaks += 1;
} else {
inner.stats.expected_invalidations += 1;
}
inner.stats.last_break_reason = Some(event.reason.clone());
}
inner.previous = Some(current);
if let Some(response) = response {
write_completion_entry(&inner.paths, &request_hash, response);
inner.stats.completion_cache_writes += 1;
}
persist_state(&inner);
PromptCacheRecord {
cache_break,
stats: inner.stats.clone(),
}
}
fn lock(&self) -> std::sync::MutexGuard<'_, PromptCacheInner> {
self.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
}
#[derive(Debug)]
struct PromptCacheInner {
config: PromptCacheConfig,
paths: PromptCachePaths,
stats: PromptCacheStats,
previous: Option<TrackedPromptState>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CompletionCacheEntry {
cached_at_unix_secs: u64,
#[serde(default = "current_fingerprint_version")]
fingerprint_version: u32,
response: MessageResponse,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
struct TrackedPromptState {
observed_at_unix_secs: u64,
#[serde(default = "current_fingerprint_version")]
fingerprint_version: u32,
model_hash: u64,
system_hash: u64,
tools_hash: u64,
messages_hash: u64,
cache_read_input_tokens: u32,
}
impl TrackedPromptState {
fn from_usage(request: &MessageRequest, usage: &Usage) -> Self {
let hashes = RequestFingerprints::from_request(request);
Self {
observed_at_unix_secs: now_unix_secs(),
fingerprint_version: current_fingerprint_version(),
model_hash: hashes.model,
system_hash: hashes.system,
tools_hash: hashes.tools,
messages_hash: hashes.messages,
cache_read_input_tokens: usage.cache_read_input_tokens,
}
}
}
#[derive(Debug, Clone, Copy)]
struct RequestFingerprints {
model: u64,
system: u64,
tools: u64,
messages: u64,
}
impl RequestFingerprints {
fn from_request(request: &MessageRequest) -> Self {
Self {
model: hash_serializable(&request.model),
system: hash_serializable(&request.system),
tools: hash_serializable(&request.tools),
messages: hash_serializable(&request.messages),
}
}
}
fn detect_cache_break(
config: &PromptCacheConfig,
previous: Option<&TrackedPromptState>,
current: &TrackedPromptState,
) -> Option<CacheBreakEvent> {
let previous = previous?;
if previous.fingerprint_version != current.fingerprint_version {
return Some(CacheBreakEvent {
unexpected: false,
reason: format!(
"fingerprint version changed (v{} -> v{})",
previous.fingerprint_version, current.fingerprint_version
),
previous_cache_read_input_tokens: previous.cache_read_input_tokens,
current_cache_read_input_tokens: current.cache_read_input_tokens,
token_drop: previous
.cache_read_input_tokens
.saturating_sub(current.cache_read_input_tokens),
});
}
let token_drop = previous
.cache_read_input_tokens
.saturating_sub(current.cache_read_input_tokens);
if token_drop < config.cache_break_min_drop {
return None;
}
let mut reasons = Vec::new();
if previous.model_hash != current.model_hash {
reasons.push("model changed");
}
if previous.system_hash != current.system_hash {
reasons.push("system prompt changed");
}
if previous.tools_hash != current.tools_hash {
reasons.push("tool definitions changed");
}
if previous.messages_hash != current.messages_hash {
reasons.push("message payload changed");
}
let elapsed = current
.observed_at_unix_secs
.saturating_sub(previous.observed_at_unix_secs);
let (unexpected, reason) = if reasons.is_empty() {
if elapsed > config.prompt_ttl.as_secs() {
(
false,
format!("possible prompt cache TTL expiry after {elapsed}s"),
)
} else {
(
true,
"cache read tokens dropped while prompt fingerprint remained stable".to_string(),
)
}
} else {
(false, reasons.join(", "))
};
Some(CacheBreakEvent {
unexpected,
reason,
previous_cache_read_input_tokens: previous.cache_read_input_tokens,
current_cache_read_input_tokens: current.cache_read_input_tokens,
token_drop,
})
}
fn apply_usage_to_stats(
stats: &mut PromptCacheStats,
usage: &Usage,
request_hash: &str,
source: &str,
) {
stats.total_cache_creation_input_tokens += u64::from(usage.cache_creation_input_tokens);
stats.total_cache_read_input_tokens += u64::from(usage.cache_read_input_tokens);
stats.last_cache_creation_input_tokens = Some(usage.cache_creation_input_tokens);
stats.last_cache_read_input_tokens = Some(usage.cache_read_input_tokens);
stats.last_request_hash = Some(request_hash.to_string());
stats.last_cache_source = Some(source.to_string());
}
fn persist_state(inner: &PromptCacheInner) {
let _ = ensure_cache_dirs(&inner.paths);
let _ = write_json(&inner.paths.stats_path, &inner.stats);
if let Some(previous) = &inner.previous {
let _ = write_json(&inner.paths.session_state_path, previous);
}
}
fn write_completion_entry(
paths: &PromptCachePaths,
request_hash: &str,
response: &MessageResponse,
) {
let _ = ensure_cache_dirs(paths);
let entry = CompletionCacheEntry {
cached_at_unix_secs: now_unix_secs(),
fingerprint_version: current_fingerprint_version(),
response: response.clone(),
};
let _ = write_json(&paths.completion_entry_path(request_hash), &entry);
}
fn ensure_cache_dirs(paths: &PromptCachePaths) -> std::io::Result<()> {
fs::create_dir_all(&paths.completion_dir)
}
fn write_json<T: Serialize>(path: &Path, value: &T) -> std::io::Result<()> {
let json = serde_json::to_vec_pretty(value)
.map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?;
fs::write(path, json)
}
fn read_json<T: for<'de> Deserialize<'de>>(path: &Path) -> Option<T> {
let bytes = fs::read(path).ok()?;
serde_json::from_slice(&bytes).ok()
}
fn request_hash_hex(request: &MessageRequest) -> String {
format!(
"{REQUEST_FINGERPRINT_PREFIX}-{:016x}",
hash_serializable(request)
)
}
fn hash_serializable<T: Serialize>(value: &T) -> u64 {
let json = serde_json::to_vec(value).unwrap_or_default();
stable_hash_bytes(&json)
}
fn sanitize_path_segment(value: &str) -> String {
let sanitized: String = value
.chars()
.map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '-' })
.collect();
if sanitized.len() <= MAX_SANITIZED_LENGTH {
return sanitized;
}
let suffix = format!("-{:x}", hash_string(value));
format!(
"{}{}",
&sanitized[..MAX_SANITIZED_LENGTH.saturating_sub(suffix.len())],
suffix
)
}
fn hash_string(value: &str) -> u64 {
stable_hash_bytes(value.as_bytes())
}
fn base_cache_root() -> PathBuf {
if let Some(config_home) = std::env::var_os("CLAUDE_CONFIG_HOME") {
return PathBuf::from(config_home)
.join("cache")
.join("prompt-cache");
}
if let Some(home) = std::env::var_os("HOME") {
return PathBuf::from(home)
.join(".claude")
.join("cache")
.join("prompt-cache");
}
std::env::temp_dir().join("claude-prompt-cache")
}
fn now_unix_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_or(0, |duration| duration.as_secs())
}
const fn current_fingerprint_version() -> u32 {
REQUEST_FINGERPRINT_VERSION
}
fn stable_hash_bytes(bytes: &[u8]) -> u64 {
let mut hash = FNV_OFFSET_BASIS;
for byte in bytes {
hash ^= u64::from(*byte);
hash = hash.wrapping_mul(FNV_PRIME);
}
hash
}
#[cfg(test)]
mod tests {
use std::sync::{Mutex, OnceLock};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use super::{
detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache,
PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX,
};
use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage};
fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
#[test]
fn path_builder_sanitizes_session_identifier() {
let paths = PromptCachePaths::for_session("session:/with spaces");
let session_dir = paths
.session_dir
.file_name()
.and_then(|value| value.to_str())
.expect("session dir name");
assert_eq!(session_dir, "session--with-spaces");
assert!(paths.completion_dir.ends_with("completions"));
assert!(paths.stats_path.ends_with("stats.json"));
assert!(paths.session_state_path.ends_with("session-state.json"));
}
#[test]
fn request_fingerprint_drives_unexpected_break_detection() {
let request = sample_request("same");
let previous = TrackedPromptState::from_usage(
&request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 6_000,
output_tokens: 0,
},
);
let current = TrackedPromptState::from_usage(
&request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 1_000,
output_tokens: 0,
},
);
let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), &current)
.expect("break should be detected");
assert!(event.unexpected);
assert!(event.reason.contains("stable"));
}
#[test]
fn changed_prompt_marks_break_as_expected() {
let previous_request = sample_request("first");
let current_request = sample_request("second");
let previous = TrackedPromptState::from_usage(
&previous_request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 6_000,
output_tokens: 0,
},
);
let current = TrackedPromptState::from_usage(
&current_request,
&Usage {
input_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 1_000,
output_tokens: 0,
},
);
let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), &current)
.expect("break should be detected");
assert!(!event.unexpected);
assert!(event.reason.contains("message payload changed"));
}
#[test]
fn completion_cache_round_trip_persists_recent_response() {
let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-test-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let cache = PromptCache::new("unit-test-session");
let request = sample_request("cache me");
let response = sample_response(42, 12, "cached");
assert!(cache.lookup_completion(&request).is_none());
let record = cache.record_response(&request, &response);
assert!(record.cache_break.is_none());
let cached = cache
.lookup_completion(&request)
.expect("cached response should load");
assert_eq!(cached.content, response.content);
let stats = cache.stats();
assert_eq!(stats.completion_cache_hits, 1);
assert_eq!(stats.completion_cache_misses, 1);
assert_eq!(stats.completion_cache_writes, 1);
let persisted = read_json::<super::PromptCacheStats>(&cache.paths().stats_path)
.expect("stats should persist");
assert_eq!(persisted.completion_cache_hits, 1);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[test]
fn distinct_requests_do_not_collide_in_completion_cache() {
let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-distinct-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let cache = PromptCache::new("distinct-request-session");
let first_request = sample_request("first");
let second_request = sample_request("second");
let response = sample_response(42, 12, "cached");
let _ = cache.record_response(&first_request, &response);
assert!(cache.lookup_completion(&second_request).is_none());
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[test]
fn expired_completion_entries_are_not_reused() {
let _guard = test_env_lock();
let temp_root = std::env::temp_dir().join(format!(
"prompt-cache-expired-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let cache = PromptCache::with_config(PromptCacheConfig {
session_id: "expired-session".to_string(),
completion_ttl: Duration::ZERO,
..PromptCacheConfig::default()
});
let request = sample_request("expire me");
let response = sample_response(7, 3, "stale");
let _ = cache.record_response(&request, &response);
assert!(cache.lookup_completion(&request).is_none());
let stats = cache.stats();
assert_eq!(stats.completion_cache_hits, 0);
assert_eq!(stats.completion_cache_misses, 1);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[test]
fn sanitize_path_caps_long_values() {
let long_value = "x".repeat(200);
let sanitized = sanitize_path_segment(&long_value);
assert!(sanitized.len() <= 80);
}
#[test]
fn request_hashes_are_versioned_and_stable() {
let request = sample_request("stable");
let first = request_hash_hex(&request);
let second = request_hash_hex(&request);
assert_eq!(first, second);
assert!(first.starts_with(REQUEST_FINGERPRINT_PREFIX));
}
fn sample_request(text: &str) -> MessageRequest {
MessageRequest {
model: "claude-3-7-sonnet-latest".to_string(),
max_tokens: 64,
messages: vec![InputMessage::user_text(text)],
system: Some("system".to_string()),
tools: None,
tool_choice: None,
stream: false,
..Default::default()
}
}
fn sample_response(
cache_read_input_tokens: u32,
output_tokens: u32,
text: &str,
) -> MessageResponse {
MessageResponse {
id: "msg_test".to_string(),
kind: "message".to_string(),
role: "assistant".to_string(),
content: vec![OutputContentBlock::Text {
text: text.to_string(),
}],
model: "claude-3-7-sonnet-latest".to_string(),
stop_reason: Some("end_turn".to_string()),
stop_sequence: None,
usage: Usage {
input_tokens: 10,
cache_creation_input_tokens: 5,
cache_read_input_tokens,
output_tokens,
},
request_id: Some("req_test".to_string()),
}
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +1,11 @@
use crate::error::ApiError;
use crate::types::StreamEvent;
use serde_json::Value;
use reqwest::StatusCode;
#[derive(Debug, Default)]
pub struct SseParser {
buffer: Vec<u8>,
provider: Option<String>,
model: Option<String>,
}
impl SseParser {
@ -14,12 +14,23 @@ impl SseParser {
Self::default()
}
/// Attach the provider name and model to this parser so that JSON
/// deserialization failures within streamed frames carry enough context
/// for callers to understand which upstream produced the unparseable
/// payload.
#[must_use]
pub fn with_context(mut self, provider: impl Into<String>, model: impl Into<String>) -> Self {
self.provider = Some(provider.into());
self.model = Some(model.into());
self
}
pub fn push(&mut self, chunk: &[u8]) -> Result<Vec<StreamEvent>, ApiError> {
self.buffer.extend_from_slice(chunk);
let mut events = Vec::new();
while let Some(frame) = self.next_frame() {
if let Some(event) = parse_frame(&frame)? {
if let Some(event) = self.parse_frame_with_context(&frame)? {
events.push(event);
}
}
@ -33,12 +44,18 @@ impl SseParser {
}
let trailing = std::mem::take(&mut self.buffer);
match parse_frame(&String::from_utf8_lossy(&trailing))? {
match self.parse_frame_with_context(&String::from_utf8_lossy(&trailing))? {
Some(event) => Ok(vec![event]),
None => Ok(Vec::new()),
}
}
fn parse_frame_with_context(&self, frame: &str) -> Result<Option<StreamEvent>, ApiError> {
let provider = self.provider.as_deref().unwrap_or("unknown");
let model = self.model.as_deref().unwrap_or("unknown");
parse_frame_with_provider(frame, provider, model)
}
fn next_frame(&mut self) -> Option<String> {
let separator = self
.buffer
@ -63,6 +80,14 @@ impl SseParser {
}
pub fn parse_frame(frame: &str) -> Result<Option<StreamEvent>, ApiError> {
parse_frame_with_provider(frame, "unknown", "unknown")
}
pub(crate) fn parse_frame_with_provider(
frame: &str,
provider: &str,
model: &str,
) -> Result<Option<StreamEvent>, ApiError> {
let trimmed = frame.trim();
if trimmed.is_empty() {
return Ok(None);
@ -97,75 +122,9 @@ pub fn parse_frame(frame: &str) -> Result<Option<StreamEvent>, ApiError> {
return Ok(None);
}
if matches!(event_name, Some("error")) {
return Err(parse_error_event(&payload));
}
// Some "Anthropic-compatible" gateways put the event type in the SSE `event:` field,
// and omit the `{ "type": ... }` discriminator from the JSON `data:` payload.
// Our Rust enums are tagged with `#[serde(tag = "type")]`, so we synthesize it here.
match serde_json::from_str::<StreamEvent>(&payload) {
Ok(event) => Ok(Some(event)),
Err(error) => {
// Best-effort: if we have an SSE event name and the payload is a JSON object
// without a `type` field, inject it and retry.
let Some(event_name) = event_name else {
return Err(ApiError::from(error));
};
let Ok(Value::Object(mut object)) = serde_json::from_str::<Value>(&payload) else {
return Err(ApiError::from(error));
};
if object
.get("type")
.and_then(Value::as_str)
.is_some_and(|value| value == "error")
{
return Err(parse_error_object(&object, payload));
}
if object.contains_key("type") {
return Err(ApiError::from(error));
}
object.insert("type".to_string(), Value::String(event_name.to_string()));
serde_json::from_value::<StreamEvent>(Value::Object(object))
serde_json::from_str::<StreamEvent>(&payload)
.map(Some)
.map_err(ApiError::from)
}
}
}
fn parse_error_event(payload: &str) -> ApiError {
match serde_json::from_str::<Value>(payload) {
Ok(Value::Object(object)) => parse_error_object(&object, payload.to_string()),
_ => ApiError::Api {
status: StatusCode::BAD_GATEWAY,
error_type: Some("stream_error".to_string()),
message: Some(payload.to_string()),
body: payload.to_string(),
retryable: false,
},
}
}
fn parse_error_object(object: &serde_json::Map<String, Value>, body: String) -> ApiError {
let nested = object.get("error").and_then(Value::as_object);
let error_type = nested
.and_then(|error| error.get("type"))
.or_else(|| object.get("type"))
.and_then(Value::as_str)
.map(ToOwned::to_owned);
let message = nested
.and_then(|error| error.get("message"))
.or_else(|| object.get("message"))
.and_then(Value::as_str)
.map(ToOwned::to_owned);
ApiError::Api {
status: StatusCode::BAD_GATEWAY,
error_type,
message,
body,
retryable: false,
}
.map_err(|error| ApiError::json_deserialize(provider, model, &payload, error))
}
#[cfg(test)]
@ -263,26 +222,6 @@ mod tests {
assert_eq!(event, None);
}
#[test]
fn parses_event_name_when_payload_omits_type() {
let frame = concat!("event: message_stop\n", "data: {}\n\n");
let event = parse_frame(frame).expect("frame should parse");
assert_eq!(event, Some(StreamEvent::MessageStop(crate::types::MessageStopEvent {})));
}
#[test]
fn surfaces_stream_error_events() {
let frame = concat!(
"event: error\n",
"data: {\"error\":{\"type\":\"invalid_request_error\",\"message\":\"bad input\"}}\n\n"
);
let error = parse_frame(frame).expect_err("error frame should surface");
assert_eq!(
error.to_string(),
"api returned 502 Bad Gateway (invalid_request_error): bad input"
);
}
#[test]
fn parses_split_json_across_data_lines() {
let frame = concat!(
@ -364,4 +303,28 @@ mod tests {
))
);
}
#[test]
fn given_message_delta_frame_with_empty_usage_when_parsed_then_usage_defaults_to_zero() {
// given
let frame = concat!(
"event: message_delta\n",
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{}}\n\n"
);
// when
let event = parse_frame(frame).expect("frame should parse");
// then
assert_eq!(
event,
Some(StreamEvent::MessageDelta(crate::types::MessageDeltaEvent {
delta: MessageDelta {
stop_reason: Some("end_turn".to_string()),
stop_sequence: None,
},
usage: Usage::default(),
}))
);
}
}

View File

@ -1,7 +1,8 @@
use runtime::{pricing_for_model, TokenUsage, UsageCostEstimate};
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct MessageRequest {
pub model: String,
pub max_tokens: u32,
@ -14,6 +15,22 @@ pub struct MessageRequest {
pub tool_choice: Option<ToolChoice>,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub stream: bool,
/// OpenAI-compatible tuning parameters. Optional — omitted from payload when None.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
/// Reasoning effort level for OpenAI-compatible reasoning models (e.g. `o4-mini`).
/// Accepted values: `"low"`, `"medium"`, `"high"`. Omitted when `None`.
/// Silently ignored by backends that do not support it.
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_effort: Option<String>,
}
impl MessageRequest {
@ -75,14 +92,6 @@ pub enum InputContentBlock {
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
is_error: bool,
},
Thinking {
thinking: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
signature: Option<String>,
},
RedactedThinking {
data: Value,
},
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
@ -120,6 +129,7 @@ pub struct MessageResponse {
pub stop_reason: Option<String>,
#[serde(default)]
pub stop_sequence: Option<String>,
#[serde(default)]
pub usage: Usage,
#[serde(default)]
pub request_id: Option<String>,
@ -154,20 +164,44 @@ pub enum OutputContentBlock {
},
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct Usage {
#[serde(default)]
pub input_tokens: u32,
#[serde(default)]
pub cache_creation_input_tokens: u32,
#[serde(default)]
pub cache_read_input_tokens: u32,
#[serde(default)]
pub output_tokens: u32,
}
impl Usage {
#[must_use]
pub const fn total_tokens(&self) -> u32 {
self.input_tokens + self.output_tokens
self.input_tokens
+ self.output_tokens
+ self.cache_creation_input_tokens
+ self.cache_read_input_tokens
}
#[must_use]
pub const fn token_usage(&self) -> TokenUsage {
TokenUsage {
input_tokens: self.input_tokens,
output_tokens: self.output_tokens,
cache_creation_input_tokens: self.cache_creation_input_tokens,
cache_read_input_tokens: self.cache_read_input_tokens,
}
}
#[must_use]
pub fn estimated_cost_usd(&self, model: &str) -> UsageCostEstimate {
let usage = self.token_usage();
pricing_for_model(model).map_or_else(
|| usage.estimate_cost_usd(),
|pricing| usage.estimate_cost_usd_with_pricing(pricing),
)
}
}
@ -179,6 +213,7 @@ pub struct MessageStartEvent {
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MessageDeltaEvent {
pub delta: MessageDelta,
#[serde(default)]
pub usage: Usage,
}
@ -229,3 +264,47 @@ pub enum StreamEvent {
ContentBlockStop(ContentBlockStopEvent),
MessageStop(MessageStopEvent),
}
#[cfg(test)]
mod tests {
use runtime::format_usd;
use super::{MessageResponse, Usage};
#[test]
fn usage_total_tokens_includes_cache_tokens() {
let usage = Usage {
input_tokens: 10,
cache_creation_input_tokens: 2,
cache_read_input_tokens: 3,
output_tokens: 4,
};
assert_eq!(usage.total_tokens(), 19);
assert_eq!(usage.token_usage().total_tokens(), 19);
}
#[test]
fn message_response_estimates_cost_from_model_usage() {
let response = MessageResponse {
id: "msg_cost".to_string(),
kind: "message".to_string(),
role: "assistant".to_string(),
content: Vec::new(),
model: "claude-sonnet-4-20250514".to_string(),
stop_reason: Some("end_turn".to_string()),
stop_sequence: None,
usage: Usage {
input_tokens: 1_000_000,
cache_creation_input_tokens: 100_000,
cache_read_input_tokens: 200_000,
output_tokens: 500_000,
},
request_id: None,
};
let cost = response.usage.estimated_cost_usd(&response.model);
assert_eq!(format_usd(cost.total_cost_usd()), "$54.6750");
assert_eq!(response.total_tokens(), 1_800_000);
}
}

View File

@ -1,17 +1,27 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{Mutex as StdMutex, OnceLock};
use std::time::Duration;
use api::{
ApiClient, ApiError, AuthSource, ContentBlockDelta, ContentBlockDeltaEvent,
AnthropicClient, ApiClient, ApiError, AuthSource, ContentBlockDelta, ContentBlockDeltaEvent,
ContentBlockStartEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest,
OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition,
OutputContentBlock, PromptCache, PromptCacheConfig, ProviderClient, StreamEvent, ToolChoice,
ToolDefinition,
};
use serde_json::json;
use telemetry::{ClientIdentity, MemoryTelemetrySink, SessionTracer, TelemetryEvent};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::Mutex;
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<StdMutex<()>> = OnceLock::new();
LOCK.get_or_init(|| StdMutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
#[tokio::test]
async fn send_message_posts_json_and_parses_response() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
@ -20,8 +30,8 @@ async fn send_message_posts_json_and_parses_response() {
"\"id\":\"msg_test\",",
"\"type\":\"message\",",
"\"role\":\"assistant\",",
"\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claw\"}],",
"\"model\":\"claude-sonnet-4-6\",",
"\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claude\"}],",
"\"model\":\"claude-3-7-sonnet-latest\",",
"\"stop_reason\":\"end_turn\",",
"\"stop_sequence\":null,",
"\"usage\":{\"input_tokens\":12,\"output_tokens\":4},",
@ -45,10 +55,12 @@ async fn send_message_posts_json_and_parses_response() {
assert_eq!(response.id, "msg_test");
assert_eq!(response.total_tokens(), 16);
assert_eq!(response.request_id.as_deref(), Some("req_body_123"));
assert_eq!(response.usage.cache_creation_input_tokens, 0);
assert_eq!(response.usage.cache_read_input_tokens, 0);
assert_eq!(
response.content,
vec![OutputContentBlock::Text {
text: "Hello from Claw".to_string(),
text: "Hello from Claude".to_string(),
}]
);
@ -64,23 +76,258 @@ async fn send_message_posts_json_and_parses_response() {
request.headers.get("authorization").map(String::as_str),
Some("Bearer proxy-token")
);
assert_eq!(
request.headers.get("anthropic-version").map(String::as_str),
Some("2023-06-01")
);
assert_eq!(
request.headers.get("user-agent").map(String::as_str),
Some("claude-code/0.1.0")
);
assert_eq!(
request.headers.get("anthropic-beta").map(String::as_str),
Some("claude-code-20250219,prompt-caching-scope-2026-01-05")
);
let body: serde_json::Value =
serde_json::from_str(&request.body).expect("request body should be json");
assert_eq!(
body.get("model").and_then(serde_json::Value::as_str),
Some("claude-sonnet-4-6")
Some("claude-3-7-sonnet-latest")
);
assert!(body.get("stream").is_none());
assert_eq!(body["tools"][0]["name"], json!("get_weather"));
assert_eq!(body["tool_choice"]["type"], json!("auto"));
assert!(
body.get("betas").is_none(),
"betas must travel via the anthropic-beta header, not the request body"
);
}
#[tokio::test]
async fn send_message_blocks_oversized_requests_before_the_http_call() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state.clone(),
vec![http_response("200 OK", "application/json", "{}")],
)
.await;
let client = AnthropicClient::new("test-key").with_base_url(server.base_url());
let error = client
.send_message(&MessageRequest {
model: "claude-sonnet-4-6".to_string(),
max_tokens: 64_000,
messages: vec![InputMessage {
role: "user".to_string(),
content: vec![InputContentBlock::Text {
text: "x".repeat(600_000),
}],
}],
system: Some("Keep the answer short.".to_string()),
tools: None,
tool_choice: None,
stream: false,
..Default::default()
})
.await
.expect_err("oversized request should fail local context-window preflight");
assert!(matches!(error, ApiError::ContextWindowExceeded { .. }));
assert!(
state.lock().await.is_empty(),
"preflight failure should avoid any upstream HTTP request"
);
}
#[tokio::test]
async fn send_message_applies_request_profile_and_records_telemetry() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state.clone(),
vec![http_response_with_headers(
"200 OK",
"application/json",
concat!(
"{",
"\"id\":\"msg_profile\",",
"\"type\":\"message\",",
"\"role\":\"assistant\",",
"\"content\":[{\"type\":\"text\",\"text\":\"ok\"}],",
"\"model\":\"claude-3-7-sonnet-latest\",",
"\"stop_reason\":\"end_turn\",",
"\"stop_sequence\":null,",
"\"usage\":{\"input_tokens\":1,\"cache_creation_input_tokens\":2,\"cache_read_input_tokens\":3,\"output_tokens\":1}",
"}"
),
&[("request-id", "req_profile_123")],
)],
)
.await;
let sink = Arc::new(MemoryTelemetrySink::default());
let client = AnthropicClient::new("test-key")
.with_base_url(server.base_url())
.with_client_identity(ClientIdentity::new("claude-code", "9.9.9").with_runtime("rust-cli"))
.with_beta("tools-2026-04-01")
.with_extra_body_param("metadata", json!({"source": "clawd-code"}))
.with_session_tracer(SessionTracer::new("session-telemetry", sink.clone()));
let response = client
.send_message(&sample_request(false))
.await
.expect("request should succeed");
assert_eq!(response.request_id.as_deref(), Some("req_profile_123"));
let captured = state.lock().await;
let request = captured.first().expect("server should capture request");
assert_eq!(
request.headers.get("anthropic-beta").map(String::as_str),
Some("claude-code-20250219,prompt-caching-scope-2026-01-05,tools-2026-04-01")
);
assert_eq!(
request.headers.get("user-agent").map(String::as_str),
Some("claude-code/9.9.9")
);
let body: serde_json::Value =
serde_json::from_str(&request.body).expect("request body should be json");
assert_eq!(body["metadata"]["source"], json!("clawd-code"));
assert!(
body.get("betas").is_none(),
"betas must travel via the anthropic-beta header, not the request body"
);
let events = sink.events();
assert_eq!(events.len(), 6);
assert!(matches!(
&events[0],
TelemetryEvent::HttpRequestStarted {
session_id,
attempt: 1,
method,
path,
..
} if session_id == "session-telemetry" && method == "POST" && path == "/v1/messages"
));
assert!(matches!(
&events[1],
TelemetryEvent::SessionTrace(trace) if trace.name == "http_request_started"
));
assert!(matches!(
&events[2],
TelemetryEvent::HttpRequestSucceeded {
request_id,
status: 200,
..
} if request_id.as_deref() == Some("req_profile_123")
));
assert!(matches!(
&events[3],
TelemetryEvent::SessionTrace(trace) if trace.name == "http_request_succeeded"
));
assert!(matches!(
&events[4],
TelemetryEvent::Analytics(event)
if event.namespace == "api"
&& event.action == "message_usage"
&& event.properties.get("request_id") == Some(&json!("req_profile_123"))
&& event.properties.get("total_tokens") == Some(&json!(7))
&& event.properties.get("estimated_cost_usd") == Some(&json!("$0.0001"))
));
assert!(matches!(
&events[5],
TelemetryEvent::SessionTrace(trace) if trace.name == "analytics"
));
}
#[tokio::test]
async fn send_message_parses_prompt_cache_token_usage_from_response() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let body = concat!(
"{",
"\"id\":\"msg_cache_tokens\",",
"\"type\":\"message\",",
"\"role\":\"assistant\",",
"\"content\":[{\"type\":\"text\",\"text\":\"Cache tokens\"}],",
"\"model\":\"claude-3-7-sonnet-latest\",",
"\"stop_reason\":\"end_turn\",",
"\"stop_sequence\":null,",
"\"usage\":{\"input_tokens\":12,\"cache_creation_input_tokens\":321,\"cache_read_input_tokens\":654,\"output_tokens\":4}",
"}"
);
let server = spawn_server(
state,
vec![http_response("200 OK", "application/json", body)],
)
.await;
let client = AnthropicClient::new("test-key").with_base_url(server.base_url());
let response = client
.send_message(&sample_request(false))
.await
.expect("request should succeed");
assert_eq!(response.usage.input_tokens, 12);
assert_eq!(response.usage.cache_creation_input_tokens, 321);
assert_eq!(response.usage.cache_read_input_tokens, 654);
assert_eq!(response.usage.output_tokens, 4);
}
#[tokio::test]
async fn given_empty_usage_object_when_send_message_parses_response_then_usage_defaults_to_zero() {
// given
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let body = concat!(
"{",
"\"id\":\"msg_empty_usage\",",
"\"type\":\"message\",",
"\"role\":\"assistant\",",
"\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claude\"}],",
"\"model\":\"claude-3-7-sonnet-latest\",",
"\"stop_reason\":\"end_turn\",",
"\"stop_sequence\":null,",
"\"usage\":{}",
"}"
);
let server = spawn_server(
state,
vec![http_response("200 OK", "application/json", body)],
)
.await;
let client = AnthropicClient::new("test-key").with_base_url(server.base_url());
// when
let response = client
.send_message(&sample_request(false))
.await
.expect("response with empty usage object should still parse");
// then
assert_eq!(response.id, "msg_empty_usage");
assert_eq!(response.total_tokens(), 0);
assert_eq!(response.usage.input_tokens, 0);
assert_eq!(response.usage.cache_creation_input_tokens, 0);
assert_eq!(response.usage.cache_read_input_tokens, 0);
assert_eq!(response.usage.output_tokens, 0);
}
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn stream_message_parses_sse_events_with_tool_use() {
let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!(
"api-stream-cache-{}-{}",
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let sse = concat!(
"event: message_start\n",
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n",
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":13,\"cache_read_input_tokens\":21,\"output_tokens\":0}}}\n\n",
"event: content_block_start\n",
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n",
"event: content_block_delta\n",
@ -88,7 +335,7 @@ async fn stream_message_parses_sse_events_with_tool_use() {
"event: content_block_stop\n",
"data: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
"event: message_delta\n",
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n",
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":34,\"cache_read_input_tokens\":55,\"output_tokens\":1}}\n\n",
"event: message_stop\n",
"data: {\"type\":\"message_stop\"}\n\n",
"data: [DONE]\n\n"
@ -106,7 +353,8 @@ async fn stream_message_parses_sse_events_with_tool_use() {
let client = ApiClient::new("test-key")
.with_auth_token(Some("proxy-token".to_string()))
.with_base_url(server.base_url());
.with_base_url(server.base_url())
.with_prompt_cache(PromptCache::new("stream-session"));
let mut stream = client
.stream_message(&sample_request(false))
.await
@ -160,6 +408,20 @@ async fn stream_message_parses_sse_events_with_tool_use() {
let captured = state.lock().await;
let request = captured.first().expect("server should capture request");
assert!(request.body.contains("\"stream\":true"));
let cache_stats = client
.prompt_cache_stats()
.expect("prompt cache stats should exist");
assert_eq!(cache_stats.tracked_requests, 1);
assert_eq!(cache_stats.last_cache_creation_input_tokens, Some(34));
assert_eq!(cache_stats.last_cache_read_input_tokens, Some(55));
assert_eq!(
cache_stats.last_cache_source.as_deref(),
Some("api-response")
);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[tokio::test]
@ -176,7 +438,7 @@ async fn retries_retryable_failures_before_succeeding() {
http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
"{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
),
],
)
@ -196,28 +458,28 @@ async fn retries_retryable_failures_before_succeeding() {
}
#[tokio::test]
async fn provider_client_dispatches_api_requests() {
async fn provider_client_dispatches_anthropic_requests() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state.clone(),
vec![http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_provider\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Dispatched\"}],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
"{\"id\":\"msg_provider\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Dispatched\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
)],
)
.await;
let client = ProviderClient::from_model_with_default_auth(
let client = ProviderClient::from_model_with_anthropic_auth(
"claude-sonnet-4-6",
Some(AuthSource::ApiKey("test-key".to_string())),
)
.expect("api provider client should be constructed");
.expect("anthropic provider client should be constructed");
let client = match client {
ProviderClient::ClawApi(client) => {
ProviderClient::ClawApi(client.with_base_url(server.base_url()))
ProviderClient::Anthropic(client) => {
ProviderClient::Anthropic(client.with_base_url(server.base_url()))
}
other => panic!("expected default provider, got {other:?}"),
other => panic!("expected anthropic provider, got {other:?}"),
};
let response = client
@ -284,13 +546,194 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
}
}
#[tokio::test]
async fn retries_multiple_retryable_failures_with_exponential_backoff_and_jitter() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state.clone(),
vec![
http_response(
"429 Too Many Requests",
"application/json",
"{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down\"}}",
),
http_response(
"500 Internal Server Error",
"application/json",
"{\"type\":\"error\",\"error\":{\"type\":\"api_error\",\"message\":\"boom\"}}",
),
http_response(
"503 Service Unavailable",
"application/json",
"{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"busy\"}}",
),
http_response(
"429 Too Many Requests",
"application/json",
"{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down again\"}}",
),
http_response(
"503 Service Unavailable",
"application/json",
"{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"still busy\"}}",
),
http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_exp_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered after 5\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
),
],
)
.await;
let client = ApiClient::new("test-key")
.with_base_url(server.base_url())
.with_retry_policy(8, Duration::from_millis(1), Duration::from_millis(4));
let started_at = std::time::Instant::now();
let response = client
.send_message(&sample_request(false))
.await
.expect("8-retry policy should absorb 5 retryable failures");
let elapsed = started_at.elapsed();
assert_eq!(response.total_tokens(), 5);
assert_eq!(
state.lock().await.len(),
6,
"client should issue 1 original + 5 retry requests before the 200"
);
// Jittered sleeps are bounded by 2 * max_backoff per retry (base + jitter),
// so 5 sleeps fit comfortably below this upper bound with generous slack.
assert!(
elapsed < Duration::from_secs(5),
"retries should complete promptly, took {elapsed:?}"
);
}
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn send_message_reuses_recent_completion_cache_entries() {
let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!(
"api-prompt-cache-{}-{}",
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state.clone(),
vec![http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_cached\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Cached once\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":4000,\"output_tokens\":2}}",
)],
)
.await;
let client = AnthropicClient::new("test-key")
.with_base_url(server.base_url())
.with_prompt_cache(PromptCache::new("integration-session"));
let first = client
.send_message(&sample_request(false))
.await
.expect("first request should succeed");
let second = client
.send_message(&sample_request(false))
.await
.expect("second request should reuse cache");
assert_eq!(first.content, second.content);
assert_eq!(state.lock().await.len(), 1);
let cache_stats = client
.prompt_cache_stats()
.expect("prompt cache stats should exist");
assert_eq!(cache_stats.completion_cache_hits, 1);
assert_eq!(cache_stats.completion_cache_misses, 1);
assert_eq!(cache_stats.completion_cache_writes, 1);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[tokio::test]
#[allow(clippy::await_holding_lock)]
async fn send_message_tracks_unexpected_prompt_cache_breaks() {
let _guard = env_lock();
let temp_root = std::env::temp_dir().join(format!(
"api-prompt-break-{}-{}",
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("time")
.as_nanos()
));
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state,
vec![
http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_one\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"One\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":6000,\"output_tokens\":2}}",
),
http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_two\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Two\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":0,\"cache_read_input_tokens\":1000,\"output_tokens\":2}}",
),
],
)
.await;
let request = sample_request(false);
let client = AnthropicClient::new("test-key")
.with_base_url(server.base_url())
.with_prompt_cache(PromptCache::with_config(PromptCacheConfig {
session_id: "break-session".to_string(),
completion_ttl: Duration::from_secs(0),
..PromptCacheConfig::default()
}));
client
.send_message(&request)
.await
.expect("first response should succeed");
client
.send_message(&request)
.await
.expect("second response should succeed");
let cache_stats = client
.prompt_cache_stats()
.expect("prompt cache stats should exist");
assert_eq!(cache_stats.unexpected_cache_breaks, 1);
assert_eq!(
cache_stats.last_break_reason.as_deref(),
Some("cache read tokens dropped while prompt fingerprint remained stable")
);
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
std::env::remove_var("CLAUDE_CONFIG_HOME");
}
#[tokio::test]
#[ignore = "requires ANTHROPIC_API_KEY and network access"]
async fn live_stream_smoke_test() {
let client = ApiClient::from_env().expect("ANTHROPIC_API_KEY must be set");
let mut stream = client
.stream_message(&MessageRequest {
model: std::env::var("CLAW_MODEL").unwrap_or_else(|_| "claude-sonnet-4-6".to_string()),
model: std::env::var("ANTHROPIC_MODEL")
.unwrap_or_else(|_| "claude-3-7-sonnet-latest".to_string()),
max_tokens: 32,
messages: vec![InputMessage::user_text(
"Reply with exactly: hello from rust",
@ -299,6 +742,7 @@ async fn live_stream_smoke_test() {
tools: None,
tool_choice: None,
stream: false,
..Default::default()
})
.await
.expect("live stream should start");
@ -450,7 +894,7 @@ fn http_response_with_headers(
fn sample_request(stream: bool) -> MessageRequest {
MessageRequest {
model: "claude-sonnet-4-6".to_string(),
model: "claude-3-7-sonnet-latest".to_string(),
max_tokens: 64,
messages: vec![InputMessage {
role: "user".to_string(),
@ -479,5 +923,6 @@ fn sample_request(stream: bool) -> MessageRequest {
}]),
tool_choice: Some(ToolChoice::Auto),
stream,
..Default::default()
}
}

View File

@ -4,9 +4,10 @@ use std::sync::Arc;
use std::sync::{Mutex as StdMutex, OnceLock};
use api::{
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig,
OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition,
ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
ContentBlockStopEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest,
OpenAiCompatClient, OpenAiCompatConfig, OutputContentBlock, ProviderClient, StreamEvent,
ToolChoice, ToolDefinition,
};
use serde_json::json;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
@ -62,6 +63,43 @@ async fn send_message_uses_openai_compatible_endpoint_and_auth() {
assert_eq!(body["tools"][0]["type"], json!("function"));
}
#[tokio::test]
async fn send_message_blocks_oversized_xai_requests_before_the_http_call() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state.clone(),
vec![http_response("200 OK", "application/json", "{}")],
)
.await;
let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
.with_base_url(server.base_url());
let error = client
.send_message(&MessageRequest {
model: "grok-3".to_string(),
max_tokens: 64_000,
messages: vec![InputMessage {
role: "user".to_string(),
content: vec![InputContentBlock::Text {
text: "x".repeat(300_000),
}],
}],
system: Some("Keep the answer short.".to_string()),
tools: None,
tool_choice: None,
stream: false,
..Default::default()
})
.await
.expect_err("oversized request should fail local context-window preflight");
assert!(matches!(error, ApiError::ContextWindowExceeded { .. }));
assert!(
state.lock().await.is_empty(),
"preflight failure should avoid any upstream HTTP request"
);
}
#[tokio::test]
async fn send_message_accepts_full_chat_completions_endpoint_override() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
@ -195,6 +233,83 @@ async fn stream_message_normalizes_text_and_multiple_tool_calls() {
assert!(request.body.contains("\"stream\":true"));
}
#[allow(clippy::await_holding_lock)]
#[tokio::test]
async fn openai_streaming_requests_opt_into_usage_chunks() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let sse = concat!(
"data: {\"id\":\"chatcmpl_openai_stream\",\"model\":\"gpt-5\",\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n",
"data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n",
"data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}\n\n",
"data: [DONE]\n\n"
);
let server = spawn_server(
state.clone(),
vec![http_response_with_headers(
"200 OK",
"text/event-stream",
sse,
&[("x-request-id", "req_openai_stream")],
)],
)
.await;
let client = OpenAiCompatClient::new("openai-test-key", OpenAiCompatConfig::openai())
.with_base_url(server.base_url());
let mut stream = client
.stream_message(&sample_request(false))
.await
.expect("stream should start");
assert_eq!(stream.request_id(), Some("req_openai_stream"));
let mut events = Vec::new();
while let Some(event) = stream.next_event().await.expect("event should parse") {
events.push(event);
}
assert!(matches!(events[0], StreamEvent::MessageStart(_)));
assert!(matches!(
events[1],
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
content_block: OutputContentBlock::Text { .. },
..
})
));
assert!(matches!(
events[2],
StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
delta: ContentBlockDelta::TextDelta { .. },
..
})
));
assert!(matches!(
events[3],
StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 })
));
assert!(matches!(
events[4],
StreamEvent::MessageDelta(MessageDeltaEvent { .. })
));
assert!(matches!(events[5], StreamEvent::MessageStop(_)));
match &events[4] {
StreamEvent::MessageDelta(MessageDeltaEvent { usage, .. }) => {
assert_eq!(usage.input_tokens, 9);
assert_eq!(usage.output_tokens, 4);
}
other => panic!("expected message delta, got {other:?}"),
}
let captured = state.lock().await;
let request = captured.first().expect("captured request");
assert_eq!(request.path, "/chat/completions");
let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
assert_eq!(body["stream"], json!(true));
assert_eq!(body["stream_options"], json!({"include_usage": true}));
}
#[allow(clippy::await_holding_lock)]
#[tokio::test]
async fn provider_client_dispatches_xai_requests_from_env() {
let _lock = env_lock();
@ -382,6 +497,7 @@ fn sample_request(stream: bool) -> MessageRequest {
}]),
tool_choice: Some(ToolChoice::Auto),
stream,
..Default::default()
}
}
@ -389,7 +505,7 @@ fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<StdMutex<()>> = OnceLock::new();
LOCK.get_or_init(|| StdMutex::new(()))
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
struct ScopedEnvVar {

View File

@ -22,7 +22,9 @@ fn provider_client_reports_missing_xai_credentials_for_grok_models() {
.expect_err("grok requests without XAI_API_KEY should fail fast");
match error {
ApiError::MissingCredentials { provider, env_vars } => {
ApiError::MissingCredentials {
provider, env_vars, ..
} => {
assert_eq!(provider, "xAI");
assert_eq!(env_vars, &["XAI_API_KEY"]);
}
@ -31,18 +33,18 @@ fn provider_client_reports_missing_xai_credentials_for_grok_models() {
}
#[test]
fn provider_client_uses_explicit_auth_without_env_lookup() {
fn provider_client_uses_explicit_anthropic_auth_without_env_lookup() {
let _lock = env_lock();
let _api_key = EnvVarGuard::set("ANTHROPIC_API_KEY", None);
let _auth_token = EnvVarGuard::set("ANTHROPIC_AUTH_TOKEN", None);
let _anthropic_api_key = EnvVarGuard::set("ANTHROPIC_API_KEY", None);
let _anthropic_auth_token = EnvVarGuard::set("ANTHROPIC_AUTH_TOKEN", None);
let client = ProviderClient::from_model_with_default_auth(
let client = ProviderClient::from_model_with_anthropic_auth(
"claude-sonnet-4-6",
Some(AuthSource::ApiKey("claw-test-key".to_string())),
Some(AuthSource::ApiKey("anthropic-test-key".to_string())),
)
.expect("explicit auth should avoid env lookup");
.expect("explicit anthropic auth should avoid env lookup");
assert_eq!(client.provider_kind(), ProviderKind::ClawApi);
assert_eq!(client.provider_kind(), ProviderKind::Anthropic);
}
#[test]
@ -57,7 +59,7 @@ fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
struct EnvVarGuard {

View File

@ -0,0 +1,173 @@
use std::ffi::OsString;
use std::sync::{Mutex, OnceLock};
use api::{build_http_client_with, ProxyConfig};
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
}
struct EnvVarGuard {
key: &'static str,
original: Option<OsString>,
}
impl EnvVarGuard {
fn set(key: &'static str, value: Option<&str>) -> Self {
let original = std::env::var_os(key);
match value {
Some(value) => std::env::set_var(key, value),
None => std::env::remove_var(key),
}
Self { key, original }
}
}
impl Drop for EnvVarGuard {
fn drop(&mut self) {
match &self.original {
Some(value) => std::env::set_var(self.key, value),
None => std::env::remove_var(self.key),
}
}
}
#[test]
fn proxy_config_from_env_reads_uppercase_proxy_vars() {
// given
let _lock = env_lock();
let _http = EnvVarGuard::set("HTTP_PROXY", Some("http://proxy.corp:3128"));
let _https = EnvVarGuard::set("HTTPS_PROXY", Some("http://secure.corp:3129"));
let _no = EnvVarGuard::set("NO_PROXY", Some("localhost,127.0.0.1"));
let _http_lower = EnvVarGuard::set("http_proxy", None);
let _https_lower = EnvVarGuard::set("https_proxy", None);
let _no_lower = EnvVarGuard::set("no_proxy", None);
// when
let config = ProxyConfig::from_env();
// then
assert_eq!(config.http_proxy.as_deref(), Some("http://proxy.corp:3128"));
assert_eq!(
config.https_proxy.as_deref(),
Some("http://secure.corp:3129")
);
assert_eq!(config.no_proxy.as_deref(), Some("localhost,127.0.0.1"));
assert!(config.proxy_url.is_none());
assert!(!config.is_empty());
}
#[test]
fn proxy_config_from_env_reads_lowercase_proxy_vars() {
// given
let _lock = env_lock();
let _http = EnvVarGuard::set("HTTP_PROXY", None);
let _https = EnvVarGuard::set("HTTPS_PROXY", None);
let _no = EnvVarGuard::set("NO_PROXY", None);
let _http_lower = EnvVarGuard::set("http_proxy", Some("http://lower.corp:3128"));
let _https_lower = EnvVarGuard::set("https_proxy", Some("http://lower-secure.corp:3129"));
let _no_lower = EnvVarGuard::set("no_proxy", Some(".internal"));
// when
let config = ProxyConfig::from_env();
// then
assert_eq!(config.http_proxy.as_deref(), Some("http://lower.corp:3128"));
assert_eq!(
config.https_proxy.as_deref(),
Some("http://lower-secure.corp:3129")
);
assert_eq!(config.no_proxy.as_deref(), Some(".internal"));
assert!(!config.is_empty());
}
#[test]
fn proxy_config_from_env_is_empty_when_no_vars_set() {
// given
let _lock = env_lock();
let _http = EnvVarGuard::set("HTTP_PROXY", None);
let _https = EnvVarGuard::set("HTTPS_PROXY", None);
let _no = EnvVarGuard::set("NO_PROXY", None);
let _http_lower = EnvVarGuard::set("http_proxy", None);
let _https_lower = EnvVarGuard::set("https_proxy", None);
let _no_lower = EnvVarGuard::set("no_proxy", None);
// when
let config = ProxyConfig::from_env();
// then
assert!(config.is_empty());
assert!(config.http_proxy.is_none());
assert!(config.https_proxy.is_none());
assert!(config.no_proxy.is_none());
}
#[test]
fn proxy_config_from_env_treats_empty_values_as_unset() {
// given
let _lock = env_lock();
let _http = EnvVarGuard::set("HTTP_PROXY", Some(""));
let _https = EnvVarGuard::set("HTTPS_PROXY", Some(""));
let _http_lower = EnvVarGuard::set("http_proxy", Some(""));
let _https_lower = EnvVarGuard::set("https_proxy", Some(""));
let _no = EnvVarGuard::set("NO_PROXY", Some(""));
let _no_lower = EnvVarGuard::set("no_proxy", Some(""));
// when
let config = ProxyConfig::from_env();
// then
assert!(config.is_empty());
}
#[test]
fn build_client_with_env_proxy_config_succeeds() {
// given
let _lock = env_lock();
let _http = EnvVarGuard::set("HTTP_PROXY", Some("http://proxy.corp:3128"));
let _https = EnvVarGuard::set("HTTPS_PROXY", Some("http://secure.corp:3129"));
let _no = EnvVarGuard::set("NO_PROXY", Some("localhost"));
let _http_lower = EnvVarGuard::set("http_proxy", None);
let _https_lower = EnvVarGuard::set("https_proxy", None);
let _no_lower = EnvVarGuard::set("no_proxy", None);
let config = ProxyConfig::from_env();
// when
let result = build_http_client_with(&config);
// then
assert!(result.is_ok());
}
#[test]
fn build_client_with_proxy_url_config_succeeds() {
// given
let config = ProxyConfig::from_proxy_url("http://unified.corp:3128");
// when
let result = build_http_client_with(&config);
// then
assert!(result.is_ok());
}
#[test]
fn proxy_config_from_env_prefers_uppercase_over_lowercase() {
// given
let _lock = env_lock();
let _http_upper = EnvVarGuard::set("HTTP_PROXY", Some("http://upper.corp:3128"));
let _http_lower = EnvVarGuard::set("http_proxy", Some("http://lower.corp:3128"));
let _https = EnvVarGuard::set("HTTPS_PROXY", None);
let _https_lower = EnvVarGuard::set("https_proxy", None);
let _no = EnvVarGuard::set("NO_PROXY", None);
let _no_lower = EnvVarGuard::set("no_proxy", None);
// when
let config = ProxyConfig::from_env();
// then
assert_eq!(config.http_proxy.as_deref(), Some("http://upper.corp:3128"));
}

View File

@ -16,7 +16,7 @@ use std::thread;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use api::{
resolve_startup_auth_source, AuthSource, ClawApiClient, ContentBlockDelta, InputContentBlock,
resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock,
InputMessage, MessageRequest, MessageResponse, OutputContentBlock,
StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock,
};
@ -329,7 +329,7 @@ fn join_optional_args(args: &[String]) -> Option<String> {
fn parse_direct_slash_cli_action(rest: &[String]) -> Result<CliAction, String> {
let raw = rest.join(" ");
match SlashCommand::parse(&raw) {
match SlashCommand::parse(&raw).map_err(|e| e.to_string())? {
Some(SlashCommand::Help) => Ok(CliAction::Help),
Some(SlashCommand::Agents { args }) => Ok(CliAction::Agents { args }),
Some(SlashCommand::Skills { args }) => Ok(CliAction::Skills { args }),
@ -484,7 +484,7 @@ fn dump_manifests() {
}
fn print_bootstrap_plan() {
for phase in runtime::BootstrapPlan::claw_default().phases() {
for phase in runtime::BootstrapPlan::claude_code_default().phases() {
println!("- {phase:?}");
}
}
@ -541,7 +541,7 @@ fn run_login() -> Result<(), Box<dyn std::error::Error>> {
return Err(io::Error::new(io::ErrorKind::InvalidData, "oauth state mismatch").into());
}
let client = ClawApiClient::from_auth(AuthSource::None).with_base_url(api::read_base_url());
let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(api::read_base_url());
let exchange_request =
OAuthTokenExchangeRequest::from_config(oauth, code, state, pkce.verifier, redirect_uri);
let runtime = tokio::runtime::Runtime::new()?;
@ -650,7 +650,7 @@ fn resume_session(session_path: &Path, commands: &[String]) {
let mut session = session;
for raw_command in commands {
let Some(command) = SlashCommand::parse(raw_command) else {
let Ok(Some(command)) = SlashCommand::parse(raw_command) else {
eprintln!("unsupported resumed command: {raw_command}");
std::process::exit(2);
};
@ -987,8 +987,6 @@ fn run_resume_command(
}
SlashCommand::Bughunter { .. }
| SlashCommand::Branch { .. }
| SlashCommand::Worktree { .. }
| SlashCommand::CommitPushPr { .. }
| SlashCommand::Commit
| SlashCommand::Pr { .. }
| SlashCommand::Issue { .. }
@ -1000,7 +998,7 @@ fn run_resume_command(
| SlashCommand::Permissions { .. }
| SlashCommand::Session { .. }
| SlashCommand::Plugins { .. }
| SlashCommand::Unknown(_) => Err("unsupported resumed slash command".into()),
| _ => Err("unsupported resumed slash command".into()),
}
}
@ -1024,7 +1022,7 @@ fn run_repl(
cli.persist_session()?;
break;
}
if let Some(command) = SlashCommand::parse(trimmed) {
if let Ok(Some(command)) = SlashCommand::parse(trimmed) {
if cli.handle_repl_command(command)? {
cli.persist_session()?;
}
@ -1336,24 +1334,14 @@ impl LiveCli {
);
false
}
SlashCommand::Worktree { .. } => {
eprintln!(
"{}",
render_mode_unavailable("worktree", "git worktree commands")
);
false
}
SlashCommand::CommitPushPr { .. } => {
eprintln!(
"{}",
render_mode_unavailable("commit-push-pr", "commit + push + PR automation")
);
false
}
SlashCommand::Unknown(name) => {
eprintln!("{}", render_unknown_repl_command(&name));
false
}
_ => {
eprintln!("command not available in this mode");
false
}
})
}
@ -2505,12 +2493,6 @@ fn render_export_text(session: &Session) -> String {
for block in &message.blocks {
match block {
ContentBlock::Text { text } => lines.push(text.clone()),
ContentBlock::Thinking { thinking, .. } => {
lines.push(format!("[thinking] {thinking}"));
}
ContentBlock::RedactedThinking { .. } => {
lines.push("[thinking] <redacted>".to_string());
}
ContentBlock::ToolUse { id, name, input } => {
lines.push(format!("[tool_use id={id} name={name}] {input}"));
}
@ -2995,7 +2977,7 @@ fn build_runtime(
CliToolExecutor::new(allowed_tools.clone(), emit_output, tool_registry.clone()),
permission_policy(permission_mode, &tool_registry),
system_prompt,
feature_config,
&feature_config,
))
}
@ -3047,7 +3029,7 @@ impl runtime::PermissionPrompter for CliPermissionPrompter {
struct DefaultRuntimeClient {
runtime: tokio::runtime::Runtime,
client: ClawApiClient,
client: AnthropicClient,
model: String,
enable_tools: bool,
emit_output: bool,
@ -3067,7 +3049,7 @@ impl DefaultRuntimeClient {
) -> Result<Self, Box<dyn std::error::Error>> {
Ok(Self {
runtime: tokio::runtime::Runtime::new()?,
client: ClawApiClient::from_auth(resolve_cli_auth_source()?)
client: AnthropicClient::from_auth(resolve_cli_auth_source()?)
.with_base_url(api::read_base_url()),
model,
enable_tools,
@ -3105,6 +3087,12 @@ impl ApiClient for DefaultRuntimeClient {
.then(|| filter_tool_specs(&self.tool_registry, self.allowed_tools.as_ref())),
tool_choice: self.enable_tools.then_some(ToolChoice::Auto),
stream: true,
temperature: None,
top_p: None,
frequency_penalty: None,
presence_penalty: None,
stop: None,
reasoning_effort: None,
};
self.runtime.block_on(async {
@ -3173,7 +3161,6 @@ impl ApiClient for DefaultRuntimeClient {
.and_then(|()| out.flush())
.map_err(|error| RuntimeError::new(error.to_string()))?;
}
events.push(AssistantEvent::ThinkingDelta(thinking));
}
}
ContentBlockDelta::SignatureDelta { .. } => {}
@ -3254,9 +3241,7 @@ fn final_assistant_text(summary: &runtime::TurnSummary) -> String {
.iter()
.filter_map(|block| match block {
ContentBlock::Text { text } => Some(text.as_str()),
ContentBlock::Thinking { thinking, .. } => Some(thinking.as_str()),
ContentBlock::RedactedThinking { .. }
| ContentBlock::ToolUse { .. }
ContentBlock::ToolUse { .. }
| ContentBlock::ToolResult { .. } => None,
})
.collect::<Vec<_>>()
@ -3276,9 +3261,7 @@ fn collect_tool_uses(summary: &runtime::TurnSummary) -> Vec<serde_json::Value> {
"name": name,
"input": input,
})),
ContentBlock::Thinking { .. }
| ContentBlock::RedactedThinking { .. }
| ContentBlock::Text { .. }
ContentBlock::Text { .. }
| ContentBlock::ToolResult { .. } => None,
})
.collect()
@ -3301,9 +3284,7 @@ fn collect_tool_results(summary: &runtime::TurnSummary) -> Vec<serde_json::Value
"output": output,
"is_error": is_error,
})),
ContentBlock::Thinking { .. }
| ContentBlock::RedactedThinking { .. }
| ContentBlock::Text { .. }
ContentBlock::Text { .. }
| ContentBlock::ToolUse { .. } => None,
})
.collect()
@ -3851,7 +3832,6 @@ fn push_output_block(
write!(out, "\x1b[2m{thinking}\x1b[0m")
.and_then(|()| out.flush())
.map_err(|error| RuntimeError::new(error.to_string()))?;
events.push(AssistantEvent::ThinkingDelta(thinking));
}
}
OutputContentBlock::RedactedThinking { .. } => {}
@ -3942,7 +3922,7 @@ impl ToolExecutor for CliToolExecutor {
}
fn permission_policy(mode: PermissionMode, tool_registry: &GlobalToolRegistry) -> PermissionPolicy {
tool_registry.permission_specs(None).into_iter().fold(
tool_registry.permission_specs(None).unwrap_or_default().into_iter().fold(
PermissionPolicy::new(mode),
|policy, (name, required_permission)| {
policy.with_tool_requirement(name, required_permission)
@ -3963,16 +3943,6 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec<InputMessage> {
.iter()
.map(|block| match block {
ContentBlock::Text { text } => InputContentBlock::Text { text: text.clone() },
ContentBlock::Thinking {
thinking,
signature,
} => InputContentBlock::Thinking {
thinking: thinking.clone(),
signature: signature.clone(),
},
ContentBlock::RedactedThinking { data } => InputContentBlock::RedactedThinking {
data: serde_json::from_str(&data.render()).unwrap_or(serde_json::Value::Null),
},
ContentBlock::ToolUse { id, name, input } => InputContentBlock::ToolUse {
id: id.clone(),
name: name.clone(),
@ -4735,39 +4705,39 @@ mod tests {
#[test]
fn clear_command_requires_explicit_confirmation_flag() {
assert_eq!(
SlashCommand::parse("/clear"),
Some(SlashCommand::Clear { confirm: false })
SlashCommand::parse("/clear").map_err(|e| e.to_string()),
Ok(Some(SlashCommand::Clear { confirm: false }))
);
assert_eq!(
SlashCommand::parse("/clear --confirm"),
Some(SlashCommand::Clear { confirm: true })
SlashCommand::parse("/clear --confirm").map_err(|e| e.to_string()),
Ok(Some(SlashCommand::Clear { confirm: true }))
);
}
#[test]
fn parses_resume_and_config_slash_commands() {
assert_eq!(
SlashCommand::parse("/resume saved-session.json"),
Some(SlashCommand::Resume {
SlashCommand::parse("/resume saved-session.json").map_err(|e| e.to_string()),
Ok(Some(SlashCommand::Resume {
session_path: Some("saved-session.json".to_string())
})
}))
);
assert_eq!(
SlashCommand::parse("/clear --confirm"),
Some(SlashCommand::Clear { confirm: true })
SlashCommand::parse("/clear --confirm").map_err(|e| e.to_string()),
Ok(Some(SlashCommand::Clear { confirm: true }))
);
assert_eq!(
SlashCommand::parse("/config"),
Some(SlashCommand::Config { section: None })
SlashCommand::parse("/config").map_err(|e| e.to_string()),
Ok(Some(SlashCommand::Config { section: None }))
);
assert_eq!(
SlashCommand::parse("/config env"),
Some(SlashCommand::Config {
SlashCommand::parse("/config env").map_err(|e| e.to_string()),
Ok(Some(SlashCommand::Config {
section: Some("env".to_string())
})
}))
);
assert_eq!(SlashCommand::parse("/memory"), Some(SlashCommand::Memory));
assert_eq!(SlashCommand::parse("/init"), Some(SlashCommand::Init));
assert_eq!(SlashCommand::parse("/memory").map_err(|e| e.to_string()), Ok(Some(SlashCommand::Memory)));
assert_eq!(SlashCommand::parse("/init").map_err(|e| e.to_string()), Ok(Some(SlashCommand::Init)));
}
#[test]

File diff suppressed because it is too large Load Diff

View File

@ -70,16 +70,12 @@ fn upstream_repo_candidates(primary_repo_root: &Path) -> Vec<PathBuf> {
}
for ancestor in primary_repo_root.ancestors().take(4) {
candidates.push(ancestor.join("claude-code"));
candidates.push(ancestor.join("claw-code"));
candidates.push(ancestor.join("clawd-code"));
}
candidates.push(
primary_repo_root
.join("reference-source")
.join("claude-code"),
);
candidates.push(primary_repo_root.join("vendor").join("claude-code"));
candidates.push(primary_repo_root.join("reference-source").join("claw-code"));
candidates.push(primary_repo_root.join("vendor").join("claw-code"));
let mut deduped = Vec::new();
for candidate in candidates {

View File

@ -41,6 +41,7 @@ mod tests {
})
}
#[allow(clippy::too_many_lines)]
fn write_mock_server_script(root: &std::path::Path) -> PathBuf {
let script_path = root.join("mock_lsp_server.py");
fs::write(

View File

@ -0,0 +1,18 @@
[package]
name = "mock-anthropic-service"
version.workspace = true
edition.workspace = true
license.workspace = true
publish.workspace = true
[[bin]]
name = "mock-anthropic-service"
path = "src/main.rs"
[dependencies]
api = { path = "../api" }
serde_json.workspace = true
tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "signal", "sync"] }
[lints]
workspace = true

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,34 @@
use std::env;
use mock_anthropic_service::MockAnthropicService;
#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut bind_addr = String::from("127.0.0.1:0");
let mut args = env::args().skip(1);
while let Some(arg) = args.next() {
match arg.as_str() {
"--bind" => {
bind_addr = args
.next()
.ok_or_else(|| "missing value for --bind".to_string())?;
}
flag if flag.starts_with("--bind=") => {
bind_addr = flag[7..].to_string();
}
"--help" | "-h" => {
println!("Usage: mock-anthropic-service [--bind HOST:PORT]");
return Ok(());
}
other => {
return Err(format!("unsupported argument: {other}").into());
}
}
}
let server = MockAnthropicService::spawn_on(&bind_addr).await?;
println!("MOCK_ANTHROPIC_BASE_URL={}", server.base_url());
tokio::signal::ctrl_c().await?;
drop(server);
Ok(())
}

View File

@ -1,6 +1,4 @@
use std::ffi::OsStr;
#[cfg(not(windows))]
use std::path::Path;
use std::process::Command;
use serde_json::json;
@ -11,6 +9,7 @@ use crate::{PluginError, PluginHooks, PluginRegistry};
pub enum HookEvent {
PreToolUse,
PostToolUse,
PostToolUseFailure,
}
impl HookEvent {
@ -18,6 +17,7 @@ impl HookEvent {
match self {
Self::PreToolUse => "PreToolUse",
Self::PostToolUse => "PostToolUse",
Self::PostToolUseFailure => "PostToolUseFailure",
}
}
}
@ -25,6 +25,7 @@ impl HookEvent {
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HookRunResult {
denied: bool,
failed: bool,
messages: Vec<String>,
}
@ -33,6 +34,7 @@ impl HookRunResult {
pub fn allow(messages: Vec<String>) -> Self {
Self {
denied: false,
failed: false,
messages,
}
}
@ -42,6 +44,11 @@ impl HookRunResult {
self.denied
}
#[must_use]
pub fn is_failed(&self) -> bool {
self.failed
}
#[must_use]
pub fn messages(&self) -> &[String] {
&self.messages
@ -65,7 +72,7 @@ impl HookRunner {
#[must_use]
pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
self.run_commands(
Self::run_commands(
HookEvent::PreToolUse,
&self.hooks.pre_tool_use,
tool_name,
@ -83,7 +90,7 @@ impl HookRunner {
tool_output: &str,
is_error: bool,
) -> HookRunResult {
self.run_commands(
Self::run_commands(
HookEvent::PostToolUse,
&self.hooks.post_tool_use,
tool_name,
@ -93,8 +100,24 @@ impl HookRunner {
)
}
fn run_commands(
#[must_use]
pub fn run_post_tool_use_failure(
&self,
tool_name: &str,
tool_input: &str,
tool_error: &str,
) -> HookRunResult {
Self::run_commands(
HookEvent::PostToolUseFailure,
&self.hooks.post_tool_use_failure,
tool_name,
tool_input,
Some(tool_error),
true,
)
}
fn run_commands(
event: HookEvent,
commands: &[String],
tool_name: &str,
@ -106,20 +129,12 @@ impl HookRunner {
return HookRunResult::allow(Vec::new());
}
let payload = json!({
"hook_event_name": event.as_str(),
"tool_name": tool_name,
"tool_input": parse_tool_input(tool_input),
"tool_input_json": tool_input,
"tool_output": tool_output,
"tool_result_is_error": is_error,
})
.to_string();
let payload = hook_payload(event, tool_name, tool_input, tool_output, is_error).to_string();
let mut messages = Vec::new();
for command in commands {
match self.run_command(
match Self::run_command(
command,
event,
tool_name,
@ -139,19 +154,26 @@ impl HookRunner {
}));
return HookRunResult {
denied: true,
failed: false,
messages,
};
}
HookCommandOutcome::Failed { message } => {
messages.push(message);
return HookRunResult {
denied: false,
failed: true,
messages,
};
}
HookCommandOutcome::Warn { message } => messages.push(message),
}
}
HookRunResult::allow(messages)
}
#[allow(clippy::too_many_arguments, clippy::unused_self)]
#[allow(clippy::too_many_arguments)]
fn run_command(
&self,
command: &str,
event: HookEvent,
tool_name: &str,
@ -180,7 +202,7 @@ impl HookRunner {
match output.status.code() {
Some(0) => HookCommandOutcome::Allow { message },
Some(2) => HookCommandOutcome::Deny { message },
Some(code) => HookCommandOutcome::Warn {
Some(code) => HookCommandOutcome::Failed {
message: format_hook_warning(
command,
code,
@ -188,7 +210,7 @@ impl HookRunner {
stderr.as_str(),
),
},
None => HookCommandOutcome::Warn {
None => HookCommandOutcome::Failed {
message: format!(
"{} hook `{command}` terminated by signal while handling `{tool_name}`",
event.as_str()
@ -196,7 +218,7 @@ impl HookRunner {
},
}
}
Err(error) => HookCommandOutcome::Warn {
Err(error) => HookCommandOutcome::Failed {
message: format!(
"{} hook `{command}` failed to start for `{tool_name}`: {error}",
event.as_str()
@ -209,7 +231,34 @@ impl HookRunner {
enum HookCommandOutcome {
Allow { message: Option<String> },
Deny { message: Option<String> },
Warn { message: String },
Failed { message: String },
}
fn hook_payload(
event: HookEvent,
tool_name: &str,
tool_input: &str,
tool_output: Option<&str>,
is_error: bool,
) -> serde_json::Value {
match event {
HookEvent::PostToolUseFailure => json!({
"hook_event_name": event.as_str(),
"tool_name": tool_name,
"tool_input": parse_tool_input(tool_input),
"tool_input_json": tool_input,
"tool_error": tool_output,
"tool_result_is_error": true,
}),
_ => json!({
"hook_event_name": event.as_str(),
"tool_name": tool_name,
"tool_input": parse_tool_input(tool_input),
"tool_input_json": tool_input,
"tool_output": tool_output,
"tool_result_is_error": is_error,
}),
}
}
fn parse_tool_input(tool_input: &str) -> serde_json::Value {
@ -217,8 +266,7 @@ fn parse_tool_input(tool_input: &str) -> serde_json::Value {
}
fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String {
let mut message =
format!("Hook `{command}` exited with status {code}; allowing tool execution to continue");
let mut message = format!("Hook `{command}` exited with status {code}");
if let Some(stdout) = stdout.filter(|stdout| !stdout.is_empty()) {
message.push_str(": ");
message.push_str(stdout);
@ -288,7 +336,28 @@ impl CommandWithStdin {
let mut child = self.command.spawn()?;
if let Some(mut child_stdin) = child.stdin.take() {
use std::io::Write as _;
child_stdin.write_all(stdin)?;
// Tolerate BrokenPipe: a hook script that runs to completion
// (or exits early without reading stdin) closes its stdin
// before the parent finishes writing the JSON payload, and
// the kernel raises EPIPE on the parent's write_all. That is
// not a hook failure — the child still exited cleanly and we
// still need to wait_with_output() to capture stdout/stderr
// and the real exit code. Other write errors (e.g. EIO,
// permission, OOM) still propagate.
//
// This was the root cause of the Linux CI flake on
// hooks::tests::collects_and_runs_hooks_from_enabled_plugins
// (ROADMAP #25, runs 24120271422 / 24120538408 / 24121392171
// / 24121776826): the test hook scripts run in microseconds
// and the parent's stdin write races against child exit.
// macOS pipes happen to buffer the small payload before the
// child exits; Linux pipes do not, so the race shows up
// deterministically on ubuntu runners.
match child_stdin.write_all(stdin) {
Ok(()) => {}
Err(error) if error.kind() == std::io::ErrorKind::BrokenPipe => {}
Err(error) => return Err(error),
}
}
child.wait_with_output()
}
@ -310,23 +379,55 @@ mod tests {
std::env::temp_dir().join(format!("plugins-hook-runner-{label}-{nanos}"))
}
fn write_hook_plugin(root: &Path, name: &str, pre_message: &str, post_message: &str) {
fs::create_dir_all(root.join(".claw-plugin")).expect("manifest dir");
fn make_executable(path: &Path) {
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let perms = fs::Permissions::from_mode(0o755);
fs::set_permissions(path, perms)
.unwrap_or_else(|e| panic!("chmod +x {}: {e}", path.display()));
}
#[cfg(not(unix))]
let _ = path;
}
fn write_hook_plugin(
root: &Path,
name: &str,
pre_message: &str,
post_message: &str,
failure_message: &str,
) {
fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir");
fs::create_dir_all(root.join("hooks")).expect("hooks dir");
let pre_path = root.join("hooks").join("pre.sh");
fs::write(
root.join("hooks").join("pre.sh"),
&pre_path,
format!("#!/bin/sh\nprintf '%s\\n' '{pre_message}'\n"),
)
.expect("write pre hook");
make_executable(&pre_path);
let post_path = root.join("hooks").join("post.sh");
fs::write(
root.join("hooks").join("post.sh"),
&post_path,
format!("#!/bin/sh\nprintf '%s\\n' '{post_message}'\n"),
)
.expect("write post hook");
make_executable(&post_path);
let failure_path = root.join("hooks").join("failure.sh");
fs::write(
root.join(".claw-plugin").join("plugin.json"),
&failure_path,
format!("#!/bin/sh\nprintf '%s\\n' '{failure_message}'\n"),
)
.expect("write failure hook");
make_executable(&failure_path);
fs::write(
root.join(".claude-plugin").join("plugin.json"),
format!(
"{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"hook plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"]\n }}\n}}"
"{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"hook plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"],\n \"PostToolUseFailure\": [\"./hooks/failure.sh\"]\n }}\n}}"
),
)
.expect("write plugin manifest");
@ -334,6 +435,7 @@ mod tests {
#[test]
fn collects_and_runs_hooks_from_enabled_plugins() {
// given
let config_home = temp_dir("config");
let first_source_root = temp_dir("source-a");
let second_source_root = temp_dir("source-b");
@ -342,12 +444,14 @@ mod tests {
"first",
"plugin pre one",
"plugin post one",
"plugin failure one",
);
write_hook_plugin(
&second_source_root,
"second",
"plugin pre two",
"plugin post two",
"plugin failure two",
);
let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home));
@ -359,8 +463,10 @@ mod tests {
.expect("second plugin install should succeed");
let registry = manager.plugin_registry().expect("registry should build");
// when
let runner = HookRunner::from_registry(&registry).expect("plugin hooks should load");
// then
assert_eq!(
runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#),
HookRunResult::allow(vec![
@ -375,6 +481,13 @@ mod tests {
"plugin post two".to_string(),
])
);
assert_eq!(
runner.run_post_tool_use_failure("Read", r#"{"path":"README.md"}"#, "tool failed",),
HookRunResult::allow(vec![
"plugin failure one".to_string(),
"plugin failure two".to_string(),
])
);
let _ = fs::remove_dir_all(config_home);
let _ = fs::remove_dir_all(first_source_root);
@ -383,14 +496,68 @@ mod tests {
#[test]
fn pre_tool_use_denies_when_plugin_hook_exits_two() {
// given
let runner = HookRunner::new(crate::PluginHooks {
pre_tool_use: vec!["printf 'blocked by plugin'; exit 2".to_string()],
post_tool_use: Vec::new(),
post_tool_use_failure: Vec::new(),
});
// when
let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#);
// then
assert!(result.is_denied());
assert_eq!(result.messages(), &["blocked by plugin".to_string()]);
}
#[test]
fn propagates_plugin_hook_failures() {
// given
let runner = HookRunner::new(crate::PluginHooks {
pre_tool_use: vec![
"printf 'broken plugin hook'; exit 1".to_string(),
"printf 'later plugin hook'".to_string(),
],
post_tool_use: Vec::new(),
post_tool_use_failure: Vec::new(),
});
// when
let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#);
// then
assert!(result.is_failed());
assert!(result
.messages()
.iter()
.any(|message| message.contains("broken plugin hook")));
assert!(!result
.messages()
.iter()
.any(|message| message == "later plugin hook"));
}
#[test]
#[cfg(unix)]
fn generated_hook_scripts_are_executable() {
use std::os::unix::fs::PermissionsExt;
// given
let root = temp_dir("exec-guard");
write_hook_plugin(&root, "exec-check", "pre", "post", "fail");
// then
for script in ["pre.sh", "post.sh", "failure.sh"] {
let path = root.join("hooks").join(script);
let mode = fs::metadata(&path)
.unwrap_or_else(|e| panic!("{script} metadata: {e}"))
.permissions()
.mode();
assert!(
mode & 0o111 != 0,
"{script} must have at least one execute bit set, got mode {mode:#o}"
);
}
}
}

View File

@ -18,7 +18,7 @@ const BUNDLED_MARKETPLACE: &str = "bundled";
const SETTINGS_FILE_NAME: &str = "settings.json";
const REGISTRY_FILE_NAME: &str = "installed.json";
const MANIFEST_FILE_NAME: &str = "plugin.json";
const MANIFEST_RELATIVE_PATH: &str = ".claw-plugin/plugin.json";
const MANIFEST_RELATIVE_PATH: &str = ".claude-plugin/plugin.json";
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
@ -67,12 +67,16 @@ pub struct PluginHooks {
pub pre_tool_use: Vec<String>,
#[serde(rename = "PostToolUse", default)]
pub post_tool_use: Vec<String>,
#[serde(rename = "PostToolUseFailure", default)]
pub post_tool_use_failure: Vec<String>,
}
impl PluginHooks {
#[must_use]
pub fn is_empty(&self) -> bool {
self.pre_tool_use.is_empty() && self.post_tool_use.is_empty()
self.pre_tool_use.is_empty()
&& self.post_tool_use.is_empty()
&& self.post_tool_use_failure.is_empty()
}
#[must_use]
@ -85,6 +89,9 @@ impl PluginHooks {
.post_tool_use
.extend(other.post_tool_use.iter().cloned());
merged
.post_tool_use_failure
.extend(other.post_tool_use_failure.iter().cloned());
merged
}
}
@ -302,14 +309,14 @@ impl PluginTool {
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.env("CLAW_PLUGIN_ID", &self.plugin_id)
.env("CLAW_PLUGIN_NAME", &self.plugin_name)
.env("CLAW_TOOL_NAME", &self.definition.name)
.env("CLAW_TOOL_INPUT", &input_json);
.env("CLAWD_PLUGIN_ID", &self.plugin_id)
.env("CLAWD_PLUGIN_NAME", &self.plugin_name)
.env("CLAWD_TOOL_NAME", &self.definition.name)
.env("CLAWD_TOOL_INPUT", &input_json);
if let Some(root) = &self.root {
process
.current_dir(root)
.env("CLAW_PLUGIN_ROOT", root.display().to_string());
.env("CLAWD_PLUGIN_ROOT", root.display().to_string());
}
let mut child = process.spawn()?;
@ -648,6 +655,106 @@ pub struct PluginSummary {
pub enabled: bool,
}
#[derive(Debug)]
pub struct PluginLoadFailure {
pub plugin_root: PathBuf,
pub kind: PluginKind,
pub source: String,
error: Box<PluginError>,
}
impl PluginLoadFailure {
#[must_use]
pub fn new(plugin_root: PathBuf, kind: PluginKind, source: String, error: PluginError) -> Self {
Self {
plugin_root,
kind,
source,
error: Box::new(error),
}
}
#[must_use]
pub fn error(&self) -> &PluginError {
self.error.as_ref()
}
}
impl Display for PluginLoadFailure {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"failed to load {} plugin from `{}` (source: {}): {}",
self.kind,
self.plugin_root.display(),
self.source,
self.error()
)
}
}
#[derive(Debug)]
pub struct PluginRegistryReport {
registry: PluginRegistry,
failures: Vec<PluginLoadFailure>,
}
impl PluginRegistryReport {
#[must_use]
pub fn new(registry: PluginRegistry, failures: Vec<PluginLoadFailure>) -> Self {
Self { registry, failures }
}
#[must_use]
pub fn registry(&self) -> &PluginRegistry {
&self.registry
}
#[must_use]
pub fn failures(&self) -> &[PluginLoadFailure] {
&self.failures
}
#[must_use]
pub fn has_failures(&self) -> bool {
!self.failures.is_empty()
}
#[must_use]
pub fn summaries(&self) -> Vec<PluginSummary> {
self.registry.summaries()
}
pub fn into_registry(self) -> Result<PluginRegistry, PluginError> {
if self.failures.is_empty() {
Ok(self.registry)
} else {
Err(PluginError::LoadFailures(self.failures))
}
}
}
#[derive(Debug, Default)]
struct PluginDiscovery {
plugins: Vec<PluginDefinition>,
failures: Vec<PluginLoadFailure>,
}
impl PluginDiscovery {
fn push_plugin(&mut self, plugin: PluginDefinition) {
self.plugins.push(plugin);
}
fn push_failure(&mut self, failure: PluginLoadFailure) {
self.failures.push(failure);
}
fn extend(&mut self, other: Self) {
self.plugins.extend(other.plugins);
self.failures.extend(other.failures);
}
}
#[derive(Debug, Clone, Default, PartialEq)]
pub struct PluginRegistry {
plugins: Vec<RegisteredPlugin>,
@ -802,6 +909,10 @@ pub enum PluginManifestValidationError {
kind: &'static str,
path: PathBuf,
},
PathIsDirectory {
kind: &'static str,
path: PathBuf,
},
InvalidToolInputSchema {
tool_name: String,
},
@ -809,6 +920,9 @@ pub enum PluginManifestValidationError {
tool_name: String,
permission: String,
},
UnsupportedManifestContract {
detail: String,
},
}
impl Display for PluginManifestValidationError {
@ -838,6 +952,9 @@ impl Display for PluginManifestValidationError {
Self::MissingPath { kind, path } => {
write!(f, "{kind} path `{}` does not exist", path.display())
}
Self::PathIsDirectory { kind, path } => {
write!(f, "{kind} path `{}` must point to a file", path.display())
}
Self::InvalidToolInputSchema { tool_name } => {
write!(
f,
@ -851,6 +968,7 @@ impl Display for PluginManifestValidationError {
f,
"plugin tool `{tool_name}` requiredPermission `{permission}` must be read-only, workspace-write, or danger-full-access"
),
Self::UnsupportedManifestContract { detail } => f.write_str(detail),
}
}
}
@ -860,6 +978,7 @@ pub enum PluginError {
Io(std::io::Error),
Json(serde_json::Error),
ManifestValidation(Vec<PluginManifestValidationError>),
LoadFailures(Vec<PluginLoadFailure>),
InvalidManifest(String),
NotFound(String),
CommandFailed(String),
@ -879,6 +998,15 @@ impl Display for PluginError {
}
Ok(())
}
Self::LoadFailures(failures) => {
for (index, failure) in failures.iter().enumerate() {
if index > 0 {
write!(f, "; ")?;
}
write!(f, "{failure}")?;
}
Ok(())
}
Self::InvalidManifest(message)
| Self::NotFound(message)
| Self::CommandFailed(message) => write!(f, "{message}"),
@ -935,15 +1063,23 @@ impl PluginManager {
}
pub fn plugin_registry(&self) -> Result<PluginRegistry, PluginError> {
Ok(PluginRegistry::new(
self.discover_plugins()?
.into_iter()
.map(|plugin| {
let enabled = self.is_enabled(plugin.metadata());
RegisteredPlugin::new(plugin, enabled)
})
.collect(),
))
self.plugin_registry_report()?.into_registry()
}
pub fn plugin_registry_report(&self) -> Result<PluginRegistryReport, PluginError> {
self.sync_bundled_plugins()?;
let mut discovery = PluginDiscovery::default();
discovery.plugins.extend(builtin_plugins());
let installed = self.discover_installed_plugins_with_failures()?;
discovery.extend(installed);
let external =
self.discover_external_directory_plugins_with_failures(&discovery.plugins)?;
discovery.extend(external);
Ok(self.build_registry_report(discovery))
}
pub fn list_plugins(&self) -> Result<Vec<PluginSummary>, PluginError> {
@ -955,11 +1091,12 @@ impl PluginManager {
}
pub fn discover_plugins(&self) -> Result<Vec<PluginDefinition>, PluginError> {
self.sync_bundled_plugins()?;
let mut plugins = builtin_plugins();
plugins.extend(self.discover_installed_plugins()?);
plugins.extend(self.discover_external_directory_plugins(&plugins)?);
Ok(plugins)
Ok(self
.plugin_registry()?
.plugins
.into_iter()
.map(|plugin| plugin.definition)
.collect())
}
pub fn aggregated_hooks(&self) -> Result<PluginHooks, PluginError> {
@ -1094,9 +1231,9 @@ impl PluginManager {
})
}
fn discover_installed_plugins(&self) -> Result<Vec<PluginDefinition>, PluginError> {
fn discover_installed_plugins_with_failures(&self) -> Result<PluginDiscovery, PluginError> {
let mut registry = self.load_registry()?;
let mut plugins = Vec::new();
let mut discovery = PluginDiscovery::default();
let mut seen_ids = BTreeSet::<String>::new();
let mut seen_paths = BTreeSet::<PathBuf>::new();
let mut stale_registry_ids = Vec::new();
@ -1111,10 +1248,21 @@ impl PluginManager {
|| install_path.display().to_string(),
|record| describe_install_source(&record.source),
);
let plugin = load_plugin_definition(&install_path, kind, source, kind.marketplace())?;
match load_plugin_definition(&install_path, kind, source.clone(), kind.marketplace()) {
Ok(plugin) => {
if seen_ids.insert(plugin.metadata().id.clone()) {
seen_paths.insert(install_path);
plugins.push(plugin);
discovery.push_plugin(plugin);
}
}
Err(error) => {
discovery.push_failure(PluginLoadFailure::new(
install_path,
kind,
source,
error,
));
}
}
}
@ -1127,15 +1275,27 @@ impl PluginManager {
stale_registry_ids.push(record.id.clone());
continue;
}
let plugin = load_plugin_definition(
let source = describe_install_source(&record.source);
match load_plugin_definition(
&record.install_path,
record.kind,
describe_install_source(&record.source),
source.clone(),
record.kind.marketplace(),
)?;
) {
Ok(plugin) => {
if seen_ids.insert(plugin.metadata().id.clone()) {
seen_paths.insert(record.install_path.clone());
plugins.push(plugin);
discovery.push_plugin(plugin);
}
}
Err(error) => {
discovery.push_failure(PluginLoadFailure::new(
record.install_path.clone(),
record.kind,
source,
error,
));
}
}
}
@ -1146,47 +1306,51 @@ impl PluginManager {
self.store_registry(&registry)?;
}
Ok(plugins)
Ok(discovery)
}
fn discover_external_directory_plugins(
fn discover_external_directory_plugins_with_failures(
&self,
existing_plugins: &[PluginDefinition],
) -> Result<Vec<PluginDefinition>, PluginError> {
let mut plugins = Vec::new();
) -> Result<PluginDiscovery, PluginError> {
let mut discovery = PluginDiscovery::default();
for directory in &self.config.external_dirs {
for root in discover_plugin_dirs(directory)? {
let plugin = load_plugin_definition(
let source = root.display().to_string();
match load_plugin_definition(
&root,
PluginKind::External,
root.display().to_string(),
source.clone(),
EXTERNAL_MARKETPLACE,
)?;
) {
Ok(plugin) => {
if existing_plugins
.iter()
.chain(plugins.iter())
.chain(discovery.plugins.iter())
.all(|existing| existing.metadata().id != plugin.metadata().id)
{
plugins.push(plugin);
discovery.push_plugin(plugin);
}
}
Err(error) => {
discovery.push_failure(PluginLoadFailure::new(
root,
PluginKind::External,
source,
error,
));
}
}
}
}
Ok(plugins)
Ok(discovery)
}
fn installed_plugin_registry(&self) -> Result<PluginRegistry, PluginError> {
pub fn installed_plugin_registry_report(&self) -> Result<PluginRegistryReport, PluginError> {
self.sync_bundled_plugins()?;
Ok(PluginRegistry::new(
self.discover_installed_plugins()?
.into_iter()
.map(|plugin| {
let enabled = self.is_enabled(plugin.metadata());
RegisteredPlugin::new(plugin, enabled)
})
.collect(),
))
Ok(self.build_registry_report(self.discover_installed_plugins_with_failures()?))
}
fn sync_bundled_plugins(&self) -> Result<(), PluginError> {
@ -1332,6 +1496,26 @@ impl PluginManager {
}
})
}
fn installed_plugin_registry(&self) -> Result<PluginRegistry, PluginError> {
self.installed_plugin_registry_report()?.into_registry()
}
fn build_registry_report(&self, discovery: PluginDiscovery) -> PluginRegistryReport {
PluginRegistryReport::new(
PluginRegistry::new(
discovery
.plugins
.into_iter()
.map(|plugin| {
let enabled = self.is_enabled(plugin.metadata());
RegisteredPlugin::new(plugin, enabled)
})
.collect(),
),
discovery.failures,
)
}
}
#[must_use]
@ -1414,10 +1598,73 @@ fn load_manifest_from_path(
manifest_path.display()
))
})?;
let raw_manifest: RawPluginManifest = serde_json::from_str(&contents)?;
let raw_json: Value = serde_json::from_str(&contents)?;
let compatibility_errors = detect_claude_code_manifest_contract_gaps(&raw_json);
if !compatibility_errors.is_empty() {
return Err(PluginError::ManifestValidation(compatibility_errors));
}
let raw_manifest: RawPluginManifest = serde_json::from_value(raw_json)?;
build_plugin_manifest(root, raw_manifest)
}
fn detect_claude_code_manifest_contract_gaps(
raw_manifest: &Value,
) -> Vec<PluginManifestValidationError> {
let Some(root) = raw_manifest.as_object() else {
return Vec::new();
};
let mut errors = Vec::new();
for (field, detail) in [
(
"skills",
"plugin manifest field `skills` uses the Claude Code plugin contract; `claw` does not load plugin-managed skills and instead discovers skills from local roots such as `.claw/skills`, `.omc/skills`, `.agents/skills`, `~/.omc/skills`, and `~/.claude/skills/omc-learned`.",
),
(
"mcpServers",
"plugin manifest field `mcpServers` uses the Claude Code plugin contract; `claw` does not import MCP servers from plugin manifests.",
),
(
"agents",
"plugin manifest field `agents` uses the Claude Code plugin contract; `claw` does not load plugin-managed agent markdown catalogs from plugin manifests.",
),
] {
if root.contains_key(field) {
errors.push(PluginManifestValidationError::UnsupportedManifestContract {
detail: detail.to_string(),
});
}
}
if root
.get("commands")
.and_then(Value::as_array)
.is_some_and(|commands| commands.iter().any(Value::is_string))
{
errors.push(PluginManifestValidationError::UnsupportedManifestContract {
detail: "plugin manifest field `commands` uses Claude Code-style directory globs; `claw` slash dispatch is still built-in and does not load plugin slash command markdown files.".to_string(),
});
}
if let Some(hooks) = root.get("hooks").and_then(Value::as_object) {
for hook_name in hooks.keys() {
if !matches!(
hook_name.as_str(),
"PreToolUse" | "PostToolUse" | "PostToolUseFailure"
) {
errors.push(PluginManifestValidationError::UnsupportedManifestContract {
detail: format!(
"plugin hook `{hook_name}` uses the Claude Code lifecycle contract; `claw` plugins currently support only PreToolUse, PostToolUse, and PostToolUseFailure."
),
});
}
}
}
errors
}
fn plugin_manifest_path(root: &Path) -> Result<PathBuf, PluginError> {
let direct_path = root.join(MANIFEST_FILE_NAME);
if direct_path.exists() {
@ -1449,6 +1696,12 @@ fn build_plugin_manifest(
let permissions = build_manifest_permissions(&raw.permissions, &mut errors);
validate_command_entries(root, raw.hooks.pre_tool_use.iter(), "hook", &mut errors);
validate_command_entries(root, raw.hooks.post_tool_use.iter(), "hook", &mut errors);
validate_command_entries(
root,
raw.hooks.post_tool_use_failure.iter(),
"hook",
&mut errors,
);
validate_command_entries(
root,
raw.lifecycle.init.iter(),
@ -1676,6 +1929,8 @@ fn validate_command_entry(
};
if !path.exists() {
errors.push(PluginManifestValidationError::MissingPath { kind, path });
} else if !path.is_file() {
errors.push(PluginManifestValidationError::PathIsDirectory { kind, path });
}
}
@ -1691,6 +1946,11 @@ fn resolve_hooks(root: &Path, hooks: &PluginHooks) -> PluginHooks {
.iter()
.map(|entry| resolve_hook_entry(root, entry))
.collect(),
post_tool_use_failure: hooks
.post_tool_use_failure
.iter()
.map(|entry| resolve_hook_entry(root, entry))
.collect(),
}
}
@ -1739,7 +1999,12 @@ fn validate_hook_paths(root: Option<&Path>, hooks: &PluginHooks) -> Result<(), P
let Some(root) = root else {
return Ok(());
};
for entry in hooks.pre_tool_use.iter().chain(hooks.post_tool_use.iter()) {
for entry in hooks
.pre_tool_use
.iter()
.chain(hooks.post_tool_use.iter())
.chain(hooks.post_tool_use_failure.iter())
{
validate_command_path(root, entry, "hook")?;
}
Ok(())
@ -1783,6 +2048,12 @@ fn validate_command_path(root: &Path, entry: &str, kind: &str) -> Result<(), Plu
path.display()
)));
}
if !path.is_file() {
return Err(PluginError::InvalidManifest(format!(
"{kind} path `{}` must point to a file",
path.display()
)));
}
Ok(())
}
@ -2094,6 +2365,30 @@ mod tests {
);
}
fn write_directory_path_plugin(root: &Path, name: &str) {
fs::create_dir_all(root.join("hooks").join("pre-dir")).expect("hook dir");
fs::create_dir_all(root.join("tools").join("tool-dir")).expect("tool dir");
fs::create_dir_all(root.join("commands").join("sync-dir")).expect("command dir");
fs::create_dir_all(root.join("lifecycle").join("init-dir")).expect("lifecycle dir");
write_file(
root.join(MANIFEST_FILE_NAME).as_path(),
format!(
"{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"directory path plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre-dir\"]\n }},\n \"lifecycle\": {{\n \"Init\": [\"./lifecycle/init-dir\"]\n }},\n \"tools\": [\n {{\n \"name\": \"dir_tool\",\n \"description\": \"Directory tool\",\n \"inputSchema\": {{\"type\": \"object\"}},\n \"command\": \"./tools/tool-dir\"\n }}\n ],\n \"commands\": [\n {{\n \"name\": \"sync\",\n \"description\": \"Directory command\",\n \"command\": \"./commands/sync-dir\"\n }}\n ]\n}}"
)
.as_str(),
);
}
fn write_broken_failure_hook_plugin(root: &Path, name: &str) {
write_file(
root.join(MANIFEST_RELATIVE_PATH).as_path(),
format!(
"{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"broken plugin\",\n \"hooks\": {{\n \"PostToolUseFailure\": [\"./hooks/missing-failure.sh\"]\n }}\n}}"
)
.as_str(),
);
}
fn write_lifecycle_plugin(root: &Path, name: &str, version: &str) -> PathBuf {
let log_path = root.join("lifecycle.log");
write_file(
@ -2122,7 +2417,7 @@ mod tests {
let script_path = root.join("tools").join("echo-json.sh");
write_file(
&script_path,
"#!/bin/sh\nINPUT=$(cat)\nprintf '{\"plugin\":\"%s\",\"tool\":\"%s\",\"input\":%s}\\n' \"$CLAW_PLUGIN_ID\" \"$CLAW_TOOL_NAME\" \"$INPUT\"\n",
"#!/bin/sh\nINPUT=$(cat)\nprintf '{\"plugin\":\"%s\",\"tool\":\"%s\",\"input\":%s}\\n' \"$CLAWD_PLUGIN_ID\" \"$CLAWD_TOOL_NAME\" \"$INPUT\"\n",
);
#[cfg(unix)]
{
@ -2289,6 +2584,37 @@ mod tests {
let _ = fs::remove_dir_all(root);
}
#[test]
fn load_plugin_from_directory_rejects_claude_code_manifest_contracts_with_guidance() {
let root = temp_dir("manifest-claude-code-contract");
write_file(
root.join(MANIFEST_FILE_NAME).as_path(),
r#"{
"name": "oh-my-claudecode",
"version": "4.10.2",
"description": "Claude Code plugin manifest",
"hooks": {
"SessionStart": ["scripts/session-start.mjs"]
},
"agents": ["agents/*.md"],
"commands": ["commands/**/*.md"],
"skills": "./skills/",
"mcpServers": "./.mcp.json"
}"#,
);
let error = load_plugin_from_directory(&root)
.expect_err("Claude Code plugin manifest should fail with guidance");
let rendered = error.to_string();
assert!(rendered.contains("field `skills` uses the Claude Code plugin contract"));
assert!(rendered.contains("field `mcpServers` uses the Claude Code plugin contract"));
assert!(rendered.contains("field `agents` uses the Claude Code plugin contract"));
assert!(rendered.contains("field `commands` uses Claude Code-style directory globs"));
assert!(rendered.contains("hook `SessionStart` uses the Claude Code lifecycle contract"));
let _ = fs::remove_dir_all(root);
}
#[test]
fn load_plugin_from_directory_rejects_missing_tool_or_command_paths() {
let root = temp_dir("manifest-paths");
@ -2315,6 +2641,90 @@ mod tests {
let _ = fs::remove_dir_all(root);
}
#[test]
fn load_plugin_from_directory_rejects_missing_lifecycle_paths() {
// given
let root = temp_dir("manifest-lifecycle-paths");
write_file(
root.join(MANIFEST_FILE_NAME).as_path(),
r#"{
"name": "missing-lifecycle-paths",
"version": "1.0.0",
"description": "Missing lifecycle path validation",
"lifecycle": {
"Init": ["./lifecycle/init.sh"],
"Shutdown": ["./lifecycle/shutdown.sh"]
}
}"#,
);
// when
let error =
load_plugin_from_directory(&root).expect_err("missing lifecycle paths should fail");
// then
match error {
PluginError::ManifestValidation(errors) => {
assert!(errors.iter().any(|error| matches!(
error,
PluginManifestValidationError::MissingPath { kind, path }
if *kind == "lifecycle command"
&& path.ends_with(Path::new("lifecycle/init.sh"))
)));
assert!(errors.iter().any(|error| matches!(
error,
PluginManifestValidationError::MissingPath { kind, path }
if *kind == "lifecycle command"
&& path.ends_with(Path::new("lifecycle/shutdown.sh"))
)));
}
other => panic!("expected manifest validation errors, got {other}"),
}
let _ = fs::remove_dir_all(root);
}
#[test]
fn load_plugin_from_directory_rejects_directory_command_paths() {
// given
let root = temp_dir("manifest-directory-paths");
write_directory_path_plugin(&root, "directory-paths");
// when
let error =
load_plugin_from_directory(&root).expect_err("directory command paths should fail");
// then
match error {
PluginError::ManifestValidation(errors) => {
assert!(errors.iter().any(|error| matches!(
error,
PluginManifestValidationError::PathIsDirectory { kind, path }
if *kind == "hook" && path.ends_with(Path::new("hooks/pre-dir"))
)));
assert!(errors.iter().any(|error| matches!(
error,
PluginManifestValidationError::PathIsDirectory { kind, path }
if *kind == "lifecycle command"
&& path.ends_with(Path::new("lifecycle/init-dir"))
)));
assert!(errors.iter().any(|error| matches!(
error,
PluginManifestValidationError::PathIsDirectory { kind, path }
if *kind == "tool" && path.ends_with(Path::new("tools/tool-dir"))
)));
assert!(errors.iter().any(|error| matches!(
error,
PluginManifestValidationError::PathIsDirectory { kind, path }
if *kind == "command" && path.ends_with(Path::new("commands/sync-dir"))
)));
}
other => panic!("expected manifest validation errors, got {other}"),
}
let _ = fs::remove_dir_all(root);
}
#[test]
fn load_plugin_from_directory_rejects_invalid_permissions() {
let root = temp_dir("manifest-invalid-permissions");
@ -2806,16 +3216,95 @@ mod tests {
let _ = fs::remove_dir_all(source_root);
}
#[test]
fn plugin_registry_report_collects_load_failures_without_dropping_valid_plugins() {
// given
let config_home = temp_dir("report-home");
let external_root = temp_dir("report-external");
write_external_plugin(&external_root.join("valid"), "valid-report", "1.0.0");
write_broken_plugin(&external_root.join("broken"), "broken-report");
let mut config = PluginManagerConfig::new(&config_home);
config.external_dirs = vec![external_root.clone()];
let manager = PluginManager::new(config);
// when
let report = manager
.plugin_registry_report()
.expect("report should tolerate invalid external plugins");
// then
assert!(report.registry().contains("valid-report@external"));
assert_eq!(report.failures().len(), 1);
assert_eq!(report.failures()[0].kind, PluginKind::External);
assert!(report.failures()[0]
.plugin_root
.ends_with(Path::new("broken")));
assert!(report.failures()[0]
.error()
.to_string()
.contains("does not exist"));
let error = manager
.plugin_registry()
.expect_err("strict registry should surface load failures");
match error {
PluginError::LoadFailures(failures) => {
assert_eq!(failures.len(), 1);
assert!(failures[0].plugin_root.ends_with(Path::new("broken")));
}
other => panic!("expected load failures, got {other}"),
}
let _ = fs::remove_dir_all(config_home);
let _ = fs::remove_dir_all(external_root);
}
#[test]
fn installed_plugin_registry_report_collects_load_failures_from_install_root() {
// given
let config_home = temp_dir("installed-report-home");
let bundled_root = temp_dir("installed-report-bundled");
let install_root = config_home.join("plugins").join("installed");
write_external_plugin(&install_root.join("valid"), "installed-valid", "1.0.0");
write_broken_plugin(&install_root.join("broken"), "installed-broken");
let mut config = PluginManagerConfig::new(&config_home);
config.bundled_root = Some(bundled_root.clone());
config.install_root = Some(install_root);
let manager = PluginManager::new(config);
// when
let report = manager
.installed_plugin_registry_report()
.expect("installed report should tolerate invalid installed plugins");
// then
assert!(report.registry().contains("installed-valid@external"));
assert_eq!(report.failures().len(), 1);
assert!(report.failures()[0]
.plugin_root
.ends_with(Path::new("broken")));
let _ = fs::remove_dir_all(config_home);
let _ = fs::remove_dir_all(bundled_root);
}
#[test]
fn rejects_plugin_sources_with_missing_hook_paths() {
// given
let config_home = temp_dir("broken-home");
let source_root = temp_dir("broken-source");
write_broken_plugin(&source_root, "broken");
let manager = PluginManager::new(PluginManagerConfig::new(&config_home));
// when
let error = manager
.validate_plugin_source(source_root.to_str().expect("utf8 path"))
.expect_err("missing hook file should fail validation");
// then
assert!(error.to_string().contains("does not exist"));
let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home));
@ -2828,6 +3317,33 @@ mod tests {
let _ = fs::remove_dir_all(source_root);
}
#[test]
fn rejects_plugin_sources_with_missing_failure_hook_paths() {
// given
let config_home = temp_dir("broken-failure-home");
let source_root = temp_dir("broken-failure-source");
write_broken_failure_hook_plugin(&source_root, "broken-failure");
let manager = PluginManager::new(PluginManagerConfig::new(&config_home));
// when
let error = manager
.validate_plugin_source(source_root.to_str().expect("utf8 path"))
.expect_err("missing failure hook file should fail validation");
// then
assert!(error.to_string().contains("does not exist"));
let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home));
let install_error = manager
.install(source_root.to_str().expect("utf8 path"))
.expect_err("install should reject invalid failure hook paths");
assert!(install_error.to_string().contains("does not exist"));
let _ = fs::remove_dir_all(config_home);
let _ = fs::remove_dir_all(source_root);
}
#[test]
fn plugin_registry_runs_initialize_and_shutdown_for_enabled_plugins() {
let config_home = temp_dir("lifecycle-home");

View File

@ -7,13 +7,14 @@ publish.workspace = true
[dependencies]
sha2 = "0.10"
which = "7"
glob = "0.3"
lsp = { path = "../lsp" }
plugins = { path = "../plugins" }
regex = "1"
serde = { version = "1", features = ["derive"] }
serde_json.workspace = true
tokio = { version = "1", features = ["io-util", "macros", "process", "rt", "rt-multi-thread", "time"] }
telemetry = { path = "../telemetry" }
tokio = { version = "1", features = ["io-std", "io-util", "macros", "process", "rt", "rt-multi-thread", "time"] }
walkdir = "2"
[lints]

View File

@ -14,6 +14,7 @@ use crate::sandbox::{
};
use crate::ConfigLoader;
/// Input schema for the built-in bash execution tool.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct BashCommandInput {
pub command: String,
@ -33,6 +34,7 @@ pub struct BashCommandInput {
pub allowed_mounts: Option<Vec<String>>,
}
/// Output returned from a bash tool invocation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct BashCommandOutput {
pub stdout: String,
@ -64,6 +66,7 @@ pub struct BashCommandOutput {
pub sandbox_status: Option<SandboxStatus>,
}
/// Executes a shell command with the requested sandbox settings.
pub fn execute_bash(input: BashCommandInput) -> io::Result<BashCommandOutput> {
let cwd = env::current_dir()?;
let sandbox_status = sandbox_status_for_input(&input, &cwd);
@ -134,8 +137,8 @@ async fn execute_bash_async(
};
let (output, interrupted) = output_result;
let stdout = String::from_utf8_lossy(&output.stdout).into_owned();
let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
let stdout = truncate_output(&String::from_utf8_lossy(&output.stdout));
let stderr = truncate_output(&String::from_utf8_lossy(&output.stderr));
let no_output_expected = Some(stdout.trim().is_empty() && stderr.trim().is_empty());
let return_code_interpretation = output.status.code().and_then(|code| {
if code == 0 {
@ -197,36 +200,31 @@ fn prepare_command(
return prepared;
}
let mut prepared = if cfg!(target_os = "windows") && !sh_exists() {
let prepared = if cfg!(target_os = "windows") && !sh_exists() {
let mut p = Command::new("cmd");
p.arg("/C").arg(command);
p.arg("/C").arg(command).current_dir(cwd);
p
} else {
let mut p = Command::new("sh");
p.arg("-lc").arg(command);
p.arg("-lc").arg(command).current_dir(cwd);
if sandbox_status.filesystem_active {
p.env("HOME", cwd.join(".sandbox-home"));
p.env("TMPDIR", cwd.join(".sandbox-tmp"));
}
p
};
prepared.current_dir(cwd);
if sandbox_status.filesystem_active {
prepared.env("HOME", cwd.join(".sandbox-home"));
prepared.env("TMPDIR", cwd.join(".sandbox-tmp"));
}
prepared
}
fn sh_exists() -> bool {
env::var_os("PATH").is_some_and(|paths| {
env::split_paths(&paths).any(|path| {
#[cfg(windows)]
{
path.join("sh.exe").exists() || path.join("sh.bat").exists() || path.join("sh").exists()
which::which("sh").is_ok()
}
#[cfg(not(windows))]
{
path.join("sh").exists()
true
}
})
})
}
fn prepare_tokio_command(
@ -247,20 +245,19 @@ fn prepare_tokio_command(
return prepared;
}
let mut prepared = if cfg!(target_os = "windows") && !sh_exists() {
let prepared = if cfg!(target_os = "windows") && !sh_exists() {
let mut p = TokioCommand::new("cmd");
p.arg("/C").arg(command);
p.arg("/C").arg(command).current_dir(cwd);
p
} else {
let mut p = TokioCommand::new("sh");
p.arg("-lc").arg(command);
p.arg("-lc").arg(command).current_dir(cwd);
if sandbox_status.filesystem_active {
p.env("HOME", cwd.join(".sandbox-home"));
p.env("TMPDIR", cwd.join(".sandbox-tmp"));
}
p
};
prepared.current_dir(cwd);
if sandbox_status.filesystem_active {
prepared.env("HOME", cwd.join(".sandbox-home"));
prepared.env("TMPDIR", cwd.join(".sandbox-tmp"));
}
prepared
}
@ -312,3 +309,53 @@ mod tests {
assert!(!output.sandbox_status.expect("sandbox status").enabled);
}
}
/// Maximum output bytes before truncation (16 KiB, matching upstream).
const MAX_OUTPUT_BYTES: usize = 16_384;
/// Truncate output to `MAX_OUTPUT_BYTES`, appending a marker when trimmed.
fn truncate_output(s: &str) -> String {
if s.len() <= MAX_OUTPUT_BYTES {
return s.to_string();
}
// Find the last valid UTF-8 boundary at or before MAX_OUTPUT_BYTES
let mut end = MAX_OUTPUT_BYTES;
while end > 0 && !s.is_char_boundary(end) {
end -= 1;
}
let mut truncated = s[..end].to_string();
truncated.push_str("\n\n[output truncated — exceeded 16384 bytes]");
truncated
}
#[cfg(test)]
mod truncation_tests {
use super::*;
#[test]
fn short_output_unchanged() {
let s = "hello world";
assert_eq!(truncate_output(s), s);
}
#[test]
fn long_output_truncated() {
let s = "x".repeat(20_000);
let result = truncate_output(&s);
assert!(result.len() < 20_000);
assert!(result.ends_with("[output truncated — exceeded 16384 bytes]"));
}
#[test]
fn exact_boundary_unchanged() {
let s = "a".repeat(MAX_OUTPUT_BYTES);
assert_eq!(truncate_output(&s), s);
}
#[test]
fn one_over_boundary_truncated() {
let s = "a".repeat(MAX_OUTPUT_BYTES + 1);
let result = truncate_output(&s);
assert!(result.contains("[output truncated"));
}
}

File diff suppressed because it is too large Load Diff

View File

@ -21,7 +21,7 @@ pub struct BootstrapPlan {
impl BootstrapPlan {
#[must_use]
pub fn claw_default() -> Self {
pub fn claude_code_default() -> Self {
Self::from_phases(vec![
BootstrapPhase::CliEntry,
BootstrapPhase::FastPathVersion,
@ -54,3 +54,58 @@ impl BootstrapPlan {
&self.phases
}
}
#[cfg(test)]
mod tests {
use super::{BootstrapPhase, BootstrapPlan};
#[test]
fn from_phases_deduplicates_while_preserving_order() {
// given
let phases = vec![
BootstrapPhase::CliEntry,
BootstrapPhase::FastPathVersion,
BootstrapPhase::CliEntry,
BootstrapPhase::MainRuntime,
BootstrapPhase::FastPathVersion,
];
// when
let plan = BootstrapPlan::from_phases(phases);
// then
assert_eq!(
plan.phases(),
&[
BootstrapPhase::CliEntry,
BootstrapPhase::FastPathVersion,
BootstrapPhase::MainRuntime,
]
);
}
#[test]
fn claude_code_default_covers_each_phase_once() {
// given
let expected = [
BootstrapPhase::CliEntry,
BootstrapPhase::FastPathVersion,
BootstrapPhase::StartupProfiler,
BootstrapPhase::SystemPromptFastPath,
BootstrapPhase::ChromeMcpFastPath,
BootstrapPhase::DaemonWorkerFastPath,
BootstrapPhase::BridgeFastPath,
BootstrapPhase::DaemonFastPath,
BootstrapPhase::BackgroundSessionFastPath,
BootstrapPhase::TemplateFastPath,
BootstrapPhase::EnvironmentRunnerFastPath,
BootstrapPhase::MainRuntime,
];
// when
let plan = BootstrapPlan::claude_code_default();
// then
assert_eq!(plan.phases(), &expected);
}
}

View File

@ -0,0 +1,144 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BranchLockIntent {
#[serde(rename = "laneId")]
pub lane_id: String,
pub branch: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub worktree: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub modules: Vec<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BranchLockCollision {
pub branch: String,
pub module: String,
#[serde(rename = "laneIds")]
pub lane_ids: Vec<String>,
}
#[must_use]
pub fn detect_branch_lock_collisions(intents: &[BranchLockIntent]) -> Vec<BranchLockCollision> {
let mut collisions = Vec::new();
for (index, left) in intents.iter().enumerate() {
for right in &intents[index + 1..] {
if left.branch != right.branch {
continue;
}
for module in overlapping_modules(&left.modules, &right.modules) {
collisions.push(BranchLockCollision {
branch: left.branch.clone(),
module,
lane_ids: vec![left.lane_id.clone(), right.lane_id.clone()],
});
}
}
}
collisions.sort_by(|a, b| {
a.branch
.cmp(&b.branch)
.then(a.module.cmp(&b.module))
.then(a.lane_ids.cmp(&b.lane_ids))
});
collisions.dedup();
collisions
}
fn overlapping_modules(left: &[String], right: &[String]) -> Vec<String> {
let mut overlaps = Vec::new();
for left_module in left {
for right_module in right {
if modules_overlap(left_module, right_module) {
overlaps.push(shared_scope(left_module, right_module));
}
}
}
overlaps.sort();
overlaps.dedup();
overlaps
}
fn modules_overlap(left: &str, right: &str) -> bool {
left == right
|| left.starts_with(&format!("{right}/"))
|| right.starts_with(&format!("{left}/"))
}
fn shared_scope(left: &str, right: &str) -> String {
if left.starts_with(&format!("{right}/")) || left == right {
right.to_string()
} else {
left.to_string()
}
}
#[cfg(test)]
mod tests {
use super::{detect_branch_lock_collisions, BranchLockIntent};
#[test]
fn detects_same_branch_same_module_collisions() {
let collisions = detect_branch_lock_collisions(&[
BranchLockIntent {
lane_id: "lane-a".to_string(),
branch: "feature/lock".to_string(),
worktree: Some("wt-a".to_string()),
modules: vec!["runtime/mcp".to_string()],
},
BranchLockIntent {
lane_id: "lane-b".to_string(),
branch: "feature/lock".to_string(),
worktree: Some("wt-b".to_string()),
modules: vec!["runtime/mcp".to_string()],
},
]);
assert_eq!(collisions.len(), 1);
assert_eq!(collisions[0].branch, "feature/lock");
assert_eq!(collisions[0].module, "runtime/mcp");
}
#[test]
fn detects_nested_module_scope_collisions() {
let collisions = detect_branch_lock_collisions(&[
BranchLockIntent {
lane_id: "lane-a".to_string(),
branch: "feature/lock".to_string(),
worktree: None,
modules: vec!["runtime".to_string()],
},
BranchLockIntent {
lane_id: "lane-b".to_string(),
branch: "feature/lock".to_string(),
worktree: None,
modules: vec!["runtime/mcp".to_string()],
},
]);
assert_eq!(collisions[0].module, "runtime");
}
#[test]
fn ignores_different_branches() {
let collisions = detect_branch_lock_collisions(&[
BranchLockIntent {
lane_id: "lane-a".to_string(),
branch: "feature/a".to_string(),
worktree: None,
modules: vec!["runtime/mcp".to_string()],
},
BranchLockIntent {
lane_id: "lane-b".to_string(),
branch: "feature/b".to_string(),
worktree: None,
modules: vec!["runtime/mcp".to_string()],
},
]);
assert!(collisions.is_empty());
}
}

View File

@ -5,6 +5,7 @@ const COMPACT_CONTINUATION_PREAMBLE: &str =
const COMPACT_RECENT_MESSAGES_NOTE: &str = "Recent messages are preserved verbatim.";
const COMPACT_DIRECT_RESUME_INSTRUCTION: &str = "Continue the conversation from where it left off without asking the user any further questions. Resume directly — do not acknowledge the summary, do not recap what was happening, and do not preface with continuation text.";
/// Thresholds controlling when and how a session is compacted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CompactionConfig {
pub preserve_recent_messages: usize,
@ -20,6 +21,7 @@ impl Default for CompactionConfig {
}
}
/// Result of compacting a session into a summary plus preserved tail messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompactionResult {
pub summary: String,
@ -28,11 +30,13 @@ pub struct CompactionResult {
pub removed_message_count: usize,
}
/// Roughly estimates the token footprint of the current session transcript.
#[must_use]
pub fn estimate_session_tokens(session: &Session) -> usize {
session.messages.iter().map(estimate_message_tokens).sum()
}
/// Returns `true` when the session exceeds the configured compaction budget.
#[must_use]
pub fn should_compact(session: &Session, config: CompactionConfig) -> bool {
let start = compacted_summary_prefix_len(session);
@ -46,6 +50,7 @@ pub fn should_compact(session: &Session, config: CompactionConfig) -> bool {
>= config.max_estimated_tokens
}
/// Normalizes a compaction summary into user-facing continuation text.
#[must_use]
pub fn format_compact_summary(summary: &str) -> String {
let without_analysis = strip_tag_block(summary, "analysis");
@ -61,6 +66,7 @@ pub fn format_compact_summary(summary: &str) -> String {
collapse_blank_lines(&formatted).trim().to_string()
}
/// Builds the synthetic system message used after session compaction.
#[must_use]
pub fn get_compact_continuation_message(
summary: &str,
@ -85,6 +91,7 @@ pub fn get_compact_continuation_message(
base
}
/// Compacts a session by summarizing older messages and preserving the recent tail.
#[must_use]
pub fn compact_session(session: &Session, config: CompactionConfig) -> CompactionResult {
if !should_compact(session, config) {
@ -101,10 +108,55 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio
.first()
.and_then(extract_existing_compacted_summary);
let compacted_prefix_len = usize::from(existing_summary.is_some());
let keep_from = session
let raw_keep_from = session
.messages
.len()
.saturating_sub(config.preserve_recent_messages);
// Ensure we do not split a tool-use / tool-result pair at the compaction
// boundary. If the first preserved message is a user message whose first
// block is a ToolResult, the assistant message with the matching ToolUse
// was slated for removal — that produces an orphaned tool role message on
// the OpenAI-compat path (400: tool message must follow assistant with
// tool_calls). Walk the boundary back until we start at a safe point.
let keep_from = {
let mut k = raw_keep_from;
// If the first preserved message is a tool-result turn, ensure its
// paired assistant tool-use turn is preserved too. Without this fix,
// the OpenAI-compat adapter sends an orphaned 'tool' role message
// with no preceding assistant 'tool_calls', which providers reject
// with a 400. We walk back only if the immediately preceding message
// is NOT an assistant message that contains a ToolUse block (i.e. the
// pair is actually broken at the boundary).
loop {
if k == 0 || k <= compacted_prefix_len {
break;
}
let first_preserved = &session.messages[k];
let starts_with_tool_result = first_preserved
.blocks
.first()
.map(|b| matches!(b, ContentBlock::ToolResult { .. }))
.unwrap_or(false);
if !starts_with_tool_result {
break;
}
// Check the message just before the current boundary.
let preceding = &session.messages[k - 1];
let preceding_has_tool_use = preceding
.blocks
.iter()
.any(|b| matches!(b, ContentBlock::ToolUse { .. }));
if preceding_has_tool_use {
// Pair is intact — walk back one more to include the assistant turn.
k = k.saturating_sub(1);
break;
}
// Preceding message has no ToolUse but we have a ToolResult —
// this is already an orphaned pair; walk back to try to fix it.
k = k.saturating_sub(1);
}
k
};
let removed = &session.messages[compacted_prefix_len..keep_from];
let preserved = session.messages[keep_from..].to_vec();
let summary =
@ -119,13 +171,14 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio
}];
compacted_messages.extend(preserved);
let mut compacted_session = session.clone();
compacted_session.messages = compacted_messages;
compacted_session.record_compaction(summary.clone(), removed.len());
CompactionResult {
summary,
formatted_summary,
compacted_session: Session {
version: session.version,
messages: compacted_messages,
},
compacted_session,
removed_message_count: removed.len(),
}
}
@ -160,9 +213,7 @@ fn summarize_messages(messages: &[ConversationMessage]) -> String {
.filter_map(|block| match block {
ContentBlock::ToolUse { name, .. } => Some(name.as_str()),
ContentBlock::ToolResult { tool_name, .. } => Some(tool_name.as_str()),
ContentBlock::Text { .. }
| ContentBlock::Thinking { .. }
| ContentBlock::RedactedThinking { .. } => None,
ContentBlock::Text { .. } => None,
})
.collect::<Vec<_>>();
tool_names.sort_unstable();
@ -277,8 +328,6 @@ fn summarize_block(block: &ContentBlock) -> String {
"tool_result {tool_name}: {}{output}",
if *is_error { "error " } else { "" }
),
ContentBlock::Thinking { thinking, .. } => format!("thinking: {thinking}"),
ContentBlock::RedactedThinking { .. } => "thinking: <redacted>".to_string(),
};
truncate_summary(&raw, 160)
}
@ -328,8 +377,6 @@ fn collect_key_files(messages: &[ConversationMessage]) -> Vec<String> {
.flat_map(|message| message.blocks.iter())
.map(|block| match block {
ContentBlock::Text { text } => text.as_str(),
ContentBlock::Thinking { thinking, .. } => thinking.as_str(),
ContentBlock::RedactedThinking { .. } => "",
ContentBlock::ToolUse { input, .. } => input.as_str(),
ContentBlock::ToolResult { output, .. } => output.as_str(),
})
@ -354,9 +401,7 @@ fn first_text_block(message: &ConversationMessage) -> Option<&str> {
ContentBlock::Text { text } if !text.trim().is_empty() => Some(text.as_str()),
ContentBlock::ToolUse { .. }
| ContentBlock::ToolResult { .. }
| ContentBlock::Text { .. }
| ContentBlock::Thinking { .. }
| ContentBlock::RedactedThinking { .. } => None,
| ContentBlock::Text { .. } => None,
})
}
@ -402,8 +447,6 @@ fn estimate_message_tokens(message: &ConversationMessage) -> usize {
.iter()
.map(|block| match block {
ContentBlock::Text { text } => text.len() / 4 + 1,
ContentBlock::Thinking { thinking, .. } => thinking.len() / 4 + 1,
ContentBlock::RedactedThinking { .. } => 1,
ContentBlock::ToolUse { name, input, .. } => (name.len() + input.len()) / 4 + 1,
ContentBlock::ToolResult {
tool_name, output, ..
@ -512,7 +555,7 @@ fn extract_summary_timeline(summary: &str) -> Vec<String> {
#[cfg(test)]
mod tests {
use super::{
collect_key_files, compact_session, estimate_session_tokens, format_compact_summary,
collect_key_files, compact_session, format_compact_summary,
get_compact_continuation_message, infer_pending_work, should_compact, CompactionConfig,
};
use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};
@ -525,10 +568,8 @@ mod tests {
#[test]
fn leaves_small_sessions_unchanged() {
let session = Session {
version: 1,
messages: vec![ConversationMessage::user_text("hello")],
};
let mut session = Session::new();
session.messages = vec![ConversationMessage::user_text("hello")];
let result = compact_session(&session, CompactionConfig::default());
assert_eq!(result.removed_message_count, 0);
@ -539,9 +580,8 @@ mod tests {
#[test]
fn compacts_older_messages_into_a_system_summary() {
let session = Session {
version: 1,
messages: vec![
let mut session = Session::new();
session.messages = vec![
ConversationMessage::user_text("one ".repeat(200)),
ConversationMessage::assistant(vec![ContentBlock::Text {
text: "two ".repeat(200),
@ -554,8 +594,7 @@ mod tests {
}],
usage: None,
},
],
};
];
let result = compact_session(
&session,
@ -565,7 +604,14 @@ mod tests {
},
);
assert_eq!(result.removed_message_count, 2);
// With the tool-use/tool-result boundary fix, the compaction preserves
// one extra message to avoid an orphaned tool result at the boundary.
// messages[1] (assistant) must be kept along with messages[2] (tool result).
assert!(
result.removed_message_count <= 2,
"expected at most 2 removed, got {}",
result.removed_message_count
);
assert_eq!(
result.compacted_session.messages[0].role,
MessageRole::System
@ -583,28 +629,29 @@ mod tests {
max_estimated_tokens: 1,
}
));
// Note: with the tool-use/tool-result boundary guard the compacted session
// may preserve one extra message at the boundary, so token reduction is
// not guaranteed for small sessions. The invariant that matters is that
// the removed_message_count is non-zero (something was compacted).
assert!(
estimate_session_tokens(&result.compacted_session) < estimate_session_tokens(&session)
result.removed_message_count > 0,
"compaction must remove at least one message"
);
}
#[test]
fn keeps_previous_compacted_context_when_compacting_again() {
let initial_session = Session {
version: 1,
messages: vec![
let mut initial_session = Session::new();
initial_session.messages = vec![
ConversationMessage::user_text("Investigate rust/crates/runtime/src/compact.rs"),
ConversationMessage::assistant(vec![ContentBlock::Text {
text: "I will inspect the compact flow.".to_string(),
}]),
ConversationMessage::user_text(
"Also update rust/crates/runtime/src/conversation.rs",
),
ConversationMessage::user_text("Also update rust/crates/runtime/src/conversation.rs"),
ConversationMessage::assistant(vec![ContentBlock::Text {
text: "Next: preserve prior summary context during auto compact.".to_string(),
}]),
],
};
];
let config = CompactionConfig {
preserve_recent_messages: 2,
max_estimated_tokens: 1,
@ -619,13 +666,9 @@ mod tests {
}]),
]);
let second = compact_session(
&Session {
version: 1,
messages: follow_up_messages,
},
config,
);
let mut second_session = Session::new();
second_session.messages = follow_up_messages;
let second = compact_session(&second_session, config);
assert!(second
.formatted_summary
@ -654,9 +697,8 @@ mod tests {
#[test]
fn ignores_existing_compacted_summary_when_deciding_to_recompact() {
let summary = "<summary>Conversation summary:\n- Scope: earlier work preserved.\n- Key timeline:\n - user: large preserved context\n</summary>";
let session = Session {
version: 1,
messages: vec![
let mut session = Session::new();
session.messages = vec![
ConversationMessage {
role: MessageRole::System,
blocks: vec![ContentBlock::Text {
@ -668,8 +710,7 @@ mod tests {
ConversationMessage::assistant(vec![ContentBlock::Text {
text: "recent".to_string(),
}]),
],
};
];
assert!(!should_compact(
&session,
@ -692,10 +733,84 @@ mod tests {
#[test]
fn extracts_key_files_from_message_content() {
let files = collect_key_files(&[ConversationMessage::user_text(
"Update rust/crates/runtime/src/compact.rs and rust/crates/tools/src/lib.rs next.",
"Update rust/crates/runtime/src/compact.rs and rust/crates/rusty-claude-cli/src/main.rs next.",
)]);
assert!(files.contains(&"rust/crates/runtime/src/compact.rs".to_string()));
assert!(files.contains(&"rust/crates/tools/src/lib.rs".to_string()));
assert!(files.contains(&"rust/crates/rusty-claude-cli/src/main.rs".to_string()));
}
/// Regression: compaction must not split an assistant(ToolUse) /
/// user(ToolResult) pair at the boundary. An orphaned tool-result message
/// without the preceding assistant tool_calls causes a 400 on the
/// OpenAI-compat path (gaebal-gajae repro 2026-04-09).
#[test]
fn compaction_does_not_split_tool_use_tool_result_pair() {
use crate::session::{ContentBlock, Session};
let tool_id = "call_abc";
let mut session = Session::default();
// Turn 1: user prompt
session
.push_message(ConversationMessage::user_text("Search for files"))
.unwrap();
// Turn 2: assistant calls a tool
session
.push_message(ConversationMessage::assistant(vec![
ContentBlock::ToolUse {
id: tool_id.to_string(),
name: "search".to_string(),
input: "{\"q\":\"*.rs\"}".to_string(),
},
]))
.unwrap();
// Turn 3: tool result
session
.push_message(ConversationMessage::tool_result(
tool_id,
"search",
"found 5 files",
false,
))
.unwrap();
// Turn 4: assistant final response
session
.push_message(ConversationMessage::assistant(vec![ContentBlock::Text {
text: "Done.".to_string(),
}]))
.unwrap();
// Compact preserving only 1 recent message — without the fix this
// would cut the boundary so that the tool result (turn 3) is first,
// without its preceding assistant tool_calls (turn 2).
let config = CompactionConfig {
preserve_recent_messages: 1,
..CompactionConfig::default()
};
let result = compact_session(&session, config);
// After compaction, no two consecutive messages should have the pattern
// tool_result immediately following a non-assistant message (i.e. an
// orphaned tool result without a preceding assistant ToolUse).
let messages = &result.compacted_session.messages;
for i in 1..messages.len() {
let curr_is_tool_result = messages[i]
.blocks
.first()
.map(|b| matches!(b, ContentBlock::ToolResult { .. }))
.unwrap_or(false);
if curr_is_tool_result {
let prev_has_tool_use = messages[i - 1]
.blocks
.iter()
.any(|b| matches!(b, ContentBlock::ToolUse { .. }));
assert!(
prev_has_tool_use,
"message[{}] is a ToolResult but message[{}] has no ToolUse: {:?}",
i,
i - 1,
&messages[i - 1].blocks
);
}
}
}
#[test]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,901 @@
use std::collections::BTreeMap;
use std::path::Path;
use crate::config::ConfigError;
use crate::json::JsonValue;
/// Diagnostic emitted when a config file contains a suspect field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConfigDiagnostic {
pub path: String,
pub field: String,
pub line: Option<usize>,
pub kind: DiagnosticKind,
}
/// Classification of the diagnostic.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DiagnosticKind {
UnknownKey {
suggestion: Option<String>,
},
WrongType {
expected: &'static str,
got: &'static str,
},
Deprecated {
replacement: &'static str,
},
}
impl std::fmt::Display for ConfigDiagnostic {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let location = self
.line
.map_or_else(String::new, |line| format!(" (line {line})"));
match &self.kind {
DiagnosticKind::UnknownKey { suggestion: None } => {
write!(f, "{}: unknown key \"{}\"{location}", self.path, self.field)
}
DiagnosticKind::UnknownKey {
suggestion: Some(hint),
} => {
write!(
f,
"{}: unknown key \"{}\"{location}. Did you mean \"{}\"?",
self.path, self.field, hint
)
}
DiagnosticKind::WrongType { expected, got } => {
write!(
f,
"{}: field \"{}\" must be {expected}, got {got}{location}",
self.path, self.field
)
}
DiagnosticKind::Deprecated { replacement } => {
write!(
f,
"{}: field \"{}\" is deprecated{location}. Use \"{replacement}\" instead",
self.path, self.field
)
}
}
}
}
/// Result of validating a single config file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
pub errors: Vec<ConfigDiagnostic>,
pub warnings: Vec<ConfigDiagnostic>,
}
impl ValidationResult {
#[must_use]
pub fn is_ok(&self) -> bool {
self.errors.is_empty()
}
fn merge(&mut self, other: Self) {
self.errors.extend(other.errors);
self.warnings.extend(other.warnings);
}
}
// ---- known-key schema ----
/// Expected type for a config field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FieldType {
String,
Bool,
Object,
StringArray,
Number,
}
impl FieldType {
fn label(self) -> &'static str {
match self {
Self::String => "a string",
Self::Bool => "a boolean",
Self::Object => "an object",
Self::StringArray => "an array of strings",
Self::Number => "a number",
}
}
fn matches(self, value: &JsonValue) -> bool {
match self {
Self::String => value.as_str().is_some(),
Self::Bool => value.as_bool().is_some(),
Self::Object => value.as_object().is_some(),
Self::StringArray => value
.as_array()
.is_some_and(|arr| arr.iter().all(|v| v.as_str().is_some())),
Self::Number => value.as_i64().is_some(),
}
}
}
fn json_type_label(value: &JsonValue) -> &'static str {
match value {
JsonValue::Null => "null",
JsonValue::Bool(_) => "a boolean",
JsonValue::Number(_) => "a number",
JsonValue::String(_) => "a string",
JsonValue::Array(_) => "an array",
JsonValue::Object(_) => "an object",
}
}
struct FieldSpec {
name: &'static str,
expected: FieldType,
}
struct DeprecatedField {
name: &'static str,
replacement: &'static str,
}
const TOP_LEVEL_FIELDS: &[FieldSpec] = &[
FieldSpec {
name: "$schema",
expected: FieldType::String,
},
FieldSpec {
name: "model",
expected: FieldType::String,
},
FieldSpec {
name: "hooks",
expected: FieldType::Object,
},
FieldSpec {
name: "permissions",
expected: FieldType::Object,
},
FieldSpec {
name: "permissionMode",
expected: FieldType::String,
},
FieldSpec {
name: "mcpServers",
expected: FieldType::Object,
},
FieldSpec {
name: "oauth",
expected: FieldType::Object,
},
FieldSpec {
name: "enabledPlugins",
expected: FieldType::Object,
},
FieldSpec {
name: "plugins",
expected: FieldType::Object,
},
FieldSpec {
name: "sandbox",
expected: FieldType::Object,
},
FieldSpec {
name: "env",
expected: FieldType::Object,
},
FieldSpec {
name: "aliases",
expected: FieldType::Object,
},
FieldSpec {
name: "providerFallbacks",
expected: FieldType::Object,
},
FieldSpec {
name: "trustedRoots",
expected: FieldType::StringArray,
},
];
const HOOKS_FIELDS: &[FieldSpec] = &[
FieldSpec {
name: "PreToolUse",
expected: FieldType::StringArray,
},
FieldSpec {
name: "PostToolUse",
expected: FieldType::StringArray,
},
FieldSpec {
name: "PostToolUseFailure",
expected: FieldType::StringArray,
},
];
const PERMISSIONS_FIELDS: &[FieldSpec] = &[
FieldSpec {
name: "defaultMode",
expected: FieldType::String,
},
FieldSpec {
name: "allow",
expected: FieldType::StringArray,
},
FieldSpec {
name: "deny",
expected: FieldType::StringArray,
},
FieldSpec {
name: "ask",
expected: FieldType::StringArray,
},
];
const PLUGINS_FIELDS: &[FieldSpec] = &[
FieldSpec {
name: "enabled",
expected: FieldType::Object,
},
FieldSpec {
name: "externalDirectories",
expected: FieldType::StringArray,
},
FieldSpec {
name: "installRoot",
expected: FieldType::String,
},
FieldSpec {
name: "registryPath",
expected: FieldType::String,
},
FieldSpec {
name: "bundledRoot",
expected: FieldType::String,
},
FieldSpec {
name: "maxOutputTokens",
expected: FieldType::Number,
},
];
const SANDBOX_FIELDS: &[FieldSpec] = &[
FieldSpec {
name: "enabled",
expected: FieldType::Bool,
},
FieldSpec {
name: "namespaceRestrictions",
expected: FieldType::Bool,
},
FieldSpec {
name: "networkIsolation",
expected: FieldType::Bool,
},
FieldSpec {
name: "filesystemMode",
expected: FieldType::String,
},
FieldSpec {
name: "allowedMounts",
expected: FieldType::StringArray,
},
];
const OAUTH_FIELDS: &[FieldSpec] = &[
FieldSpec {
name: "clientId",
expected: FieldType::String,
},
FieldSpec {
name: "authorizeUrl",
expected: FieldType::String,
},
FieldSpec {
name: "tokenUrl",
expected: FieldType::String,
},
FieldSpec {
name: "callbackPort",
expected: FieldType::Number,
},
FieldSpec {
name: "manualRedirectUrl",
expected: FieldType::String,
},
FieldSpec {
name: "scopes",
expected: FieldType::StringArray,
},
];
const DEPRECATED_FIELDS: &[DeprecatedField] = &[
DeprecatedField {
name: "permissionMode",
replacement: "permissions.defaultMode",
},
DeprecatedField {
name: "enabledPlugins",
replacement: "plugins.enabled",
},
];
// ---- line-number resolution ----
/// Find the 1-based line number where a JSON key first appears in the raw source.
fn find_key_line(source: &str, key: &str) -> Option<usize> {
// Search for `"key"` followed by optional whitespace and a colon.
let needle = format!("\"{key}\"");
let mut search_start = 0;
while let Some(offset) = source[search_start..].find(&needle) {
let absolute = search_start + offset;
let after = absolute + needle.len();
// Verify the next non-whitespace char is `:` to confirm this is a key, not a value.
if source[after..].chars().find(|ch| !ch.is_ascii_whitespace()) == Some(':') {
return Some(source[..absolute].chars().filter(|&ch| ch == '\n').count() + 1);
}
search_start = after;
}
None
}
// ---- core validation ----
fn validate_object_keys(
object: &BTreeMap<String, JsonValue>,
known_fields: &[FieldSpec],
prefix: &str,
source: &str,
path_display: &str,
) -> ValidationResult {
let mut result = ValidationResult {
errors: Vec::new(),
warnings: Vec::new(),
};
let known_names: Vec<&str> = known_fields.iter().map(|f| f.name).collect();
for (key, value) in object {
let field_path = if prefix.is_empty() {
key.clone()
} else {
format!("{prefix}.{key}")
};
if let Some(spec) = known_fields.iter().find(|f| f.name == key) {
// Type check.
if !spec.expected.matches(value) {
result.errors.push(ConfigDiagnostic {
path: path_display.to_string(),
field: field_path,
line: find_key_line(source, key),
kind: DiagnosticKind::WrongType {
expected: spec.expected.label(),
got: json_type_label(value),
},
});
}
} else if DEPRECATED_FIELDS.iter().any(|d| d.name == key) {
// Deprecated key — handled separately, not an unknown-key error.
} else {
// Unknown key.
let suggestion = suggest_field(key, &known_names);
result.errors.push(ConfigDiagnostic {
path: path_display.to_string(),
field: field_path,
line: find_key_line(source, key),
kind: DiagnosticKind::UnknownKey { suggestion },
});
}
}
result
}
fn suggest_field(input: &str, candidates: &[&str]) -> Option<String> {
let input_lower = input.to_ascii_lowercase();
candidates
.iter()
.filter_map(|candidate| {
let distance = simple_edit_distance(&input_lower, &candidate.to_ascii_lowercase());
(distance <= 3).then_some((distance, *candidate))
})
.min_by_key(|(distance, _)| *distance)
.map(|(_, name)| name.to_string())
}
fn simple_edit_distance(left: &str, right: &str) -> usize {
if left.is_empty() {
return right.len();
}
if right.is_empty() {
return left.len();
}
let right_chars: Vec<char> = right.chars().collect();
let mut previous: Vec<usize> = (0..=right_chars.len()).collect();
let mut current = vec![0; right_chars.len() + 1];
for (left_index, left_char) in left.chars().enumerate() {
current[0] = left_index + 1;
for (right_index, right_char) in right_chars.iter().enumerate() {
let cost = usize::from(left_char != *right_char);
current[right_index + 1] = (previous[right_index + 1] + 1)
.min(current[right_index] + 1)
.min(previous[right_index] + cost);
}
previous.clone_from(&current);
}
previous[right_chars.len()]
}
/// Validate a parsed config file's keys and types against the known schema.
///
/// Returns diagnostics (errors and deprecation warnings) without blocking the load.
pub fn validate_config_file(
object: &BTreeMap<String, JsonValue>,
source: &str,
file_path: &Path,
) -> ValidationResult {
let path_display = file_path.display().to_string();
let mut result = validate_object_keys(object, TOP_LEVEL_FIELDS, "", source, &path_display);
// Check deprecated fields.
for deprecated in DEPRECATED_FIELDS {
if object.contains_key(deprecated.name) {
result.warnings.push(ConfigDiagnostic {
path: path_display.clone(),
field: deprecated.name.to_string(),
line: find_key_line(source, deprecated.name),
kind: DiagnosticKind::Deprecated {
replacement: deprecated.replacement,
},
});
}
}
// Validate known nested objects.
if let Some(hooks) = object.get("hooks").and_then(JsonValue::as_object) {
result.merge(validate_object_keys(
hooks,
HOOKS_FIELDS,
"hooks",
source,
&path_display,
));
}
if let Some(permissions) = object.get("permissions").and_then(JsonValue::as_object) {
result.merge(validate_object_keys(
permissions,
PERMISSIONS_FIELDS,
"permissions",
source,
&path_display,
));
}
if let Some(plugins) = object.get("plugins").and_then(JsonValue::as_object) {
result.merge(validate_object_keys(
plugins,
PLUGINS_FIELDS,
"plugins",
source,
&path_display,
));
}
if let Some(sandbox) = object.get("sandbox").and_then(JsonValue::as_object) {
result.merge(validate_object_keys(
sandbox,
SANDBOX_FIELDS,
"sandbox",
source,
&path_display,
));
}
if let Some(oauth) = object.get("oauth").and_then(JsonValue::as_object) {
result.merge(validate_object_keys(
oauth,
OAUTH_FIELDS,
"oauth",
source,
&path_display,
));
}
result
}
/// Check whether a file path uses an unsupported config format (e.g. TOML).
pub fn check_unsupported_format(file_path: &Path) -> Result<(), ConfigError> {
if let Some(ext) = file_path.extension().and_then(|e| e.to_str()) {
if ext.eq_ignore_ascii_case("toml") {
return Err(ConfigError::Parse(format!(
"{}: TOML config files are not supported. Use JSON (settings.json) instead",
file_path.display()
)));
}
}
Ok(())
}
/// Format all diagnostics into a human-readable report.
#[must_use]
pub fn format_diagnostics(result: &ValidationResult) -> String {
let mut lines = Vec::new();
for warning in &result.warnings {
lines.push(format!("warning: {warning}"));
}
for error in &result.errors {
lines.push(format!("error: {error}"));
}
lines.join("\n")
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
fn test_path() -> PathBuf {
PathBuf::from("/test/settings.json")
}
#[test]
fn detects_unknown_top_level_key() {
// given
let source = r#"{"model": "opus", "unknownField": true}"#;
let parsed = JsonValue::parse(source).expect("valid json");
let object = parsed.as_object().expect("object");
// when
let result = validate_config_file(object, source, &test_path());
// then
assert_eq!(result.errors.len(), 1);
assert_eq!(result.errors[0].field, "unknownField");
assert!(matches!(
result.errors[0].kind,
DiagnosticKind::UnknownKey { .. }
));
}
#[test]
fn detects_wrong_type_for_model() {
// given
let source = r#"{"model": 123}"#;
let parsed = JsonValue::parse(source).expect("valid json");
let object = parsed.as_object().expect("object");
// when
let result = validate_config_file(object, source, &test_path());
// then
assert_eq!(result.errors.len(), 1);
assert_eq!(result.errors[0].field, "model");
assert!(matches!(
result.errors[0].kind,
DiagnosticKind::WrongType {
expected: "a string",
got: "a number"
}
));
}
#[test]
fn detects_deprecated_permission_mode() {
// given
let source = r#"{"permissionMode": "plan"}"#;
let parsed = JsonValue::parse(source).expect("valid json");
let object = parsed.as_object().expect("object");
// when
let result = validate_config_file(object, source, &test_path());
// then
assert_eq!(result.warnings.len(), 1);
assert_eq!(result.warnings[0].field, "permissionMode");
assert!(matches!(
result.warnings[0].kind,
DiagnosticKind::Deprecated {
replacement: "permissions.defaultMode"
}
));
}
#[test]
fn detects_deprecated_enabled_plugins() {
// given
let source = r#"{"enabledPlugins": {"tool-guard@builtin": true}}"#;
let parsed = JsonValue::parse(source).expect("valid json");
let object = parsed.as_object().expect("object");
// when
let result = validate_config_file(object, source, &test_path());
// then
assert_eq!(result.warnings.len(), 1);
assert_eq!(result.warnings[0].field, "enabledPlugins");
assert!(matches!(
result.warnings[0].kind,
DiagnosticKind::Deprecated {
replacement: "plugins.enabled"
}
));
}
#[test]
fn reports_line_number_for_unknown_key() {
// given
let source = "{\n \"model\": \"opus\",\n \"badKey\": true\n}";
let parsed = JsonValue::parse(source).expect("valid json");
let object = parsed.as_object().expect("object");
// when
let result = validate_config_file(object, source, &test_path());
// then
assert_eq!(result.errors.len(), 1);
assert_eq!(result.errors[0].line, Some(3));
assert_eq!(result.errors[0].field, "badKey");
}
#[test]
fn reports_line_number_for_wrong_type() {
// given
let source = "{\n \"model\": 42\n}";
let parsed = JsonValue::parse(source).expect("valid json");
let object = parsed.as_object().expect("object");
// when
let result = validate_config_file(object, source, &test_path());
// then
assert_eq!(result.errors.len(), 1);
assert_eq!(result.errors[0].line, Some(2));
}
#[test]
fn validates_nested_hooks_keys() {
// given
let source = r#"{"hooks": {"PreToolUse": ["cmd"], "BadHook": ["x"]}}"#;
let parsed = JsonValue::parse(source).expect("valid json");
let object = parsed.as_object().expect("object");
// when
let result = validate_config_file(object, source, &test_path());
// then
assert_eq!(result.errors.len(), 1);
assert_eq!(result.errors[0].field, "hooks.BadHook");
}
#[test]
fn validates_nested_permissions_keys() {
// given
let source = r#"{"permissions": {"allow": ["Read"], "denyAll": true}}"#;
let parsed = JsonValue::parse(source).expect("valid json");
let object = parsed.as_object().expect("object");
// when
let result = validate_config_file(object, source, &test_path());
// then
assert_eq!(result.errors.len(), 1);
assert_eq!(result.errors[0].field, "permissions.denyAll");
}
#[test]
fn validates_nested_sandbox_keys() {
// given
let source = r#"{"sandbox": {"enabled": true, "containerMode": "strict"}}"#;
let parsed = JsonValue::parse(source).expect("valid json");
let object = parsed.as_object().expect("object");
// when
let result = validate_config_file(object, source, &test_path());
// then
assert_eq!(result.errors.len(), 1);
assert_eq!(result.errors[0].field, "sandbox.containerMode");
}
#[test]
fn validates_nested_plugins_keys() {
// given
let source = r#"{"plugins": {"installRoot": "/tmp", "autoUpdate": true}}"#;
let parsed = JsonValue::parse(source).expect("valid json");
let object = parsed.as_object().expect("object");
// when
let result = validate_config_file(object, source, &test_path());
// then
assert_eq!(result.errors.len(), 1);
assert_eq!(result.errors[0].field, "plugins.autoUpdate");
}
#[test]
fn validates_nested_oauth_keys() {
// given
let source = r#"{"oauth": {"clientId": "abc", "secret": "hidden"}}"#;
let parsed = JsonValue::parse(source).expect("valid json");
let object = parsed.as_object().expect("object");
// when
let result = validate_config_file(object, source, &test_path());
// then
assert_eq!(result.errors.len(), 1);
assert_eq!(result.errors[0].field, "oauth.secret");
}
#[test]
fn valid_config_produces_no_diagnostics() {
// given
let source = r#"{
"model": "opus",
"hooks": {"PreToolUse": ["guard"]},
"permissions": {"defaultMode": "plan", "allow": ["Read"]},
"mcpServers": {},
"sandbox": {"enabled": false}
}"#;
let parsed = JsonValue::parse(source).expect("valid json");
let object = parsed.as_object().expect("object");
// when
let result = validate_config_file(object, source, &test_path());
// then
assert!(result.is_ok());
assert!(result.warnings.is_empty());
}
#[test]
fn suggests_close_field_name() {
// given
let source = r#"{"modle": "opus"}"#;
let parsed = JsonValue::parse(source).expect("valid json");
let object = parsed.as_object().expect("object");
// when
let result = validate_config_file(object, source, &test_path());
// then
assert_eq!(result.errors.len(), 1);
match &result.errors[0].kind {
DiagnosticKind::UnknownKey {
suggestion: Some(s),
} => assert_eq!(s, "model"),
other => panic!("expected suggestion, got {other:?}"),
}
}
#[test]
fn format_diagnostics_includes_all_entries() {
// given
let source = r#"{"permissionMode": "plan", "badKey": 1}"#;
let parsed = JsonValue::parse(source).expect("valid json");
let object = parsed.as_object().expect("object");
let result = validate_config_file(object, source, &test_path());
// when
let output = format_diagnostics(&result);
// then
assert!(output.contains("warning:"));
assert!(output.contains("error:"));
assert!(output.contains("badKey"));
assert!(output.contains("permissionMode"));
}
#[test]
fn check_unsupported_format_rejects_toml() {
// given
let path = PathBuf::from("/home/.claw/settings.toml");
// when
let result = check_unsupported_format(&path);
// then
assert!(result.is_err());
let message = result.unwrap_err().to_string();
assert!(message.contains("TOML"));
assert!(message.contains("settings.toml"));
}
#[test]
fn check_unsupported_format_allows_json() {
// given
let path = PathBuf::from("/home/.claw/settings.json");
// when / then
assert!(check_unsupported_format(&path).is_ok());
}
#[test]
fn wrong_type_in_nested_sandbox_field() {
// given
let source = r#"{"sandbox": {"enabled": "yes"}}"#;
let parsed = JsonValue::parse(source).expect("valid json");
let object = parsed.as_object().expect("object");
// when
let result = validate_config_file(object, source, &test_path());
// then
assert_eq!(result.errors.len(), 1);
assert_eq!(result.errors[0].field, "sandbox.enabled");
assert!(matches!(
result.errors[0].kind,
DiagnosticKind::WrongType {
expected: "a boolean",
got: "a string"
}
));
}
#[test]
fn display_format_unknown_key_with_line() {
// given
let diag = ConfigDiagnostic {
path: "/test/settings.json".to_string(),
field: "badKey".to_string(),
line: Some(5),
kind: DiagnosticKind::UnknownKey { suggestion: None },
};
// when
let output = diag.to_string();
// then
assert_eq!(
output,
r#"/test/settings.json: unknown key "badKey" (line 5)"#
);
}
#[test]
fn display_format_wrong_type_with_line() {
// given
let diag = ConfigDiagnostic {
path: "/test/settings.json".to_string(),
field: "model".to_string(),
line: Some(2),
kind: DiagnosticKind::WrongType {
expected: "a string",
got: "a number",
},
};
// when
let output = diag.to_string();
// then
assert_eq!(
output,
r#"/test/settings.json: field "model" must be a string, got a number (line 2)"#
);
}
#[test]
fn display_format_deprecated_with_line() {
// given
let diag = ConfigDiagnostic {
path: "/test/settings.json".to_string(),
field: "permissionMode".to_string(),
line: Some(3),
kind: DiagnosticKind::Deprecated {
replacement: "permissions.defaultMode",
},
};
// when
let output = diag.to_string();
// then
assert_eq!(
output,
r#"/test/settings.json: field "permissionMode" is deprecated (line 3). Use "permissions.defaultMode" instead"#
);
}
}

File diff suppressed because it is too large Load Diff

View File

@ -9,6 +9,41 @@ use regex::RegexBuilder;
use serde::{Deserialize, Serialize};
use walkdir::WalkDir;
/// Maximum file size that can be read (10 MB).
const MAX_READ_SIZE: u64 = 10 * 1024 * 1024;
/// Maximum file size that can be written (10 MB).
const MAX_WRITE_SIZE: usize = 10 * 1024 * 1024;
/// Check whether a file appears to contain binary content by examining
/// the first chunk for NUL bytes.
fn is_binary_file(path: &Path) -> io::Result<bool> {
use std::io::Read;
let mut file = fs::File::open(path)?;
let mut buffer = [0u8; 8192];
let bytes_read = file.read(&mut buffer)?;
Ok(buffer[..bytes_read].contains(&0))
}
/// Validate that a resolved path stays within the given workspace root.
/// Returns the canonical path on success, or an error if the path escapes
/// the workspace boundary (e.g. via `../` traversal or symlink).
#[allow(dead_code)]
fn validate_workspace_boundary(resolved: &Path, workspace_root: &Path) -> io::Result<()> {
if !resolved.starts_with(workspace_root) {
return Err(io::Error::new(
io::ErrorKind::PermissionDenied,
format!(
"path {} escapes workspace boundary {}",
resolved.display(),
workspace_root.display()
),
));
}
Ok(())
}
/// Text payload returned by file-reading operations.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TextFilePayload {
#[serde(rename = "filePath")]
@ -22,6 +57,7 @@ pub struct TextFilePayload {
pub total_lines: usize,
}
/// Output envelope for the `read_file` tool.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ReadFileOutput {
#[serde(rename = "type")]
@ -29,6 +65,7 @@ pub struct ReadFileOutput {
pub file: TextFilePayload,
}
/// Structured patch hunk emitted by write and edit operations.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct StructuredPatchHunk {
#[serde(rename = "oldStart")]
@ -42,6 +79,7 @@ pub struct StructuredPatchHunk {
pub lines: Vec<String>,
}
/// Output envelope for full-file write operations.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct WriteFileOutput {
#[serde(rename = "type")]
@ -57,6 +95,7 @@ pub struct WriteFileOutput {
pub git_diff: Option<serde_json::Value>,
}
/// Output envelope for targeted string-replacement edits.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct EditFileOutput {
#[serde(rename = "filePath")]
@ -77,6 +116,7 @@ pub struct EditFileOutput {
pub git_diff: Option<serde_json::Value>,
}
/// Result of a glob-based filename search.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GlobSearchOutput {
#[serde(rename = "durationMs")]
@ -87,6 +127,7 @@ pub struct GlobSearchOutput {
pub truncated: bool,
}
/// Parameters accepted by the grep-style search tool.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct GrepSearchInput {
pub pattern: String,
@ -112,6 +153,7 @@ pub struct GrepSearchInput {
pub multiline: Option<bool>,
}
/// Result payload returned by the grep-style search tool.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct GrepSearchOutput {
pub mode: Option<String>,
@ -129,12 +171,35 @@ pub struct GrepSearchOutput {
pub applied_offset: Option<usize>,
}
/// Reads a text file and returns a line-windowed payload.
pub fn read_file(
path: &str,
offset: Option<usize>,
limit: Option<usize>,
) -> io::Result<ReadFileOutput> {
let absolute_path = normalize_path(path)?;
// Check file size before reading
let metadata = fs::metadata(&absolute_path)?;
if metadata.len() > MAX_READ_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"file is too large ({} bytes, max {} bytes)",
metadata.len(),
MAX_READ_SIZE
),
));
}
// Detect binary files
if is_binary_file(&absolute_path)? {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"file appears to be binary",
));
}
let content = fs::read_to_string(&absolute_path)?;
let lines: Vec<&str> = content.lines().collect();
let start_index = offset.unwrap_or(0).min(lines.len());
@ -155,7 +220,19 @@ pub fn read_file(
})
}
/// Replaces a file's contents and returns patch metadata.
pub fn write_file(path: &str, content: &str) -> io::Result<WriteFileOutput> {
if content.len() > MAX_WRITE_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"content is too large ({} bytes, max {} bytes)",
content.len(),
MAX_WRITE_SIZE
),
));
}
let absolute_path = normalize_path_allow_missing(path)?;
let original_file = fs::read_to_string(&absolute_path).ok();
if let Some(parent) = absolute_path.parent() {
@ -177,6 +254,7 @@ pub fn write_file(path: &str, content: &str) -> io::Result<WriteFileOutput> {
})
}
/// Performs an in-file string replacement and returns patch metadata.
pub fn edit_file(
path: &str,
old_string: &str,
@ -217,6 +295,7 @@ pub fn edit_file(
})
}
/// Expands a glob pattern and returns matching filenames.
pub fn glob_search(pattern: &str, path: Option<&str>) -> io::Result<GlobSearchOutput> {
let started = Instant::now();
let base_dir = path
@ -229,14 +308,22 @@ pub fn glob_search(pattern: &str, path: Option<&str>) -> io::Result<GlobSearchOu
base_dir.join(pattern).to_string_lossy().into_owned()
};
// The `glob` crate does not support brace expansion ({a,b,c}).
// Expand braces into multiple patterns so patterns like
// `Assets/**/*.{cs,uxml,uss}` work correctly.
let expanded = expand_braces(&search_pattern);
let mut seen = std::collections::HashSet::new();
let mut matches = Vec::new();
let entries = glob::glob(&search_pattern)
for pat in &expanded {
let entries = glob::glob(pat)
.map_err(|error| io::Error::new(io::ErrorKind::InvalidInput, error.to_string()))?;
for entry in entries.flatten() {
if entry.is_file() {
if entry.is_file() && seen.insert(entry.clone()) {
matches.push(entry);
}
}
}
matches.sort_by_key(|path| {
fs::metadata(path)
@ -260,6 +347,7 @@ pub fn glob_search(pattern: &str, path: Option<&str>) -> io::Result<GlobSearchOu
})
}
/// Runs a regex search over workspace files with optional context lines.
pub fn grep_search(input: &GrepSearchInput) -> io::Result<GrepSearchOutput> {
let base_path = input
.path
@ -477,18 +565,105 @@ fn normalize_path_allow_missing(path: &str) -> io::Result<PathBuf> {
Ok(candidate)
}
/// Read a file with workspace boundary enforcement.
#[allow(dead_code)]
pub fn read_file_in_workspace(
path: &str,
offset: Option<usize>,
limit: Option<usize>,
workspace_root: &Path,
) -> io::Result<ReadFileOutput> {
let absolute_path = normalize_path(path)?;
let canonical_root = workspace_root
.canonicalize()
.unwrap_or_else(|_| workspace_root.to_path_buf());
validate_workspace_boundary(&absolute_path, &canonical_root)?;
read_file(path, offset, limit)
}
/// Write a file with workspace boundary enforcement.
#[allow(dead_code)]
pub fn write_file_in_workspace(
path: &str,
content: &str,
workspace_root: &Path,
) -> io::Result<WriteFileOutput> {
let absolute_path = normalize_path_allow_missing(path)?;
let canonical_root = workspace_root
.canonicalize()
.unwrap_or_else(|_| workspace_root.to_path_buf());
validate_workspace_boundary(&absolute_path, &canonical_root)?;
write_file(path, content)
}
/// Edit a file with workspace boundary enforcement.
#[allow(dead_code)]
pub fn edit_file_in_workspace(
path: &str,
old_string: &str,
new_string: &str,
replace_all: bool,
workspace_root: &Path,
) -> io::Result<EditFileOutput> {
let absolute_path = normalize_path(path)?;
let canonical_root = workspace_root
.canonicalize()
.unwrap_or_else(|_| workspace_root.to_path_buf());
validate_workspace_boundary(&absolute_path, &canonical_root)?;
edit_file(path, old_string, new_string, replace_all)
}
/// Check whether a path is a symlink that resolves outside the workspace.
#[allow(dead_code)]
pub fn is_symlink_escape(path: &Path, workspace_root: &Path) -> io::Result<bool> {
let metadata = fs::symlink_metadata(path)?;
if !metadata.is_symlink() {
return Ok(false);
}
let resolved = path.canonicalize()?;
let canonical_root = workspace_root
.canonicalize()
.unwrap_or_else(|_| workspace_root.to_path_buf());
Ok(!resolved.starts_with(&canonical_root))
}
/// Expand shell-style brace groups in a glob pattern.
///
/// Handles one level of braces: `foo.{a,b,c}` → `["foo.a", "foo.b", "foo.c"]`.
/// Nested braces are not expanded (uncommon in practice).
/// Patterns without braces pass through unchanged.
fn expand_braces(pattern: &str) -> Vec<String> {
let Some(open) = pattern.find('{') else {
return vec![pattern.to_owned()];
};
let Some(close) = pattern[open..].find('}').map(|i| open + i) else {
// Unmatched brace — treat as literal.
return vec![pattern.to_owned()];
};
let prefix = &pattern[..open];
let suffix = &pattern[close + 1..];
let alternatives = &pattern[open + 1..close];
alternatives
.split(',')
.flat_map(|alt| expand_braces(&format!("{prefix}{alt}{suffix}")))
.collect()
}
#[cfg(test)]
mod tests {
use std::time::{SystemTime, UNIX_EPOCH};
use super::{edit_file, glob_search, grep_search, read_file, write_file, GrepSearchInput};
use super::{
edit_file, expand_braces, glob_search, grep_search, is_symlink_escape, read_file,
read_file_in_workspace, write_file, GrepSearchInput, MAX_WRITE_SIZE,
};
fn temp_path(name: &str) -> std::path::PathBuf {
let unique = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should move forward")
.as_nanos();
std::env::temp_dir().join(format!("claw-native-{name}-{unique}"))
std::env::temp_dir().join(format!("clawd-native-{name}-{unique}"))
}
#[test]
@ -513,6 +688,73 @@ mod tests {
assert!(output.replace_all);
}
#[test]
fn rejects_binary_files() {
let path = temp_path("binary-test.bin");
std::fs::write(&path, b"\x00\x01\x02\x03binary content").expect("write should succeed");
let result = read_file(path.to_string_lossy().as_ref(), None, None);
assert!(result.is_err());
let error = result.unwrap_err();
assert_eq!(error.kind(), std::io::ErrorKind::InvalidData);
assert!(error.to_string().contains("binary"));
}
#[test]
fn rejects_oversized_writes() {
let path = temp_path("oversize-write.txt");
let huge = "x".repeat(MAX_WRITE_SIZE + 1);
let result = write_file(path.to_string_lossy().as_ref(), &huge);
assert!(result.is_err());
let error = result.unwrap_err();
assert_eq!(error.kind(), std::io::ErrorKind::InvalidData);
assert!(error.to_string().contains("too large"));
}
#[test]
fn enforces_workspace_boundary() {
let workspace = temp_path("workspace-boundary");
std::fs::create_dir_all(&workspace).expect("workspace dir should be created");
let inside = workspace.join("inside.txt");
write_file(inside.to_string_lossy().as_ref(), "safe content")
.expect("write inside workspace should succeed");
// Reading inside workspace should succeed
let result =
read_file_in_workspace(inside.to_string_lossy().as_ref(), None, None, &workspace);
assert!(result.is_ok());
// Reading outside workspace should fail
let outside = temp_path("outside-boundary.txt");
write_file(outside.to_string_lossy().as_ref(), "unsafe content")
.expect("write outside should succeed");
let result =
read_file_in_workspace(outside.to_string_lossy().as_ref(), None, None, &workspace);
assert!(result.is_err());
let error = result.unwrap_err();
assert_eq!(error.kind(), std::io::ErrorKind::PermissionDenied);
assert!(error.to_string().contains("escapes workspace"));
}
#[test]
fn detects_symlink_escape() {
let workspace = temp_path("symlink-workspace");
std::fs::create_dir_all(&workspace).expect("workspace dir should be created");
let outside = temp_path("symlink-target.txt");
std::fs::write(&outside, "target content").expect("target should write");
let _link_path = workspace.join("escape-link.txt");
#[cfg(unix)]
{
std::os::unix::fs::symlink(&outside, &link_path).expect("symlink should create");
assert!(is_symlink_escape(&link_path, &workspace).expect("check should succeed"));
}
// Non-symlink file should not be an escape
let normal = workspace.join("normal.txt");
std::fs::write(&normal, "normal content").expect("normal file should write");
assert!(!is_symlink_escape(&normal, &workspace).expect("check should succeed"));
}
#[test]
fn globs_and_greps_directory() {
let dir = temp_path("search-dir");
@ -547,4 +789,51 @@ mod tests {
.expect("grep should succeed");
assert!(grep_output.content.unwrap_or_default().contains("hello"));
}
#[test]
fn expand_braces_no_braces() {
assert_eq!(expand_braces("*.rs"), vec!["*.rs"]);
}
#[test]
fn expand_braces_single_group() {
let mut result = expand_braces("Assets/**/*.{cs,uxml,uss}");
result.sort();
assert_eq!(
result,
vec!["Assets/**/*.cs", "Assets/**/*.uss", "Assets/**/*.uxml",]
);
}
#[test]
fn expand_braces_nested() {
let mut result = expand_braces("src/{a,b}.{rs,toml}");
result.sort();
assert_eq!(
result,
vec!["src/a.rs", "src/a.toml", "src/b.rs", "src/b.toml"]
);
}
#[test]
fn expand_braces_unmatched() {
assert_eq!(expand_braces("foo.{bar"), vec!["foo.{bar"]);
}
#[test]
fn glob_search_with_braces_finds_files() {
let dir = temp_path("glob-braces");
std::fs::create_dir_all(&dir).unwrap();
std::fs::write(dir.join("a.rs"), "fn main() {}").unwrap();
std::fs::write(dir.join("b.toml"), "[package]").unwrap();
std::fs::write(dir.join("c.txt"), "hello").unwrap();
let result =
glob_search("*.{rs,toml}", Some(dir.to_str().unwrap())).expect("glob should succeed");
assert_eq!(
result.num_files, 2,
"should match .rs and .toml but not .txt"
);
let _ = std::fs::remove_dir_all(&dir);
}
}

View File

@ -0,0 +1,324 @@
use std::path::Path;
use std::process::Command;
/// A single git commit entry from the log.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GitCommitEntry {
pub hash: String,
pub subject: String,
}
/// Git-aware context gathered at startup for injection into the system prompt.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GitContext {
pub branch: Option<String>,
pub recent_commits: Vec<GitCommitEntry>,
pub staged_files: Vec<String>,
}
const MAX_RECENT_COMMITS: usize = 5;
impl GitContext {
/// Detect the git context from the given working directory.
///
/// Returns `None` when the directory is not inside a git repository.
#[must_use]
pub fn detect(cwd: &Path) -> Option<Self> {
// Quick gate: is this a git repo at all?
let rev_parse = Command::new("git")
.args(["rev-parse", "--is-inside-work-tree"])
.current_dir(cwd)
.output()
.ok()?;
if !rev_parse.status.success() {
return None;
}
Some(Self {
branch: read_branch(cwd),
recent_commits: read_recent_commits(cwd),
staged_files: read_staged_files(cwd),
})
}
/// Render a human-readable summary suitable for system-prompt injection.
#[must_use]
pub fn render(&self) -> String {
let mut lines = Vec::new();
if let Some(branch) = &self.branch {
lines.push(format!("Git branch: {branch}"));
}
if !self.recent_commits.is_empty() {
lines.push(String::new());
lines.push("Recent commits:".to_string());
for entry in &self.recent_commits {
lines.push(format!(" {} {}", entry.hash, entry.subject));
}
}
if !self.staged_files.is_empty() {
lines.push(String::new());
lines.push("Staged files:".to_string());
for file in &self.staged_files {
lines.push(format!(" {file}"));
}
}
lines.join("\n")
}
}
fn read_branch(cwd: &Path) -> Option<String> {
let output = Command::new("git")
.args(["rev-parse", "--abbrev-ref", "HEAD"])
.current_dir(cwd)
.output()
.ok()?;
if !output.status.success() {
return None;
}
let branch = String::from_utf8(output.stdout).ok()?;
let trimmed = branch.trim();
if trimmed.is_empty() || trimmed == "HEAD" {
None
} else {
Some(trimmed.to_string())
}
}
fn read_recent_commits(cwd: &Path) -> Vec<GitCommitEntry> {
let output = Command::new("git")
.args([
"--no-optional-locks",
"log",
"--oneline",
"-n",
&MAX_RECENT_COMMITS.to_string(),
"--no-decorate",
])
.current_dir(cwd)
.output()
.ok();
let Some(output) = output else {
return Vec::new();
};
if !output.status.success() {
return Vec::new();
}
let stdout = String::from_utf8(output.stdout).unwrap_or_default();
stdout
.lines()
.filter_map(|line| {
let line = line.trim();
if line.is_empty() {
return None;
}
let (hash, subject) = line.split_once(' ')?;
Some(GitCommitEntry {
hash: hash.to_string(),
subject: subject.to_string(),
})
})
.collect()
}
fn read_staged_files(cwd: &Path) -> Vec<String> {
let output = Command::new("git")
.args(["--no-optional-locks", "diff", "--cached", "--name-only"])
.current_dir(cwd)
.output()
.ok();
let Some(output) = output else {
return Vec::new();
};
if !output.status.success() {
return Vec::new();
}
let stdout = String::from_utf8(output.stdout).unwrap_or_default();
stdout
.lines()
.filter(|line| !line.trim().is_empty())
.map(|line| line.trim().to_string())
.collect()
}
#[cfg(test)]
mod tests {
use super::{GitCommitEntry, GitContext};
use std::fs;
use std::process::Command;
use std::time::{SystemTime, UNIX_EPOCH};
fn temp_dir(label: &str) -> std::path::PathBuf {
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after epoch")
.as_nanos();
std::env::temp_dir().join(format!("runtime-git-context-{label}-{nanos}"))
}
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
crate::test_env_lock()
}
fn ensure_valid_cwd() {
if std::env::current_dir().is_err() {
std::env::set_current_dir(env!("CARGO_MANIFEST_DIR"))
.expect("test cwd should be recoverable");
}
}
#[test]
fn returns_none_for_non_git_directory() {
// given
let _guard = env_lock();
ensure_valid_cwd();
let root = temp_dir("non-git");
fs::create_dir_all(&root).expect("create dir");
// when
let context = GitContext::detect(&root);
// then
assert!(context.is_none());
fs::remove_dir_all(root).expect("cleanup");
}
#[test]
fn detects_branch_name_and_commits() {
// given
let _guard = env_lock();
ensure_valid_cwd();
let root = temp_dir("branch-commits");
fs::create_dir_all(&root).expect("create dir");
git(&root, &["init", "--quiet", "--initial-branch=main"]);
git(&root, &["config", "user.email", "tests@example.com"]);
git(&root, &["config", "user.name", "Git Context Tests"]);
fs::write(root.join("a.txt"), "a\n").expect("write a");
git(&root, &["add", "a.txt"]);
git(&root, &["commit", "-m", "first commit", "--quiet"]);
fs::write(root.join("b.txt"), "b\n").expect("write b");
git(&root, &["add", "b.txt"]);
git(&root, &["commit", "-m", "second commit", "--quiet"]);
// when
let context = GitContext::detect(&root).expect("should detect git repo");
// then
assert_eq!(context.branch.as_deref(), Some("main"));
assert_eq!(context.recent_commits.len(), 2);
assert_eq!(context.recent_commits[0].subject, "second commit");
assert_eq!(context.recent_commits[1].subject, "first commit");
assert!(context.staged_files.is_empty());
fs::remove_dir_all(root).expect("cleanup");
}
#[test]
fn detects_staged_files() {
// given
let _guard = env_lock();
ensure_valid_cwd();
let root = temp_dir("staged");
fs::create_dir_all(&root).expect("create dir");
git(&root, &["init", "--quiet", "--initial-branch=main"]);
git(&root, &["config", "user.email", "tests@example.com"]);
git(&root, &["config", "user.name", "Git Context Tests"]);
fs::write(root.join("init.txt"), "init\n").expect("write init");
git(&root, &["add", "init.txt"]);
git(&root, &["commit", "-m", "initial", "--quiet"]);
fs::write(root.join("staged.txt"), "staged\n").expect("write staged");
git(&root, &["add", "staged.txt"]);
// when
let context = GitContext::detect(&root).expect("should detect git repo");
// then
assert_eq!(context.staged_files, vec!["staged.txt"]);
fs::remove_dir_all(root).expect("cleanup");
}
#[test]
fn render_formats_all_sections() {
// given
let context = GitContext {
branch: Some("feat/test".to_string()),
recent_commits: vec![
GitCommitEntry {
hash: "abc1234".to_string(),
subject: "add feature".to_string(),
},
GitCommitEntry {
hash: "def5678".to_string(),
subject: "fix bug".to_string(),
},
],
staged_files: vec!["src/main.rs".to_string()],
};
// when
let rendered = context.render();
// then
assert!(rendered.contains("Git branch: feat/test"));
assert!(rendered.contains("abc1234 add feature"));
assert!(rendered.contains("def5678 fix bug"));
assert!(rendered.contains("src/main.rs"));
}
#[test]
fn render_omits_empty_sections() {
// given
let context = GitContext {
branch: Some("main".to_string()),
recent_commits: Vec::new(),
staged_files: Vec::new(),
};
// when
let rendered = context.render();
// then
assert!(rendered.contains("Git branch: main"));
assert!(!rendered.contains("Recent commits:"));
assert!(!rendered.contains("Staged files:"));
}
#[test]
fn limits_to_five_recent_commits() {
// given
let _guard = env_lock();
ensure_valid_cwd();
let root = temp_dir("five-commits");
fs::create_dir_all(&root).expect("create dir");
git(&root, &["init", "--quiet", "--initial-branch=main"]);
git(&root, &["config", "user.email", "tests@example.com"]);
git(&root, &["config", "user.name", "Git Context Tests"]);
for i in 1..=8 {
let name = format!("file{i}.txt");
fs::write(root.join(&name), format!("{i}\n")).expect("write file");
git(&root, &["add", &name]);
git(&root, &["commit", "-m", &format!("commit {i}"), "--quiet"]);
}
// when
let context = GitContext::detect(&root).expect("should detect git repo");
// then
assert_eq!(context.recent_commits.len(), 5);
assert_eq!(context.recent_commits[0].subject, "commit 8");
assert_eq!(context.recent_commits[4].subject, "commit 4");
fs::remove_dir_all(root).expect("cleanup");
}
fn git(cwd: &std::path::Path, args: &[&str]) {
let status = Command::new("git")
.args(args)
.current_dir(cwd)
.output()
.unwrap_or_else(|_| panic!("git {args:?} should run"))
.status;
assert!(status.success(), "git {args:?} failed");
}
}

View File

@ -0,0 +1,152 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GreenLevel {
TargetedTests,
Package,
Workspace,
MergeReady,
}
impl GreenLevel {
#[must_use]
pub fn as_str(self) -> &'static str {
match self {
Self::TargetedTests => "targeted_tests",
Self::Package => "package",
Self::Workspace => "workspace",
Self::MergeReady => "merge_ready",
}
}
}
impl std::fmt::Display for GreenLevel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.as_str())
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct GreenContract {
pub required_level: GreenLevel,
}
impl GreenContract {
#[must_use]
pub fn new(required_level: GreenLevel) -> Self {
Self { required_level }
}
#[must_use]
pub fn evaluate(self, observed_level: Option<GreenLevel>) -> GreenContractOutcome {
match observed_level {
Some(level) if level >= self.required_level => GreenContractOutcome::Satisfied {
required_level: self.required_level,
observed_level: level,
},
_ => GreenContractOutcome::Unsatisfied {
required_level: self.required_level,
observed_level,
},
}
}
#[must_use]
pub fn is_satisfied_by(self, observed_level: GreenLevel) -> bool {
observed_level >= self.required_level
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "outcome", rename_all = "snake_case")]
pub enum GreenContractOutcome {
Satisfied {
required_level: GreenLevel,
observed_level: GreenLevel,
},
Unsatisfied {
required_level: GreenLevel,
observed_level: Option<GreenLevel>,
},
}
impl GreenContractOutcome {
#[must_use]
pub fn is_satisfied(&self) -> bool {
matches!(self, Self::Satisfied { .. })
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn given_matching_level_when_evaluating_contract_then_it_is_satisfied() {
// given
let contract = GreenContract::new(GreenLevel::Package);
// when
let outcome = contract.evaluate(Some(GreenLevel::Package));
// then
assert_eq!(
outcome,
GreenContractOutcome::Satisfied {
required_level: GreenLevel::Package,
observed_level: GreenLevel::Package,
}
);
assert!(outcome.is_satisfied());
}
#[test]
fn given_higher_level_when_checking_requirement_then_it_still_satisfies_contract() {
// given
let contract = GreenContract::new(GreenLevel::TargetedTests);
// when
let is_satisfied = contract.is_satisfied_by(GreenLevel::Workspace);
// then
assert!(is_satisfied);
}
#[test]
fn given_lower_level_when_evaluating_contract_then_it_is_unsatisfied() {
// given
let contract = GreenContract::new(GreenLevel::Workspace);
// when
let outcome = contract.evaluate(Some(GreenLevel::Package));
// then
assert_eq!(
outcome,
GreenContractOutcome::Unsatisfied {
required_level: GreenLevel::Workspace,
observed_level: Some(GreenLevel::Package),
}
);
assert!(!outcome.is_satisfied());
}
#[test]
fn given_no_green_level_when_evaluating_contract_then_contract_is_unsatisfied() {
// given
let contract = GreenContract::new(GreenLevel::MergeReady);
// when
let outcome = contract.evaluate(None);
// then
assert_eq!(
outcome,
GreenContractOutcome::Unsatisfied {
required_level: GreenLevel::MergeReady,
observed_level: None,
}
);
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,7 @@
use std::collections::BTreeMap;
use std::fmt::{Display, Formatter};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JsonValue {
Null,
Bool(bool),

View File

@ -0,0 +1,383 @@
#![allow(clippy::similar_names)]
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LaneEventName {
#[serde(rename = "lane.started")]
Started,
#[serde(rename = "lane.ready")]
Ready,
#[serde(rename = "lane.prompt_misdelivery")]
PromptMisdelivery,
#[serde(rename = "lane.blocked")]
Blocked,
#[serde(rename = "lane.red")]
Red,
#[serde(rename = "lane.green")]
Green,
#[serde(rename = "lane.commit.created")]
CommitCreated,
#[serde(rename = "lane.pr.opened")]
PrOpened,
#[serde(rename = "lane.merge.ready")]
MergeReady,
#[serde(rename = "lane.finished")]
Finished,
#[serde(rename = "lane.failed")]
Failed,
#[serde(rename = "lane.reconciled")]
Reconciled,
#[serde(rename = "lane.merged")]
Merged,
#[serde(rename = "lane.superseded")]
Superseded,
#[serde(rename = "lane.closed")]
Closed,
#[serde(rename = "branch.stale_against_main")]
BranchStaleAgainstMain,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LaneEventStatus {
Running,
Ready,
Blocked,
Red,
Green,
Completed,
Failed,
Reconciled,
Merged,
Superseded,
Closed,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LaneFailureClass {
PromptDelivery,
TrustGate,
BranchDivergence,
Compile,
Test,
PluginStartup,
McpStartup,
McpHandshake,
GatewayRouting,
ToolRuntime,
Infra,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct LaneEventBlocker {
#[serde(rename = "failureClass")]
pub failure_class: LaneFailureClass,
pub detail: String,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct LaneCommitProvenance {
pub commit: String,
pub branch: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub worktree: Option<String>,
#[serde(rename = "canonicalCommit", skip_serializing_if = "Option::is_none")]
pub canonical_commit: Option<String>,
#[serde(rename = "supersededBy", skip_serializing_if = "Option::is_none")]
pub superseded_by: Option<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub lineage: Vec<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct LaneEvent {
pub event: LaneEventName,
pub status: LaneEventStatus,
#[serde(rename = "emittedAt")]
pub emitted_at: String,
#[serde(rename = "failureClass", skip_serializing_if = "Option::is_none")]
pub failure_class: Option<LaneFailureClass>,
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<Value>,
}
impl LaneEvent {
#[must_use]
pub fn new(
event: LaneEventName,
status: LaneEventStatus,
emitted_at: impl Into<String>,
) -> Self {
Self {
event,
status,
emitted_at: emitted_at.into(),
failure_class: None,
detail: None,
data: None,
}
}
#[must_use]
pub fn started(emitted_at: impl Into<String>) -> Self {
Self::new(LaneEventName::Started, LaneEventStatus::Running, emitted_at)
}
#[must_use]
pub fn finished(emitted_at: impl Into<String>, detail: Option<String>) -> Self {
Self::new(
LaneEventName::Finished,
LaneEventStatus::Completed,
emitted_at,
)
.with_optional_detail(detail)
}
#[must_use]
pub fn commit_created(
emitted_at: impl Into<String>,
detail: Option<String>,
provenance: LaneCommitProvenance,
) -> Self {
Self::new(
LaneEventName::CommitCreated,
LaneEventStatus::Completed,
emitted_at,
)
.with_optional_detail(detail)
.with_data(serde_json::to_value(provenance).expect("commit provenance should serialize"))
}
#[must_use]
pub fn superseded(
emitted_at: impl Into<String>,
detail: Option<String>,
provenance: LaneCommitProvenance,
) -> Self {
Self::new(
LaneEventName::Superseded,
LaneEventStatus::Superseded,
emitted_at,
)
.with_optional_detail(detail)
.with_data(serde_json::to_value(provenance).expect("commit provenance should serialize"))
}
#[must_use]
pub fn blocked(emitted_at: impl Into<String>, blocker: &LaneEventBlocker) -> Self {
Self::new(LaneEventName::Blocked, LaneEventStatus::Blocked, emitted_at)
.with_failure_class(blocker.failure_class)
.with_detail(blocker.detail.clone())
}
#[must_use]
pub fn failed(emitted_at: impl Into<String>, blocker: &LaneEventBlocker) -> Self {
Self::new(LaneEventName::Failed, LaneEventStatus::Failed, emitted_at)
.with_failure_class(blocker.failure_class)
.with_detail(blocker.detail.clone())
}
#[must_use]
pub fn with_failure_class(mut self, failure_class: LaneFailureClass) -> Self {
self.failure_class = Some(failure_class);
self
}
#[must_use]
pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
self.detail = Some(detail.into());
self
}
#[must_use]
pub fn with_optional_detail(mut self, detail: Option<String>) -> Self {
self.detail = detail;
self
}
#[must_use]
pub fn with_data(mut self, data: Value) -> Self {
self.data = Some(data);
self
}
}
#[must_use]
pub fn dedupe_superseded_commit_events(events: &[LaneEvent]) -> Vec<LaneEvent> {
let mut keep = vec![true; events.len()];
let mut latest_by_key = std::collections::BTreeMap::<String, usize>::new();
for (index, event) in events.iter().enumerate() {
if event.event != LaneEventName::CommitCreated {
continue;
}
let Some(data) = event.data.as_ref() else {
continue;
};
let key = data
.get("canonicalCommit")
.or_else(|| data.get("commit"))
.and_then(serde_json::Value::as_str)
.map(str::to_string);
let superseded = data
.get("supersededBy")
.and_then(serde_json::Value::as_str)
.is_some();
if superseded {
keep[index] = false;
continue;
}
if let Some(key) = key {
if let Some(previous) = latest_by_key.insert(key, index) {
keep[previous] = false;
}
}
}
events
.iter()
.cloned()
.zip(keep)
.filter_map(|(event, retain)| retain.then_some(event))
.collect()
}
#[cfg(test)]
mod tests {
use serde_json::json;
use super::{
dedupe_superseded_commit_events, LaneCommitProvenance, LaneEvent, LaneEventBlocker,
LaneEventName, LaneEventStatus, LaneFailureClass,
};
#[test]
fn canonical_lane_event_names_serialize_to_expected_wire_values() {
let cases = [
(LaneEventName::Started, "lane.started"),
(LaneEventName::Ready, "lane.ready"),
(LaneEventName::PromptMisdelivery, "lane.prompt_misdelivery"),
(LaneEventName::Blocked, "lane.blocked"),
(LaneEventName::Red, "lane.red"),
(LaneEventName::Green, "lane.green"),
(LaneEventName::CommitCreated, "lane.commit.created"),
(LaneEventName::PrOpened, "lane.pr.opened"),
(LaneEventName::MergeReady, "lane.merge.ready"),
(LaneEventName::Finished, "lane.finished"),
(LaneEventName::Failed, "lane.failed"),
(LaneEventName::Reconciled, "lane.reconciled"),
(LaneEventName::Merged, "lane.merged"),
(LaneEventName::Superseded, "lane.superseded"),
(LaneEventName::Closed, "lane.closed"),
(
LaneEventName::BranchStaleAgainstMain,
"branch.stale_against_main",
),
];
for (event, expected) in cases {
assert_eq!(
serde_json::to_value(event).expect("serialize event"),
json!(expected)
);
}
}
#[test]
fn failure_classes_cover_canonical_taxonomy_wire_values() {
let cases = [
(LaneFailureClass::PromptDelivery, "prompt_delivery"),
(LaneFailureClass::TrustGate, "trust_gate"),
(LaneFailureClass::BranchDivergence, "branch_divergence"),
(LaneFailureClass::Compile, "compile"),
(LaneFailureClass::Test, "test"),
(LaneFailureClass::PluginStartup, "plugin_startup"),
(LaneFailureClass::McpStartup, "mcp_startup"),
(LaneFailureClass::McpHandshake, "mcp_handshake"),
(LaneFailureClass::GatewayRouting, "gateway_routing"),
(LaneFailureClass::ToolRuntime, "tool_runtime"),
(LaneFailureClass::Infra, "infra"),
];
for (failure_class, expected) in cases {
assert_eq!(
serde_json::to_value(failure_class).expect("serialize failure class"),
json!(expected)
);
}
}
#[test]
fn blocked_and_failed_events_reuse_blocker_details() {
let blocker = LaneEventBlocker {
failure_class: LaneFailureClass::McpStartup,
detail: "broken server".to_string(),
};
let blocked = LaneEvent::blocked("2026-04-04T00:00:00Z", &blocker);
let failed = LaneEvent::failed("2026-04-04T00:00:01Z", &blocker);
assert_eq!(blocked.event, LaneEventName::Blocked);
assert_eq!(blocked.status, LaneEventStatus::Blocked);
assert_eq!(blocked.failure_class, Some(LaneFailureClass::McpStartup));
assert_eq!(failed.event, LaneEventName::Failed);
assert_eq!(failed.status, LaneEventStatus::Failed);
assert_eq!(failed.detail.as_deref(), Some("broken server"));
}
#[test]
fn commit_events_can_carry_worktree_and_supersession_metadata() {
let event = LaneEvent::commit_created(
"2026-04-04T00:00:00Z",
Some("commit created".to_string()),
LaneCommitProvenance {
commit: "abc123".to_string(),
branch: "feature/provenance".to_string(),
worktree: Some("wt-a".to_string()),
canonical_commit: Some("abc123".to_string()),
superseded_by: None,
lineage: vec!["abc123".to_string()],
},
);
let event_json = serde_json::to_value(&event).expect("lane event should serialize");
assert_eq!(event_json["event"], "lane.commit.created");
assert_eq!(event_json["data"]["branch"], "feature/provenance");
assert_eq!(event_json["data"]["worktree"], "wt-a");
}
#[test]
fn dedupes_superseded_commit_events_by_canonical_commit() {
let retained = dedupe_superseded_commit_events(&[
LaneEvent::commit_created(
"2026-04-04T00:00:00Z",
Some("old".to_string()),
LaneCommitProvenance {
commit: "old123".to_string(),
branch: "feature/provenance".to_string(),
worktree: Some("wt-a".to_string()),
canonical_commit: Some("canon123".to_string()),
superseded_by: Some("new123".to_string()),
lineage: vec!["old123".to_string(), "new123".to_string()],
},
),
LaneEvent::commit_created(
"2026-04-04T00:00:01Z",
Some("new".to_string()),
LaneCommitProvenance {
commit: "new123".to_string(),
branch: "feature/provenance".to_string(),
worktree: Some("wt-b".to_string()),
canonical_commit: Some("canon123".to_string()),
superseded_by: None,
lineage: vec!["old123".to_string(), "new123".to_string()],
},
),
]);
assert_eq!(retained.len(), 1);
assert_eq!(retained[0].detail.as_deref(), Some("new"));
}
}

View File

@ -1,64 +1,112 @@
//! Core runtime primitives for the `claw` CLI and supporting crates.
//!
//! This crate owns session persistence, permission evaluation, prompt assembly,
//! MCP plumbing, tool-facing file operations, and the core conversation loop
//! that drives interactive and one-shot turns.
mod bash;
pub mod bash_validation;
mod bootstrap;
pub mod branch_lock;
mod compact;
mod config;
pub mod config_validate;
mod conversation;
mod file_ops;
mod git_context;
pub mod green_contract;
mod hooks;
mod json;
mod lane_events;
pub mod lsp_client;
mod mcp;
mod mcp_client;
pub mod mcp_lifecycle_hardened;
pub mod mcp_server;
mod mcp_stdio;
pub mod mcp_tool_bridge;
mod oauth;
pub mod permission_enforcer;
mod permissions;
pub mod plugin_lifecycle;
mod policy_engine;
mod prompt;
pub mod recovery_recipes;
mod remote;
pub mod sandbox;
mod session;
pub mod session_control;
pub use session_control::SessionStore;
mod sse;
pub mod stale_base;
pub mod stale_branch;
pub mod summary_compression;
pub mod task_packet;
pub mod task_registry;
pub mod team_cron_registry;
#[cfg(test)]
mod trust_resolver;
mod usage;
pub mod worker_boot;
pub use lsp::{
FileDiagnostics, LspContextEnrichment, LspError, LspManager, LspServerConfig,
SymbolLocation, WorkspaceDiagnostics,
};
pub use bash::{execute_bash, BashCommandInput, BashCommandOutput};
pub use bootstrap::{BootstrapPhase, BootstrapPlan};
pub use branch_lock::{detect_branch_lock_collisions, BranchLockCollision, BranchLockIntent};
pub use compact::{
compact_session, estimate_session_tokens, format_compact_summary,
get_compact_continuation_message, should_compact, CompactionConfig, CompactionResult,
};
pub use config::{
ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpManagedProxyServerConfig,
McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig,
ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpConfigCollection,
McpManagedProxyServerConfig, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig,
McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig,
ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, RuntimeHookConfig,
RuntimePluginConfig, ScopedMcpServerConfig, CLAW_SETTINGS_SCHEMA_NAME,
ProviderFallbackConfig, ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig,
RuntimeHookConfig, RuntimePermissionRuleConfig, RuntimePluginConfig, ScopedMcpServerConfig,
CLAW_SETTINGS_SCHEMA_NAME,
};
pub use config_validate::{
check_unsupported_format, format_diagnostics, validate_config_file, ConfigDiagnostic,
DiagnosticKind, ValidationResult,
};
pub use conversation::{
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
ToolError, ToolExecutor, TurnSummary,
auto_compaction_threshold_from_env, ApiClient, ApiRequest, AssistantEvent, AutoCompactionEvent,
ConversationRuntime, PromptCacheEvent, RuntimeError, StaticToolExecutor, ToolError,
ToolExecutor, TurnSummary,
};
pub use file_ops::{
edit_file, glob_search, grep_search, read_file, write_file, EditFileOutput, GlobSearchOutput,
GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload,
WriteFileOutput,
};
pub use hooks::{HookEvent, HookRunResult, HookRunner};
pub use git_context::{GitCommitEntry, GitContext};
pub use hooks::{
HookAbortSignal, HookEvent, HookProgressEvent, HookProgressReporter, HookRunResult, HookRunner,
};
pub use lane_events::{
dedupe_superseded_commit_events, LaneCommitProvenance, LaneEvent, LaneEventBlocker,
LaneEventName, LaneEventStatus, LaneFailureClass,
};
pub use mcp::{
mcp_server_signature, mcp_tool_name, mcp_tool_prefix, normalize_name_for_mcp,
scoped_mcp_config_hash, unwrap_ccr_proxy_url,
};
pub use mcp_client::{
McpManagedProxyTransport, McpClientAuth, McpClientBootstrap, McpClientTransport,
McpClientAuth, McpClientBootstrap, McpClientTransport, McpManagedProxyTransport,
McpRemoteTransport, McpSdkTransport, McpStdioTransport,
};
pub use mcp_lifecycle_hardened::{
McpDegradedReport, McpErrorSurface, McpFailedServer, McpLifecyclePhase, McpLifecycleState,
McpLifecycleValidator, McpPhaseResult,
};
pub use mcp_server::{McpServer, McpServerSpec, ToolCallHandler, MCP_SERVER_PROTOCOL_VERSION};
pub use mcp_stdio::{
spawn_mcp_stdio_process, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
ManagedMcpTool, McpInitializeClientInfo, McpInitializeParams, McpInitializeResult,
McpInitializeServerInfo, McpListResourcesParams, McpListResourcesResult, McpListToolsParams,
McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpResource,
McpResourceContents, McpServerManager, McpServerManagerError, McpStdioProcess, McpTool,
McpToolCallContent, McpToolCallParams, McpToolCallResult, UnsupportedMcpServer,
ManagedMcpTool, McpDiscoveryFailure, McpInitializeClientInfo, McpInitializeParams,
McpInitializeResult, McpInitializeServerInfo, McpListResourcesParams, McpListResourcesResult,
McpListToolsParams, McpListToolsResult, McpReadResourceParams, McpReadResourceResult,
McpResource, McpResourceContents, McpServerManager, McpServerManagerError, McpStdioProcess,
McpTool, McpToolCallContent, McpToolCallParams, McpToolCallResult, McpToolDiscoveryReport,
UnsupportedMcpServer,
};
pub use oauth::{
clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair,
@ -68,22 +116,59 @@ pub use oauth::{
PkceChallengeMethod, PkceCodePair,
};
pub use permissions::{
PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision,
PermissionPrompter, PermissionRequest,
PermissionContext, PermissionMode, PermissionOutcome, PermissionOverride, PermissionPolicy,
PermissionPromptDecision, PermissionPrompter, PermissionRequest,
};
pub use plugin_lifecycle::{
DegradedMode, DiscoveryResult, PluginHealthcheck, PluginLifecycle, PluginLifecycleEvent,
PluginState, ResourceInfo, ServerHealth, ServerStatus, ToolInfo,
};
pub use policy_engine::{
evaluate, DiffScope, GreenLevel, LaneBlocker, LaneContext, PolicyAction, PolicyCondition,
PolicyEngine, PolicyRule, ReconcileReason, ReviewStatus,
};
pub use prompt::{
load_system_prompt, prepend_bullets, ContextFile, ProjectContext, PromptBuildError,
SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY,
};
pub use recovery_recipes::{
attempt_recovery, recipe_for, EscalationPolicy, FailureScenario, RecoveryContext,
RecoveryEvent, RecoveryRecipe, RecoveryResult, RecoveryStep,
};
pub use remote::{
inherited_upstream_proxy_env, no_proxy_list, read_token, upstream_proxy_ws_url,
RemoteSessionContext, UpstreamProxyBootstrap, UpstreamProxyState, DEFAULT_REMOTE_BASE_URL,
DEFAULT_SESSION_TOKEN_PATH, DEFAULT_SYSTEM_CA_BUNDLE, NO_PROXY_HOSTS, UPSTREAM_PROXY_ENV_KEYS,
};
pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, SessionError};
pub use sandbox::{
build_linux_sandbox_command, detect_container_environment, detect_container_environment_from,
resolve_sandbox_status, resolve_sandbox_status_for_request, ContainerEnvironment,
FilesystemIsolationMode, LinuxSandboxCommand, SandboxConfig, SandboxDetectionInputs,
SandboxRequest, SandboxStatus,
};
pub use session::{
ContentBlock, ConversationMessage, MessageRole, Session, SessionCompaction, SessionError,
SessionFork, SessionPromptEntry,
};
pub use sse::{IncrementalSseParser, SseEvent};
pub use stale_base::{
check_base_commit, format_stale_base_warning, read_claw_base_file, resolve_expected_base,
BaseCommitSource, BaseCommitState,
};
pub use stale_branch::{
apply_policy, check_freshness, BranchFreshness, StaleBranchAction, StaleBranchEvent,
StaleBranchPolicy,
};
pub use task_packet::{validate_packet, TaskPacket, TaskPacketValidationError, ValidatedPacket};
#[cfg(test)]
pub use trust_resolver::{TrustConfig, TrustDecision, TrustEvent, TrustPolicy, TrustResolver};
pub use usage::{
format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker,
};
pub use worker_boot::{
Worker, WorkerEvent, WorkerEventKind, WorkerEventPayload, WorkerFailure, WorkerFailureKind,
WorkerPromptTarget, WorkerReadySnapshot, WorkerRegistry, WorkerStatus, WorkerTrustResolution,
};
#[cfg(test)]
pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {

View File

@ -0,0 +1,747 @@
#![allow(clippy::should_implement_trait, clippy::must_use_candidate)]
//! LSP (Language Server Protocol) client registry for tool dispatch.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use serde::{Deserialize, Serialize};
/// Supported LSP actions.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LspAction {
Diagnostics,
Hover,
Definition,
References,
Completion,
Symbols,
Format,
}
impl LspAction {
pub fn from_str(s: &str) -> Option<Self> {
match s {
"diagnostics" => Some(Self::Diagnostics),
"hover" => Some(Self::Hover),
"definition" | "goto_definition" => Some(Self::Definition),
"references" | "find_references" => Some(Self::References),
"completion" | "completions" => Some(Self::Completion),
"symbols" | "document_symbols" => Some(Self::Symbols),
"format" | "formatting" => Some(Self::Format),
_ => None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LspDiagnostic {
pub path: String,
pub line: u32,
pub character: u32,
pub severity: String,
pub message: String,
pub source: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LspLocation {
pub path: String,
pub line: u32,
pub character: u32,
pub end_line: Option<u32>,
pub end_character: Option<u32>,
pub preview: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LspHoverResult {
pub content: String,
pub language: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LspCompletionItem {
pub label: String,
pub kind: Option<String>,
pub detail: Option<String>,
pub insert_text: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LspSymbol {
pub name: String,
pub kind: String,
pub path: String,
pub line: u32,
pub character: u32,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LspServerStatus {
Connected,
Disconnected,
Starting,
Error,
}
impl std::fmt::Display for LspServerStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Connected => write!(f, "connected"),
Self::Disconnected => write!(f, "disconnected"),
Self::Starting => write!(f, "starting"),
Self::Error => write!(f, "error"),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LspServerState {
pub language: String,
pub status: LspServerStatus,
pub root_path: Option<String>,
pub capabilities: Vec<String>,
pub diagnostics: Vec<LspDiagnostic>,
}
#[derive(Debug, Clone, Default)]
pub struct LspRegistry {
inner: Arc<Mutex<RegistryInner>>,
}
#[derive(Debug, Default)]
struct RegistryInner {
servers: HashMap<String, LspServerState>,
}
impl LspRegistry {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn register(
&self,
language: &str,
status: LspServerStatus,
root_path: Option<&str>,
capabilities: Vec<String>,
) {
let mut inner = self.inner.lock().expect("lsp registry lock poisoned");
inner.servers.insert(
language.to_owned(),
LspServerState {
language: language.to_owned(),
status,
root_path: root_path.map(str::to_owned),
capabilities,
diagnostics: Vec::new(),
},
);
}
pub fn get(&self, language: &str) -> Option<LspServerState> {
let inner = self.inner.lock().expect("lsp registry lock poisoned");
inner.servers.get(language).cloned()
}
/// Find the appropriate server for a file path based on extension.
pub fn find_server_for_path(&self, path: &str) -> Option<LspServerState> {
let ext = std::path::Path::new(path)
.extension()
.and_then(|e| e.to_str())
.unwrap_or("");
let language = match ext {
"rs" => "rust",
"ts" | "tsx" => "typescript",
"js" | "jsx" => "javascript",
"py" => "python",
"go" => "go",
"java" => "java",
"c" | "h" => "c",
"cpp" | "hpp" | "cc" => "cpp",
"rb" => "ruby",
"lua" => "lua",
_ => return None,
};
self.get(language)
}
/// List all registered servers.
pub fn list_servers(&self) -> Vec<LspServerState> {
let inner = self.inner.lock().expect("lsp registry lock poisoned");
inner.servers.values().cloned().collect()
}
/// Add diagnostics to a server.
pub fn add_diagnostics(
&self,
language: &str,
diagnostics: Vec<LspDiagnostic>,
) -> Result<(), String> {
let mut inner = self.inner.lock().expect("lsp registry lock poisoned");
let server = inner
.servers
.get_mut(language)
.ok_or_else(|| format!("LSP server not found for language: {language}"))?;
server.diagnostics.extend(diagnostics);
Ok(())
}
/// Get diagnostics for a specific file path.
pub fn get_diagnostics(&self, path: &str) -> Vec<LspDiagnostic> {
let inner = self.inner.lock().expect("lsp registry lock poisoned");
inner
.servers
.values()
.flat_map(|s| &s.diagnostics)
.filter(|d| d.path == path)
.cloned()
.collect()
}
/// Clear diagnostics for a language server.
pub fn clear_diagnostics(&self, language: &str) -> Result<(), String> {
let mut inner = self.inner.lock().expect("lsp registry lock poisoned");
let server = inner
.servers
.get_mut(language)
.ok_or_else(|| format!("LSP server not found for language: {language}"))?;
server.diagnostics.clear();
Ok(())
}
/// Disconnect a server.
pub fn disconnect(&self, language: &str) -> Option<LspServerState> {
let mut inner = self.inner.lock().expect("lsp registry lock poisoned");
inner.servers.remove(language)
}
#[must_use]
pub fn len(&self) -> usize {
let inner = self.inner.lock().expect("lsp registry lock poisoned");
inner.servers.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Dispatch an LSP action and return a structured result.
pub fn dispatch(
&self,
action: &str,
path: Option<&str>,
line: Option<u32>,
character: Option<u32>,
_query: Option<&str>,
) -> Result<serde_json::Value, String> {
let lsp_action =
LspAction::from_str(action).ok_or_else(|| format!("unknown LSP action: {action}"))?;
// For diagnostics, we can check existing cached diagnostics
if lsp_action == LspAction::Diagnostics {
if let Some(path) = path {
let diags = self.get_diagnostics(path);
return Ok(serde_json::json!({
"action": "diagnostics",
"path": path,
"diagnostics": diags,
"count": diags.len()
}));
}
// All diagnostics across all servers
let inner = self.inner.lock().expect("lsp registry lock poisoned");
let all_diags: Vec<_> = inner
.servers
.values()
.flat_map(|s| &s.diagnostics)
.collect();
return Ok(serde_json::json!({
"action": "diagnostics",
"diagnostics": all_diags,
"count": all_diags.len()
}));
}
// For other actions, we need a connected server for the given file
let path = path.ok_or("path is required for this LSP action")?;
let server = self
.find_server_for_path(path)
.ok_or_else(|| format!("no LSP server available for path: {path}"))?;
if server.status != LspServerStatus::Connected {
return Err(format!(
"LSP server for '{}' is not connected (status: {})",
server.language, server.status
));
}
// Return structured placeholder — actual LSP JSON-RPC calls would
// go through the real LSP process here.
Ok(serde_json::json!({
"action": action,
"path": path,
"line": line,
"character": character,
"language": server.language,
"status": "dispatched",
"message": format!("LSP {} dispatched to {} server", action, server.language)
}))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn registers_and_retrieves_server() {
let registry = LspRegistry::new();
registry.register(
"rust",
LspServerStatus::Connected,
Some("/workspace"),
vec!["hover".into(), "completion".into()],
);
let server = registry.get("rust").expect("should exist");
assert_eq!(server.language, "rust");
assert_eq!(server.status, LspServerStatus::Connected);
assert_eq!(server.capabilities.len(), 2);
}
#[test]
fn finds_server_by_file_extension() {
let registry = LspRegistry::new();
registry.register("rust", LspServerStatus::Connected, None, vec![]);
registry.register("typescript", LspServerStatus::Connected, None, vec![]);
let rs_server = registry.find_server_for_path("src/main.rs").unwrap();
assert_eq!(rs_server.language, "rust");
let ts_server = registry.find_server_for_path("src/index.ts").unwrap();
assert_eq!(ts_server.language, "typescript");
assert!(registry.find_server_for_path("data.csv").is_none());
}
#[test]
fn manages_diagnostics() {
let registry = LspRegistry::new();
registry.register("rust", LspServerStatus::Connected, None, vec![]);
registry
.add_diagnostics(
"rust",
vec![LspDiagnostic {
path: "src/main.rs".into(),
line: 10,
character: 5,
severity: "error".into(),
message: "mismatched types".into(),
source: Some("rust-analyzer".into()),
}],
)
.unwrap();
let diags = registry.get_diagnostics("src/main.rs");
assert_eq!(diags.len(), 1);
assert_eq!(diags[0].message, "mismatched types");
registry.clear_diagnostics("rust").unwrap();
assert!(registry.get_diagnostics("src/main.rs").is_empty());
}
#[test]
fn dispatches_diagnostics_action() {
let registry = LspRegistry::new();
registry.register("rust", LspServerStatus::Connected, None, vec![]);
registry
.add_diagnostics(
"rust",
vec![LspDiagnostic {
path: "src/lib.rs".into(),
line: 1,
character: 0,
severity: "warning".into(),
message: "unused import".into(),
source: None,
}],
)
.unwrap();
let result = registry
.dispatch("diagnostics", Some("src/lib.rs"), None, None, None)
.unwrap();
assert_eq!(result["count"], 1);
}
#[test]
fn dispatches_hover_action() {
let registry = LspRegistry::new();
registry.register("rust", LspServerStatus::Connected, None, vec![]);
let result = registry
.dispatch("hover", Some("src/main.rs"), Some(10), Some(5), None)
.unwrap();
assert_eq!(result["action"], "hover");
assert_eq!(result["language"], "rust");
}
#[test]
fn rejects_action_on_disconnected_server() {
let registry = LspRegistry::new();
registry.register("rust", LspServerStatus::Disconnected, None, vec![]);
assert!(registry
.dispatch("hover", Some("src/main.rs"), Some(1), Some(0), None)
.is_err());
}
#[test]
fn rejects_unknown_action() {
let registry = LspRegistry::new();
assert!(registry
.dispatch("unknown_action", Some("file.rs"), None, None, None)
.is_err());
}
#[test]
fn disconnects_server() {
let registry = LspRegistry::new();
registry.register("rust", LspServerStatus::Connected, None, vec![]);
assert_eq!(registry.len(), 1);
let removed = registry.disconnect("rust");
assert!(removed.is_some());
assert!(registry.is_empty());
}
#[test]
fn lsp_action_from_str_all_aliases() {
// given
let cases = [
("diagnostics", Some(LspAction::Diagnostics)),
("hover", Some(LspAction::Hover)),
("definition", Some(LspAction::Definition)),
("goto_definition", Some(LspAction::Definition)),
("references", Some(LspAction::References)),
("find_references", Some(LspAction::References)),
("completion", Some(LspAction::Completion)),
("completions", Some(LspAction::Completion)),
("symbols", Some(LspAction::Symbols)),
("document_symbols", Some(LspAction::Symbols)),
("format", Some(LspAction::Format)),
("formatting", Some(LspAction::Format)),
("unknown", None),
];
// when
let resolved: Vec<_> = cases
.into_iter()
.map(|(input, expected)| (input, LspAction::from_str(input), expected))
.collect();
// then
for (input, actual, expected) in resolved {
assert_eq!(actual, expected, "unexpected action resolution for {input}");
}
}
#[test]
fn lsp_server_status_display_all_variants() {
// given
let cases = [
(LspServerStatus::Connected, "connected"),
(LspServerStatus::Disconnected, "disconnected"),
(LspServerStatus::Starting, "starting"),
(LspServerStatus::Error, "error"),
];
// when
let rendered: Vec<_> = cases
.into_iter()
.map(|(status, expected)| (status.to_string(), expected))
.collect();
// then
assert_eq!(
rendered,
vec![
("connected".to_string(), "connected"),
("disconnected".to_string(), "disconnected"),
("starting".to_string(), "starting"),
("error".to_string(), "error"),
]
);
}
#[test]
fn dispatch_diagnostics_without_path_aggregates() {
// given
let registry = LspRegistry::new();
registry.register("rust", LspServerStatus::Connected, None, vec![]);
registry.register("python", LspServerStatus::Connected, None, vec![]);
registry
.add_diagnostics(
"rust",
vec![LspDiagnostic {
path: "src/lib.rs".into(),
line: 1,
character: 0,
severity: "warning".into(),
message: "unused import".into(),
source: Some("rust-analyzer".into()),
}],
)
.expect("rust diagnostics should add");
registry
.add_diagnostics(
"python",
vec![LspDiagnostic {
path: "script.py".into(),
line: 2,
character: 4,
severity: "error".into(),
message: "undefined name".into(),
source: Some("pyright".into()),
}],
)
.expect("python diagnostics should add");
// when
let result = registry
.dispatch("diagnostics", None, None, None, None)
.expect("aggregate diagnostics should work");
// then
assert_eq!(result["action"], "diagnostics");
assert_eq!(result["count"], 2);
assert_eq!(result["diagnostics"].as_array().map(Vec::len), Some(2));
}
#[test]
fn dispatch_non_diagnostics_requires_path() {
// given
let registry = LspRegistry::new();
// when
let result = registry.dispatch("hover", None, Some(1), Some(0), None);
// then
assert_eq!(
result.expect_err("path should be required"),
"path is required for this LSP action"
);
}
#[test]
fn dispatch_no_server_for_path_errors() {
// given
let registry = LspRegistry::new();
// when
let result = registry.dispatch("hover", Some("notes.md"), Some(1), Some(0), None);
// then
let error = result.expect_err("missing server should fail");
assert!(error.contains("no LSP server available for path: notes.md"));
}
#[test]
fn dispatch_disconnected_server_error_payload() {
// given
let registry = LspRegistry::new();
registry.register("typescript", LspServerStatus::Disconnected, None, vec![]);
// when
let result = registry.dispatch("hover", Some("src/index.ts"), Some(3), Some(2), None);
// then
let error = result.expect_err("disconnected server should fail");
assert!(error.contains("typescript"));
assert!(error.contains("disconnected"));
}
#[test]
fn find_server_for_all_extensions() {
// given
let registry = LspRegistry::new();
for language in [
"rust",
"typescript",
"javascript",
"python",
"go",
"java",
"c",
"cpp",
"ruby",
"lua",
] {
registry.register(language, LspServerStatus::Connected, None, vec![]);
}
let cases = [
("src/main.rs", "rust"),
("src/index.ts", "typescript"),
("src/view.tsx", "typescript"),
("src/app.js", "javascript"),
("src/app.jsx", "javascript"),
("script.py", "python"),
("main.go", "go"),
("Main.java", "java"),
("native.c", "c"),
("native.h", "c"),
("native.cpp", "cpp"),
("native.hpp", "cpp"),
("native.cc", "cpp"),
("script.rb", "ruby"),
("script.lua", "lua"),
];
// when
let resolved: Vec<_> = cases
.into_iter()
.map(|(path, expected)| {
(
path,
registry
.find_server_for_path(path)
.map(|server| server.language),
expected,
)
})
.collect();
// then
for (path, actual, expected) in resolved {
assert_eq!(
actual.as_deref(),
Some(expected),
"unexpected mapping for {path}"
);
}
}
#[test]
fn find_server_for_path_no_extension() {
// given
let registry = LspRegistry::new();
registry.register("rust", LspServerStatus::Connected, None, vec![]);
// when
let result = registry.find_server_for_path("Makefile");
// then
assert!(result.is_none());
}
#[test]
fn list_servers_with_multiple() {
// given
let registry = LspRegistry::new();
registry.register("rust", LspServerStatus::Connected, None, vec![]);
registry.register("typescript", LspServerStatus::Starting, None, vec![]);
registry.register("python", LspServerStatus::Error, None, vec![]);
// when
let servers = registry.list_servers();
// then
assert_eq!(servers.len(), 3);
assert!(servers.iter().any(|server| server.language == "rust"));
assert!(servers.iter().any(|server| server.language == "typescript"));
assert!(servers.iter().any(|server| server.language == "python"));
}
#[test]
fn get_missing_server_returns_none() {
// given
let registry = LspRegistry::new();
// when
let server = registry.get("missing");
// then
assert!(server.is_none());
}
#[test]
fn add_diagnostics_missing_language_errors() {
// given
let registry = LspRegistry::new();
// when
let result = registry.add_diagnostics("missing", vec![]);
// then
let error = result.expect_err("missing language should fail");
assert!(error.contains("LSP server not found for language: missing"));
}
#[test]
fn get_diagnostics_across_servers() {
// given
let registry = LspRegistry::new();
let shared_path = "shared/file.txt";
registry.register("rust", LspServerStatus::Connected, None, vec![]);
registry.register("python", LspServerStatus::Connected, None, vec![]);
registry
.add_diagnostics(
"rust",
vec![LspDiagnostic {
path: shared_path.into(),
line: 4,
character: 1,
severity: "warning".into(),
message: "warn".into(),
source: None,
}],
)
.expect("rust diagnostics should add");
registry
.add_diagnostics(
"python",
vec![LspDiagnostic {
path: shared_path.into(),
line: 8,
character: 3,
severity: "error".into(),
message: "err".into(),
source: None,
}],
)
.expect("python diagnostics should add");
// when
let diagnostics = registry.get_diagnostics(shared_path);
// then
assert_eq!(diagnostics.len(), 2);
assert!(diagnostics
.iter()
.any(|diagnostic| diagnostic.message == "warn"));
assert!(diagnostics
.iter()
.any(|diagnostic| diagnostic.message == "err"));
}
#[test]
fn clear_diagnostics_missing_language_errors() {
// given
let registry = LspRegistry::new();
// when
let result = registry.clear_diagnostics("missing");
// then
let error = result.expect_err("missing language should fail");
assert!(error.contains("LSP server not found for language: missing"));
}
}

View File

@ -84,10 +84,13 @@ pub fn mcp_server_signature(config: &McpServerConfig) -> Option<String> {
pub fn scoped_mcp_config_hash(config: &ScopedMcpServerConfig) -> String {
let rendered = match &config.config {
McpServerConfig::Stdio(stdio) => format!(
"stdio|{}|{}|{}",
"stdio|{}|{}|{}|{}",
stdio.command,
render_command_signature(&stdio.args),
render_env_signature(&stdio.env)
render_env_signature(&stdio.env),
stdio
.tool_call_timeout_ms
.map_or_else(String::new, |timeout_ms| timeout_ms.to_string())
),
McpServerConfig::Sse(remote) => format!(
"sse|{}|{}|{}|{}",
@ -245,6 +248,7 @@ mod tests {
command: "uvx".to_string(),
args: vec!["mcp-server".to_string()],
env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]),
tool_call_timeout_ms: None,
});
assert_eq!(
mcp_server_signature(&stdio),

View File

@ -3,6 +3,8 @@ use std::collections::BTreeMap;
use crate::config::{McpOAuthConfig, McpServerConfig, ScopedMcpServerConfig};
use crate::mcp::{mcp_server_signature, mcp_tool_prefix, normalize_name_for_mcp};
pub const DEFAULT_MCP_TOOL_CALL_TIMEOUT_MS: u64 = 60_000;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum McpClientTransport {
Stdio(McpStdioTransport),
@ -18,6 +20,7 @@ pub struct McpStdioTransport {
pub command: String,
pub args: Vec<String>,
pub env: BTreeMap<String, String>,
pub tool_call_timeout_ms: Option<u64>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
@ -75,6 +78,7 @@ impl McpClientTransport {
command: config.command.clone(),
args: config.args.clone(),
env: config.env.clone(),
tool_call_timeout_ms: config.tool_call_timeout_ms,
}),
McpServerConfig::Sse(config) => Self::Sse(McpRemoteTransport {
url: config.url.clone(),
@ -105,6 +109,14 @@ impl McpClientTransport {
}
}
impl McpStdioTransport {
#[must_use]
pub fn resolved_tool_call_timeout_ms(&self) -> u64 {
self.tool_call_timeout_ms
.unwrap_or(DEFAULT_MCP_TOOL_CALL_TIMEOUT_MS)
}
}
impl McpClientAuth {
#[must_use]
pub fn from_oauth(oauth: Option<McpOAuthConfig>) -> Self {
@ -136,6 +148,7 @@ mod tests {
command: "uvx".to_string(),
args: vec!["mcp-server".to_string()],
env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]),
tool_call_timeout_ms: Some(15_000),
}),
};
@ -154,6 +167,7 @@ mod tests {
transport.env.get("TOKEN").map(String::as_str),
Some("secret")
);
assert_eq!(transport.tool_call_timeout_ms, Some(15_000));
}
other => panic!("expected stdio transport, got {other:?}"),
}

View File

@ -0,0 +1,843 @@
#![allow(clippy::unnested_or_patterns, clippy::map_unwrap_or)]
use std::collections::{BTreeMap, BTreeSet};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
fn now_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum McpLifecyclePhase {
ConfigLoad,
ServerRegistration,
SpawnConnect,
InitializeHandshake,
ToolDiscovery,
ResourceDiscovery,
Ready,
Invocation,
ErrorSurfacing,
Shutdown,
Cleanup,
}
impl McpLifecyclePhase {
#[must_use]
pub fn all() -> [Self; 11] {
[
Self::ConfigLoad,
Self::ServerRegistration,
Self::SpawnConnect,
Self::InitializeHandshake,
Self::ToolDiscovery,
Self::ResourceDiscovery,
Self::Ready,
Self::Invocation,
Self::ErrorSurfacing,
Self::Shutdown,
Self::Cleanup,
]
}
}
impl std::fmt::Display for McpLifecyclePhase {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ConfigLoad => write!(f, "config_load"),
Self::ServerRegistration => write!(f, "server_registration"),
Self::SpawnConnect => write!(f, "spawn_connect"),
Self::InitializeHandshake => write!(f, "initialize_handshake"),
Self::ToolDiscovery => write!(f, "tool_discovery"),
Self::ResourceDiscovery => write!(f, "resource_discovery"),
Self::Ready => write!(f, "ready"),
Self::Invocation => write!(f, "invocation"),
Self::ErrorSurfacing => write!(f, "error_surfacing"),
Self::Shutdown => write!(f, "shutdown"),
Self::Cleanup => write!(f, "cleanup"),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct McpErrorSurface {
pub phase: McpLifecyclePhase,
pub server_name: Option<String>,
pub message: String,
pub context: BTreeMap<String, String>,
pub recoverable: bool,
pub timestamp: u64,
}
impl McpErrorSurface {
#[must_use]
pub fn new(
phase: McpLifecyclePhase,
server_name: Option<String>,
message: impl Into<String>,
context: BTreeMap<String, String>,
recoverable: bool,
) -> Self {
Self {
phase,
server_name,
message: message.into(),
context,
recoverable,
timestamp: now_secs(),
}
}
}
impl std::fmt::Display for McpErrorSurface {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"MCP lifecycle error during {}: {}",
self.phase, self.message
)?;
if let Some(server_name) = &self.server_name {
write!(f, " (server: {server_name})")?;
}
if !self.context.is_empty() {
write!(f, " with context {:?}", self.context)?;
}
if self.recoverable {
write!(f, " [recoverable]")?;
}
Ok(())
}
}
impl std::error::Error for McpErrorSurface {}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum McpPhaseResult {
Success {
phase: McpLifecyclePhase,
duration: Duration,
},
Failure {
phase: McpLifecyclePhase,
error: McpErrorSurface,
},
Timeout {
phase: McpLifecyclePhase,
waited: Duration,
error: McpErrorSurface,
},
}
impl McpPhaseResult {
#[must_use]
pub fn phase(&self) -> McpLifecyclePhase {
match self {
Self::Success { phase, .. }
| Self::Failure { phase, .. }
| Self::Timeout { phase, .. } => *phase,
}
}
}
#[derive(Debug, Clone, Default)]
pub struct McpLifecycleState {
current_phase: Option<McpLifecyclePhase>,
phase_errors: BTreeMap<McpLifecyclePhase, Vec<McpErrorSurface>>,
phase_timestamps: BTreeMap<McpLifecyclePhase, u64>,
phase_results: Vec<McpPhaseResult>,
}
impl McpLifecycleState {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn current_phase(&self) -> Option<McpLifecyclePhase> {
self.current_phase
}
#[must_use]
pub fn errors_for_phase(&self, phase: McpLifecyclePhase) -> &[McpErrorSurface] {
self.phase_errors
.get(&phase)
.map(Vec::as_slice)
.unwrap_or(&[])
}
#[must_use]
pub fn results(&self) -> &[McpPhaseResult] {
&self.phase_results
}
#[must_use]
pub fn phase_timestamps(&self) -> &BTreeMap<McpLifecyclePhase, u64> {
&self.phase_timestamps
}
#[must_use]
pub fn phase_timestamp(&self, phase: McpLifecyclePhase) -> Option<u64> {
self.phase_timestamps.get(&phase).copied()
}
fn record_phase(&mut self, phase: McpLifecyclePhase) {
self.current_phase = Some(phase);
self.phase_timestamps.insert(phase, now_secs());
}
fn record_error(&mut self, error: McpErrorSurface) {
self.phase_errors
.entry(error.phase)
.or_default()
.push(error);
}
fn record_result(&mut self, result: McpPhaseResult) {
self.phase_results.push(result);
}
fn can_resume_after_error(&self) -> bool {
match self.phase_results.last() {
Some(McpPhaseResult::Failure { error, .. } | McpPhaseResult::Timeout { error, .. }) => {
error.recoverable
}
_ => false,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct McpFailedServer {
pub server_name: String,
pub phase: McpLifecyclePhase,
pub error: McpErrorSurface,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct McpDegradedReport {
pub working_servers: Vec<String>,
pub failed_servers: Vec<McpFailedServer>,
pub available_tools: Vec<String>,
pub missing_tools: Vec<String>,
}
impl McpDegradedReport {
#[must_use]
pub fn new(
working_servers: Vec<String>,
failed_servers: Vec<McpFailedServer>,
available_tools: Vec<String>,
expected_tools: Vec<String>,
) -> Self {
let working_servers = dedupe_sorted(working_servers);
let available_tools = dedupe_sorted(available_tools);
let available_tool_set: BTreeSet<_> = available_tools.iter().cloned().collect();
let expected_tools = dedupe_sorted(expected_tools);
let missing_tools = expected_tools
.into_iter()
.filter(|tool| !available_tool_set.contains(tool))
.collect();
Self {
working_servers,
failed_servers,
available_tools,
missing_tools,
}
}
}
#[derive(Debug, Clone, Default)]
pub struct McpLifecycleValidator {
state: McpLifecycleState,
}
impl McpLifecycleValidator {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn state(&self) -> &McpLifecycleState {
&self.state
}
#[must_use]
pub fn validate_phase_transition(from: McpLifecyclePhase, to: McpLifecyclePhase) -> bool {
match (from, to) {
(McpLifecyclePhase::ConfigLoad, McpLifecyclePhase::ServerRegistration)
| (McpLifecyclePhase::ServerRegistration, McpLifecyclePhase::SpawnConnect)
| (McpLifecyclePhase::SpawnConnect, McpLifecyclePhase::InitializeHandshake)
| (McpLifecyclePhase::InitializeHandshake, McpLifecyclePhase::ToolDiscovery)
| (McpLifecyclePhase::ToolDiscovery, McpLifecyclePhase::ResourceDiscovery)
| (McpLifecyclePhase::ToolDiscovery, McpLifecyclePhase::Ready)
| (McpLifecyclePhase::ResourceDiscovery, McpLifecyclePhase::Ready)
| (McpLifecyclePhase::Ready, McpLifecyclePhase::Invocation)
| (McpLifecyclePhase::Invocation, McpLifecyclePhase::Ready)
| (McpLifecyclePhase::ErrorSurfacing, McpLifecyclePhase::Ready)
| (McpLifecyclePhase::ErrorSurfacing, McpLifecyclePhase::Shutdown)
| (McpLifecyclePhase::Shutdown, McpLifecyclePhase::Cleanup) => true,
(_, McpLifecyclePhase::Shutdown) => from != McpLifecyclePhase::Cleanup,
(_, McpLifecyclePhase::ErrorSurfacing) => {
from != McpLifecyclePhase::Cleanup && from != McpLifecyclePhase::Shutdown
}
_ => false,
}
}
pub fn run_phase(&mut self, phase: McpLifecyclePhase) -> McpPhaseResult {
let started = Instant::now();
if let Some(current_phase) = self.state.current_phase() {
if current_phase == McpLifecyclePhase::ErrorSurfacing
&& phase == McpLifecyclePhase::Ready
&& !self.state.can_resume_after_error()
{
return self.record_failure(McpErrorSurface::new(
phase,
None,
"cannot return to ready after a non-recoverable MCP lifecycle failure",
BTreeMap::from([
("from".to_string(), current_phase.to_string()),
("to".to_string(), phase.to_string()),
]),
false,
));
}
if !Self::validate_phase_transition(current_phase, phase) {
return self.record_failure(McpErrorSurface::new(
phase,
None,
format!("invalid MCP lifecycle transition from {current_phase} to {phase}"),
BTreeMap::from([
("from".to_string(), current_phase.to_string()),
("to".to_string(), phase.to_string()),
]),
false,
));
}
} else if phase != McpLifecyclePhase::ConfigLoad {
return self.record_failure(McpErrorSurface::new(
phase,
None,
format!("invalid initial MCP lifecycle phase {phase}"),
BTreeMap::from([("phase".to_string(), phase.to_string())]),
false,
));
}
self.state.record_phase(phase);
let result = McpPhaseResult::Success {
phase,
duration: started.elapsed(),
};
self.state.record_result(result.clone());
result
}
pub fn record_failure(&mut self, error: McpErrorSurface) -> McpPhaseResult {
let phase = error.phase;
self.state.record_error(error.clone());
self.state.record_phase(McpLifecyclePhase::ErrorSurfacing);
let result = McpPhaseResult::Failure { phase, error };
self.state.record_result(result.clone());
result
}
pub fn record_timeout(
&mut self,
phase: McpLifecyclePhase,
waited: Duration,
server_name: Option<String>,
mut context: BTreeMap<String, String>,
) -> McpPhaseResult {
context.insert("waited_ms".to_string(), waited.as_millis().to_string());
let error = McpErrorSurface::new(
phase,
server_name,
format!(
"MCP lifecycle phase {phase} timed out after {} ms",
waited.as_millis()
),
context,
true,
);
self.state.record_error(error.clone());
self.state.record_phase(McpLifecyclePhase::ErrorSurfacing);
let result = McpPhaseResult::Timeout {
phase,
waited,
error,
};
self.state.record_result(result.clone());
result
}
}
fn dedupe_sorted(mut values: Vec<String>) -> Vec<String> {
values.sort();
values.dedup();
values
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn phase_display_matches_serde_name() {
// given
let phases = McpLifecyclePhase::all();
// when
let serialized = phases
.into_iter()
.map(|phase| {
(
phase.to_string(),
serde_json::to_value(phase).expect("serialize phase"),
)
})
.collect::<Vec<_>>();
// then
for (display, json_value) in serialized {
assert_eq!(json_value, json!(display));
}
}
#[test]
fn given_startup_path_when_running_to_cleanup_then_each_control_transition_succeeds() {
// given
let mut validator = McpLifecycleValidator::new();
let phases = [
McpLifecyclePhase::ConfigLoad,
McpLifecyclePhase::ServerRegistration,
McpLifecyclePhase::SpawnConnect,
McpLifecyclePhase::InitializeHandshake,
McpLifecyclePhase::ToolDiscovery,
McpLifecyclePhase::ResourceDiscovery,
McpLifecyclePhase::Ready,
McpLifecyclePhase::Invocation,
McpLifecyclePhase::Ready,
McpLifecyclePhase::Shutdown,
McpLifecyclePhase::Cleanup,
];
// when
let results = phases
.into_iter()
.map(|phase| validator.run_phase(phase))
.collect::<Vec<_>>();
// then
assert!(results
.iter()
.all(|result| matches!(result, McpPhaseResult::Success { .. })));
assert_eq!(
validator.state().current_phase(),
Some(McpLifecyclePhase::Cleanup)
);
for phase in [
McpLifecyclePhase::ConfigLoad,
McpLifecyclePhase::ServerRegistration,
McpLifecyclePhase::SpawnConnect,
McpLifecyclePhase::InitializeHandshake,
McpLifecyclePhase::ToolDiscovery,
McpLifecyclePhase::ResourceDiscovery,
McpLifecyclePhase::Ready,
McpLifecyclePhase::Invocation,
McpLifecyclePhase::Shutdown,
McpLifecyclePhase::Cleanup,
] {
assert!(validator.state().phase_timestamp(phase).is_some());
}
}
#[test]
fn given_tool_discovery_when_resource_discovery_is_skipped_then_ready_is_still_allowed() {
// given
let mut validator = McpLifecycleValidator::new();
for phase in [
McpLifecyclePhase::ConfigLoad,
McpLifecyclePhase::ServerRegistration,
McpLifecyclePhase::SpawnConnect,
McpLifecyclePhase::InitializeHandshake,
McpLifecyclePhase::ToolDiscovery,
] {
let result = validator.run_phase(phase);
assert!(matches!(result, McpPhaseResult::Success { .. }));
}
// when
let result = validator.run_phase(McpLifecyclePhase::Ready);
// then
assert!(matches!(result, McpPhaseResult::Success { .. }));
assert_eq!(
validator.state().current_phase(),
Some(McpLifecyclePhase::Ready)
);
}
#[test]
fn validates_expected_phase_transitions() {
// given
let valid_transitions = [
(
McpLifecyclePhase::ConfigLoad,
McpLifecyclePhase::ServerRegistration,
),
(
McpLifecyclePhase::ServerRegistration,
McpLifecyclePhase::SpawnConnect,
),
(
McpLifecyclePhase::SpawnConnect,
McpLifecyclePhase::InitializeHandshake,
),
(
McpLifecyclePhase::InitializeHandshake,
McpLifecyclePhase::ToolDiscovery,
),
(
McpLifecyclePhase::ToolDiscovery,
McpLifecyclePhase::ResourceDiscovery,
),
(McpLifecyclePhase::ToolDiscovery, McpLifecyclePhase::Ready),
(
McpLifecyclePhase::ResourceDiscovery,
McpLifecyclePhase::Ready,
),
(McpLifecyclePhase::Ready, McpLifecyclePhase::Invocation),
(McpLifecyclePhase::Invocation, McpLifecyclePhase::Ready),
(McpLifecyclePhase::Ready, McpLifecyclePhase::Shutdown),
(
McpLifecyclePhase::Invocation,
McpLifecyclePhase::ErrorSurfacing,
),
(
McpLifecyclePhase::ErrorSurfacing,
McpLifecyclePhase::Shutdown,
),
(McpLifecyclePhase::Shutdown, McpLifecyclePhase::Cleanup),
];
// when / then
for (from, to) in valid_transitions {
assert!(McpLifecycleValidator::validate_phase_transition(from, to));
}
assert!(!McpLifecycleValidator::validate_phase_transition(
McpLifecyclePhase::Ready,
McpLifecyclePhase::ConfigLoad,
));
assert!(!McpLifecycleValidator::validate_phase_transition(
McpLifecyclePhase::Cleanup,
McpLifecyclePhase::Ready,
));
}
#[test]
fn given_invalid_transition_when_running_phase_then_structured_failure_is_recorded() {
// given
let mut validator = McpLifecycleValidator::new();
let _ = validator.run_phase(McpLifecyclePhase::ConfigLoad);
let _ = validator.run_phase(McpLifecyclePhase::ServerRegistration);
// when
let result = validator.run_phase(McpLifecyclePhase::Ready);
// then
match result {
McpPhaseResult::Failure { phase, error } => {
assert_eq!(phase, McpLifecyclePhase::Ready);
assert!(!error.recoverable);
assert_eq!(error.phase, McpLifecyclePhase::Ready);
assert_eq!(
error.context.get("from").map(String::as_str),
Some("server_registration")
);
assert_eq!(error.context.get("to").map(String::as_str), Some("ready"));
}
other => panic!("expected failure result, got {other:?}"),
}
assert_eq!(
validator.state().current_phase(),
Some(McpLifecyclePhase::ErrorSurfacing)
);
assert_eq!(
validator
.state()
.errors_for_phase(McpLifecyclePhase::Ready)
.len(),
1
);
}
#[test]
fn given_each_phase_when_failure_is_recorded_then_error_is_tracked_per_phase() {
// given
let mut validator = McpLifecycleValidator::new();
// when / then
for phase in McpLifecyclePhase::all() {
let result = validator.record_failure(McpErrorSurface::new(
phase,
Some("alpha".to_string()),
format!("failure at {phase}"),
BTreeMap::from([("server".to_string(), "alpha".to_string())]),
phase == McpLifecyclePhase::ResourceDiscovery,
));
match result {
McpPhaseResult::Failure {
phase: failed_phase,
error,
} => {
assert_eq!(failed_phase, phase);
assert_eq!(error.phase, phase);
assert_eq!(
error.recoverable,
phase == McpLifecyclePhase::ResourceDiscovery
);
}
other => panic!("expected failure result, got {other:?}"),
}
assert_eq!(validator.state().errors_for_phase(phase).len(), 1);
}
}
#[test]
fn given_spawn_connect_timeout_when_recorded_then_waited_duration_is_preserved() {
// given
let mut validator = McpLifecycleValidator::new();
let waited = Duration::from_millis(250);
// when
let result = validator.record_timeout(
McpLifecyclePhase::SpawnConnect,
waited,
Some("alpha".to_string()),
BTreeMap::from([("attempt".to_string(), "1".to_string())]),
);
// then
match result {
McpPhaseResult::Timeout {
phase,
waited: actual,
error,
} => {
assert_eq!(phase, McpLifecyclePhase::SpawnConnect);
assert_eq!(actual, waited);
assert!(error.recoverable);
assert_eq!(error.server_name.as_deref(), Some("alpha"));
}
other => panic!("expected timeout result, got {other:?}"),
}
let errors = validator
.state()
.errors_for_phase(McpLifecyclePhase::SpawnConnect);
assert_eq!(errors.len(), 1);
assert_eq!(
errors[0].context.get("waited_ms").map(String::as_str),
Some("250")
);
assert_eq!(
validator.state().current_phase(),
Some(McpLifecyclePhase::ErrorSurfacing)
);
}
#[test]
fn given_partial_server_health_when_building_degraded_report_then_missing_tools_are_reported() {
// given
let failed = vec![McpFailedServer {
server_name: "broken".to_string(),
phase: McpLifecyclePhase::InitializeHandshake,
error: McpErrorSurface::new(
McpLifecyclePhase::InitializeHandshake,
Some("broken".to_string()),
"initialize failed",
BTreeMap::from([("reason".to_string(), "broken pipe".to_string())]),
false,
),
}];
// when
let report = McpDegradedReport::new(
vec!["alpha".to_string(), "beta".to_string(), "alpha".to_string()],
failed,
vec![
"alpha.echo".to_string(),
"beta.search".to_string(),
"alpha.echo".to_string(),
],
vec![
"alpha.echo".to_string(),
"beta.search".to_string(),
"broken.fetch".to_string(),
],
);
// then
assert_eq!(
report.working_servers,
vec!["alpha".to_string(), "beta".to_string()]
);
assert_eq!(report.failed_servers.len(), 1);
assert_eq!(report.failed_servers[0].server_name, "broken");
assert_eq!(
report.available_tools,
vec!["alpha.echo".to_string(), "beta.search".to_string()]
);
assert_eq!(report.missing_tools, vec!["broken.fetch".to_string()]);
}
#[test]
fn given_failure_during_resource_discovery_when_shutting_down_then_cleanup_still_succeeds() {
// given
let mut validator = McpLifecycleValidator::new();
for phase in [
McpLifecyclePhase::ConfigLoad,
McpLifecyclePhase::ServerRegistration,
McpLifecyclePhase::SpawnConnect,
McpLifecyclePhase::InitializeHandshake,
McpLifecyclePhase::ToolDiscovery,
] {
let result = validator.run_phase(phase);
assert!(matches!(result, McpPhaseResult::Success { .. }));
}
let _ = validator.record_failure(McpErrorSurface::new(
McpLifecyclePhase::ResourceDiscovery,
Some("alpha".to_string()),
"resource listing failed",
BTreeMap::from([("reason".to_string(), "timeout".to_string())]),
true,
));
// when
let shutdown = validator.run_phase(McpLifecyclePhase::Shutdown);
let cleanup = validator.run_phase(McpLifecyclePhase::Cleanup);
// then
assert!(matches!(shutdown, McpPhaseResult::Success { .. }));
assert!(matches!(cleanup, McpPhaseResult::Success { .. }));
assert_eq!(
validator.state().current_phase(),
Some(McpLifecyclePhase::Cleanup)
);
assert!(validator
.state()
.phase_timestamp(McpLifecyclePhase::ErrorSurfacing)
.is_some());
}
#[test]
fn error_surface_display_includes_phase_server_and_recoverable_flag() {
// given
let error = McpErrorSurface::new(
McpLifecyclePhase::SpawnConnect,
Some("alpha".to_string()),
"process exited early",
BTreeMap::from([("exit_code".to_string(), "1".to_string())]),
true,
);
// when
let rendered = error.to_string();
// then
assert!(rendered.contains("spawn_connect"));
assert!(rendered.contains("process exited early"));
assert!(rendered.contains("server: alpha"));
assert!(rendered.contains("recoverable"));
let trait_object: &dyn std::error::Error = &error;
assert_eq!(trait_object.to_string(), rendered);
}
#[test]
fn given_nonrecoverable_failure_when_returning_to_ready_then_validator_rejects_resume() {
// given
let mut validator = McpLifecycleValidator::new();
for phase in [
McpLifecyclePhase::ConfigLoad,
McpLifecyclePhase::ServerRegistration,
McpLifecyclePhase::SpawnConnect,
McpLifecyclePhase::InitializeHandshake,
McpLifecyclePhase::ToolDiscovery,
McpLifecyclePhase::Ready,
] {
let result = validator.run_phase(phase);
assert!(matches!(result, McpPhaseResult::Success { .. }));
}
let _ = validator.record_failure(McpErrorSurface::new(
McpLifecyclePhase::Invocation,
Some("alpha".to_string()),
"tool call corrupted the session",
BTreeMap::from([("reason".to_string(), "invalid frame".to_string())]),
false,
));
// when
let result = validator.run_phase(McpLifecyclePhase::Ready);
// then
match result {
McpPhaseResult::Failure { phase, error } => {
assert_eq!(phase, McpLifecyclePhase::Ready);
assert!(!error.recoverable);
assert!(error.message.contains("non-recoverable"));
}
other => panic!("expected failure result, got {other:?}"),
}
assert_eq!(
validator.state().current_phase(),
Some(McpLifecyclePhase::ErrorSurfacing)
);
}
#[test]
fn given_recoverable_failure_when_returning_to_ready_then_validator_allows_resume() {
// given
let mut validator = McpLifecycleValidator::new();
for phase in [
McpLifecyclePhase::ConfigLoad,
McpLifecyclePhase::ServerRegistration,
McpLifecyclePhase::SpawnConnect,
McpLifecyclePhase::InitializeHandshake,
McpLifecyclePhase::ToolDiscovery,
McpLifecyclePhase::Ready,
] {
let result = validator.run_phase(phase);
assert!(matches!(result, McpPhaseResult::Success { .. }));
}
let _ = validator.record_failure(McpErrorSurface::new(
McpLifecyclePhase::Invocation,
Some("alpha".to_string()),
"tool call failed but can be retried",
BTreeMap::from([("reason".to_string(), "upstream timeout".to_string())]),
true,
));
// when
let result = validator.run_phase(McpLifecyclePhase::Ready);
// then
assert!(matches!(result, McpPhaseResult::Success { .. }));
assert_eq!(
validator.state().current_phase(),
Some(McpLifecyclePhase::Ready)
);
}
}

View File

@ -0,0 +1,440 @@
//! Minimal Model Context Protocol (MCP) server.
//!
//! Implements a newline-safe, LSP-framed JSON-RPC server over stdio that
//! answers `initialize`, `tools/list`, and `tools/call` requests. The framing
//! matches the client transport implemented in [`crate::mcp_stdio`] so this
//! server can be driven by either an external MCP client (e.g. Claude
//! Desktop) or `claw`'s own [`McpServerManager`](crate::McpServerManager).
//!
//! The server is intentionally small: it exposes a list of pre-built
//! [`McpTool`] descriptors and delegates `tools/call` to a caller-supplied
//! handler. Tool execution itself lives in the `tools` crate; this module is
//! purely the transport + dispatch loop.
//!
//! [`McpTool`]: crate::mcp_stdio::McpTool
use std::io;
use serde_json::{json, Value as JsonValue};
use tokio::io::{
stdin, stdout, AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, Stdin, Stdout,
};
use crate::mcp_stdio::{
JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse, McpInitializeResult,
McpInitializeServerInfo, McpListToolsResult, McpTool, McpToolCallContent, McpToolCallParams,
McpToolCallResult,
};
/// Protocol version the server advertises during `initialize`.
///
/// Matches the version used by the built-in client in
/// [`crate::mcp_stdio`], so the two stay in lockstep.
pub const MCP_SERVER_PROTOCOL_VERSION: &str = "2025-03-26";
/// Synchronous handler invoked for every `tools/call` request.
///
/// Returning `Ok(text)` yields a single `text` content block and
/// `isError: false`. Returning `Err(message)` yields a `text` block with the
/// error and `isError: true`, mirroring the error-surfacing convention used
/// elsewhere in claw.
pub type ToolCallHandler =
Box<dyn Fn(&str, &JsonValue) -> Result<String, String> + Send + Sync + 'static>;
/// Configuration for an [`McpServer`] instance.
///
/// Named `McpServerSpec` rather than `McpServerConfig` to avoid colliding
/// with the existing client-side [`crate::config::McpServerConfig`] that
/// describes *remote* MCP servers the runtime connects to.
pub struct McpServerSpec {
/// Name advertised in the `serverInfo` field of the `initialize` response.
pub server_name: String,
/// Version advertised in the `serverInfo` field of the `initialize`
/// response.
pub server_version: String,
/// Tool descriptors returned for `tools/list`.
pub tools: Vec<McpTool>,
/// Handler invoked for `tools/call`.
pub tool_handler: ToolCallHandler,
}
/// Minimal MCP stdio server.
///
/// The server runs a blocking read/dispatch/write loop over the current
/// process's stdin/stdout, terminating cleanly when the peer closes the
/// stream.
pub struct McpServer {
spec: McpServerSpec,
stdin: BufReader<Stdin>,
stdout: Stdout,
}
impl McpServer {
#[must_use]
pub fn new(spec: McpServerSpec) -> Self {
Self {
spec,
stdin: BufReader::new(stdin()),
stdout: stdout(),
}
}
/// Runs the server until the client closes stdin.
///
/// Returns `Ok(())` on clean EOF; any other I/O error is propagated so
/// callers can log and exit non-zero.
pub async fn run(&mut self) -> io::Result<()> {
loop {
let Some(payload) = read_frame(&mut self.stdin).await? else {
return Ok(());
};
// Requests and notifications share a wire format; the absence of
// `id` distinguishes notifications, which must never receive a
// response.
let message: JsonValue = match serde_json::from_slice(&payload) {
Ok(value) => value,
Err(error) => {
// Parse error with null id per JSON-RPC 2.0 §4.2.
let response = JsonRpcResponse::<JsonValue> {
jsonrpc: "2.0".to_string(),
id: JsonRpcId::Null,
result: None,
error: Some(JsonRpcError {
code: -32700,
message: format!("parse error: {error}"),
data: None,
}),
};
write_response(&mut self.stdout, &response).await?;
continue;
}
};
if message.get("id").is_none() {
// Notification: dispatch for side effects only (e.g. log),
// but send no reply.
continue;
}
let request: JsonRpcRequest<JsonValue> = match serde_json::from_value(message) {
Ok(request) => request,
Err(error) => {
let response = JsonRpcResponse::<JsonValue> {
jsonrpc: "2.0".to_string(),
id: JsonRpcId::Null,
result: None,
error: Some(JsonRpcError {
code: -32600,
message: format!("invalid request: {error}"),
data: None,
}),
};
write_response(&mut self.stdout, &response).await?;
continue;
}
};
let response = self.dispatch(request);
write_response(&mut self.stdout, &response).await?;
}
}
fn dispatch(&self, request: JsonRpcRequest<JsonValue>) -> JsonRpcResponse<JsonValue> {
let id = request.id.clone();
match request.method.as_str() {
"initialize" => self.handle_initialize(id),
"tools/list" => self.handle_tools_list(id),
"tools/call" => self.handle_tools_call(id, request.params),
other => JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id,
result: None,
error: Some(JsonRpcError {
code: -32601,
message: format!("method not found: {other}"),
data: None,
}),
},
}
}
fn handle_initialize(&self, id: JsonRpcId) -> JsonRpcResponse<JsonValue> {
let result = McpInitializeResult {
protocol_version: MCP_SERVER_PROTOCOL_VERSION.to_string(),
capabilities: json!({ "tools": {} }),
server_info: McpInitializeServerInfo {
name: self.spec.server_name.clone(),
version: self.spec.server_version.clone(),
},
};
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id,
result: serde_json::to_value(result).ok(),
error: None,
}
}
fn handle_tools_list(&self, id: JsonRpcId) -> JsonRpcResponse<JsonValue> {
let result = McpListToolsResult {
tools: self.spec.tools.clone(),
next_cursor: None,
};
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id,
result: serde_json::to_value(result).ok(),
error: None,
}
}
fn handle_tools_call(
&self,
id: JsonRpcId,
params: Option<JsonValue>,
) -> JsonRpcResponse<JsonValue> {
let Some(params) = params else {
return invalid_params_response(id, "missing params for tools/call");
};
let call: McpToolCallParams = match serde_json::from_value(params) {
Ok(value) => value,
Err(error) => {
return invalid_params_response(id, &format!("invalid tools/call params: {error}"));
}
};
let arguments = call.arguments.unwrap_or_else(|| json!({}));
let tool_result = (self.spec.tool_handler)(&call.name, &arguments);
let (text, is_error) = match tool_result {
Ok(text) => (text, false),
Err(message) => (message, true),
};
let mut data = std::collections::BTreeMap::new();
data.insert("text".to_string(), JsonValue::String(text));
let call_result = McpToolCallResult {
content: vec![McpToolCallContent {
kind: "text".to_string(),
data,
}],
structured_content: None,
is_error: Some(is_error),
meta: None,
};
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id,
result: serde_json::to_value(call_result).ok(),
error: None,
}
}
}
fn invalid_params_response(id: JsonRpcId, message: &str) -> JsonRpcResponse<JsonValue> {
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id,
result: None,
error: Some(JsonRpcError {
code: -32602,
message: message.to_string(),
data: None,
}),
}
}
/// Reads a single LSP-framed JSON-RPC payload from `reader`.
///
/// Returns `Ok(None)` on clean EOF before any header bytes have been read,
/// matching how [`crate::mcp_stdio::McpStdioProcess`] treats stream closure.
async fn read_frame(reader: &mut BufReader<Stdin>) -> io::Result<Option<Vec<u8>>> {
let mut content_length: Option<usize> = None;
let mut first_header = true;
loop {
let mut line = String::new();
let bytes_read = reader.read_line(&mut line).await?;
if bytes_read == 0 {
if first_header {
return Ok(None);
}
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"MCP stdio stream closed while reading headers",
));
}
first_header = false;
if line == "\r\n" || line == "\n" {
break;
}
let header = line.trim_end_matches(['\r', '\n']);
if let Some((name, value)) = header.split_once(':') {
if name.trim().eq_ignore_ascii_case("Content-Length") {
let parsed = value
.trim()
.parse::<usize>()
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
content_length = Some(parsed);
}
}
}
let content_length = content_length.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidData, "missing Content-Length header")
})?;
let mut payload = vec![0_u8; content_length];
reader.read_exact(&mut payload).await?;
Ok(Some(payload))
}
async fn write_response(
stdout: &mut Stdout,
response: &JsonRpcResponse<JsonValue>,
) -> io::Result<()> {
let body = serde_json::to_vec(response)
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
let header = format!("Content-Length: {}\r\n\r\n", body.len());
stdout.write_all(header.as_bytes()).await?;
stdout.write_all(&body).await?;
stdout.flush().await
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn dispatch_initialize_returns_server_info() {
let server = McpServer {
spec: McpServerSpec {
server_name: "test".to_string(),
server_version: "9.9.9".to_string(),
tools: Vec::new(),
tool_handler: Box::new(|_, _| Ok(String::new())),
},
stdin: BufReader::new(stdin()),
stdout: stdout(),
};
let request = JsonRpcRequest::<JsonValue> {
jsonrpc: "2.0".to_string(),
id: JsonRpcId::Number(1),
method: "initialize".to_string(),
params: None,
};
let response = server.dispatch(request);
assert_eq!(response.id, JsonRpcId::Number(1));
assert!(response.error.is_none());
let result = response.result.expect("initialize result");
assert_eq!(result["protocolVersion"], MCP_SERVER_PROTOCOL_VERSION);
assert_eq!(result["serverInfo"]["name"], "test");
assert_eq!(result["serverInfo"]["version"], "9.9.9");
}
#[test]
fn dispatch_tools_list_returns_registered_tools() {
let tool = McpTool {
name: "echo".to_string(),
description: Some("Echo".to_string()),
input_schema: Some(json!({"type": "object"})),
annotations: None,
meta: None,
};
let server = McpServer {
spec: McpServerSpec {
server_name: "test".to_string(),
server_version: "0.0.0".to_string(),
tools: vec![tool.clone()],
tool_handler: Box::new(|_, _| Ok(String::new())),
},
stdin: BufReader::new(stdin()),
stdout: stdout(),
};
let request = JsonRpcRequest::<JsonValue> {
jsonrpc: "2.0".to_string(),
id: JsonRpcId::Number(2),
method: "tools/list".to_string(),
params: None,
};
let response = server.dispatch(request);
assert!(response.error.is_none());
let result = response.result.expect("tools/list result");
assert_eq!(result["tools"][0]["name"], "echo");
}
#[test]
fn dispatch_tools_call_wraps_handler_output() {
let server = McpServer {
spec: McpServerSpec {
server_name: "test".to_string(),
server_version: "0.0.0".to_string(),
tools: Vec::new(),
tool_handler: Box::new(|name, args| Ok(format!("called {name} with {args}"))),
},
stdin: BufReader::new(stdin()),
stdout: stdout(),
};
let request = JsonRpcRequest::<JsonValue> {
jsonrpc: "2.0".to_string(),
id: JsonRpcId::Number(3),
method: "tools/call".to_string(),
params: Some(json!({
"name": "echo",
"arguments": {"text": "hi"}
})),
};
let response = server.dispatch(request);
assert!(response.error.is_none());
let result = response.result.expect("tools/call result");
assert_eq!(result["isError"], false);
assert_eq!(result["content"][0]["type"], "text");
assert!(result["content"][0]["text"]
.as_str()
.unwrap()
.starts_with("called echo"));
}
#[test]
fn dispatch_tools_call_surfaces_handler_error() {
let server = McpServer {
spec: McpServerSpec {
server_name: "test".to_string(),
server_version: "0.0.0".to_string(),
tools: Vec::new(),
tool_handler: Box::new(|_, _| Err("boom".to_string())),
},
stdin: BufReader::new(stdin()),
stdout: stdout(),
};
let request = JsonRpcRequest::<JsonValue> {
jsonrpc: "2.0".to_string(),
id: JsonRpcId::Number(4),
method: "tools/call".to_string(),
params: Some(json!({"name": "broken"})),
};
let response = server.dispatch(request);
let result = response.result.expect("tools/call result");
assert_eq!(result["isError"], true);
assert_eq!(result["content"][0]["text"], "boom");
}
#[test]
fn dispatch_unknown_method_returns_method_not_found() {
let server = McpServer {
spec: McpServerSpec {
server_name: "test".to_string(),
server_version: "0.0.0".to_string(),
tools: Vec::new(),
tool_handler: Box::new(|_, _| Ok(String::new())),
},
stdin: BufReader::new(stdin()),
stdout: stdout(),
};
let request = JsonRpcRequest::<JsonValue> {
jsonrpc: "2.0".to_string(),
id: JsonRpcId::Number(5),
method: "nonsense".to_string(),
params: None,
};
let response = server.dispatch(request);
let error = response.error.expect("error payload");
assert_eq!(error.code, -32601);
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,920 @@
#![allow(
clippy::await_holding_lock,
clippy::doc_markdown,
clippy::match_same_arms,
clippy::must_use_candidate,
clippy::uninlined_format_args,
clippy::unnested_or_patterns
)]
//! Bridge between MCP tool surface (ListMcpResources, ReadMcpResource, McpAuth, MCP)
//! and the existing McpServerManager runtime.
//!
//! Provides a stateful client registry that tool handlers can use to
//! connect to MCP servers and invoke their capabilities.
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use crate::mcp::mcp_tool_name;
use crate::mcp_stdio::McpServerManager;
use serde::{Deserialize, Serialize};
/// Status of a managed MCP server connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum McpConnectionStatus {
Disconnected,
Connecting,
Connected,
AuthRequired,
Error,
}
impl std::fmt::Display for McpConnectionStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Disconnected => write!(f, "disconnected"),
Self::Connecting => write!(f, "connecting"),
Self::Connected => write!(f, "connected"),
Self::AuthRequired => write!(f, "auth_required"),
Self::Error => write!(f, "error"),
}
}
}
/// Metadata about an MCP resource.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpResourceInfo {
pub uri: String,
pub name: String,
pub description: Option<String>,
pub mime_type: Option<String>,
}
/// Metadata about an MCP tool exposed by a server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolInfo {
pub name: String,
pub description: Option<String>,
pub input_schema: Option<serde_json::Value>,
}
/// Tracked state of an MCP server connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerState {
pub server_name: String,
pub status: McpConnectionStatus,
pub tools: Vec<McpToolInfo>,
pub resources: Vec<McpResourceInfo>,
pub server_info: Option<String>,
pub error_message: Option<String>,
}
#[derive(Debug, Clone, Default)]
pub struct McpToolRegistry {
inner: Arc<Mutex<HashMap<String, McpServerState>>>,
manager: Arc<OnceLock<Arc<Mutex<McpServerManager>>>>,
}
impl McpToolRegistry {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn set_manager(
&self,
manager: Arc<Mutex<McpServerManager>>,
) -> Result<(), Arc<Mutex<McpServerManager>>> {
self.manager.set(manager)
}
pub fn register_server(
&self,
server_name: &str,
status: McpConnectionStatus,
tools: Vec<McpToolInfo>,
resources: Vec<McpResourceInfo>,
server_info: Option<String>,
) {
let mut inner = self.inner.lock().expect("mcp registry lock poisoned");
inner.insert(
server_name.to_owned(),
McpServerState {
server_name: server_name.to_owned(),
status,
tools,
resources,
server_info,
error_message: None,
},
);
}
pub fn get_server(&self, server_name: &str) -> Option<McpServerState> {
let inner = self.inner.lock().expect("mcp registry lock poisoned");
inner.get(server_name).cloned()
}
pub fn list_servers(&self) -> Vec<McpServerState> {
let inner = self.inner.lock().expect("mcp registry lock poisoned");
inner.values().cloned().collect()
}
pub fn list_resources(&self, server_name: &str) -> Result<Vec<McpResourceInfo>, String> {
let inner = self.inner.lock().expect("mcp registry lock poisoned");
match inner.get(server_name) {
Some(state) => {
if state.status != McpConnectionStatus::Connected {
return Err(format!(
"server '{}' is not connected (status: {})",
server_name, state.status
));
}
Ok(state.resources.clone())
}
None => Err(format!("server '{}' not found", server_name)),
}
}
pub fn read_resource(&self, server_name: &str, uri: &str) -> Result<McpResourceInfo, String> {
let inner = self.inner.lock().expect("mcp registry lock poisoned");
let state = inner
.get(server_name)
.ok_or_else(|| format!("server '{}' not found", server_name))?;
if state.status != McpConnectionStatus::Connected {
return Err(format!(
"server '{}' is not connected (status: {})",
server_name, state.status
));
}
state
.resources
.iter()
.find(|r| r.uri == uri)
.cloned()
.ok_or_else(|| format!("resource '{}' not found on server '{}'", uri, server_name))
}
pub fn list_tools(&self, server_name: &str) -> Result<Vec<McpToolInfo>, String> {
let inner = self.inner.lock().expect("mcp registry lock poisoned");
match inner.get(server_name) {
Some(state) => {
if state.status != McpConnectionStatus::Connected {
return Err(format!(
"server '{}' is not connected (status: {})",
server_name, state.status
));
}
Ok(state.tools.clone())
}
None => Err(format!("server '{}' not found", server_name)),
}
}
fn spawn_tool_call(
manager: Arc<Mutex<McpServerManager>>,
qualified_tool_name: String,
arguments: Option<serde_json::Value>,
) -> Result<serde_json::Value, String> {
let join_handle = std::thread::Builder::new()
.name(format!("mcp-tool-call-{qualified_tool_name}"))
.spawn(move || {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.map_err(|error| format!("failed to create MCP tool runtime: {error}"))?;
runtime.block_on(async move {
let response = {
let mut manager = manager
.lock()
.map_err(|_| "mcp server manager lock poisoned".to_string())?;
manager
.discover_tools()
.await
.map_err(|error| error.to_string())?;
let response = manager
.call_tool(&qualified_tool_name, arguments)
.await
.map_err(|error| error.to_string());
let shutdown = manager.shutdown().await.map_err(|error| error.to_string());
match (response, shutdown) {
(Ok(response), Ok(())) => Ok(response),
(Err(error), Ok(())) | (Err(error), Err(_)) => Err(error),
(Ok(_), Err(error)) => Err(error),
}
}?;
if let Some(error) = response.error {
return Err(format!(
"MCP server returned JSON-RPC error for tools/call: {} ({})",
error.message, error.code
));
}
let result = response.result.ok_or_else(|| {
"MCP server returned no result for tools/call".to_string()
})?;
serde_json::to_value(result)
.map_err(|error| format!("failed to serialize MCP tool result: {error}"))
})
})
.map_err(|error| format!("failed to spawn MCP tool call thread: {error}"))?;
join_handle.join().map_err(|panic_payload| {
if let Some(message) = panic_payload.downcast_ref::<&str>() {
format!("MCP tool call thread panicked: {message}")
} else if let Some(message) = panic_payload.downcast_ref::<String>() {
format!("MCP tool call thread panicked: {message}")
} else {
"MCP tool call thread panicked".to_string()
}
})?
}
pub fn call_tool(
&self,
server_name: &str,
tool_name: &str,
arguments: &serde_json::Value,
) -> Result<serde_json::Value, String> {
let inner = self.inner.lock().expect("mcp registry lock poisoned");
let state = inner
.get(server_name)
.ok_or_else(|| format!("server '{}' not found", server_name))?;
if state.status != McpConnectionStatus::Connected {
return Err(format!(
"server '{}' is not connected (status: {})",
server_name, state.status
));
}
if !state.tools.iter().any(|t| t.name == tool_name) {
return Err(format!(
"tool '{}' not found on server '{}'",
tool_name, server_name
));
}
drop(inner);
let manager = self
.manager
.get()
.cloned()
.ok_or_else(|| "MCP server manager is not configured".to_string())?;
Self::spawn_tool_call(
manager,
mcp_tool_name(server_name, tool_name),
(!arguments.is_null()).then(|| arguments.clone()),
)
}
/// Set auth status for a server.
pub fn set_auth_status(
&self,
server_name: &str,
status: McpConnectionStatus,
) -> Result<(), String> {
let mut inner = self.inner.lock().expect("mcp registry lock poisoned");
let state = inner
.get_mut(server_name)
.ok_or_else(|| format!("server '{}' not found", server_name))?;
state.status = status;
Ok(())
}
/// Disconnect / remove a server.
pub fn disconnect(&self, server_name: &str) -> Option<McpServerState> {
let mut inner = self.inner.lock().expect("mcp registry lock poisoned");
inner.remove(server_name)
}
/// Number of registered servers.
#[must_use]
pub fn len(&self) -> usize {
let inner = self.inner.lock().expect("mcp registry lock poisoned");
inner.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
#[cfg(all(test, unix))]
mod tests {
use std::collections::BTreeMap;
use std::fs;
use std::os::unix::fs::PermissionsExt;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use super::*;
use crate::config::{
ConfigSource, McpServerConfig, McpStdioServerConfig, ScopedMcpServerConfig,
};
fn temp_dir() -> PathBuf {
static NEXT_TEMP_DIR_ID: AtomicU64 = AtomicU64::new(0);
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after epoch")
.as_nanos();
let unique_id = NEXT_TEMP_DIR_ID.fetch_add(1, Ordering::Relaxed);
std::env::temp_dir().join(format!("runtime-mcp-tool-bridge-{nanos}-{unique_id}"))
}
fn cleanup_script(script_path: &Path) {
if let Some(root) = script_path.parent() {
let _ = fs::remove_dir_all(root);
}
}
fn write_bridge_mcp_server_script() -> PathBuf {
let root = temp_dir();
fs::create_dir_all(&root).expect("temp dir");
let script_path = root.join("bridge-mcp-server.py");
let script = [
"#!/usr/bin/env python3",
"import json, os, sys",
"LABEL = os.environ.get('MCP_SERVER_LABEL', 'server')",
"LOG_PATH = os.environ.get('MCP_LOG_PATH')",
"",
"def log(method):",
" if LOG_PATH:",
" with open(LOG_PATH, 'a', encoding='utf-8') as handle:",
" handle.write(f'{method}\\n')",
"",
"def read_message():",
" header = b''",
r" while not header.endswith(b'\r\n\r\n'):",
" chunk = sys.stdin.buffer.read(1)",
" if not chunk:",
" return None",
" header += chunk",
" length = 0",
r" for line in header.decode().split('\r\n'):",
r" if line.lower().startswith('content-length:'):",
r" length = int(line.split(':', 1)[1].strip())",
" payload = sys.stdin.buffer.read(length)",
" return json.loads(payload.decode())",
"",
"def send_message(message):",
" payload = json.dumps(message).encode()",
r" sys.stdout.buffer.write(f'Content-Length: {len(payload)}\r\n\r\n'.encode() + payload)",
" sys.stdout.buffer.flush()",
"",
"while True:",
" request = read_message()",
" if request is None:",
" break",
" method = request['method']",
" log(method)",
" if method == 'initialize':",
" send_message({",
" 'jsonrpc': '2.0',",
" 'id': request['id'],",
" 'result': {",
" 'protocolVersion': request['params']['protocolVersion'],",
" 'capabilities': {'tools': {}},",
" 'serverInfo': {'name': LABEL, 'version': '1.0.0'}",
" }",
" })",
" elif method == 'tools/list':",
" send_message({",
" 'jsonrpc': '2.0',",
" 'id': request['id'],",
" 'result': {",
" 'tools': [",
" {",
" 'name': 'echo',",
" 'description': f'Echo tool for {LABEL}',",
" 'inputSchema': {",
" 'type': 'object',",
" 'properties': {'text': {'type': 'string'}},",
" 'required': ['text']",
" }",
" }",
" ]",
" }",
" })",
" elif method == 'tools/call':",
" args = request['params'].get('arguments') or {}",
" text = args.get('text', '')",
" send_message({",
" 'jsonrpc': '2.0',",
" 'id': request['id'],",
" 'result': {",
" 'content': [{'type': 'text', 'text': f'{LABEL}:{text}'}],",
" 'structuredContent': {'server': LABEL, 'echoed': text},",
" 'isError': False",
" }",
" })",
" else:",
" send_message({",
" 'jsonrpc': '2.0',",
" 'id': request['id'],",
" 'error': {'code': -32601, 'message': f'unknown method: {method}'},",
" })",
"",
]
.join("\n");
fs::write(&script_path, script).expect("write script");
let mut permissions = fs::metadata(&script_path).expect("metadata").permissions();
permissions.set_mode(0o755);
fs::set_permissions(&script_path, permissions).expect("chmod");
script_path
}
fn manager_server_config(
script_path: &Path,
server_name: &str,
log_path: &Path,
) -> ScopedMcpServerConfig {
ScopedMcpServerConfig {
scope: ConfigSource::Local,
config: McpServerConfig::Stdio(McpStdioServerConfig {
command: "python3".to_string(),
args: vec![script_path.to_string_lossy().into_owned()],
env: BTreeMap::from([
("MCP_SERVER_LABEL".to_string(), server_name.to_string()),
(
"MCP_LOG_PATH".to_string(),
log_path.to_string_lossy().into_owned(),
),
]),
tool_call_timeout_ms: Some(1_000),
}),
}
}
#[test]
fn registers_and_retrieves_server() {
let registry = McpToolRegistry::new();
registry.register_server(
"test-server",
McpConnectionStatus::Connected,
vec![McpToolInfo {
name: "greet".into(),
description: Some("Greet someone".into()),
input_schema: None,
}],
vec![McpResourceInfo {
uri: "res://data".into(),
name: "Data".into(),
description: None,
mime_type: Some("application/json".into()),
}],
Some("TestServer v1.0".into()),
);
let server = registry.get_server("test-server").expect("should exist");
assert_eq!(server.status, McpConnectionStatus::Connected);
assert_eq!(server.tools.len(), 1);
assert_eq!(server.resources.len(), 1);
}
#[test]
fn lists_resources_from_connected_server() {
let registry = McpToolRegistry::new();
registry.register_server(
"srv",
McpConnectionStatus::Connected,
vec![],
vec![McpResourceInfo {
uri: "res://alpha".into(),
name: "Alpha".into(),
description: None,
mime_type: None,
}],
None,
);
let resources = registry.list_resources("srv").expect("should succeed");
assert_eq!(resources.len(), 1);
assert_eq!(resources[0].uri, "res://alpha");
}
#[test]
fn rejects_resource_listing_for_disconnected_server() {
let registry = McpToolRegistry::new();
registry.register_server(
"srv",
McpConnectionStatus::Disconnected,
vec![],
vec![],
None,
);
assert!(registry.list_resources("srv").is_err());
}
#[test]
fn reads_specific_resource() {
let registry = McpToolRegistry::new();
registry.register_server(
"srv",
McpConnectionStatus::Connected,
vec![],
vec![McpResourceInfo {
uri: "res://data".into(),
name: "Data".into(),
description: Some("Test data".into()),
mime_type: Some("text/plain".into()),
}],
None,
);
let resource = registry
.read_resource("srv", "res://data")
.expect("should find");
assert_eq!(resource.name, "Data");
assert!(registry.read_resource("srv", "res://missing").is_err());
}
#[test]
fn given_connected_server_without_manager_when_calling_tool_then_it_errors() {
let registry = McpToolRegistry::new();
registry.register_server(
"srv",
McpConnectionStatus::Connected,
vec![McpToolInfo {
name: "greet".into(),
description: None,
input_schema: None,
}],
vec![],
None,
);
let error = registry
.call_tool("srv", "greet", &serde_json::json!({"name": "world"}))
.expect_err("should require a configured manager");
assert!(error.contains("MCP server manager is not configured"));
// Unknown tool should fail
assert!(registry
.call_tool("srv", "missing", &serde_json::json!({}))
.is_err());
}
#[test]
fn given_connected_server_with_manager_when_calling_tool_then_it_returns_live_result() {
let script_path = write_bridge_mcp_server_script();
let root = script_path.parent().expect("script parent");
let log_path = root.join("bridge.log");
let servers = BTreeMap::from([(
"alpha".to_string(),
manager_server_config(&script_path, "alpha", &log_path),
)]);
let manager = Arc::new(Mutex::new(McpServerManager::from_servers(&servers)));
let registry = McpToolRegistry::new();
registry.register_server(
"alpha",
McpConnectionStatus::Connected,
vec![McpToolInfo {
name: "echo".into(),
description: Some("Echo tool for alpha".into()),
input_schema: Some(serde_json::json!({
"type": "object",
"properties": {"text": {"type": "string"}},
"required": ["text"]
})),
}],
vec![],
Some("bridge test server".into()),
);
registry
.set_manager(Arc::clone(&manager))
.expect("manager should only be set once");
let result = registry
.call_tool("alpha", "echo", &serde_json::json!({"text": "hello"}))
.expect("should return live MCP result");
assert_eq!(
result["structuredContent"]["server"],
serde_json::json!("alpha")
);
assert_eq!(
result["structuredContent"]["echoed"],
serde_json::json!("hello")
);
assert_eq!(
result["content"][0]["text"],
serde_json::json!("alpha:hello")
);
let log = fs::read_to_string(&log_path).expect("read log");
assert_eq!(
log.lines().collect::<Vec<_>>(),
vec!["initialize", "tools/list", "tools/call"]
);
cleanup_script(&script_path);
}
#[test]
fn rejects_tool_call_on_disconnected_server() {
let registry = McpToolRegistry::new();
registry.register_server(
"srv",
McpConnectionStatus::AuthRequired,
vec![McpToolInfo {
name: "greet".into(),
description: None,
input_schema: None,
}],
vec![],
None,
);
assert!(registry
.call_tool("srv", "greet", &serde_json::json!({}))
.is_err());
}
#[test]
fn sets_auth_and_disconnects() {
let registry = McpToolRegistry::new();
registry.register_server(
"srv",
McpConnectionStatus::AuthRequired,
vec![],
vec![],
None,
);
registry
.set_auth_status("srv", McpConnectionStatus::Connected)
.expect("should succeed");
let state = registry.get_server("srv").unwrap();
assert_eq!(state.status, McpConnectionStatus::Connected);
let removed = registry.disconnect("srv");
assert!(removed.is_some());
assert!(registry.is_empty());
}
#[test]
fn rejects_operations_on_missing_server() {
let registry = McpToolRegistry::new();
assert!(registry.list_resources("missing").is_err());
assert!(registry.read_resource("missing", "uri").is_err());
assert!(registry.list_tools("missing").is_err());
assert!(registry
.call_tool("missing", "tool", &serde_json::json!({}))
.is_err());
assert!(registry
.set_auth_status("missing", McpConnectionStatus::Connected)
.is_err());
}
#[test]
fn mcp_connection_status_display_all_variants() {
// given
let cases = [
(McpConnectionStatus::Disconnected, "disconnected"),
(McpConnectionStatus::Connecting, "connecting"),
(McpConnectionStatus::Connected, "connected"),
(McpConnectionStatus::AuthRequired, "auth_required"),
(McpConnectionStatus::Error, "error"),
];
// when
let rendered: Vec<_> = cases
.into_iter()
.map(|(status, expected)| (status.to_string(), expected))
.collect();
// then
assert_eq!(
rendered,
vec![
("disconnected".to_string(), "disconnected"),
("connecting".to_string(), "connecting"),
("connected".to_string(), "connected"),
("auth_required".to_string(), "auth_required"),
("error".to_string(), "error"),
]
);
}
#[test]
fn list_servers_returns_all_registered() {
// given
let registry = McpToolRegistry::new();
registry.register_server(
"alpha",
McpConnectionStatus::Connected,
vec![],
vec![],
None,
);
registry.register_server(
"beta",
McpConnectionStatus::Connecting,
vec![],
vec![],
None,
);
// when
let servers = registry.list_servers();
// then
assert_eq!(servers.len(), 2);
assert!(servers.iter().any(|server| server.server_name == "alpha"));
assert!(servers.iter().any(|server| server.server_name == "beta"));
}
#[test]
fn list_tools_from_connected_server() {
// given
let registry = McpToolRegistry::new();
registry.register_server(
"srv",
McpConnectionStatus::Connected,
vec![McpToolInfo {
name: "inspect".into(),
description: Some("Inspect data".into()),
input_schema: Some(serde_json::json!({"type": "object"})),
}],
vec![],
None,
);
// when
let tools = registry.list_tools("srv").expect("tools should list");
// then
assert_eq!(tools.len(), 1);
assert_eq!(tools[0].name, "inspect");
}
#[test]
fn list_tools_rejects_disconnected_server() {
// given
let registry = McpToolRegistry::new();
registry.register_server(
"srv",
McpConnectionStatus::AuthRequired,
vec![],
vec![],
None,
);
// when
let result = registry.list_tools("srv");
// then
let error = result.expect_err("non-connected server should fail");
assert!(error.contains("not connected"));
assert!(error.contains("auth_required"));
}
#[test]
fn list_tools_rejects_missing_server() {
// given
let registry = McpToolRegistry::new();
// when
let result = registry.list_tools("missing");
// then
assert_eq!(
result.expect_err("missing server should fail"),
"server 'missing' not found"
);
}
#[test]
fn get_server_returns_none_for_missing() {
// given
let registry = McpToolRegistry::new();
// when
let server = registry.get_server("missing");
// then
assert!(server.is_none());
}
#[test]
fn call_tool_payload_structure() {
let script_path = write_bridge_mcp_server_script();
let root = script_path.parent().expect("script parent");
let log_path = root.join("payload.log");
let servers = BTreeMap::from([(
"srv".to_string(),
manager_server_config(&script_path, "srv", &log_path),
)]);
let registry = McpToolRegistry::new();
let arguments = serde_json::json!({"text": "world"});
registry.register_server(
"srv",
McpConnectionStatus::Connected,
vec![McpToolInfo {
name: "echo".into(),
description: Some("Echo tool for srv".into()),
input_schema: Some(serde_json::json!({
"type": "object",
"properties": {"text": {"type": "string"}},
"required": ["text"]
})),
}],
vec![],
None,
);
registry
.set_manager(Arc::new(Mutex::new(McpServerManager::from_servers(
&servers,
))))
.expect("manager should only be set once");
let result = registry
.call_tool("srv", "echo", &arguments)
.expect("tool should return live payload");
assert_eq!(result["structuredContent"]["server"], "srv");
assert_eq!(result["structuredContent"]["echoed"], "world");
assert_eq!(result["content"][0]["text"], "srv:world");
cleanup_script(&script_path);
}
#[test]
fn upsert_overwrites_existing_server() {
// given
let registry = McpToolRegistry::new();
registry.register_server("srv", McpConnectionStatus::Connecting, vec![], vec![], None);
// when
registry.register_server(
"srv",
McpConnectionStatus::Connected,
vec![McpToolInfo {
name: "inspect".into(),
description: None,
input_schema: None,
}],
vec![],
Some("Inspector".into()),
);
let state = registry.get_server("srv").expect("server should exist");
// then
assert_eq!(state.status, McpConnectionStatus::Connected);
assert_eq!(state.tools.len(), 1);
assert_eq!(state.server_info.as_deref(), Some("Inspector"));
}
#[test]
fn disconnect_missing_returns_none() {
// given
let registry = McpToolRegistry::new();
// when
let removed = registry.disconnect("missing");
// then
assert!(removed.is_none());
}
#[test]
fn len_and_is_empty_transitions() {
// given
let registry = McpToolRegistry::new();
// when
registry.register_server(
"alpha",
McpConnectionStatus::Connected,
vec![],
vec![],
None,
);
registry.register_server("beta", McpConnectionStatus::Connected, vec![], vec![], None);
let after_create = registry.len();
registry.disconnect("alpha");
let after_first_remove = registry.len();
registry.disconnect("beta");
// then
assert_eq!(after_create, 2);
assert_eq!(after_first_remove, 1);
assert_eq!(registry.len(), 0);
assert!(registry.is_empty());
}
}

View File

@ -9,6 +9,7 @@ use sha2::{Digest, Sha256};
use crate::config::OAuthConfig;
/// Persisted OAuth access token bundle used by the CLI.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OAuthTokenSet {
pub access_token: String,
@ -17,6 +18,7 @@ pub struct OAuthTokenSet {
pub scopes: Vec<String>,
}
/// PKCE verifier/challenge pair generated for an OAuth authorization flow.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PkceCodePair {
pub verifier: String,
@ -24,6 +26,7 @@ pub struct PkceCodePair {
pub challenge_method: PkceChallengeMethod,
}
/// Challenge algorithms supported by the local PKCE helpers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PkceChallengeMethod {
S256,
@ -38,6 +41,7 @@ impl PkceChallengeMethod {
}
}
/// Parameters needed to build an authorization URL for browser-based login.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthAuthorizationRequest {
pub authorize_url: String,
@ -50,6 +54,7 @@ pub struct OAuthAuthorizationRequest {
pub extra_params: BTreeMap<String, String>,
}
/// Request body for exchanging an OAuth authorization code for tokens.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthTokenExchangeRequest {
pub grant_type: &'static str,
@ -60,6 +65,7 @@ pub struct OAuthTokenExchangeRequest {
pub state: String,
}
/// Request body for refreshing an existing OAuth token set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthRefreshRequest {
pub grant_type: &'static str,
@ -68,6 +74,7 @@ pub struct OAuthRefreshRequest {
pub scopes: Vec<String>,
}
/// Parsed query parameters returned to the local OAuth callback endpoint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OAuthCallbackParams {
pub code: Option<String>,
@ -327,15 +334,16 @@ fn credentials_home_dir() -> io::Result<PathBuf> {
if let Some(path) = std::env::var_os("CLAW_CONFIG_HOME") {
return Ok(PathBuf::from(path));
}
if let Some(path) = std::env::var_os("HOME") {
return Ok(PathBuf::from(path).join(".claw"));
}
if cfg!(target_os = "windows") {
if let Some(path) = std::env::var_os("USERPROFILE") {
return Ok(PathBuf::from(path).join(".claw"));
}
}
Err(io::Error::new(io::ErrorKind::NotFound, "HOME or USERPROFILE is not set"))
let home = std::env::var_os("HOME")
.or_else(|| std::env::var_os("USERPROFILE"))
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::NotFound,
"HOME is not set (on Windows, set USERPROFILE or HOME, \
or use CLAW_CONFIG_HOME to point directly at the config directory)",
)
})?;
Ok(PathBuf::from(home).join(".claw"))
}
fn read_credentials_root(path: &PathBuf) -> io::Result<Map<String, Value>> {
@ -448,7 +456,7 @@ fn decode_hex(byte: u8) -> Result<u8, String> {
b'0'..=b'9' => Ok(byte - b'0'),
b'a'..=b'f' => Ok(byte - b'a' + 10),
b'A'..=b'F' => Ok(byte - b'A' + 10),
_ => Err(format!("invalid percent-encoding byte: {byte}")),
_ => Err(format!("invalid percent byte: {byte}")),
}
}

View File

@ -0,0 +1,551 @@
#![allow(
clippy::match_wildcard_for_single_variants,
clippy::must_use_candidate,
clippy::uninlined_format_args
)]
//! Permission enforcement layer that gates tool execution based on the
//! active `PermissionPolicy`.
use crate::permissions::{PermissionMode, PermissionOutcome, PermissionPolicy};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "outcome")]
pub enum EnforcementResult {
/// Tool execution is allowed.
Allowed,
/// Tool execution was denied due to insufficient permissions.
Denied {
tool: String,
active_mode: String,
required_mode: String,
reason: String,
},
}
#[derive(Debug, Clone, PartialEq)]
pub struct PermissionEnforcer {
policy: PermissionPolicy,
}
impl PermissionEnforcer {
#[must_use]
pub fn new(policy: PermissionPolicy) -> Self {
Self { policy }
}
/// Check whether a tool can be executed under the current permission policy.
/// Auto-denies when prompting is required but no prompter is provided.
pub fn check(&self, tool_name: &str, input: &str) -> EnforcementResult {
// When the active mode is Prompt, defer to the caller's interactive
// prompt flow rather than hard-denying (the enforcer has no prompter).
if self.policy.active_mode() == PermissionMode::Prompt {
return EnforcementResult::Allowed;
}
let outcome = self.policy.authorize(tool_name, input, None);
match outcome {
PermissionOutcome::Allow => EnforcementResult::Allowed,
PermissionOutcome::Deny { reason } => {
let active_mode = self.policy.active_mode();
let required_mode = self.policy.required_mode_for(tool_name);
EnforcementResult::Denied {
tool: tool_name.to_owned(),
active_mode: active_mode.as_str().to_owned(),
required_mode: required_mode.as_str().to_owned(),
reason,
}
}
}
}
#[must_use]
pub fn is_allowed(&self, tool_name: &str, input: &str) -> bool {
matches!(self.check(tool_name, input), EnforcementResult::Allowed)
}
#[must_use]
pub fn active_mode(&self) -> PermissionMode {
self.policy.active_mode()
}
/// Classify a file operation against workspace boundaries.
pub fn check_file_write(&self, path: &str, workspace_root: &str) -> EnforcementResult {
let mode = self.policy.active_mode();
match mode {
PermissionMode::ReadOnly => EnforcementResult::Denied {
tool: "write_file".to_owned(),
active_mode: mode.as_str().to_owned(),
required_mode: PermissionMode::WorkspaceWrite.as_str().to_owned(),
reason: format!("file writes are not allowed in '{}' mode", mode.as_str()),
},
PermissionMode::WorkspaceWrite => {
if is_within_workspace(path, workspace_root) {
EnforcementResult::Allowed
} else {
EnforcementResult::Denied {
tool: "write_file".to_owned(),
active_mode: mode.as_str().to_owned(),
required_mode: PermissionMode::DangerFullAccess.as_str().to_owned(),
reason: format!(
"path '{}' is outside workspace root '{}'",
path, workspace_root
),
}
}
}
// Allow and DangerFullAccess permit all writes
PermissionMode::Allow | PermissionMode::DangerFullAccess => EnforcementResult::Allowed,
PermissionMode::Prompt => EnforcementResult::Denied {
tool: "write_file".to_owned(),
active_mode: mode.as_str().to_owned(),
required_mode: PermissionMode::WorkspaceWrite.as_str().to_owned(),
reason: "file write requires confirmation in prompt mode".to_owned(),
},
}
}
/// Check if a bash command should be allowed based on current mode.
pub fn check_bash(&self, command: &str) -> EnforcementResult {
let mode = self.policy.active_mode();
match mode {
PermissionMode::ReadOnly => {
if is_read_only_command(command) {
EnforcementResult::Allowed
} else {
EnforcementResult::Denied {
tool: "bash".to_owned(),
active_mode: mode.as_str().to_owned(),
required_mode: PermissionMode::WorkspaceWrite.as_str().to_owned(),
reason: format!(
"command may modify state; not allowed in '{}' mode",
mode.as_str()
),
}
}
}
PermissionMode::Prompt => EnforcementResult::Denied {
tool: "bash".to_owned(),
active_mode: mode.as_str().to_owned(),
required_mode: PermissionMode::DangerFullAccess.as_str().to_owned(),
reason: "bash requires confirmation in prompt mode".to_owned(),
},
// WorkspaceWrite, Allow, DangerFullAccess: permit bash
_ => EnforcementResult::Allowed,
}
}
}
/// Simple workspace boundary check via string prefix.
fn is_within_workspace(path: &str, workspace_root: &str) -> bool {
let normalized = if path.starts_with('/') {
path.to_owned()
} else {
format!("{workspace_root}/{path}")
};
let root = if workspace_root.ends_with('/') {
workspace_root.to_owned()
} else {
format!("{workspace_root}/")
};
normalized.starts_with(&root) || normalized == workspace_root.trim_end_matches('/')
}
/// Conservative heuristic: is this bash command read-only?
fn is_read_only_command(command: &str) -> bool {
let first_token = command
.split_whitespace()
.next()
.unwrap_or("")
.rsplit('/')
.next()
.unwrap_or("");
matches!(
first_token,
"cat"
| "head"
| "tail"
| "less"
| "more"
| "wc"
| "ls"
| "find"
| "grep"
| "rg"
| "awk"
| "sed"
| "echo"
| "printf"
| "which"
| "where"
| "whoami"
| "pwd"
| "env"
| "printenv"
| "date"
| "cal"
| "df"
| "du"
| "free"
| "uptime"
| "uname"
| "file"
| "stat"
| "diff"
| "sort"
| "uniq"
| "tr"
| "cut"
| "paste"
| "tee"
| "xargs"
| "test"
| "true"
| "false"
| "type"
| "readlink"
| "realpath"
| "basename"
| "dirname"
| "sha256sum"
| "md5sum"
| "b3sum"
| "xxd"
| "hexdump"
| "od"
| "strings"
| "tree"
| "jq"
| "yq"
| "python3"
| "python"
| "node"
| "ruby"
| "cargo"
| "rustc"
| "git"
| "gh"
) && !command.contains("-i ")
&& !command.contains("--in-place")
&& !command.contains(" > ")
&& !command.contains(" >> ")
}
#[cfg(test)]
mod tests {
use super::*;
fn make_enforcer(mode: PermissionMode) -> PermissionEnforcer {
let policy = PermissionPolicy::new(mode);
PermissionEnforcer::new(policy)
}
#[test]
fn allow_mode_permits_everything() {
let enforcer = make_enforcer(PermissionMode::Allow);
assert!(enforcer.is_allowed("bash", ""));
assert!(enforcer.is_allowed("write_file", ""));
assert!(enforcer.is_allowed("edit_file", ""));
assert_eq!(
enforcer.check_file_write("/outside/path", "/workspace"),
EnforcementResult::Allowed
);
assert_eq!(enforcer.check_bash("rm -rf /"), EnforcementResult::Allowed);
}
#[test]
fn read_only_denies_writes() {
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
.with_tool_requirement("read_file", PermissionMode::ReadOnly)
.with_tool_requirement("grep_search", PermissionMode::ReadOnly)
.with_tool_requirement("write_file", PermissionMode::WorkspaceWrite);
let enforcer = PermissionEnforcer::new(policy);
assert!(enforcer.is_allowed("read_file", ""));
assert!(enforcer.is_allowed("grep_search", ""));
// write_file requires WorkspaceWrite but we're in ReadOnly
let result = enforcer.check("write_file", "");
assert!(matches!(result, EnforcementResult::Denied { .. }));
let result = enforcer.check_file_write("/workspace/file.rs", "/workspace");
assert!(matches!(result, EnforcementResult::Denied { .. }));
}
#[test]
fn read_only_allows_read_commands() {
let enforcer = make_enforcer(PermissionMode::ReadOnly);
assert_eq!(
enforcer.check_bash("cat src/main.rs"),
EnforcementResult::Allowed
);
assert_eq!(
enforcer.check_bash("grep -r 'pattern' ."),
EnforcementResult::Allowed
);
assert_eq!(enforcer.check_bash("ls -la"), EnforcementResult::Allowed);
}
#[test]
fn read_only_denies_write_commands() {
let enforcer = make_enforcer(PermissionMode::ReadOnly);
let result = enforcer.check_bash("rm file.txt");
assert!(matches!(result, EnforcementResult::Denied { .. }));
}
#[test]
fn workspace_write_allows_within_workspace() {
let enforcer = make_enforcer(PermissionMode::WorkspaceWrite);
let result = enforcer.check_file_write("/workspace/src/main.rs", "/workspace");
assert_eq!(result, EnforcementResult::Allowed);
}
#[test]
fn workspace_write_denies_outside_workspace() {
let enforcer = make_enforcer(PermissionMode::WorkspaceWrite);
let result = enforcer.check_file_write("/etc/passwd", "/workspace");
assert!(matches!(result, EnforcementResult::Denied { .. }));
}
#[test]
fn prompt_mode_denies_without_prompter() {
let enforcer = make_enforcer(PermissionMode::Prompt);
let result = enforcer.check_bash("echo test");
assert!(matches!(result, EnforcementResult::Denied { .. }));
let result = enforcer.check_file_write("/workspace/file.rs", "/workspace");
assert!(matches!(result, EnforcementResult::Denied { .. }));
}
#[test]
fn workspace_boundary_check() {
assert!(is_within_workspace("/workspace/src/main.rs", "/workspace"));
assert!(is_within_workspace("/workspace", "/workspace"));
assert!(!is_within_workspace("/etc/passwd", "/workspace"));
assert!(!is_within_workspace("/workspacex/hack", "/workspace"));
}
#[test]
fn read_only_command_heuristic() {
assert!(is_read_only_command("cat file.txt"));
assert!(is_read_only_command("grep pattern file"));
assert!(is_read_only_command("git log --oneline"));
assert!(!is_read_only_command("rm file.txt"));
assert!(!is_read_only_command("echo test > file.txt"));
assert!(!is_read_only_command("sed -i 's/a/b/' file"));
}
#[test]
fn active_mode_returns_policy_mode() {
// given
let modes = [
PermissionMode::ReadOnly,
PermissionMode::WorkspaceWrite,
PermissionMode::DangerFullAccess,
PermissionMode::Prompt,
PermissionMode::Allow,
];
// when
let active_modes: Vec<_> = modes
.into_iter()
.map(|mode| make_enforcer(mode).active_mode())
.collect();
// then
assert_eq!(active_modes, modes);
}
#[test]
fn danger_full_access_permits_file_writes_and_bash() {
// given
let enforcer = make_enforcer(PermissionMode::DangerFullAccess);
// when
let file_result = enforcer.check_file_write("/outside/workspace/file.txt", "/workspace");
let bash_result = enforcer.check_bash("rm -rf /tmp/scratch");
// then
assert_eq!(file_result, EnforcementResult::Allowed);
assert_eq!(bash_result, EnforcementResult::Allowed);
}
#[test]
fn check_denied_payload_contains_tool_and_modes() {
// given
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
.with_tool_requirement("write_file", PermissionMode::WorkspaceWrite);
let enforcer = PermissionEnforcer::new(policy);
// when
let result = enforcer.check("write_file", "{}");
// then
match result {
EnforcementResult::Denied {
tool,
active_mode,
required_mode,
reason,
} => {
assert_eq!(tool, "write_file");
assert_eq!(active_mode, "read-only");
assert_eq!(required_mode, "workspace-write");
assert!(reason.contains("requires workspace-write permission"));
}
other => panic!("expected denied result, got {other:?}"),
}
}
#[test]
fn workspace_write_relative_path_resolved() {
// given
let enforcer = make_enforcer(PermissionMode::WorkspaceWrite);
// when
let result = enforcer.check_file_write("src/main.rs", "/workspace");
// then
assert_eq!(result, EnforcementResult::Allowed);
}
#[test]
fn workspace_root_with_trailing_slash() {
// given
let enforcer = make_enforcer(PermissionMode::WorkspaceWrite);
// when
let result = enforcer.check_file_write("/workspace/src/main.rs", "/workspace/");
// then
assert_eq!(result, EnforcementResult::Allowed);
}
#[test]
fn workspace_root_equality() {
// given
let root = "/workspace/";
// when
let equal_to_root = is_within_workspace("/workspace", root);
// then
assert!(equal_to_root);
}
#[test]
fn bash_heuristic_full_path_prefix() {
// given
let full_path_command = "/usr/bin/cat Cargo.toml";
let git_path_command = "/usr/local/bin/git status";
// when
let cat_result = is_read_only_command(full_path_command);
let git_result = is_read_only_command(git_path_command);
// then
assert!(cat_result);
assert!(git_result);
}
#[test]
fn bash_heuristic_redirects_block_read_only_commands() {
// given
let overwrite = "cat Cargo.toml > out.txt";
let append = "echo test >> out.txt";
// when
let overwrite_result = is_read_only_command(overwrite);
let append_result = is_read_only_command(append);
// then
assert!(!overwrite_result);
assert!(!append_result);
}
#[test]
fn bash_heuristic_in_place_flag_blocks() {
// given
let interactive_python = "python -i script.py";
let in_place_sed = "sed --in-place 's/a/b/' file.txt";
// when
let interactive_result = is_read_only_command(interactive_python);
let in_place_result = is_read_only_command(in_place_sed);
// then
assert!(!interactive_result);
assert!(!in_place_result);
}
#[test]
fn bash_heuristic_empty_command() {
// given
let empty = "";
let whitespace = " ";
// when
let empty_result = is_read_only_command(empty);
let whitespace_result = is_read_only_command(whitespace);
// then
assert!(!empty_result);
assert!(!whitespace_result);
}
#[test]
fn prompt_mode_check_bash_denied_payload_fields() {
// given
let enforcer = make_enforcer(PermissionMode::Prompt);
// when
let result = enforcer.check_bash("git status");
// then
match result {
EnforcementResult::Denied {
tool,
active_mode,
required_mode,
reason,
} => {
assert_eq!(tool, "bash");
assert_eq!(active_mode, "prompt");
assert_eq!(required_mode, "danger-full-access");
assert_eq!(reason, "bash requires confirmation in prompt mode");
}
other => panic!("expected denied result, got {other:?}"),
}
}
#[test]
fn read_only_check_file_write_denied_payload() {
// given
let enforcer = make_enforcer(PermissionMode::ReadOnly);
// when
let result = enforcer.check_file_write("/workspace/file.txt", "/workspace");
// then
match result {
EnforcementResult::Denied {
tool,
active_mode,
required_mode,
reason,
} => {
assert_eq!(tool, "write_file");
assert_eq!(active_mode, "read-only");
assert_eq!(required_mode, "workspace-write");
assert!(reason.contains("file writes are not allowed"));
}
other => panic!("expected denied result, got {other:?}"),
}
}
}

View File

@ -1,5 +1,10 @@
use std::collections::BTreeMap;
use serde_json::Value;
use crate::config::RuntimePermissionRuleConfig;
/// Permission level assigned to a tool invocation or runtime session.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PermissionMode {
ReadOnly,
@ -22,34 +27,81 @@ impl PermissionMode {
}
}
/// Hook-provided override applied before standard permission evaluation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PermissionOverride {
Allow,
Deny,
Ask,
}
/// Additional permission context supplied by hooks or higher-level orchestration.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct PermissionContext {
override_decision: Option<PermissionOverride>,
override_reason: Option<String>,
}
impl PermissionContext {
#[must_use]
pub fn new(
override_decision: Option<PermissionOverride>,
override_reason: Option<String>,
) -> Self {
Self {
override_decision,
override_reason,
}
}
#[must_use]
pub fn override_decision(&self) -> Option<PermissionOverride> {
self.override_decision
}
#[must_use]
pub fn override_reason(&self) -> Option<&str> {
self.override_reason.as_deref()
}
}
/// Full authorization request presented to a permission prompt.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PermissionRequest {
pub tool_name: String,
pub input: String,
pub current_mode: PermissionMode,
pub required_mode: PermissionMode,
pub reason: Option<String>,
}
/// User-facing decision returned by a [`PermissionPrompter`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PermissionPromptDecision {
Allow,
Deny { reason: String },
}
/// Prompting interface used when policy requires interactive approval.
pub trait PermissionPrompter {
fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision;
}
/// Final authorization result after evaluating static rules and prompts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PermissionOutcome {
Allow,
Deny { reason: String },
}
/// Evaluates permission mode requirements plus allow/deny/ask rules.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PermissionPolicy {
active_mode: PermissionMode,
tool_requirements: BTreeMap<String, PermissionMode>,
allow_rules: Vec<PermissionRule>,
deny_rules: Vec<PermissionRule>,
ask_rules: Vec<PermissionRule>,
}
impl PermissionPolicy {
@ -58,6 +110,9 @@ impl PermissionPolicy {
Self {
active_mode,
tool_requirements: BTreeMap::new(),
allow_rules: Vec::new(),
deny_rules: Vec::new(),
ask_rules: Vec::new(),
}
}
@ -72,6 +127,26 @@ impl PermissionPolicy {
self
}
#[must_use]
pub fn with_permission_rules(mut self, config: &RuntimePermissionRuleConfig) -> Self {
self.allow_rules = config
.allow()
.iter()
.map(|rule| PermissionRule::parse(rule))
.collect();
self.deny_rules = config
.deny()
.iter()
.map(|rule| PermissionRule::parse(rule))
.collect();
self.ask_rules = config
.ask()
.iter()
.map(|rule| PermissionRule::parse(rule))
.collect();
self
}
#[must_use]
pub fn active_mode(&self) -> PermissionMode {
self.active_mode
@ -90,38 +165,121 @@ impl PermissionPolicy {
&self,
tool_name: &str,
input: &str,
mut prompter: Option<&mut dyn PermissionPrompter>,
prompter: Option<&mut dyn PermissionPrompter>,
) -> PermissionOutcome {
let current_mode = self.active_mode();
let required_mode = self.required_mode_for(tool_name);
if current_mode == PermissionMode::Allow || current_mode >= required_mode {
return PermissionOutcome::Allow;
self.authorize_with_context(tool_name, input, &PermissionContext::default(), prompter)
}
let request = PermissionRequest {
tool_name: tool_name.to_string(),
input: input.to_string(),
#[must_use]
#[allow(clippy::too_many_lines)]
pub fn authorize_with_context(
&self,
tool_name: &str,
input: &str,
context: &PermissionContext,
prompter: Option<&mut dyn PermissionPrompter>,
) -> PermissionOutcome {
if let Some(rule) = Self::find_matching_rule(&self.deny_rules, tool_name, input) {
return PermissionOutcome::Deny {
reason: format!(
"Permission to use {tool_name} has been denied by rule '{}'",
rule.raw
),
};
}
let current_mode = self.active_mode();
let required_mode = self.required_mode_for(tool_name);
let ask_rule = Self::find_matching_rule(&self.ask_rules, tool_name, input);
let allow_rule = Self::find_matching_rule(&self.allow_rules, tool_name, input);
match context.override_decision() {
Some(PermissionOverride::Deny) => {
return PermissionOutcome::Deny {
reason: context.override_reason().map_or_else(
|| format!("tool '{tool_name}' denied by hook"),
ToOwned::to_owned,
),
};
}
Some(PermissionOverride::Ask) => {
let reason = context.override_reason().map_or_else(
|| format!("tool '{tool_name}' requires approval due to hook guidance"),
ToOwned::to_owned,
);
return Self::prompt_or_deny(
tool_name,
input,
current_mode,
required_mode,
};
Some(reason),
prompter,
);
}
Some(PermissionOverride::Allow) => {
if let Some(rule) = ask_rule {
let reason = format!(
"tool '{tool_name}' requires approval due to ask rule '{}'",
rule.raw
);
return Self::prompt_or_deny(
tool_name,
input,
current_mode,
required_mode,
Some(reason),
prompter,
);
}
if allow_rule.is_some()
|| current_mode == PermissionMode::Allow
|| current_mode >= required_mode
{
return PermissionOutcome::Allow;
}
}
None => {}
}
if let Some(rule) = ask_rule {
let reason = format!(
"tool '{tool_name}' requires approval due to ask rule '{}'",
rule.raw
);
return Self::prompt_or_deny(
tool_name,
input,
current_mode,
required_mode,
Some(reason),
prompter,
);
}
if allow_rule.is_some()
|| current_mode == PermissionMode::Allow
|| current_mode >= required_mode
{
return PermissionOutcome::Allow;
}
if current_mode == PermissionMode::Prompt
|| (current_mode == PermissionMode::WorkspaceWrite
&& required_mode == PermissionMode::DangerFullAccess)
{
return match prompter.as_mut() {
Some(prompter) => match prompter.decide(&request) {
PermissionPromptDecision::Allow => PermissionOutcome::Allow,
PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason },
},
None => PermissionOutcome::Deny {
reason: format!(
let reason = Some(format!(
"tool '{tool_name}' requires approval to escalate from {} to {}",
current_mode.as_str(),
required_mode.as_str()
),
},
};
));
return Self::prompt_or_deny(
tool_name,
input,
current_mode,
required_mode,
reason,
prompter,
);
}
PermissionOutcome::Deny {
@ -132,14 +290,191 @@ impl PermissionPolicy {
),
}
}
fn prompt_or_deny(
tool_name: &str,
input: &str,
current_mode: PermissionMode,
required_mode: PermissionMode,
reason: Option<String>,
mut prompter: Option<&mut dyn PermissionPrompter>,
) -> PermissionOutcome {
let request = PermissionRequest {
tool_name: tool_name.to_string(),
input: input.to_string(),
current_mode,
required_mode,
reason: reason.clone(),
};
match prompter.as_mut() {
Some(prompter) => match prompter.decide(&request) {
PermissionPromptDecision::Allow => PermissionOutcome::Allow,
PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason },
},
None => PermissionOutcome::Deny {
reason: reason.unwrap_or_else(|| {
format!(
"tool '{tool_name}' requires approval to run while mode is {}",
current_mode.as_str()
)
}),
},
}
}
fn find_matching_rule<'a>(
rules: &'a [PermissionRule],
tool_name: &str,
input: &str,
) -> Option<&'a PermissionRule> {
rules.iter().find(|rule| rule.matches(tool_name, input))
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct PermissionRule {
raw: String,
tool_name: String,
matcher: PermissionRuleMatcher,
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum PermissionRuleMatcher {
Any,
Exact(String),
Prefix(String),
}
impl PermissionRule {
fn parse(raw: &str) -> Self {
let trimmed = raw.trim();
let open = find_first_unescaped(trimmed, '(');
let close = find_last_unescaped(trimmed, ')');
if let (Some(open), Some(close)) = (open, close) {
if close == trimmed.len() - 1 && open < close {
let tool_name = trimmed[..open].trim();
let content = &trimmed[open + 1..close];
if !tool_name.is_empty() {
let matcher = parse_rule_matcher(content);
return Self {
raw: trimmed.to_string(),
tool_name: tool_name.to_string(),
matcher,
};
}
}
}
Self {
raw: trimmed.to_string(),
tool_name: trimmed.to_string(),
matcher: PermissionRuleMatcher::Any,
}
}
fn matches(&self, tool_name: &str, input: &str) -> bool {
if self.tool_name != tool_name {
return false;
}
match &self.matcher {
PermissionRuleMatcher::Any => true,
PermissionRuleMatcher::Exact(expected) => {
extract_permission_subject(input).is_some_and(|candidate| candidate == *expected)
}
PermissionRuleMatcher::Prefix(prefix) => extract_permission_subject(input)
.is_some_and(|candidate| candidate.starts_with(prefix)),
}
}
}
fn parse_rule_matcher(content: &str) -> PermissionRuleMatcher {
let unescaped = unescape_rule_content(content.trim());
if unescaped.is_empty() || unescaped == "*" {
PermissionRuleMatcher::Any
} else if let Some(prefix) = unescaped.strip_suffix(":*") {
PermissionRuleMatcher::Prefix(prefix.to_string())
} else {
PermissionRuleMatcher::Exact(unescaped)
}
}
fn unescape_rule_content(content: &str) -> String {
content
.replace(r"\(", "(")
.replace(r"\)", ")")
.replace(r"\\", r"\")
}
fn find_first_unescaped(value: &str, needle: char) -> Option<usize> {
let mut escaped = false;
for (idx, ch) in value.char_indices() {
if ch == '\\' {
escaped = !escaped;
continue;
}
if ch == needle && !escaped {
return Some(idx);
}
escaped = false;
}
None
}
fn find_last_unescaped(value: &str, needle: char) -> Option<usize> {
let chars = value.char_indices().collect::<Vec<_>>();
for (pos, (idx, ch)) in chars.iter().enumerate().rev() {
if *ch != needle {
continue;
}
let mut backslashes = 0;
for (_, prev) in chars[..pos].iter().rev() {
if *prev == '\\' {
backslashes += 1;
} else {
break;
}
}
if backslashes % 2 == 0 {
return Some(*idx);
}
}
None
}
fn extract_permission_subject(input: &str) -> Option<String> {
let parsed = serde_json::from_str::<Value>(input).ok();
if let Some(Value::Object(object)) = parsed {
for key in [
"command",
"path",
"file_path",
"filePath",
"notebook_path",
"notebookPath",
"url",
"pattern",
"code",
"message",
] {
if let Some(value) = object.get(key).and_then(Value::as_str) {
return Some(value.to_string());
}
}
}
(!input.trim().is_empty()).then(|| input.to_string())
}
#[cfg(test)]
mod tests {
use super::{
PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision,
PermissionPrompter, PermissionRequest,
PermissionContext, PermissionMode, PermissionOutcome, PermissionOverride, PermissionPolicy,
PermissionPromptDecision, PermissionPrompter, PermissionRequest,
};
use crate::config::RuntimePermissionRuleConfig;
struct RecordingPrompter {
seen: Vec<PermissionRequest>,
@ -229,4 +564,120 @@ mod tests {
PermissionOutcome::Deny { reason } if reason == "not now"
));
}
#[test]
fn applies_rule_based_denials_and_allows() {
let rules = RuntimePermissionRuleConfig::new(
vec!["bash(git:*)".to_string()],
vec!["bash(rm -rf:*)".to_string()],
Vec::new(),
);
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
.with_tool_requirement("bash", PermissionMode::DangerFullAccess)
.with_permission_rules(&rules);
assert_eq!(
policy.authorize("bash", r#"{"command":"git status"}"#, None),
PermissionOutcome::Allow
);
assert!(matches!(
policy.authorize("bash", r#"{"command":"rm -rf /tmp/x"}"#, None),
PermissionOutcome::Deny { reason } if reason.contains("denied by rule")
));
}
#[test]
fn ask_rules_force_prompt_even_when_mode_allows() {
let rules = RuntimePermissionRuleConfig::new(
Vec::new(),
Vec::new(),
vec!["bash(git:*)".to_string()],
);
let policy = PermissionPolicy::new(PermissionMode::DangerFullAccess)
.with_tool_requirement("bash", PermissionMode::DangerFullAccess)
.with_permission_rules(&rules);
let mut prompter = RecordingPrompter {
seen: Vec::new(),
allow: true,
};
let outcome = policy.authorize("bash", r#"{"command":"git status"}"#, Some(&mut prompter));
assert_eq!(outcome, PermissionOutcome::Allow);
assert_eq!(prompter.seen.len(), 1);
assert!(prompter.seen[0]
.reason
.as_deref()
.is_some_and(|reason| reason.contains("ask rule")));
}
#[test]
fn hook_allow_still_respects_ask_rules() {
let rules = RuntimePermissionRuleConfig::new(
Vec::new(),
Vec::new(),
vec!["bash(git:*)".to_string()],
);
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
.with_tool_requirement("bash", PermissionMode::DangerFullAccess)
.with_permission_rules(&rules);
let context = PermissionContext::new(
Some(PermissionOverride::Allow),
Some("hook approved".to_string()),
);
let mut prompter = RecordingPrompter {
seen: Vec::new(),
allow: true,
};
let outcome = policy.authorize_with_context(
"bash",
r#"{"command":"git status"}"#,
&context,
Some(&mut prompter),
);
assert_eq!(outcome, PermissionOutcome::Allow);
assert_eq!(prompter.seen.len(), 1);
}
#[test]
fn hook_deny_short_circuits_permission_flow() {
let policy = PermissionPolicy::new(PermissionMode::DangerFullAccess)
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
let context = PermissionContext::new(
Some(PermissionOverride::Deny),
Some("blocked by hook".to_string()),
);
assert_eq!(
policy.authorize_with_context("bash", "{}", &context, None),
PermissionOutcome::Deny {
reason: "blocked by hook".to_string(),
}
);
}
#[test]
fn hook_ask_forces_prompt() {
let policy = PermissionPolicy::new(PermissionMode::DangerFullAccess)
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
let context = PermissionContext::new(
Some(PermissionOverride::Ask),
Some("hook requested confirmation".to_string()),
);
let mut prompter = RecordingPrompter {
seen: Vec::new(),
allow: true,
};
let outcome = policy.authorize_with_context("bash", "{}", &context, Some(&mut prompter));
assert_eq!(outcome, PermissionOutcome::Allow);
assert_eq!(prompter.seen.len(), 1);
assert_eq!(
prompter.seen[0].reason.as_deref(),
Some("hook requested confirmation")
);
}
}

View File

@ -0,0 +1,533 @@
#![allow(clippy::redundant_closure_for_method_calls)]
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::config::RuntimePluginConfig;
use crate::mcp_tool_bridge::{McpResourceInfo, McpToolInfo};
fn now_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
pub type ToolInfo = McpToolInfo;
pub type ResourceInfo = McpResourceInfo;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ServerStatus {
Healthy,
Degraded,
Failed,
}
impl std::fmt::Display for ServerStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Healthy => write!(f, "healthy"),
Self::Degraded => write!(f, "degraded"),
Self::Failed => write!(f, "failed"),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ServerHealth {
pub server_name: String,
pub status: ServerStatus,
pub capabilities: Vec<String>,
pub last_error: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "state")]
pub enum PluginState {
Unconfigured,
Validated,
Starting,
Healthy,
Degraded {
healthy_servers: Vec<String>,
failed_servers: Vec<ServerHealth>,
},
Failed {
reason: String,
},
ShuttingDown,
Stopped,
}
impl PluginState {
#[must_use]
pub fn from_servers(servers: &[ServerHealth]) -> Self {
if servers.is_empty() {
return Self::Failed {
reason: "no servers available".to_string(),
};
}
let healthy_servers = servers
.iter()
.filter(|server| server.status != ServerStatus::Failed)
.map(|server| server.server_name.clone())
.collect::<Vec<_>>();
let failed_servers = servers
.iter()
.filter(|server| server.status == ServerStatus::Failed)
.cloned()
.collect::<Vec<_>>();
let has_degraded_server = servers
.iter()
.any(|server| server.status == ServerStatus::Degraded);
if failed_servers.is_empty() && !has_degraded_server {
Self::Healthy
} else if healthy_servers.is_empty() {
Self::Failed {
reason: format!("all {} servers failed", failed_servers.len()),
}
} else {
Self::Degraded {
healthy_servers,
failed_servers,
}
}
}
}
impl std::fmt::Display for PluginState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Unconfigured => write!(f, "unconfigured"),
Self::Validated => write!(f, "validated"),
Self::Starting => write!(f, "starting"),
Self::Healthy => write!(f, "healthy"),
Self::Degraded { .. } => write!(f, "degraded"),
Self::Failed { .. } => write!(f, "failed"),
Self::ShuttingDown => write!(f, "shutting_down"),
Self::Stopped => write!(f, "stopped"),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PluginHealthcheck {
pub plugin_name: String,
pub state: PluginState,
pub servers: Vec<ServerHealth>,
pub last_check: u64,
}
impl PluginHealthcheck {
#[must_use]
pub fn new(plugin_name: impl Into<String>, servers: Vec<ServerHealth>) -> Self {
let state = PluginState::from_servers(&servers);
Self {
plugin_name: plugin_name.into(),
state,
servers,
last_check: now_secs(),
}
}
#[must_use]
pub fn degraded_mode(&self, discovery: &DiscoveryResult) -> Option<DegradedMode> {
match &self.state {
PluginState::Degraded {
healthy_servers,
failed_servers,
} => Some(DegradedMode {
available_tools: discovery
.tools
.iter()
.map(|tool| tool.name.clone())
.collect(),
unavailable_tools: failed_servers
.iter()
.flat_map(|server| server.capabilities.iter().cloned())
.collect(),
reason: format!(
"{} servers healthy, {} servers failed",
healthy_servers.len(),
failed_servers.len()
),
}),
_ => None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveryResult {
pub tools: Vec<ToolInfo>,
pub resources: Vec<ResourceInfo>,
pub partial: bool,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DegradedMode {
pub available_tools: Vec<String>,
pub unavailable_tools: Vec<String>,
pub reason: String,
}
impl DegradedMode {
#[must_use]
pub fn new(
available_tools: Vec<String>,
unavailable_tools: Vec<String>,
reason: impl Into<String>,
) -> Self {
Self {
available_tools,
unavailable_tools,
reason: reason.into(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PluginLifecycleEvent {
ConfigValidated,
StartupHealthy,
StartupDegraded,
StartupFailed,
Shutdown,
}
impl std::fmt::Display for PluginLifecycleEvent {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ConfigValidated => write!(f, "config_validated"),
Self::StartupHealthy => write!(f, "startup_healthy"),
Self::StartupDegraded => write!(f, "startup_degraded"),
Self::StartupFailed => write!(f, "startup_failed"),
Self::Shutdown => write!(f, "shutdown"),
}
}
}
pub trait PluginLifecycle {
fn validate_config(&self, config: &RuntimePluginConfig) -> Result<(), String>;
fn healthcheck(&self) -> PluginHealthcheck;
fn discover(&self) -> DiscoveryResult;
fn shutdown(&mut self) -> Result<(), String>;
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(Debug, Clone)]
struct MockPluginLifecycle {
plugin_name: String,
valid_config: bool,
healthcheck: PluginHealthcheck,
discovery: DiscoveryResult,
shutdown_error: Option<String>,
shutdown_called: bool,
}
impl MockPluginLifecycle {
fn new(
plugin_name: &str,
valid_config: bool,
servers: Vec<ServerHealth>,
discovery: DiscoveryResult,
shutdown_error: Option<String>,
) -> Self {
Self {
plugin_name: plugin_name.to_string(),
valid_config,
healthcheck: PluginHealthcheck::new(plugin_name, servers),
discovery,
shutdown_error,
shutdown_called: false,
}
}
}
impl PluginLifecycle for MockPluginLifecycle {
fn validate_config(&self, _config: &RuntimePluginConfig) -> Result<(), String> {
if self.valid_config {
Ok(())
} else {
Err(format!(
"plugin `{}` failed configuration validation",
self.plugin_name
))
}
}
fn healthcheck(&self) -> PluginHealthcheck {
if self.shutdown_called {
PluginHealthcheck {
plugin_name: self.plugin_name.clone(),
state: PluginState::Stopped,
servers: self.healthcheck.servers.clone(),
last_check: now_secs(),
}
} else {
self.healthcheck.clone()
}
}
fn discover(&self) -> DiscoveryResult {
self.discovery.clone()
}
fn shutdown(&mut self) -> Result<(), String> {
if let Some(error) = &self.shutdown_error {
return Err(error.clone());
}
self.shutdown_called = true;
Ok(())
}
}
fn healthy_server(name: &str, capabilities: &[&str]) -> ServerHealth {
ServerHealth {
server_name: name.to_string(),
status: ServerStatus::Healthy,
capabilities: capabilities
.iter()
.map(|capability| capability.to_string())
.collect(),
last_error: None,
}
}
fn failed_server(name: &str, capabilities: &[&str], error: &str) -> ServerHealth {
ServerHealth {
server_name: name.to_string(),
status: ServerStatus::Failed,
capabilities: capabilities
.iter()
.map(|capability| capability.to_string())
.collect(),
last_error: Some(error.to_string()),
}
}
fn degraded_server(name: &str, capabilities: &[&str], error: &str) -> ServerHealth {
ServerHealth {
server_name: name.to_string(),
status: ServerStatus::Degraded,
capabilities: capabilities
.iter()
.map(|capability| capability.to_string())
.collect(),
last_error: Some(error.to_string()),
}
}
fn tool(name: &str) -> ToolInfo {
ToolInfo {
name: name.to_string(),
description: Some(format!("{name} tool")),
input_schema: None,
}
}
fn resource(name: &str, uri: &str) -> ResourceInfo {
ResourceInfo {
uri: uri.to_string(),
name: name.to_string(),
description: Some(format!("{name} resource")),
mime_type: Some("application/json".to_string()),
}
}
#[test]
fn full_lifecycle_happy_path() {
// given
let mut lifecycle = MockPluginLifecycle::new(
"healthy-plugin",
true,
vec![
healthy_server("alpha", &["search", "read"]),
healthy_server("beta", &["write"]),
],
DiscoveryResult {
tools: vec![tool("search"), tool("read"), tool("write")],
resources: vec![resource("docs", "file:///docs")],
partial: false,
},
None,
);
let config = RuntimePluginConfig::default();
// when
let validation = lifecycle.validate_config(&config);
let healthcheck = lifecycle.healthcheck();
let discovery = lifecycle.discover();
let shutdown = lifecycle.shutdown();
let post_shutdown = lifecycle.healthcheck();
// then
assert_eq!(validation, Ok(()));
assert_eq!(healthcheck.state, PluginState::Healthy);
assert_eq!(healthcheck.plugin_name, "healthy-plugin");
assert_eq!(discovery.tools.len(), 3);
assert_eq!(discovery.resources.len(), 1);
assert!(!discovery.partial);
assert_eq!(shutdown, Ok(()));
assert_eq!(post_shutdown.state, PluginState::Stopped);
}
#[test]
fn degraded_startup_when_one_of_three_servers_fails() {
// given
let lifecycle = MockPluginLifecycle::new(
"degraded-plugin",
true,
vec![
healthy_server("alpha", &["search"]),
failed_server("beta", &["write"], "connection refused"),
healthy_server("gamma", &["read"]),
],
DiscoveryResult {
tools: vec![tool("search"), tool("read")],
resources: vec![resource("alpha-docs", "file:///alpha")],
partial: true,
},
None,
);
// when
let healthcheck = lifecycle.healthcheck();
let discovery = lifecycle.discover();
let degraded_mode = healthcheck
.degraded_mode(&discovery)
.expect("degraded startup should expose degraded mode");
// then
match healthcheck.state {
PluginState::Degraded {
healthy_servers,
failed_servers,
} => {
assert_eq!(
healthy_servers,
vec!["alpha".to_string(), "gamma".to_string()]
);
assert_eq!(failed_servers.len(), 1);
assert_eq!(failed_servers[0].server_name, "beta");
assert_eq!(
failed_servers[0].last_error.as_deref(),
Some("connection refused")
);
}
other => panic!("expected degraded state, got {other:?}"),
}
assert!(discovery.partial);
assert_eq!(
degraded_mode.available_tools,
vec!["search".to_string(), "read".to_string()]
);
assert_eq!(degraded_mode.unavailable_tools, vec!["write".to_string()]);
assert_eq!(degraded_mode.reason, "2 servers healthy, 1 servers failed");
}
#[test]
fn degraded_server_status_keeps_server_usable() {
// given
let lifecycle = MockPluginLifecycle::new(
"soft-degraded-plugin",
true,
vec![
healthy_server("alpha", &["search"]),
degraded_server("beta", &["write"], "high latency"),
],
DiscoveryResult {
tools: vec![tool("search"), tool("write")],
resources: Vec::new(),
partial: true,
},
None,
);
// when
let healthcheck = lifecycle.healthcheck();
// then
match healthcheck.state {
PluginState::Degraded {
healthy_servers,
failed_servers,
} => {
assert_eq!(
healthy_servers,
vec!["alpha".to_string(), "beta".to_string()]
);
assert!(failed_servers.is_empty());
}
other => panic!("expected degraded state, got {other:?}"),
}
}
#[test]
fn complete_failure_when_all_servers_fail() {
// given
let lifecycle = MockPluginLifecycle::new(
"failed-plugin",
true,
vec![
failed_server("alpha", &["search"], "timeout"),
failed_server("beta", &["read"], "handshake failed"),
],
DiscoveryResult {
tools: Vec::new(),
resources: Vec::new(),
partial: false,
},
None,
);
// when
let healthcheck = lifecycle.healthcheck();
let discovery = lifecycle.discover();
// then
match &healthcheck.state {
PluginState::Failed { reason } => {
assert_eq!(reason, "all 2 servers failed");
}
other => panic!("expected failed state, got {other:?}"),
}
assert!(!discovery.partial);
assert!(discovery.tools.is_empty());
assert!(discovery.resources.is_empty());
assert!(healthcheck.degraded_mode(&discovery).is_none());
}
#[test]
fn graceful_shutdown() {
// given
let mut lifecycle = MockPluginLifecycle::new(
"shutdown-plugin",
true,
vec![healthy_server("alpha", &["search"])],
DiscoveryResult {
tools: vec![tool("search")],
resources: Vec::new(),
partial: false,
},
None,
);
// when
let shutdown = lifecycle.shutdown();
let post_shutdown = lifecycle.healthcheck();
// then
assert_eq!(shutdown, Ok(()));
assert_eq!(PluginLifecycleEvent::Shutdown.to_string(), "shutdown");
assert_eq!(post_shutdown.state, PluginState::Stopped);
}
}

View File

@ -0,0 +1,581 @@
use std::time::Duration;
pub type GreenLevel = u8;
const STALE_BRANCH_THRESHOLD: Duration = Duration::from_secs(60 * 60);
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PolicyRule {
pub name: String,
pub condition: PolicyCondition,
pub action: PolicyAction,
pub priority: u32,
}
impl PolicyRule {
#[must_use]
pub fn new(
name: impl Into<String>,
condition: PolicyCondition,
action: PolicyAction,
priority: u32,
) -> Self {
Self {
name: name.into(),
condition,
action,
priority,
}
}
#[must_use]
pub fn matches(&self, context: &LaneContext) -> bool {
self.condition.matches(context)
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyCondition {
And(Vec<PolicyCondition>),
Or(Vec<PolicyCondition>),
GreenAt { level: GreenLevel },
StaleBranch,
StartupBlocked,
LaneCompleted,
LaneReconciled,
ReviewPassed,
ScopedDiff,
TimedOut { duration: Duration },
}
impl PolicyCondition {
#[must_use]
pub fn matches(&self, context: &LaneContext) -> bool {
match self {
Self::And(conditions) => conditions
.iter()
.all(|condition| condition.matches(context)),
Self::Or(conditions) => conditions
.iter()
.any(|condition| condition.matches(context)),
Self::GreenAt { level } => context.green_level >= *level,
Self::StaleBranch => context.branch_freshness >= STALE_BRANCH_THRESHOLD,
Self::StartupBlocked => context.blocker == LaneBlocker::Startup,
Self::LaneCompleted => context.completed,
Self::LaneReconciled => context.reconciled,
Self::ReviewPassed => context.review_status == ReviewStatus::Approved,
Self::ScopedDiff => context.diff_scope == DiffScope::Scoped,
Self::TimedOut { duration } => context.branch_freshness >= *duration,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyAction {
MergeToDev,
MergeForward,
RecoverOnce,
Escalate { reason: String },
CloseoutLane,
CleanupSession,
Reconcile { reason: ReconcileReason },
Notify { channel: String },
Block { reason: String },
Chain(Vec<PolicyAction>),
}
/// Why a lane was reconciled without further action.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReconcileReason {
/// Branch already merged into main — no PR needed.
AlreadyMerged,
/// Work superseded by another lane or direct commit.
Superseded,
/// PR would be empty — all changes already landed.
EmptyDiff,
/// Lane manually closed by operator.
ManualClose,
}
impl PolicyAction {
fn flatten_into(&self, actions: &mut Vec<PolicyAction>) {
match self {
Self::Chain(chained) => {
for action in chained {
action.flatten_into(actions);
}
}
_ => actions.push(self.clone()),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LaneBlocker {
None,
Startup,
External,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReviewStatus {
Pending,
Approved,
Rejected,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiffScope {
Full,
Scoped,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LaneContext {
pub lane_id: String,
pub green_level: GreenLevel,
pub branch_freshness: Duration,
pub blocker: LaneBlocker,
pub review_status: ReviewStatus,
pub diff_scope: DiffScope,
pub completed: bool,
pub reconciled: bool,
}
impl LaneContext {
#[must_use]
pub fn new(
lane_id: impl Into<String>,
green_level: GreenLevel,
branch_freshness: Duration,
blocker: LaneBlocker,
review_status: ReviewStatus,
diff_scope: DiffScope,
completed: bool,
) -> Self {
Self {
lane_id: lane_id.into(),
green_level,
branch_freshness,
blocker,
review_status,
diff_scope,
completed,
reconciled: false,
}
}
/// Create a lane context that is already reconciled (no further action needed).
#[must_use]
pub fn reconciled(lane_id: impl Into<String>) -> Self {
Self {
lane_id: lane_id.into(),
green_level: 0,
branch_freshness: Duration::from_secs(0),
blocker: LaneBlocker::None,
review_status: ReviewStatus::Pending,
diff_scope: DiffScope::Full,
completed: true,
reconciled: true,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PolicyEngine {
rules: Vec<PolicyRule>,
}
impl PolicyEngine {
#[must_use]
pub fn new(mut rules: Vec<PolicyRule>) -> Self {
rules.sort_by_key(|rule| rule.priority);
Self { rules }
}
#[must_use]
pub fn rules(&self) -> &[PolicyRule] {
&self.rules
}
#[must_use]
pub fn evaluate(&self, context: &LaneContext) -> Vec<PolicyAction> {
evaluate(self, context)
}
}
#[must_use]
pub fn evaluate(engine: &PolicyEngine, context: &LaneContext) -> Vec<PolicyAction> {
let mut actions = Vec::new();
for rule in &engine.rules {
if rule.matches(context) {
rule.action.flatten_into(&mut actions);
}
}
actions
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use super::{
evaluate, DiffScope, LaneBlocker, LaneContext, PolicyAction, PolicyCondition, PolicyEngine,
PolicyRule, ReconcileReason, ReviewStatus, STALE_BRANCH_THRESHOLD,
};
fn default_context() -> LaneContext {
LaneContext::new(
"lane-7",
0,
Duration::from_secs(0),
LaneBlocker::None,
ReviewStatus::Pending,
DiffScope::Full,
false,
)
}
#[test]
fn merge_to_dev_rule_fires_for_green_scoped_reviewed_lane() {
// given
let engine = PolicyEngine::new(vec![PolicyRule::new(
"merge-to-dev",
PolicyCondition::And(vec![
PolicyCondition::GreenAt { level: 2 },
PolicyCondition::ScopedDiff,
PolicyCondition::ReviewPassed,
]),
PolicyAction::MergeToDev,
20,
)]);
let context = LaneContext::new(
"lane-7",
3,
Duration::from_secs(5),
LaneBlocker::None,
ReviewStatus::Approved,
DiffScope::Scoped,
false,
);
// when
let actions = engine.evaluate(&context);
// then
assert_eq!(actions, vec![PolicyAction::MergeToDev]);
}
#[test]
fn stale_branch_rule_fires_at_threshold() {
// given
let engine = PolicyEngine::new(vec![PolicyRule::new(
"merge-forward",
PolicyCondition::StaleBranch,
PolicyAction::MergeForward,
10,
)]);
let context = LaneContext::new(
"lane-7",
1,
STALE_BRANCH_THRESHOLD,
LaneBlocker::None,
ReviewStatus::Pending,
DiffScope::Full,
false,
);
// when
let actions = engine.evaluate(&context);
// then
assert_eq!(actions, vec![PolicyAction::MergeForward]);
}
#[test]
fn startup_blocked_rule_recovers_then_escalates() {
// given
let engine = PolicyEngine::new(vec![PolicyRule::new(
"startup-recovery",
PolicyCondition::StartupBlocked,
PolicyAction::Chain(vec![
PolicyAction::RecoverOnce,
PolicyAction::Escalate {
reason: "startup remained blocked".to_string(),
},
]),
15,
)]);
let context = LaneContext::new(
"lane-7",
0,
Duration::from_secs(0),
LaneBlocker::Startup,
ReviewStatus::Pending,
DiffScope::Full,
false,
);
// when
let actions = engine.evaluate(&context);
// then
assert_eq!(
actions,
vec![
PolicyAction::RecoverOnce,
PolicyAction::Escalate {
reason: "startup remained blocked".to_string(),
},
]
);
}
#[test]
fn completed_lane_rule_closes_out_and_cleans_up() {
// given
let engine = PolicyEngine::new(vec![PolicyRule::new(
"lane-closeout",
PolicyCondition::LaneCompleted,
PolicyAction::Chain(vec![
PolicyAction::CloseoutLane,
PolicyAction::CleanupSession,
]),
30,
)]);
let context = LaneContext::new(
"lane-7",
0,
Duration::from_secs(0),
LaneBlocker::None,
ReviewStatus::Pending,
DiffScope::Full,
true,
);
// when
let actions = engine.evaluate(&context);
// then
assert_eq!(
actions,
vec![PolicyAction::CloseoutLane, PolicyAction::CleanupSession]
);
}
#[test]
fn matching_rules_are_returned_in_priority_order_with_stable_ties() {
// given
let engine = PolicyEngine::new(vec![
PolicyRule::new(
"late-cleanup",
PolicyCondition::And(vec![]),
PolicyAction::CleanupSession,
30,
),
PolicyRule::new(
"first-notify",
PolicyCondition::And(vec![]),
PolicyAction::Notify {
channel: "ops".to_string(),
},
10,
),
PolicyRule::new(
"second-notify",
PolicyCondition::And(vec![]),
PolicyAction::Notify {
channel: "review".to_string(),
},
10,
),
PolicyRule::new(
"merge",
PolicyCondition::And(vec![]),
PolicyAction::MergeToDev,
20,
),
]);
let context = default_context();
// when
let actions = evaluate(&engine, &context);
// then
assert_eq!(
actions,
vec![
PolicyAction::Notify {
channel: "ops".to_string(),
},
PolicyAction::Notify {
channel: "review".to_string(),
},
PolicyAction::MergeToDev,
PolicyAction::CleanupSession,
]
);
}
#[test]
fn combinators_handle_empty_cases_and_nested_chains() {
// given
let engine = PolicyEngine::new(vec![
PolicyRule::new(
"empty-and",
PolicyCondition::And(vec![]),
PolicyAction::Notify {
channel: "orchestrator".to_string(),
},
5,
),
PolicyRule::new(
"empty-or",
PolicyCondition::Or(vec![]),
PolicyAction::Block {
reason: "should not fire".to_string(),
},
10,
),
PolicyRule::new(
"nested",
PolicyCondition::Or(vec![
PolicyCondition::StartupBlocked,
PolicyCondition::And(vec![
PolicyCondition::GreenAt { level: 2 },
PolicyCondition::TimedOut {
duration: Duration::from_secs(5),
},
]),
]),
PolicyAction::Chain(vec![
PolicyAction::Notify {
channel: "alerts".to_string(),
},
PolicyAction::Chain(vec![
PolicyAction::MergeForward,
PolicyAction::CleanupSession,
]),
]),
15,
),
]);
let context = LaneContext::new(
"lane-7",
2,
Duration::from_secs(10),
LaneBlocker::External,
ReviewStatus::Pending,
DiffScope::Full,
false,
);
// when
let actions = engine.evaluate(&context);
// then
assert_eq!(
actions,
vec![
PolicyAction::Notify {
channel: "orchestrator".to_string(),
},
PolicyAction::Notify {
channel: "alerts".to_string(),
},
PolicyAction::MergeForward,
PolicyAction::CleanupSession,
]
);
}
#[test]
fn reconciled_lane_emits_reconcile_and_cleanup() {
// given — a lane where branch is already merged, no PR needed, session stale
let engine = PolicyEngine::new(vec![
PolicyRule::new(
"reconcile-closeout",
PolicyCondition::LaneReconciled,
PolicyAction::Chain(vec![
PolicyAction::Reconcile {
reason: ReconcileReason::AlreadyMerged,
},
PolicyAction::CloseoutLane,
PolicyAction::CleanupSession,
]),
5,
),
// This rule should NOT fire — reconciled lanes are completed but we want
// the more specific reconcile rule to handle them
PolicyRule::new(
"generic-closeout",
PolicyCondition::And(vec![
PolicyCondition::LaneCompleted,
// Only fire if NOT reconciled
PolicyCondition::And(vec![]),
]),
PolicyAction::CloseoutLane,
30,
),
]);
let context = LaneContext::reconciled("lane-9411");
// when
let actions = engine.evaluate(&context);
// then — reconcile rule fires first (priority 5), then generic closeout also fires
// because reconciled context has completed=true
assert_eq!(
actions,
vec![
PolicyAction::Reconcile {
reason: ReconcileReason::AlreadyMerged,
},
PolicyAction::CloseoutLane,
PolicyAction::CleanupSession,
PolicyAction::CloseoutLane,
]
);
}
#[test]
fn reconciled_context_has_correct_defaults() {
let ctx = LaneContext::reconciled("test-lane");
assert_eq!(ctx.lane_id, "test-lane");
assert!(ctx.completed);
assert!(ctx.reconciled);
assert_eq!(ctx.blocker, LaneBlocker::None);
assert_eq!(ctx.green_level, 0);
}
#[test]
fn non_reconciled_lane_does_not_trigger_reconcile_rule() {
let engine = PolicyEngine::new(vec![PolicyRule::new(
"reconcile-closeout",
PolicyCondition::LaneReconciled,
PolicyAction::Reconcile {
reason: ReconcileReason::EmptyDiff,
},
5,
)]);
// Normal completed lane — not reconciled
let context = LaneContext::new(
"lane-7",
0,
Duration::from_secs(0),
LaneBlocker::None,
ReviewStatus::Pending,
DiffScope::Full,
true,
);
let actions = engine.evaluate(&context);
assert!(actions.is_empty());
}
#[test]
fn reconcile_reason_variants_are_distinct() {
assert_ne!(ReconcileReason::AlreadyMerged, ReconcileReason::Superseded);
assert_ne!(ReconcileReason::EmptyDiff, ReconcileReason::ManualClose);
}
}

View File

@ -4,8 +4,9 @@ use std::path::{Path, PathBuf};
use std::process::Command;
use crate::config::{ConfigError, ConfigLoader, RuntimeConfig};
use lsp::LspContextEnrichment;
use crate::git_context::GitContext;
/// Errors raised while assembling the final system prompt.
#[derive(Debug)]
pub enum PromptBuildError {
Io(std::io::Error),
@ -35,23 +36,28 @@ impl From<ConfigError> for PromptBuildError {
}
}
/// Marker separating static prompt scaffolding from dynamic runtime context.
pub const SYSTEM_PROMPT_DYNAMIC_BOUNDARY: &str = "__SYSTEM_PROMPT_DYNAMIC_BOUNDARY__";
pub const FRONTIER_MODEL_NAME: &str = "Opus 4.6";
/// Human-readable default frontier model name embedded into generated prompts.
pub const FRONTIER_MODEL_NAME: &str = "Claude Opus 4.6";
const MAX_INSTRUCTION_FILE_CHARS: usize = 4_000;
const MAX_TOTAL_INSTRUCTION_CHARS: usize = 12_000;
/// Contents of an instruction file included in prompt construction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContextFile {
pub path: PathBuf,
pub content: String,
}
/// Project-local context injected into the rendered system prompt.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProjectContext {
pub cwd: PathBuf,
pub current_date: String,
pub git_status: Option<String>,
pub git_diff: Option<String>,
pub git_context: Option<GitContext>,
pub instruction_files: Vec<ContextFile>,
}
@ -67,6 +73,7 @@ impl ProjectContext {
current_date: current_date.into(),
git_status: None,
git_diff: None,
git_context: None,
instruction_files,
})
}
@ -78,10 +85,12 @@ impl ProjectContext {
let mut context = Self::discover(cwd, current_date)?;
context.git_status = read_git_status(&context.cwd);
context.git_diff = read_git_diff(&context.cwd);
context.git_context = GitContext::detect(&context.cwd);
Ok(context)
}
}
/// Builder for the runtime system prompt and dynamic environment sections.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SystemPromptBuilder {
output_style_name: Option<String>,
@ -131,15 +140,6 @@ impl SystemPromptBuilder {
self
}
#[must_use]
pub fn with_lsp_context(mut self, enrichment: &LspContextEnrichment) -> Self {
if !enrichment.is_empty() {
self.append_sections
.push(enrichment.render_prompt_section());
}
self
}
#[must_use]
pub fn build(&self) -> Vec<String> {
let mut sections = Vec::new();
@ -194,6 +194,7 @@ impl SystemPromptBuilder {
}
}
/// Formats each item as an indented bullet for prompt sections.
#[must_use]
pub fn prepend_bullets(items: Vec<String>) -> Vec<String> {
items.into_iter().map(|item| format!(" - {item}")).collect()
@ -211,9 +212,9 @@ fn discover_instruction_files(cwd: &Path) -> std::io::Result<Vec<ContextFile>> {
let mut files = Vec::new();
for dir in directories {
for candidate in [
dir.join("CLAW.md"),
dir.join("CLAW.local.md"),
dir.join(".claw").join("CLAW.md"),
dir.join("CLAUDE.md"),
dir.join("CLAUDE.local.md"),
dir.join(".claw").join("CLAUDE.md"),
dir.join(".claw").join("instructions.md"),
] {
push_context_file(&mut files, candidate)?;
@ -292,7 +293,7 @@ fn render_project_context(project_context: &ProjectContext) -> String {
];
if !project_context.instruction_files.is_empty() {
bullets.push(format!(
"Claw instruction files discovered: {}.",
"Claude instruction files discovered: {}.",
project_context.instruction_files.len()
));
}
@ -302,16 +303,32 @@ fn render_project_context(project_context: &ProjectContext) -> String {
lines.push("Git status snapshot:".to_string());
lines.push(status.clone());
}
if let Some(ref gc) = project_context.git_context {
if !gc.recent_commits.is_empty() {
lines.push(String::new());
lines.push("Recent commits (last 5):".to_string());
for c in &gc.recent_commits {
lines.push(format!(" {} {}", c.hash, c.subject));
}
}
}
if let Some(diff) = &project_context.git_diff {
lines.push(String::new());
lines.push("Git diff snapshot:".to_string());
lines.push(diff.clone());
}
if let Some(git_context) = &project_context.git_context {
let rendered = git_context.render();
if !rendered.is_empty() {
lines.push(String::new());
lines.push(rendered);
}
}
lines.join("\n")
}
fn render_instruction_files(files: &[ContextFile]) -> String {
let mut sections = vec!["# Claw instructions".to_string()];
let mut sections = vec!["# Claude instructions".to_string()];
let mut remaining_chars = MAX_TOTAL_INSTRUCTION_CHARS;
for file in files {
if remaining_chars == 0 {
@ -411,6 +428,7 @@ fn collapse_blank_lines(content: &str) -> String {
result
}
/// Loads config and project context, then renders the system prompt text.
pub fn load_system_prompt(
cwd: impl Into<PathBuf>,
current_date: impl Into<String>,
@ -523,24 +541,31 @@ mod tests {
crate::test_env_lock()
}
fn ensure_valid_cwd() {
if std::env::current_dir().is_err() {
std::env::set_current_dir(env!("CARGO_MANIFEST_DIR"))
.expect("test cwd should be recoverable");
}
}
#[test]
fn discovers_instruction_files_from_ancestor_chain() {
let root = temp_dir();
let nested = root.join("apps").join("api");
fs::create_dir_all(nested.join(".claw")).expect("nested claw dir");
fs::write(root.join("CLAW.md"), "root instructions").expect("write root instructions");
fs::write(root.join("CLAW.local.md"), "local instructions")
fs::write(root.join("CLAUDE.md"), "root instructions").expect("write root instructions");
fs::write(root.join("CLAUDE.local.md"), "local instructions")
.expect("write local instructions");
fs::create_dir_all(root.join("apps")).expect("apps dir");
fs::create_dir_all(root.join("apps").join(".claw")).expect("apps claw dir");
fs::write(root.join("apps").join("CLAW.md"), "apps instructions")
fs::write(root.join("apps").join("CLAUDE.md"), "apps instructions")
.expect("write apps instructions");
fs::write(
root.join("apps").join(".claw").join("instructions.md"),
"apps dot claw instructions",
"apps dot claude instructions",
)
.expect("write apps dot claw instructions");
fs::write(nested.join(".claw").join("CLAW.md"), "nested rules")
.expect("write apps dot claude instructions");
fs::write(nested.join(".claw").join("CLAUDE.md"), "nested rules")
.expect("write nested rules");
fs::write(
nested.join(".claw").join("instructions.md"),
@ -561,7 +586,7 @@ mod tests {
"root instructions",
"local instructions",
"apps instructions",
"apps dot claw instructions",
"apps dot claude instructions",
"nested rules",
"nested instructions"
]
@ -574,8 +599,8 @@ mod tests {
let root = temp_dir();
let nested = root.join("apps").join("api");
fs::create_dir_all(&nested).expect("nested dir");
fs::write(root.join("CLAW.md"), "same rules\n\n").expect("write root");
fs::write(nested.join("CLAW.md"), "same rules\n").expect("write nested");
fs::write(root.join("CLAUDE.md"), "same rules\n\n").expect("write root");
fs::write(nested.join("CLAUDE.md"), "same rules\n").expect("write nested");
let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load");
assert_eq!(context.instruction_files.len(), 1);
@ -603,14 +628,15 @@ mod tests {
#[test]
fn displays_context_paths_compactly() {
assert_eq!(
display_context_path(Path::new("/tmp/project/.claw/CLAW.md")),
"CLAW.md"
display_context_path(Path::new("/tmp/project/.claw/CLAUDE.md")),
"CLAUDE.md"
);
}
#[test]
fn discover_with_git_includes_status_snapshot() {
let _guard = env_lock();
ensure_valid_cwd();
let root = temp_dir();
fs::create_dir_all(&root).expect("root dir");
std::process::Command::new("git")
@ -618,7 +644,7 @@ mod tests {
.current_dir(&root)
.status()
.expect("git init should run");
fs::write(root.join("CLAW.md"), "rules").expect("write instructions");
fs::write(root.join("CLAUDE.md"), "rules").expect("write instructions");
fs::write(root.join("tracked.txt"), "hello").expect("write tracked file");
let context =
@ -626,16 +652,99 @@ mod tests {
let status = context.git_status.expect("git status should be present");
assert!(status.contains("## No commits yet on") || status.contains("## "));
assert!(status.contains("?? CLAW.md"));
assert!(status.contains("?? CLAUDE.md"));
assert!(status.contains("?? tracked.txt"));
assert!(context.git_diff.is_none());
fs::remove_dir_all(root).expect("cleanup temp dir");
}
#[test]
fn discover_with_git_includes_recent_commits_and_renders_them() {
// given: a git repo with three commits and a current branch
let _guard = env_lock();
ensure_valid_cwd();
let root = temp_dir();
fs::create_dir_all(&root).expect("root dir");
std::process::Command::new("git")
.args(["init", "--quiet", "-b", "main"])
.current_dir(&root)
.status()
.expect("git init should run");
std::process::Command::new("git")
.args(["config", "user.email", "tests@example.com"])
.current_dir(&root)
.status()
.expect("git config email should run");
std::process::Command::new("git")
.args(["config", "user.name", "Runtime Prompt Tests"])
.current_dir(&root)
.status()
.expect("git config name should run");
for (file, message) in [
("a.txt", "first commit"),
("b.txt", "second commit"),
("c.txt", "third commit"),
] {
fs::write(root.join(file), "x\n").expect("write commit file");
std::process::Command::new("git")
.args(["add", file])
.current_dir(&root)
.status()
.expect("git add should run");
std::process::Command::new("git")
.args(["commit", "-m", message, "--quiet"])
.current_dir(&root)
.status()
.expect("git commit should run");
}
fs::write(root.join("d.txt"), "staged\n").expect("write staged file");
std::process::Command::new("git")
.args(["add", "d.txt"])
.current_dir(&root)
.status()
.expect("git add staged should run");
// when: discovering project context with git auto-include
let context =
ProjectContext::discover_with_git(&root, "2026-03-31").expect("context should load");
let rendered = SystemPromptBuilder::new()
.with_os("linux", "6.8")
.with_project_context(context.clone())
.render();
// then: branch, recent commits and staged files are present in context
let gc = context
.git_context
.as_ref()
.expect("git context should be present");
let commits: String = gc
.recent_commits
.iter()
.map(|c| c.subject.clone())
.collect::<Vec<_>>()
.join("\n");
assert!(commits.contains("first commit"));
assert!(commits.contains("second commit"));
assert!(commits.contains("third commit"));
assert_eq!(gc.recent_commits.len(), 3);
let status = context.git_status.as_deref().expect("status snapshot");
assert!(status.contains("## main"));
assert!(status.contains("A d.txt"));
assert!(rendered.contains("Recent commits (last 5):"));
assert!(rendered.contains("first commit"));
assert!(rendered.contains("Git status snapshot:"));
assert!(rendered.contains("## main"));
fs::remove_dir_all(root).expect("cleanup temp dir");
}
#[test]
fn discover_with_git_includes_diff_snapshot_for_tracked_changes() {
let _guard = env_lock();
ensure_valid_cwd();
let root = temp_dir();
fs::create_dir_all(&root).expect("root dir");
std::process::Command::new("git")
@ -677,10 +786,10 @@ mod tests {
}
#[test]
fn load_system_prompt_reads_claw_files_and_config() {
fn load_system_prompt_reads_claude_files_and_config() {
let root = temp_dir();
fs::create_dir_all(root.join(".claw")).expect("claw dir");
fs::write(root.join("CLAW.md"), "Project rules").expect("write instructions");
fs::write(root.join("CLAUDE.md"), "Project rules").expect("write instructions");
fs::write(
root.join(".claw").join("settings.json"),
r#"{"permissionMode":"acceptEdits"}"#,
@ -688,6 +797,7 @@ mod tests {
.expect("write settings");
let _guard = env_lock();
ensure_valid_cwd();
let previous = std::env::current_dir().expect("cwd");
let original_home = std::env::var("HOME").ok();
let original_claw_home = std::env::var("CLAW_CONFIG_HOME").ok();
@ -719,10 +829,10 @@ mod tests {
}
#[test]
fn renders_claw_code_style_sections_with_project_context() {
fn renders_claude_code_style_sections_with_project_context() {
let root = temp_dir();
fs::create_dir_all(root.join(".claw")).expect("claw dir");
fs::write(root.join("CLAW.md"), "Project rules").expect("write CLAW.md");
fs::write(root.join("CLAUDE.md"), "Project rules").expect("write CLAUDE.md");
fs::write(
root.join(".claw").join("settings.json"),
r#"{"permissionMode":"acceptEdits"}"#,
@ -743,7 +853,7 @@ mod tests {
assert!(prompt.contains("# System"));
assert!(prompt.contains("# Project context"));
assert!(prompt.contains("# Claw instructions"));
assert!(prompt.contains("# Claude instructions"));
assert!(prompt.contains("Project rules"));
assert!(prompt.contains("permissionMode"));
assert!(prompt.contains(SYSTEM_PROMPT_DYNAMIC_BOUNDARY));
@ -760,7 +870,7 @@ mod tests {
}
#[test]
fn discovers_dot_claw_instructions_markdown() {
fn discovers_dot_claude_instructions_markdown() {
let root = temp_dir();
let nested = root.join("apps").join("api");
fs::create_dir_all(nested.join(".claw")).expect("nested claw dir");
@ -785,10 +895,10 @@ mod tests {
#[test]
fn renders_instruction_file_metadata() {
let rendered = render_instruction_files(&[ContextFile {
path: PathBuf::from("/tmp/project/CLAW.md"),
path: PathBuf::from("/tmp/project/CLAUDE.md"),
content: "Project rules".to_string(),
}]);
assert!(rendered.contains("# Claw instructions"));
assert!(rendered.contains("# Claude instructions"));
assert!(rendered.contains("scope: /tmp/project"));
assert!(rendered.contains("Project rules"));
}

View File

@ -0,0 +1,631 @@
#![allow(clippy::cast_possible_truncation, clippy::uninlined_format_args)]
//! Recovery recipes for common failure scenarios.
//!
//! Encodes known automatic recoveries for the six failure scenarios
//! listed in ROADMAP item 8, and enforces one automatic recovery
//! attempt before escalation. Each attempt is emitted as a structured
//! recovery event.
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::worker_boot::WorkerFailureKind;
/// The six failure scenarios that have known recovery recipes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FailureScenario {
TrustPromptUnresolved,
PromptMisdelivery,
StaleBranch,
CompileRedCrossCrate,
McpHandshakeFailure,
PartialPluginStartup,
ProviderFailure,
}
impl FailureScenario {
/// Returns all known failure scenarios.
#[must_use]
pub fn all() -> &'static [FailureScenario] {
&[
Self::TrustPromptUnresolved,
Self::PromptMisdelivery,
Self::StaleBranch,
Self::CompileRedCrossCrate,
Self::McpHandshakeFailure,
Self::PartialPluginStartup,
Self::ProviderFailure,
]
}
/// Map a `WorkerFailureKind` to the corresponding `FailureScenario`.
/// This is the bridge that lets recovery policy consume worker boot events.
#[must_use]
pub fn from_worker_failure_kind(kind: WorkerFailureKind) -> Self {
match kind {
WorkerFailureKind::TrustGate => Self::TrustPromptUnresolved,
WorkerFailureKind::PromptDelivery => Self::PromptMisdelivery,
WorkerFailureKind::Protocol => Self::McpHandshakeFailure,
WorkerFailureKind::Provider => Self::ProviderFailure,
}
}
}
impl std::fmt::Display for FailureScenario {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::TrustPromptUnresolved => write!(f, "trust_prompt_unresolved"),
Self::PromptMisdelivery => write!(f, "prompt_misdelivery"),
Self::StaleBranch => write!(f, "stale_branch"),
Self::CompileRedCrossCrate => write!(f, "compile_red_cross_crate"),
Self::McpHandshakeFailure => write!(f, "mcp_handshake_failure"),
Self::PartialPluginStartup => write!(f, "partial_plugin_startup"),
Self::ProviderFailure => write!(f, "provider_failure"),
}
}
}
/// Individual step that can be executed as part of a recovery recipe.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RecoveryStep {
AcceptTrustPrompt,
RedirectPromptToAgent,
RebaseBranch,
CleanBuild,
RetryMcpHandshake { timeout: u64 },
RestartPlugin { name: String },
RestartWorker,
EscalateToHuman { reason: String },
}
/// Policy governing what happens when automatic recovery is exhausted.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EscalationPolicy {
AlertHuman,
LogAndContinue,
Abort,
}
/// A recovery recipe encodes the sequence of steps to attempt for a
/// given failure scenario, along with the maximum number of automatic
/// attempts and the escalation policy.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RecoveryRecipe {
pub scenario: FailureScenario,
pub steps: Vec<RecoveryStep>,
pub max_attempts: u32,
pub escalation_policy: EscalationPolicy,
}
/// Outcome of a recovery attempt.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RecoveryResult {
Recovered {
steps_taken: u32,
},
PartialRecovery {
recovered: Vec<RecoveryStep>,
remaining: Vec<RecoveryStep>,
},
EscalationRequired {
reason: String,
},
}
/// Structured event emitted during recovery.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RecoveryEvent {
RecoveryAttempted {
scenario: FailureScenario,
recipe: RecoveryRecipe,
result: RecoveryResult,
},
RecoverySucceeded,
RecoveryFailed,
Escalated,
}
/// Minimal context for tracking recovery state and emitting events.
///
/// Holds per-scenario attempt counts, a structured event log, and an
/// optional simulation knob for controlling step outcomes during tests.
#[derive(Debug, Clone, Default)]
pub struct RecoveryContext {
attempts: HashMap<FailureScenario, u32>,
events: Vec<RecoveryEvent>,
/// Optional step index at which simulated execution fails.
/// `None` means all steps succeed.
fail_at_step: Option<usize>,
}
impl RecoveryContext {
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// Configure a step index at which simulated execution will fail.
#[must_use]
pub fn with_fail_at_step(mut self, index: usize) -> Self {
self.fail_at_step = Some(index);
self
}
/// Returns the structured event log populated during recovery.
#[must_use]
pub fn events(&self) -> &[RecoveryEvent] {
&self.events
}
/// Returns the number of recovery attempts made for a scenario.
#[must_use]
pub fn attempt_count(&self, scenario: &FailureScenario) -> u32 {
self.attempts.get(scenario).copied().unwrap_or(0)
}
}
/// Returns the known recovery recipe for the given failure scenario.
#[must_use]
pub fn recipe_for(scenario: &FailureScenario) -> RecoveryRecipe {
match scenario {
FailureScenario::TrustPromptUnresolved => RecoveryRecipe {
scenario: *scenario,
steps: vec![RecoveryStep::AcceptTrustPrompt],
max_attempts: 1,
escalation_policy: EscalationPolicy::AlertHuman,
},
FailureScenario::PromptMisdelivery => RecoveryRecipe {
scenario: *scenario,
steps: vec![RecoveryStep::RedirectPromptToAgent],
max_attempts: 1,
escalation_policy: EscalationPolicy::AlertHuman,
},
FailureScenario::StaleBranch => RecoveryRecipe {
scenario: *scenario,
steps: vec![RecoveryStep::RebaseBranch, RecoveryStep::CleanBuild],
max_attempts: 1,
escalation_policy: EscalationPolicy::AlertHuman,
},
FailureScenario::CompileRedCrossCrate => RecoveryRecipe {
scenario: *scenario,
steps: vec![RecoveryStep::CleanBuild],
max_attempts: 1,
escalation_policy: EscalationPolicy::AlertHuman,
},
FailureScenario::McpHandshakeFailure => RecoveryRecipe {
scenario: *scenario,
steps: vec![RecoveryStep::RetryMcpHandshake { timeout: 5000 }],
max_attempts: 1,
escalation_policy: EscalationPolicy::Abort,
},
FailureScenario::PartialPluginStartup => RecoveryRecipe {
scenario: *scenario,
steps: vec![
RecoveryStep::RestartPlugin {
name: "stalled".to_string(),
},
RecoveryStep::RetryMcpHandshake { timeout: 3000 },
],
max_attempts: 1,
escalation_policy: EscalationPolicy::LogAndContinue,
},
FailureScenario::ProviderFailure => RecoveryRecipe {
scenario: *scenario,
steps: vec![RecoveryStep::RestartWorker],
max_attempts: 1,
escalation_policy: EscalationPolicy::AlertHuman,
},
}
}
/// Attempts automatic recovery for the given failure scenario.
///
/// Looks up the recipe, enforces the one-attempt-before-escalation
/// policy, simulates step execution (controlled by the context), and
/// emits structured [`RecoveryEvent`]s for every attempt.
pub fn attempt_recovery(scenario: &FailureScenario, ctx: &mut RecoveryContext) -> RecoveryResult {
let recipe = recipe_for(scenario);
let attempt_count = ctx.attempts.entry(*scenario).or_insert(0);
// Enforce one automatic recovery attempt before escalation.
if *attempt_count >= recipe.max_attempts {
let result = RecoveryResult::EscalationRequired {
reason: format!(
"max recovery attempts ({}) exceeded for {}",
recipe.max_attempts, scenario
),
};
ctx.events.push(RecoveryEvent::RecoveryAttempted {
scenario: *scenario,
recipe,
result: result.clone(),
});
ctx.events.push(RecoveryEvent::Escalated);
return result;
}
*attempt_count += 1;
// Execute steps, honoring the optional fail_at_step simulation.
let fail_index = ctx.fail_at_step;
let mut executed = Vec::new();
let mut failed = false;
for (i, step) in recipe.steps.iter().enumerate() {
if fail_index == Some(i) {
failed = true;
break;
}
executed.push(step.clone());
}
let result = if failed {
let remaining: Vec<RecoveryStep> = recipe.steps[executed.len()..].to_vec();
if executed.is_empty() {
RecoveryResult::EscalationRequired {
reason: format!("recovery failed at first step for {}", scenario),
}
} else {
RecoveryResult::PartialRecovery {
recovered: executed,
remaining,
}
}
} else {
RecoveryResult::Recovered {
steps_taken: recipe.steps.len() as u32,
}
};
// Emit the attempt as structured event data.
ctx.events.push(RecoveryEvent::RecoveryAttempted {
scenario: *scenario,
recipe,
result: result.clone(),
});
match &result {
RecoveryResult::Recovered { .. } => {
ctx.events.push(RecoveryEvent::RecoverySucceeded);
}
RecoveryResult::PartialRecovery { .. } => {
ctx.events.push(RecoveryEvent::RecoveryFailed);
}
RecoveryResult::EscalationRequired { .. } => {
ctx.events.push(RecoveryEvent::Escalated);
}
}
result
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn each_scenario_has_a_matching_recipe() {
// given
let scenarios = FailureScenario::all();
// when / then
for scenario in scenarios {
let recipe = recipe_for(scenario);
assert_eq!(
recipe.scenario, *scenario,
"recipe scenario should match requested scenario"
);
assert!(
!recipe.steps.is_empty(),
"recipe for {} should have at least one step",
scenario
);
assert!(
recipe.max_attempts >= 1,
"recipe for {} should allow at least one attempt",
scenario
);
}
}
#[test]
fn successful_recovery_returns_recovered_and_emits_events() {
// given
let mut ctx = RecoveryContext::new();
let scenario = FailureScenario::TrustPromptUnresolved;
// when
let result = attempt_recovery(&scenario, &mut ctx);
// then
assert_eq!(result, RecoveryResult::Recovered { steps_taken: 1 });
assert_eq!(ctx.events().len(), 2);
assert!(matches!(
&ctx.events()[0],
RecoveryEvent::RecoveryAttempted {
scenario: s,
result: r,
..
} if *s == FailureScenario::TrustPromptUnresolved
&& matches!(r, RecoveryResult::Recovered { steps_taken: 1 })
));
assert_eq!(ctx.events()[1], RecoveryEvent::RecoverySucceeded);
}
#[test]
fn escalation_after_max_attempts_exceeded() {
// given
let mut ctx = RecoveryContext::new();
let scenario = FailureScenario::PromptMisdelivery;
// when — first attempt succeeds
let first = attempt_recovery(&scenario, &mut ctx);
assert!(matches!(first, RecoveryResult::Recovered { .. }));
// when — second attempt should escalate
let second = attempt_recovery(&scenario, &mut ctx);
// then
assert!(
matches!(
&second,
RecoveryResult::EscalationRequired { reason }
if reason.contains("max recovery attempts")
),
"second attempt should require escalation, got: {second:?}"
);
assert_eq!(ctx.attempt_count(&scenario), 1);
assert!(ctx
.events()
.iter()
.any(|e| matches!(e, RecoveryEvent::Escalated)));
}
#[test]
fn partial_recovery_when_step_fails_midway() {
// given — PartialPluginStartup has two steps; fail at step index 1
let mut ctx = RecoveryContext::new().with_fail_at_step(1);
let scenario = FailureScenario::PartialPluginStartup;
// when
let result = attempt_recovery(&scenario, &mut ctx);
// then
match &result {
RecoveryResult::PartialRecovery {
recovered,
remaining,
} => {
assert_eq!(recovered.len(), 1, "one step should have succeeded");
assert_eq!(remaining.len(), 1, "one step should remain");
assert!(matches!(recovered[0], RecoveryStep::RestartPlugin { .. }));
assert!(matches!(
remaining[0],
RecoveryStep::RetryMcpHandshake { .. }
));
}
other => panic!("expected PartialRecovery, got {other:?}"),
}
assert!(ctx
.events()
.iter()
.any(|e| matches!(e, RecoveryEvent::RecoveryFailed)));
}
#[test]
fn first_step_failure_escalates_immediately() {
// given — fail at step index 0
let mut ctx = RecoveryContext::new().with_fail_at_step(0);
let scenario = FailureScenario::CompileRedCrossCrate;
// when
let result = attempt_recovery(&scenario, &mut ctx);
// then
assert!(
matches!(
&result,
RecoveryResult::EscalationRequired { reason }
if reason.contains("failed at first step")
),
"zero-step failure should escalate, got: {result:?}"
);
assert!(ctx
.events()
.iter()
.any(|e| matches!(e, RecoveryEvent::Escalated)));
}
#[test]
fn emitted_events_include_structured_attempt_data() {
// given
let mut ctx = RecoveryContext::new();
let scenario = FailureScenario::McpHandshakeFailure;
// when
let _ = attempt_recovery(&scenario, &mut ctx);
// then — verify the RecoveryAttempted event carries full context
let attempted = ctx
.events()
.iter()
.find(|e| matches!(e, RecoveryEvent::RecoveryAttempted { .. }))
.expect("should have emitted RecoveryAttempted event");
match attempted {
RecoveryEvent::RecoveryAttempted {
scenario: s,
recipe,
result,
} => {
assert_eq!(*s, scenario);
assert_eq!(recipe.scenario, scenario);
assert!(!recipe.steps.is_empty());
assert!(matches!(result, RecoveryResult::Recovered { .. }));
}
_ => unreachable!(),
}
// Verify the event is serializable as structured JSON
let json = serde_json::to_string(&ctx.events()[0])
.expect("recovery event should be serializable to JSON");
assert!(
json.contains("mcp_handshake_failure"),
"serialized event should contain scenario name"
);
}
#[test]
fn recovery_context_tracks_attempts_per_scenario() {
// given
let mut ctx = RecoveryContext::new();
// when
assert_eq!(ctx.attempt_count(&FailureScenario::StaleBranch), 0);
attempt_recovery(&FailureScenario::StaleBranch, &mut ctx);
// then
assert_eq!(ctx.attempt_count(&FailureScenario::StaleBranch), 1);
assert_eq!(ctx.attempt_count(&FailureScenario::PromptMisdelivery), 0);
}
#[test]
fn stale_branch_recipe_has_rebase_then_clean_build() {
// given
let recipe = recipe_for(&FailureScenario::StaleBranch);
// then
assert_eq!(recipe.steps.len(), 2);
assert_eq!(recipe.steps[0], RecoveryStep::RebaseBranch);
assert_eq!(recipe.steps[1], RecoveryStep::CleanBuild);
}
#[test]
fn partial_plugin_startup_recipe_has_restart_then_handshake() {
// given
let recipe = recipe_for(&FailureScenario::PartialPluginStartup);
// then
assert_eq!(recipe.steps.len(), 2);
assert!(matches!(
recipe.steps[0],
RecoveryStep::RestartPlugin { .. }
));
assert!(matches!(
recipe.steps[1],
RecoveryStep::RetryMcpHandshake { timeout: 3000 }
));
assert_eq!(recipe.escalation_policy, EscalationPolicy::LogAndContinue);
}
#[test]
fn failure_scenario_display_all_variants() {
// given
let cases = [
(
FailureScenario::TrustPromptUnresolved,
"trust_prompt_unresolved",
),
(FailureScenario::PromptMisdelivery, "prompt_misdelivery"),
(FailureScenario::StaleBranch, "stale_branch"),
(
FailureScenario::CompileRedCrossCrate,
"compile_red_cross_crate",
),
(
FailureScenario::McpHandshakeFailure,
"mcp_handshake_failure",
),
(
FailureScenario::PartialPluginStartup,
"partial_plugin_startup",
),
];
// when / then
for (scenario, expected) in &cases {
assert_eq!(scenario.to_string(), *expected);
}
}
#[test]
fn multi_step_success_reports_correct_steps_taken() {
// given — StaleBranch has 2 steps, no simulated failure
let mut ctx = RecoveryContext::new();
let scenario = FailureScenario::StaleBranch;
// when
let result = attempt_recovery(&scenario, &mut ctx);
// then
assert_eq!(result, RecoveryResult::Recovered { steps_taken: 2 });
}
#[test]
fn mcp_handshake_recipe_uses_abort_escalation_policy() {
// given
let recipe = recipe_for(&FailureScenario::McpHandshakeFailure);
// then
assert_eq!(recipe.escalation_policy, EscalationPolicy::Abort);
assert_eq!(recipe.max_attempts, 1);
}
#[test]
fn worker_failure_kind_maps_to_failure_scenario() {
// given / when / then — verify the bridge is correct
assert_eq!(
FailureScenario::from_worker_failure_kind(WorkerFailureKind::TrustGate),
FailureScenario::TrustPromptUnresolved,
);
assert_eq!(
FailureScenario::from_worker_failure_kind(WorkerFailureKind::PromptDelivery),
FailureScenario::PromptMisdelivery,
);
assert_eq!(
FailureScenario::from_worker_failure_kind(WorkerFailureKind::Protocol),
FailureScenario::McpHandshakeFailure,
);
assert_eq!(
FailureScenario::from_worker_failure_kind(WorkerFailureKind::Provider),
FailureScenario::ProviderFailure,
);
}
#[test]
fn provider_failure_recipe_uses_restart_worker_step() {
// given
let recipe = recipe_for(&FailureScenario::ProviderFailure);
// then
assert_eq!(recipe.scenario, FailureScenario::ProviderFailure);
assert!(recipe.steps.contains(&RecoveryStep::RestartWorker));
assert_eq!(recipe.escalation_policy, EscalationPolicy::AlertHuman);
assert_eq!(recipe.max_attempts, 1);
}
#[test]
fn provider_failure_recovery_attempt_succeeds_then_escalates() {
// given
let mut ctx = RecoveryContext::new();
let scenario = FailureScenario::ProviderFailure;
// when — first attempt
let first = attempt_recovery(&scenario, &mut ctx);
assert!(matches!(first, RecoveryResult::Recovered { .. }));
// when — second attempt should escalate (max_attempts=1)
let second = attempt_recovery(&scenario, &mut ctx);
assert!(matches!(second, RecoveryResult::EscalationRequired { .. }));
assert!(ctx
.events()
.iter()
.any(|e| matches!(e, RecoveryEvent::Escalated)));
}
}

View File

@ -72,9 +72,9 @@ impl RemoteSessionContext {
#[must_use]
pub fn from_env_map(env_map: &BTreeMap<String, String>) -> Self {
Self {
enabled: env_truthy(env_map.get("CLAW_CODE_REMOTE")),
enabled: env_truthy(env_map.get("CLAUDE_CODE_REMOTE")),
session_id: env_map
.get("CLAW_CODE_REMOTE_SESSION_ID")
.get("CLAUDE_CODE_REMOTE_SESSION_ID")
.filter(|value| !value.is_empty())
.cloned(),
base_url: env_map
@ -272,9 +272,9 @@ mod tests {
#[test]
fn remote_context_reads_env_state() {
let env = BTreeMap::from([
("CLAW_CODE_REMOTE".to_string(), "true".to_string()),
("CLAUDE_CODE_REMOTE".to_string(), "true".to_string()),
(
"CLAW_CODE_REMOTE_SESSION_ID".to_string(),
"CLAUDE_CODE_REMOTE_SESSION_ID".to_string(),
"session-123".to_string(),
),
(
@ -291,7 +291,7 @@ mod tests {
#[test]
fn bootstrap_fails_open_when_token_or_session_is_missing() {
let env = BTreeMap::from([
("CLAW_CODE_REMOTE".to_string(), "1".to_string()),
("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()),
("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()),
]);
let bootstrap = UpstreamProxyBootstrap::from_env_map(&env);
@ -307,10 +307,10 @@ mod tests {
fs::write(&token_path, "secret-token\n").expect("write token");
let env = BTreeMap::from([
("CLAW_CODE_REMOTE".to_string(), "1".to_string()),
("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()),
("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()),
(
"CLAW_CODE_REMOTE_SESSION_ID".to_string(),
"CLAUDE_CODE_REMOTE_SESSION_ID".to_string(),
"session-123".to_string(),
),
(

View File

@ -107,23 +107,11 @@ impl SandboxConfig {
#[must_use]
pub fn detect_container_environment() -> ContainerEnvironment {
let proc_1_cgroup = if cfg!(target_os = "linux") {
fs::read_to_string("/proc/1/cgroup").ok()
} else {
None
};
let proc_1_cgroup = fs::read_to_string("/proc/1/cgroup").ok();
detect_container_environment_from(SandboxDetectionInputs {
env_pairs: env::vars().collect(),
dockerenv_exists: if cfg!(target_os = "linux") {
Path::new("/.dockerenv").exists()
} else {
false
},
containerenv_exists: if cfg!(target_os = "linux") {
Path::new("/run/.containerenv").exists()
} else {
false
},
dockerenv_exists: Path::new("/.dockerenv").exists(),
containerenv_exists: Path::new("/run/.containerenv").exists(),
proc_1_cgroup: proc_1_cgroup.as_deref(),
})
}
@ -173,7 +161,7 @@ pub fn resolve_sandbox_status(config: &SandboxConfig, cwd: &Path) -> SandboxStat
#[must_use]
pub fn resolve_sandbox_status_for_request(request: &SandboxRequest, cwd: &Path) -> SandboxStatus {
let container = detect_container_environment();
let namespace_supported = cfg!(target_os = "linux") && command_exists("unshare");
let namespace_supported = cfg!(target_os = "linux") && unshare_user_namespace_works();
let network_supported = namespace_supported;
let filesystem_active =
request.enabled && request.filesystem_mode != FilesystemIsolationMode::Off;
@ -294,6 +282,27 @@ fn command_exists(command: &str) -> bool {
.is_some_and(|paths| env::split_paths(&paths).any(|path| path.join(command).exists()))
}
/// Check whether `unshare --user` actually works on this system.
/// On some CI environments (e.g. GitHub Actions), the binary exists but
/// user namespaces are restricted, causing silent failures.
fn unshare_user_namespace_works() -> bool {
use std::sync::OnceLock;
static RESULT: OnceLock<bool> = OnceLock::new();
*RESULT.get_or_init(|| {
if !command_exists("unshare") {
return false;
}
std::process::Command::new("unshare")
.args(["--user", "--map-root-user", "true"])
.stdin(std::process::Stdio::null())
.stdout(std::process::Stdio::null())
.stderr(std::process::Stdio::null())
.status()
.map(|s| s.success())
.unwrap_or(false)
})
}
#[cfg(test)]
mod tests {
use super::{

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,873 @@
#![allow(dead_code)]
use std::env;
use std::fmt::{Display, Formatter};
use std::fs;
use std::path::{Path, PathBuf};
use std::time::UNIX_EPOCH;
use crate::session::{Session, SessionError};
/// Per-worktree session store that namespaces on-disk session files by
/// workspace fingerprint so that parallel `opencode serve` instances never
/// collide.
///
/// Create via [`SessionStore::from_cwd`] (derives the store path from the
/// server's working directory) or [`SessionStore::from_data_dir`] (honours an
/// explicit `--data-dir` flag). Both constructors produce a directory layout
/// of `<data_dir>/sessions/<workspace_hash>/` where `<workspace_hash>` is a
/// stable hex digest of the canonical workspace root.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionStore {
/// Resolved root of the session namespace, e.g.
/// `/home/user/project/.claw/sessions/a1b2c3d4e5f60718/`.
sessions_root: PathBuf,
/// The canonical workspace path that was fingerprinted.
workspace_root: PathBuf,
}
impl SessionStore {
/// Build a store from the server's current working directory.
///
/// The on-disk layout becomes `<cwd>/.claw/sessions/<workspace_hash>/`.
pub fn from_cwd(cwd: impl AsRef<Path>) -> Result<Self, SessionControlError> {
let cwd = cwd.as_ref();
let sessions_root = cwd
.join(".claw")
.join("sessions")
.join(workspace_fingerprint(cwd));
fs::create_dir_all(&sessions_root)?;
Ok(Self {
sessions_root,
workspace_root: cwd.to_path_buf(),
})
}
/// Build a store from an explicit `--data-dir` flag.
///
/// The on-disk layout becomes `<data_dir>/sessions/<workspace_hash>/`
/// where `<workspace_hash>` is derived from `workspace_root`.
pub fn from_data_dir(
data_dir: impl AsRef<Path>,
workspace_root: impl AsRef<Path>,
) -> Result<Self, SessionControlError> {
let workspace_root = workspace_root.as_ref();
let sessions_root = data_dir
.as_ref()
.join("sessions")
.join(workspace_fingerprint(workspace_root));
fs::create_dir_all(&sessions_root)?;
Ok(Self {
sessions_root,
workspace_root: workspace_root.to_path_buf(),
})
}
/// The fully resolved sessions directory for this namespace.
#[must_use]
pub fn sessions_dir(&self) -> &Path {
&self.sessions_root
}
/// The workspace root this store is bound to.
#[must_use]
pub fn workspace_root(&self) -> &Path {
&self.workspace_root
}
pub fn create_handle(&self, session_id: &str) -> SessionHandle {
let id = session_id.to_string();
let path = self
.sessions_root
.join(format!("{id}.{PRIMARY_SESSION_EXTENSION}"));
SessionHandle { id, path }
}
pub fn resolve_reference(&self, reference: &str) -> Result<SessionHandle, SessionControlError> {
if is_session_reference_alias(reference) {
let latest = self.latest_session()?;
return Ok(SessionHandle {
id: latest.id,
path: latest.path,
});
}
let direct = PathBuf::from(reference);
let candidate = if direct.is_absolute() {
direct.clone()
} else {
self.workspace_root.join(&direct)
};
let looks_like_path = direct.extension().is_some() || direct.components().count() > 1;
let path = if candidate.exists() {
candidate
} else if looks_like_path {
return Err(SessionControlError::Format(
format_missing_session_reference(reference),
));
} else {
self.resolve_managed_path(reference)?
};
Ok(SessionHandle {
id: session_id_from_path(&path).unwrap_or_else(|| reference.to_string()),
path,
})
}
pub fn resolve_managed_path(&self, session_id: &str) -> Result<PathBuf, SessionControlError> {
for extension in [PRIMARY_SESSION_EXTENSION, LEGACY_SESSION_EXTENSION] {
let path = self.sessions_root.join(format!("{session_id}.{extension}"));
if path.exists() {
return Ok(path);
}
}
Err(SessionControlError::Format(
format_missing_session_reference(session_id),
))
}
pub fn list_sessions(&self) -> Result<Vec<ManagedSessionSummary>, SessionControlError> {
let mut sessions = Vec::new();
let read_result = fs::read_dir(&self.sessions_root);
let entries = match read_result {
Ok(entries) => entries,
Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(sessions),
Err(err) => return Err(err.into()),
};
for entry in entries {
let entry = entry?;
let path = entry.path();
if !is_managed_session_file(&path) {
continue;
}
let metadata = entry.metadata()?;
let modified_epoch_millis = metadata
.modified()
.ok()
.and_then(|time| time.duration_since(UNIX_EPOCH).ok())
.map(|duration| duration.as_millis())
.unwrap_or_default();
let (id, message_count, parent_session_id, branch_name) =
match Session::load_from_path(&path) {
Ok(session) => {
let parent_session_id = session
.fork
.as_ref()
.map(|fork| fork.parent_session_id.clone());
let branch_name = session
.fork
.as_ref()
.and_then(|fork| fork.branch_name.clone());
(
session.session_id,
session.messages.len(),
parent_session_id,
branch_name,
)
}
Err(_) => (
path.file_stem()
.and_then(|value| value.to_str())
.unwrap_or("unknown")
.to_string(),
0,
None,
None,
),
};
sessions.push(ManagedSessionSummary {
id,
path,
modified_epoch_millis,
message_count,
parent_session_id,
branch_name,
});
}
sessions.sort_by(|left, right| {
right
.modified_epoch_millis
.cmp(&left.modified_epoch_millis)
.then_with(|| right.id.cmp(&left.id))
});
Ok(sessions)
}
pub fn latest_session(&self) -> Result<ManagedSessionSummary, SessionControlError> {
self.list_sessions()?
.into_iter()
.next()
.ok_or_else(|| SessionControlError::Format(format_no_managed_sessions()))
}
pub fn load_session(
&self,
reference: &str,
) -> Result<LoadedManagedSession, SessionControlError> {
let handle = self.resolve_reference(reference)?;
let session = Session::load_from_path(&handle.path)?;
Ok(LoadedManagedSession {
handle: SessionHandle {
id: session.session_id.clone(),
path: handle.path,
},
session,
})
}
pub fn fork_session(
&self,
session: &Session,
branch_name: Option<String>,
) -> Result<ForkedManagedSession, SessionControlError> {
let parent_session_id = session.session_id.clone();
let forked = session.fork(branch_name);
let handle = self.create_handle(&forked.session_id);
let branch_name = forked
.fork
.as_ref()
.and_then(|fork| fork.branch_name.clone());
let forked = forked.with_persistence_path(handle.path.clone());
forked.save_to_path(&handle.path)?;
Ok(ForkedManagedSession {
parent_session_id,
handle,
session: forked,
branch_name,
})
}
}
/// Stable hex fingerprint of a workspace path.
///
/// Uses FNV-1a (64-bit) to produce a 16-char hex string that partitions the
/// on-disk session directory per workspace root.
#[must_use]
pub fn workspace_fingerprint(workspace_root: &Path) -> String {
let input = workspace_root.to_string_lossy();
let mut hash = 0xcbf2_9ce4_8422_2325_u64;
for byte in input.as_bytes() {
hash ^= u64::from(*byte);
hash = hash.wrapping_mul(0x0100_0000_01b3);
}
format!("{hash:016x}")
}
pub const PRIMARY_SESSION_EXTENSION: &str = "jsonl";
pub const LEGACY_SESSION_EXTENSION: &str = "json";
pub const LATEST_SESSION_REFERENCE: &str = "latest";
const SESSION_REFERENCE_ALIASES: &[&str] = &[LATEST_SESSION_REFERENCE, "last", "recent"];
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionHandle {
pub id: String,
pub path: PathBuf,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ManagedSessionSummary {
pub id: String,
pub path: PathBuf,
pub modified_epoch_millis: u128,
pub message_count: usize,
pub parent_session_id: Option<String>,
pub branch_name: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LoadedManagedSession {
pub handle: SessionHandle,
pub session: Session,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForkedManagedSession {
pub parent_session_id: String,
pub handle: SessionHandle,
pub session: Session,
pub branch_name: Option<String>,
}
#[derive(Debug)]
pub enum SessionControlError {
Io(std::io::Error),
Session(SessionError),
Format(String),
}
impl Display for SessionControlError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::Io(error) => write!(f, "{error}"),
Self::Session(error) => write!(f, "{error}"),
Self::Format(error) => write!(f, "{error}"),
}
}
}
impl std::error::Error for SessionControlError {}
impl From<std::io::Error> for SessionControlError {
fn from(value: std::io::Error) -> Self {
Self::Io(value)
}
}
impl From<SessionError> for SessionControlError {
fn from(value: SessionError) -> Self {
Self::Session(value)
}
}
pub fn sessions_dir() -> Result<PathBuf, SessionControlError> {
managed_sessions_dir_for(env::current_dir()?)
}
pub fn managed_sessions_dir_for(
base_dir: impl AsRef<Path>,
) -> Result<PathBuf, SessionControlError> {
let path = base_dir.as_ref().join(".claw").join("sessions");
fs::create_dir_all(&path)?;
Ok(path)
}
pub fn create_managed_session_handle(
session_id: &str,
) -> Result<SessionHandle, SessionControlError> {
create_managed_session_handle_for(env::current_dir()?, session_id)
}
pub fn create_managed_session_handle_for(
base_dir: impl AsRef<Path>,
session_id: &str,
) -> Result<SessionHandle, SessionControlError> {
let id = session_id.to_string();
let path =
managed_sessions_dir_for(base_dir)?.join(format!("{id}.{PRIMARY_SESSION_EXTENSION}"));
Ok(SessionHandle { id, path })
}
pub fn resolve_session_reference(reference: &str) -> Result<SessionHandle, SessionControlError> {
resolve_session_reference_for(env::current_dir()?, reference)
}
pub fn resolve_session_reference_for(
base_dir: impl AsRef<Path>,
reference: &str,
) -> Result<SessionHandle, SessionControlError> {
let base_dir = base_dir.as_ref();
if is_session_reference_alias(reference) {
let latest = latest_managed_session_for(base_dir)?;
return Ok(SessionHandle {
id: latest.id,
path: latest.path,
});
}
let direct = PathBuf::from(reference);
let candidate = if direct.is_absolute() {
direct.clone()
} else {
base_dir.join(&direct)
};
let looks_like_path = direct.extension().is_some() || direct.components().count() > 1;
let path = if candidate.exists() {
candidate
} else if looks_like_path {
return Err(SessionControlError::Format(
format_missing_session_reference(reference),
));
} else {
resolve_managed_session_path_for(base_dir, reference)?
};
Ok(SessionHandle {
id: session_id_from_path(&path).unwrap_or_else(|| reference.to_string()),
path,
})
}
pub fn resolve_managed_session_path(session_id: &str) -> Result<PathBuf, SessionControlError> {
resolve_managed_session_path_for(env::current_dir()?, session_id)
}
pub fn resolve_managed_session_path_for(
base_dir: impl AsRef<Path>,
session_id: &str,
) -> Result<PathBuf, SessionControlError> {
let directory = managed_sessions_dir_for(base_dir)?;
for extension in [PRIMARY_SESSION_EXTENSION, LEGACY_SESSION_EXTENSION] {
let path = directory.join(format!("{session_id}.{extension}"));
if path.exists() {
return Ok(path);
}
}
Err(SessionControlError::Format(
format_missing_session_reference(session_id),
))
}
#[must_use]
pub fn is_managed_session_file(path: &Path) -> bool {
path.extension()
.and_then(|ext| ext.to_str())
.is_some_and(|extension| {
extension == PRIMARY_SESSION_EXTENSION || extension == LEGACY_SESSION_EXTENSION
})
}
pub fn list_managed_sessions() -> Result<Vec<ManagedSessionSummary>, SessionControlError> {
list_managed_sessions_for(env::current_dir()?)
}
pub fn list_managed_sessions_for(
base_dir: impl AsRef<Path>,
) -> Result<Vec<ManagedSessionSummary>, SessionControlError> {
let mut sessions = Vec::new();
for entry in fs::read_dir(managed_sessions_dir_for(base_dir)?)? {
let entry = entry?;
let path = entry.path();
if !is_managed_session_file(&path) {
continue;
}
let metadata = entry.metadata()?;
let modified_epoch_millis = metadata
.modified()
.ok()
.and_then(|time| time.duration_since(UNIX_EPOCH).ok())
.map(|duration| duration.as_millis())
.unwrap_or_default();
let (id, message_count, parent_session_id, branch_name) =
match Session::load_from_path(&path) {
Ok(session) => {
let parent_session_id = session
.fork
.as_ref()
.map(|fork| fork.parent_session_id.clone());
let branch_name = session
.fork
.as_ref()
.and_then(|fork| fork.branch_name.clone());
(
session.session_id,
session.messages.len(),
parent_session_id,
branch_name,
)
}
Err(_) => (
path.file_stem()
.and_then(|value| value.to_str())
.unwrap_or("unknown")
.to_string(),
0,
None,
None,
),
};
sessions.push(ManagedSessionSummary {
id,
path,
modified_epoch_millis,
message_count,
parent_session_id,
branch_name,
});
}
sessions.sort_by(|left, right| {
right
.modified_epoch_millis
.cmp(&left.modified_epoch_millis)
.then_with(|| right.id.cmp(&left.id))
});
Ok(sessions)
}
pub fn latest_managed_session() -> Result<ManagedSessionSummary, SessionControlError> {
latest_managed_session_for(env::current_dir()?)
}
pub fn latest_managed_session_for(
base_dir: impl AsRef<Path>,
) -> Result<ManagedSessionSummary, SessionControlError> {
list_managed_sessions_for(base_dir)?
.into_iter()
.next()
.ok_or_else(|| SessionControlError::Format(format_no_managed_sessions()))
}
pub fn load_managed_session(reference: &str) -> Result<LoadedManagedSession, SessionControlError> {
load_managed_session_for(env::current_dir()?, reference)
}
pub fn load_managed_session_for(
base_dir: impl AsRef<Path>,
reference: &str,
) -> Result<LoadedManagedSession, SessionControlError> {
let handle = resolve_session_reference_for(base_dir, reference)?;
let session = Session::load_from_path(&handle.path)?;
Ok(LoadedManagedSession {
handle: SessionHandle {
id: session.session_id.clone(),
path: handle.path,
},
session,
})
}
pub fn fork_managed_session(
session: &Session,
branch_name: Option<String>,
) -> Result<ForkedManagedSession, SessionControlError> {
fork_managed_session_for(env::current_dir()?, session, branch_name)
}
pub fn fork_managed_session_for(
base_dir: impl AsRef<Path>,
session: &Session,
branch_name: Option<String>,
) -> Result<ForkedManagedSession, SessionControlError> {
let parent_session_id = session.session_id.clone();
let forked = session.fork(branch_name);
let handle = create_managed_session_handle_for(base_dir, &forked.session_id)?;
let branch_name = forked
.fork
.as_ref()
.and_then(|fork| fork.branch_name.clone());
let forked = forked.with_persistence_path(handle.path.clone());
forked.save_to_path(&handle.path)?;
Ok(ForkedManagedSession {
parent_session_id,
handle,
session: forked,
branch_name,
})
}
#[must_use]
pub fn is_session_reference_alias(reference: &str) -> bool {
SESSION_REFERENCE_ALIASES
.iter()
.any(|alias| reference.eq_ignore_ascii_case(alias))
}
fn session_id_from_path(path: &Path) -> Option<String> {
path.file_name()
.and_then(|value| value.to_str())
.and_then(|name| {
name.strip_suffix(&format!(".{PRIMARY_SESSION_EXTENSION}"))
.or_else(|| name.strip_suffix(&format!(".{LEGACY_SESSION_EXTENSION}")))
})
.map(ToOwned::to_owned)
}
fn format_missing_session_reference(reference: &str) -> String {
format!(
"session not found: {reference}\nHint: managed sessions live in .claw/sessions/. Try `{LATEST_SESSION_REFERENCE}` for the most recent session or `/session list` in the REPL."
)
}
fn format_no_managed_sessions() -> String {
format!(
"no managed sessions found in .claw/sessions/\nStart `claw` to create a session, then rerun with `--resume {LATEST_SESSION_REFERENCE}`."
)
}
#[cfg(test)]
mod tests {
use super::{
create_managed_session_handle_for, fork_managed_session_for, is_session_reference_alias,
list_managed_sessions_for, load_managed_session_for, resolve_session_reference_for,
workspace_fingerprint, ManagedSessionSummary, SessionStore, LATEST_SESSION_REFERENCE,
};
use crate::session::Session;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
fn temp_dir() -> PathBuf {
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after epoch")
.as_nanos();
std::env::temp_dir().join(format!("runtime-session-control-{nanos}"))
}
fn persist_session(root: &Path, text: &str) -> Session {
let mut session = Session::new();
session
.push_user_text(text)
.expect("session message should save");
let handle = create_managed_session_handle_for(root, &session.session_id)
.expect("managed session handle should build");
let session = session.with_persistence_path(handle.path.clone());
session
.save_to_path(&handle.path)
.expect("session should persist");
session
}
fn wait_for_next_millisecond() {
let start = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after epoch")
.as_millis();
while SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after epoch")
.as_millis()
<= start
{}
}
fn summary_by_id<'a>(
summaries: &'a [ManagedSessionSummary],
id: &str,
) -> &'a ManagedSessionSummary {
summaries
.iter()
.find(|summary| summary.id == id)
.expect("session summary should exist")
}
#[test]
fn creates_and_lists_managed_sessions() {
// given
let root = temp_dir();
fs::create_dir_all(&root).expect("root dir should exist");
let older = persist_session(&root, "older session");
wait_for_next_millisecond();
let newer = persist_session(&root, "newer session");
// when
let sessions = list_managed_sessions_for(&root).expect("managed sessions should list");
// then
assert_eq!(sessions.len(), 2);
assert_eq!(sessions[0].id, newer.session_id);
assert_eq!(summary_by_id(&sessions, &older.session_id).message_count, 1);
assert_eq!(summary_by_id(&sessions, &newer.session_id).message_count, 1);
fs::remove_dir_all(root).expect("temp dir should clean up");
}
#[test]
fn resolves_latest_alias_and_loads_session_from_workspace_root() {
// given
let root = temp_dir();
fs::create_dir_all(&root).expect("root dir should exist");
let older = persist_session(&root, "older session");
wait_for_next_millisecond();
let newer = persist_session(&root, "newer session");
// when
let handle = resolve_session_reference_for(&root, LATEST_SESSION_REFERENCE)
.expect("latest alias should resolve");
let loaded = load_managed_session_for(&root, "recent")
.expect("recent alias should load the latest session");
// then
assert_eq!(handle.id, newer.session_id);
assert_eq!(loaded.handle.id, newer.session_id);
assert_eq!(loaded.session.messages.len(), 1);
assert_ne!(loaded.handle.id, older.session_id);
assert!(is_session_reference_alias("last"));
fs::remove_dir_all(root).expect("temp dir should clean up");
}
#[test]
fn forks_session_into_managed_storage_with_lineage() {
// given
let root = temp_dir();
fs::create_dir_all(&root).expect("root dir should exist");
let source = persist_session(&root, "parent session");
// when
let forked = fork_managed_session_for(&root, &source, Some("incident-review".to_string()))
.expect("session should fork");
let sessions = list_managed_sessions_for(&root).expect("managed sessions should list");
let summary = summary_by_id(&sessions, &forked.handle.id);
// then
assert_eq!(forked.parent_session_id, source.session_id);
assert_eq!(forked.branch_name.as_deref(), Some("incident-review"));
assert_eq!(
summary.parent_session_id.as_deref(),
Some(source.session_id.as_str())
);
assert_eq!(summary.branch_name.as_deref(), Some("incident-review"));
assert_eq!(
forked.session.persistence_path(),
Some(forked.handle.path.as_path())
);
fs::remove_dir_all(root).expect("temp dir should clean up");
}
// ------------------------------------------------------------------
// Per-worktree session isolation (SessionStore) tests
// ------------------------------------------------------------------
fn persist_session_via_store(store: &SessionStore, text: &str) -> Session {
let mut session = Session::new();
session
.push_user_text(text)
.expect("session message should save");
let handle = store.create_handle(&session.session_id);
let session = session.with_persistence_path(handle.path.clone());
session
.save_to_path(&handle.path)
.expect("session should persist");
session
}
#[test]
fn workspace_fingerprint_is_deterministic_and_differs_per_path() {
// given
let path_a = Path::new("/tmp/worktree-alpha");
let path_b = Path::new("/tmp/worktree-beta");
// when
let fp_a1 = workspace_fingerprint(path_a);
let fp_a2 = workspace_fingerprint(path_a);
let fp_b = workspace_fingerprint(path_b);
// then
assert_eq!(fp_a1, fp_a2, "same path must produce the same fingerprint");
assert_ne!(
fp_a1, fp_b,
"different paths must produce different fingerprints"
);
assert_eq!(fp_a1.len(), 16, "fingerprint must be a 16-char hex string");
}
#[test]
fn session_store_from_cwd_isolates_sessions_by_workspace() {
// given
let base = temp_dir();
let workspace_a = base.join("repo-alpha");
let workspace_b = base.join("repo-beta");
fs::create_dir_all(&workspace_a).expect("workspace a should exist");
fs::create_dir_all(&workspace_b).expect("workspace b should exist");
let store_a = SessionStore::from_cwd(&workspace_a).expect("store a should build");
let store_b = SessionStore::from_cwd(&workspace_b).expect("store b should build");
// when
let session_a = persist_session_via_store(&store_a, "alpha work");
let _session_b = persist_session_via_store(&store_b, "beta work");
// then — each store only sees its own sessions
let list_a = store_a.list_sessions().expect("list a");
let list_b = store_b.list_sessions().expect("list b");
assert_eq!(list_a.len(), 1, "store a should see exactly one session");
assert_eq!(list_b.len(), 1, "store b should see exactly one session");
assert_eq!(list_a[0].id, session_a.session_id);
assert_ne!(
store_a.sessions_dir(),
store_b.sessions_dir(),
"session directories must differ across workspaces"
);
fs::remove_dir_all(base).expect("temp dir should clean up");
}
#[test]
fn session_store_from_data_dir_namespaces_by_workspace() {
// given
let base = temp_dir();
let data_dir = base.join("global-data");
let workspace_a = PathBuf::from("/tmp/project-one");
let workspace_b = PathBuf::from("/tmp/project-two");
fs::create_dir_all(&data_dir).expect("data dir should exist");
let store_a =
SessionStore::from_data_dir(&data_dir, &workspace_a).expect("store a should build");
let store_b =
SessionStore::from_data_dir(&data_dir, &workspace_b).expect("store b should build");
// when
persist_session_via_store(&store_a, "work in project-one");
persist_session_via_store(&store_b, "work in project-two");
// then
assert_ne!(
store_a.sessions_dir(),
store_b.sessions_dir(),
"data-dir stores must namespace by workspace"
);
assert_eq!(store_a.list_sessions().expect("list a").len(), 1);
assert_eq!(store_b.list_sessions().expect("list b").len(), 1);
assert_eq!(store_a.workspace_root(), workspace_a.as_path());
assert_eq!(store_b.workspace_root(), workspace_b.as_path());
fs::remove_dir_all(base).expect("temp dir should clean up");
}
#[test]
fn session_store_create_and_load_round_trip() {
// given
let base = temp_dir();
fs::create_dir_all(&base).expect("base dir should exist");
let store = SessionStore::from_cwd(&base).expect("store should build");
let session = persist_session_via_store(&store, "round-trip message");
// when
let loaded = store
.load_session(&session.session_id)
.expect("session should load via store");
// then
assert_eq!(loaded.handle.id, session.session_id);
assert_eq!(loaded.session.messages.len(), 1);
fs::remove_dir_all(base).expect("temp dir should clean up");
}
#[test]
fn session_store_latest_and_resolve_reference() {
// given
let base = temp_dir();
fs::create_dir_all(&base).expect("base dir should exist");
let store = SessionStore::from_cwd(&base).expect("store should build");
let _older = persist_session_via_store(&store, "older");
wait_for_next_millisecond();
let newer = persist_session_via_store(&store, "newer");
// when
let latest = store.latest_session().expect("latest should resolve");
let handle = store
.resolve_reference("latest")
.expect("latest alias should resolve");
// then
assert_eq!(latest.id, newer.session_id);
assert_eq!(handle.id, newer.session_id);
fs::remove_dir_all(base).expect("temp dir should clean up");
}
#[test]
fn session_store_fork_stays_in_same_namespace() {
// given
let base = temp_dir();
fs::create_dir_all(&base).expect("base dir should exist");
let store = SessionStore::from_cwd(&base).expect("store should build");
let source = persist_session_via_store(&store, "parent work");
// when
let forked = store
.fork_session(&source, Some("bugfix".to_string()))
.expect("fork should succeed");
let sessions = store.list_sessions().expect("list sessions");
// then
assert_eq!(
sessions.len(),
2,
"forked session must land in the same namespace"
);
assert_eq!(forked.parent_session_id, source.session_id);
assert_eq!(forked.branch_name.as_deref(), Some("bugfix"));
assert!(
forked.handle.path.starts_with(store.sessions_dir()),
"forked session path must be inside the store namespace"
);
fs::remove_dir_all(base).expect("temp dir should clean up");
}
}

View File

@ -80,7 +80,11 @@ impl IncrementalSseParser {
}
fn take_event(&mut self) -> Option<SseEvent> {
if self.data_lines.is_empty() && self.event_name.is_none() && self.id.is_none() && self.retry.is_none() {
if self.data_lines.is_empty()
&& self.event_name.is_none()
&& self.id.is_none()
&& self.retry.is_none()
{
return None;
}
@ -102,8 +106,13 @@ mod tests {
#[test]
fn parses_streaming_events() {
// given
let mut parser = IncrementalSseParser::new();
// when
let first = parser.push_chunk("event: message\ndata: hel");
// then
assert!(first.is_empty());
let second = parser.push_chunk("lo\n\nid: 1\ndata: world\n\n");
@ -125,4 +134,25 @@ mod tests {
]
);
}
#[test]
fn finish_flushes_a_trailing_event_without_separator() {
// given
let mut parser = IncrementalSseParser::new();
parser.push_chunk("event: message\ndata: trailing");
// when
let events = parser.finish();
// then
assert_eq!(
events,
vec![SseEvent {
event: Some("message".to_string()),
data: "trailing".to_string(),
id: None,
retry: None,
}]
);
}
}

View File

@ -0,0 +1,429 @@
#![allow(clippy::must_use_candidate)]
use std::path::Path;
use std::process::Command;
/// Outcome of comparing the worktree HEAD against the expected base commit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BaseCommitState {
/// HEAD matches the expected base commit.
Matches,
/// HEAD has diverged from the expected base.
Diverged { expected: String, actual: String },
/// No expected base was supplied (neither flag nor file).
NoExpectedBase,
/// The working directory is not inside a git repository.
NotAGitRepo,
}
/// Where the expected base commit originated from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BaseCommitSource {
Flag(String),
File(String),
}
/// Read the `.claw-base` file from the given directory and return the trimmed
/// commit hash, or `None` when the file is absent or empty.
pub fn read_claw_base_file(cwd: &Path) -> Option<String> {
let path = cwd.join(".claw-base");
let content = std::fs::read_to_string(path).ok()?;
let trimmed = content.trim();
if trimmed.is_empty() {
None
} else {
Some(trimmed.to_string())
}
}
/// Resolve the expected base commit: prefer the `--base-commit` flag value,
/// fall back to reading `.claw-base` from `cwd`.
pub fn resolve_expected_base(flag_value: Option<&str>, cwd: &Path) -> Option<BaseCommitSource> {
if let Some(value) = flag_value {
let trimmed = value.trim();
if !trimmed.is_empty() {
return Some(BaseCommitSource::Flag(trimmed.to_string()));
}
}
read_claw_base_file(cwd).map(BaseCommitSource::File)
}
/// Verify that the worktree HEAD matches `expected_base`.
///
/// Returns [`BaseCommitState::NoExpectedBase`] when no expected commit is
/// provided (the check is effectively a no-op in that case).
pub fn check_base_commit(cwd: &Path, expected_base: Option<&BaseCommitSource>) -> BaseCommitState {
let Some(source) = expected_base else {
return BaseCommitState::NoExpectedBase;
};
let expected_raw = match source {
BaseCommitSource::Flag(value) | BaseCommitSource::File(value) => value.as_str(),
};
let Some(head_sha) = resolve_head_sha(cwd) else {
return BaseCommitState::NotAGitRepo;
};
let Some(expected_sha) = resolve_rev(cwd, expected_raw) else {
// If the expected ref cannot be resolved, compare raw strings as a
// best-effort fallback (e.g. partial SHA provided by the caller).
return if head_sha.starts_with(expected_raw) || expected_raw.starts_with(&head_sha) {
BaseCommitState::Matches
} else {
BaseCommitState::Diverged {
expected: expected_raw.to_string(),
actual: head_sha,
}
};
};
if head_sha == expected_sha {
BaseCommitState::Matches
} else {
BaseCommitState::Diverged {
expected: expected_sha,
actual: head_sha,
}
}
}
/// Format a human-readable warning when the base commit has diverged.
///
/// Returns `None` for non-warning states (`Matches`, `NoExpectedBase`).
pub fn format_stale_base_warning(state: &BaseCommitState) -> Option<String> {
match state {
BaseCommitState::Diverged { expected, actual } => Some(format!(
"warning: worktree HEAD ({actual}) does not match expected base commit ({expected}). \
Session may run against a stale codebase."
)),
BaseCommitState::NotAGitRepo => {
Some("warning: stale-base check skipped — not inside a git repository.".to_string())
}
BaseCommitState::Matches | BaseCommitState::NoExpectedBase => None,
}
}
fn resolve_head_sha(cwd: &Path) -> Option<String> {
resolve_rev(cwd, "HEAD")
}
fn resolve_rev(cwd: &Path, rev: &str) -> Option<String> {
let output = Command::new("git")
.args(["rev-parse", rev])
.current_dir(cwd)
.output()
.ok()?;
if !output.status.success() {
return None;
}
let sha = String::from_utf8(output.stdout).ok()?;
let trimmed = sha.trim();
if trimmed.is_empty() {
None
} else {
Some(trimmed.to_string())
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use std::process::Command;
use std::time::{SystemTime, UNIX_EPOCH};
fn temp_dir() -> std::path::PathBuf {
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after epoch")
.as_nanos();
std::env::temp_dir().join(format!("runtime-stale-base-{nanos}"))
}
fn init_repo(path: &std::path::Path) {
fs::create_dir_all(path).expect("create repo dir");
run(path, &["init", "--quiet", "-b", "main"]);
run(path, &["config", "user.email", "tests@example.com"]);
run(path, &["config", "user.name", "Stale Base Tests"]);
fs::write(path.join("init.txt"), "initial\n").expect("write init file");
run(path, &["add", "."]);
run(path, &["commit", "-m", "initial commit", "--quiet"]);
}
fn run(cwd: &std::path::Path, args: &[&str]) {
let status = Command::new("git")
.args(args)
.current_dir(cwd)
.status()
.unwrap_or_else(|e| panic!("git {} failed to execute: {e}", args.join(" ")));
assert!(
status.success(),
"git {} exited with {status}",
args.join(" ")
);
}
fn commit_file(repo: &std::path::Path, name: &str, msg: &str) {
fs::write(repo.join(name), format!("{msg}\n")).expect("write file");
run(repo, &["add", name]);
run(repo, &["commit", "-m", msg, "--quiet"]);
}
fn head_sha(repo: &std::path::Path) -> String {
let output = Command::new("git")
.args(["rev-parse", "HEAD"])
.current_dir(repo)
.output()
.expect("git rev-parse HEAD");
String::from_utf8(output.stdout)
.expect("valid utf8")
.trim()
.to_string()
}
#[test]
fn matches_when_head_equals_expected_base() {
// given
let root = temp_dir();
init_repo(&root);
let sha = head_sha(&root);
let source = BaseCommitSource::Flag(sha);
// when
let state = check_base_commit(&root, Some(&source));
// then
assert_eq!(state, BaseCommitState::Matches);
fs::remove_dir_all(&root).expect("cleanup");
}
#[test]
fn diverged_when_head_moved_past_expected_base() {
// given
let root = temp_dir();
init_repo(&root);
let old_sha = head_sha(&root);
commit_file(&root, "extra.txt", "move head forward");
let new_sha = head_sha(&root);
let source = BaseCommitSource::Flag(old_sha.clone());
// when
let state = check_base_commit(&root, Some(&source));
// then
assert_eq!(
state,
BaseCommitState::Diverged {
expected: old_sha,
actual: new_sha,
}
);
fs::remove_dir_all(&root).expect("cleanup");
}
#[test]
fn no_expected_base_when_source_is_none() {
// given
let root = temp_dir();
init_repo(&root);
// when
let state = check_base_commit(&root, None);
// then
assert_eq!(state, BaseCommitState::NoExpectedBase);
fs::remove_dir_all(&root).expect("cleanup");
}
#[test]
fn not_a_git_repo_when_outside_repo() {
// given
let root = temp_dir();
fs::create_dir_all(&root).expect("create dir");
let source = BaseCommitSource::Flag("abc1234".to_string());
// when
let state = check_base_commit(&root, Some(&source));
// then
assert_eq!(state, BaseCommitState::NotAGitRepo);
fs::remove_dir_all(&root).expect("cleanup");
}
#[test]
fn reads_claw_base_file() {
// given
let root = temp_dir();
fs::create_dir_all(&root).expect("create dir");
fs::write(root.join(".claw-base"), "abc1234def5678\n").expect("write .claw-base");
// when
let value = read_claw_base_file(&root);
// then
assert_eq!(value, Some("abc1234def5678".to_string()));
fs::remove_dir_all(&root).expect("cleanup");
}
#[test]
fn returns_none_for_missing_claw_base_file() {
// given
let root = temp_dir();
fs::create_dir_all(&root).expect("create dir");
// when
let value = read_claw_base_file(&root);
// then
assert!(value.is_none());
fs::remove_dir_all(&root).expect("cleanup");
}
#[test]
fn returns_none_for_empty_claw_base_file() {
// given
let root = temp_dir();
fs::create_dir_all(&root).expect("create dir");
fs::write(root.join(".claw-base"), " \n").expect("write empty .claw-base");
// when
let value = read_claw_base_file(&root);
// then
assert!(value.is_none());
fs::remove_dir_all(&root).expect("cleanup");
}
#[test]
fn resolve_expected_base_prefers_flag_over_file() {
// given
let root = temp_dir();
fs::create_dir_all(&root).expect("create dir");
fs::write(root.join(".claw-base"), "from_file\n").expect("write .claw-base");
// when
let source = resolve_expected_base(Some("from_flag"), &root);
// then
assert_eq!(
source,
Some(BaseCommitSource::Flag("from_flag".to_string()))
);
fs::remove_dir_all(&root).expect("cleanup");
}
#[test]
fn resolve_expected_base_falls_back_to_file() {
// given
let root = temp_dir();
fs::create_dir_all(&root).expect("create dir");
fs::write(root.join(".claw-base"), "from_file\n").expect("write .claw-base");
// when
let source = resolve_expected_base(None, &root);
// then
assert_eq!(
source,
Some(BaseCommitSource::File("from_file".to_string()))
);
fs::remove_dir_all(&root).expect("cleanup");
}
#[test]
fn resolve_expected_base_returns_none_when_nothing_available() {
// given
let root = temp_dir();
fs::create_dir_all(&root).expect("create dir");
// when
let source = resolve_expected_base(None, &root);
// then
assert!(source.is_none());
fs::remove_dir_all(&root).expect("cleanup");
}
#[test]
fn format_warning_returns_message_for_diverged() {
// given
let state = BaseCommitState::Diverged {
expected: "abc1234".to_string(),
actual: "def5678".to_string(),
};
// when
let warning = format_stale_base_warning(&state);
// then
let message = warning.expect("should produce warning");
assert!(message.contains("abc1234"));
assert!(message.contains("def5678"));
assert!(message.contains("stale codebase"));
}
#[test]
fn format_warning_returns_none_for_matches() {
// given
let state = BaseCommitState::Matches;
// when
let warning = format_stale_base_warning(&state);
// then
assert!(warning.is_none());
}
#[test]
fn format_warning_returns_none_for_no_expected_base() {
// given
let state = BaseCommitState::NoExpectedBase;
// when
let warning = format_stale_base_warning(&state);
// then
assert!(warning.is_none());
}
#[test]
fn matches_with_claw_base_file_in_real_repo() {
// given
let root = temp_dir();
init_repo(&root);
let sha = head_sha(&root);
fs::write(root.join(".claw-base"), format!("{sha}\n")).expect("write .claw-base");
let source = resolve_expected_base(None, &root);
// when
let state = check_base_commit(&root, source.as_ref());
// then
assert_eq!(state, BaseCommitState::Matches);
fs::remove_dir_all(&root).expect("cleanup");
}
#[test]
fn diverged_with_claw_base_file_after_new_commit() {
// given
let root = temp_dir();
init_repo(&root);
let old_sha = head_sha(&root);
fs::write(root.join(".claw-base"), format!("{old_sha}\n")).expect("write .claw-base");
commit_file(&root, "new.txt", "advance head");
let new_sha = head_sha(&root);
let source = resolve_expected_base(None, &root);
// when
let state = check_base_commit(&root, source.as_ref());
// then
assert_eq!(
state,
BaseCommitState::Diverged {
expected: old_sha,
actual: new_sha,
}
);
fs::remove_dir_all(&root).expect("cleanup");
}
}

View File

@ -0,0 +1,417 @@
#![allow(clippy::must_use_candidate)]
use std::path::Path;
use std::process::Command;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BranchFreshness {
Fresh,
Stale {
commits_behind: usize,
missing_fixes: Vec<String>,
},
Diverged {
ahead: usize,
behind: usize,
missing_fixes: Vec<String>,
},
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StaleBranchPolicy {
AutoRebase,
AutoMergeForward,
WarnOnly,
Block,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StaleBranchEvent {
BranchStaleAgainstMain {
branch: String,
commits_behind: usize,
missing_fixes: Vec<String>,
},
RebaseAttempted {
branch: String,
result: String,
},
MergeForwardAttempted {
branch: String,
result: String,
},
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StaleBranchAction {
Noop,
Warn { message: String },
Block { message: String },
Rebase,
MergeForward,
}
pub fn check_freshness(branch: &str, main_ref: &str) -> BranchFreshness {
check_freshness_in(branch, main_ref, Path::new("."))
}
pub fn apply_policy(freshness: &BranchFreshness, policy: StaleBranchPolicy) -> StaleBranchAction {
match freshness {
BranchFreshness::Fresh => StaleBranchAction::Noop,
BranchFreshness::Stale {
commits_behind,
missing_fixes,
} => match policy {
StaleBranchPolicy::WarnOnly => StaleBranchAction::Warn {
message: format!(
"Branch is {commits_behind} commit(s) behind main. Missing fixes: {}",
if missing_fixes.is_empty() {
"(none)".to_string()
} else {
missing_fixes.join("; ")
}
),
},
StaleBranchPolicy::Block => StaleBranchAction::Block {
message: format!(
"Branch is {commits_behind} commit(s) behind main and must be updated before proceeding."
),
},
StaleBranchPolicy::AutoRebase => StaleBranchAction::Rebase,
StaleBranchPolicy::AutoMergeForward => StaleBranchAction::MergeForward,
},
BranchFreshness::Diverged {
ahead,
behind,
missing_fixes,
} => match policy {
StaleBranchPolicy::WarnOnly => StaleBranchAction::Warn {
message: format!(
"Branch has diverged: {ahead} commit(s) ahead, {behind} commit(s) behind main. Missing fixes: {}",
format_missing_fixes(missing_fixes)
),
},
StaleBranchPolicy::Block => StaleBranchAction::Block {
message: format!(
"Branch has diverged ({ahead} ahead, {behind} behind) and must be reconciled before proceeding. Missing fixes: {}",
format_missing_fixes(missing_fixes)
),
},
StaleBranchPolicy::AutoRebase => StaleBranchAction::Rebase,
StaleBranchPolicy::AutoMergeForward => StaleBranchAction::MergeForward,
},
}
}
pub(crate) fn check_freshness_in(
branch: &str,
main_ref: &str,
repo_path: &Path,
) -> BranchFreshness {
let behind = rev_list_count(main_ref, branch, repo_path);
let ahead = rev_list_count(branch, main_ref, repo_path);
if behind == 0 {
return BranchFreshness::Fresh;
}
if ahead > 0 {
return BranchFreshness::Diverged {
ahead,
behind,
missing_fixes: missing_fix_subjects(main_ref, branch, repo_path),
};
}
let missing_fixes = missing_fix_subjects(main_ref, branch, repo_path);
BranchFreshness::Stale {
commits_behind: behind,
missing_fixes,
}
}
fn format_missing_fixes(missing_fixes: &[String]) -> String {
if missing_fixes.is_empty() {
"(none)".to_string()
} else {
missing_fixes.join("; ")
}
}
fn rev_list_count(a: &str, b: &str, repo_path: &Path) -> usize {
let output = Command::new("git")
.args(["rev-list", "--count", &format!("{b}..{a}")])
.current_dir(repo_path)
.output();
match output {
Ok(o) if o.status.success() => String::from_utf8_lossy(&o.stdout)
.trim()
.parse::<usize>()
.unwrap_or(0),
_ => 0,
}
}
fn missing_fix_subjects(a: &str, b: &str, repo_path: &Path) -> Vec<String> {
let output = Command::new("git")
.args(["log", "--format=%s", &format!("{b}..{a}")])
.current_dir(repo_path)
.output();
match output {
Ok(o) if o.status.success() => String::from_utf8_lossy(&o.stdout)
.lines()
.filter(|l| !l.is_empty())
.map(String::from)
.collect(),
_ => Vec::new(),
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use std::time::{SystemTime, UNIX_EPOCH};
fn temp_dir() -> std::path::PathBuf {
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time should be after epoch")
.as_nanos();
std::env::temp_dir().join(format!("runtime-stale-branch-{nanos}"))
}
fn init_repo(path: &Path) {
fs::create_dir_all(path).expect("create repo dir");
run(path, &["init", "--quiet", "-b", "main"]);
run(path, &["config", "user.email", "tests@example.com"]);
run(path, &["config", "user.name", "Stale Branch Tests"]);
fs::write(path.join("init.txt"), "initial\n").expect("write init file");
run(path, &["add", "."]);
run(path, &["commit", "-m", "initial commit", "--quiet"]);
}
fn run(cwd: &Path, args: &[&str]) {
let status = Command::new("git")
.args(args)
.current_dir(cwd)
.status()
.unwrap_or_else(|e| panic!("git {} failed to execute: {e}", args.join(" ")));
assert!(
status.success(),
"git {} exited with {status}",
args.join(" ")
);
}
fn commit_file(repo: &Path, name: &str, msg: &str) {
fs::write(repo.join(name), format!("{msg}\n")).expect("write file");
run(repo, &["add", name]);
run(repo, &["commit", "-m", msg, "--quiet"]);
}
#[test]
fn fresh_branch_passes() {
let root = temp_dir();
init_repo(&root);
// given
run(&root, &["checkout", "-b", "topic"]);
// when
let freshness = check_freshness_in("topic", "main", &root);
// then
assert_eq!(freshness, BranchFreshness::Fresh);
fs::remove_dir_all(&root).expect("cleanup");
}
#[test]
fn fresh_branch_ahead_of_main_still_fresh() {
let root = temp_dir();
init_repo(&root);
// given
run(&root, &["checkout", "-b", "topic"]);
commit_file(&root, "feature.txt", "add feature");
// when
let freshness = check_freshness_in("topic", "main", &root);
// then
assert_eq!(freshness, BranchFreshness::Fresh);
fs::remove_dir_all(&root).expect("cleanup");
}
#[test]
fn stale_branch_detected_with_correct_behind_count_and_missing_fixes() {
let root = temp_dir();
init_repo(&root);
// given
run(&root, &["checkout", "-b", "topic"]);
run(&root, &["checkout", "main"]);
commit_file(&root, "fix1.txt", "fix: resolve timeout");
commit_file(&root, "fix2.txt", "fix: handle null pointer");
// when
let freshness = check_freshness_in("topic", "main", &root);
// then
match freshness {
BranchFreshness::Stale {
commits_behind,
missing_fixes,
} => {
assert_eq!(commits_behind, 2);
assert_eq!(missing_fixes.len(), 2);
assert_eq!(missing_fixes[0], "fix: handle null pointer");
assert_eq!(missing_fixes[1], "fix: resolve timeout");
}
other => panic!("expected Stale, got {other:?}"),
}
fs::remove_dir_all(&root).expect("cleanup");
}
#[test]
fn diverged_branch_detection() {
let root = temp_dir();
init_repo(&root);
// given
run(&root, &["checkout", "-b", "topic"]);
commit_file(&root, "topic_work.txt", "topic work");
run(&root, &["checkout", "main"]);
commit_file(&root, "main_fix.txt", "main fix");
// when
let freshness = check_freshness_in("topic", "main", &root);
// then
match freshness {
BranchFreshness::Diverged {
ahead,
behind,
missing_fixes,
} => {
assert_eq!(ahead, 1);
assert_eq!(behind, 1);
assert_eq!(missing_fixes, vec!["main fix".to_string()]);
}
other => panic!("expected Diverged, got {other:?}"),
}
fs::remove_dir_all(&root).expect("cleanup");
}
#[test]
fn policy_noop_for_fresh_branch() {
// given
let freshness = BranchFreshness::Fresh;
// when
let action = apply_policy(&freshness, StaleBranchPolicy::WarnOnly);
// then
assert_eq!(action, StaleBranchAction::Noop);
}
#[test]
fn policy_warn_for_stale_branch() {
// given
let freshness = BranchFreshness::Stale {
commits_behind: 3,
missing_fixes: vec!["fix: timeout".into(), "fix: null ptr".into()],
};
// when
let action = apply_policy(&freshness, StaleBranchPolicy::WarnOnly);
// then
match action {
StaleBranchAction::Warn { message } => {
assert!(message.contains("3 commit(s) behind"));
assert!(message.contains("fix: timeout"));
assert!(message.contains("fix: null ptr"));
}
other => panic!("expected Warn, got {other:?}"),
}
}
#[test]
fn policy_block_for_stale_branch() {
// given
let freshness = BranchFreshness::Stale {
commits_behind: 1,
missing_fixes: vec!["hotfix".into()],
};
// when
let action = apply_policy(&freshness, StaleBranchPolicy::Block);
// then
match action {
StaleBranchAction::Block { message } => {
assert!(message.contains("1 commit(s) behind"));
}
other => panic!("expected Block, got {other:?}"),
}
}
#[test]
fn policy_auto_rebase_for_stale_branch() {
// given
let freshness = BranchFreshness::Stale {
commits_behind: 2,
missing_fixes: vec![],
};
// when
let action = apply_policy(&freshness, StaleBranchPolicy::AutoRebase);
// then
assert_eq!(action, StaleBranchAction::Rebase);
}
#[test]
fn policy_auto_merge_forward_for_diverged_branch() {
// given
let freshness = BranchFreshness::Diverged {
ahead: 5,
behind: 2,
missing_fixes: vec!["fix: merge main".into()],
};
// when
let action = apply_policy(&freshness, StaleBranchPolicy::AutoMergeForward);
// then
assert_eq!(action, StaleBranchAction::MergeForward);
}
#[test]
fn policy_warn_for_diverged_branch() {
// given
let freshness = BranchFreshness::Diverged {
ahead: 3,
behind: 1,
missing_fixes: vec!["main hotfix".into()],
};
// when
let action = apply_policy(&freshness, StaleBranchPolicy::WarnOnly);
// then
match action {
StaleBranchAction::Warn { message } => {
assert!(message.contains("diverged"));
assert!(message.contains("3 commit(s) ahead"));
assert!(message.contains("1 commit(s) behind"));
assert!(message.contains("main hotfix"));
}
other => panic!("expected Warn, got {other:?}"),
}
}
}

View File

@ -0,0 +1,300 @@
use std::collections::BTreeSet;
const DEFAULT_MAX_CHARS: usize = 1_200;
const DEFAULT_MAX_LINES: usize = 24;
const DEFAULT_MAX_LINE_CHARS: usize = 160;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SummaryCompressionBudget {
pub max_chars: usize,
pub max_lines: usize,
pub max_line_chars: usize,
}
impl Default for SummaryCompressionBudget {
fn default() -> Self {
Self {
max_chars: DEFAULT_MAX_CHARS,
max_lines: DEFAULT_MAX_LINES,
max_line_chars: DEFAULT_MAX_LINE_CHARS,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SummaryCompressionResult {
pub summary: String,
pub original_chars: usize,
pub compressed_chars: usize,
pub original_lines: usize,
pub compressed_lines: usize,
pub removed_duplicate_lines: usize,
pub omitted_lines: usize,
pub truncated: bool,
}
#[must_use]
pub fn compress_summary(
summary: &str,
budget: SummaryCompressionBudget,
) -> SummaryCompressionResult {
let original_chars = summary.chars().count();
let original_lines = summary.lines().count();
let normalized = normalize_lines(summary, budget.max_line_chars);
if normalized.lines.is_empty() || budget.max_chars == 0 || budget.max_lines == 0 {
return SummaryCompressionResult {
summary: String::new(),
original_chars,
compressed_chars: 0,
original_lines,
compressed_lines: 0,
removed_duplicate_lines: normalized.removed_duplicate_lines,
omitted_lines: normalized.lines.len(),
truncated: original_chars > 0,
};
}
let selected = select_line_indexes(&normalized.lines, budget);
let mut compressed_lines = selected
.iter()
.map(|index| normalized.lines[*index].clone())
.collect::<Vec<_>>();
if compressed_lines.is_empty() {
compressed_lines.push(truncate_line(&normalized.lines[0], budget.max_chars));
}
let omitted_lines = normalized
.lines
.len()
.saturating_sub(compressed_lines.len());
if omitted_lines > 0 {
let omission_notice = omission_notice(omitted_lines);
push_line_with_budget(&mut compressed_lines, omission_notice, budget);
}
let compressed_summary = compressed_lines.join("\n");
SummaryCompressionResult {
compressed_chars: compressed_summary.chars().count(),
compressed_lines: compressed_lines.len(),
removed_duplicate_lines: normalized.removed_duplicate_lines,
omitted_lines,
truncated: compressed_summary != summary.trim(),
summary: compressed_summary,
original_chars,
original_lines,
}
}
#[must_use]
pub fn compress_summary_text(summary: &str) -> String {
compress_summary(summary, SummaryCompressionBudget::default()).summary
}
#[derive(Debug, Default)]
struct NormalizedSummary {
lines: Vec<String>,
removed_duplicate_lines: usize,
}
fn normalize_lines(summary: &str, max_line_chars: usize) -> NormalizedSummary {
let mut seen = BTreeSet::new();
let mut lines = Vec::new();
let mut removed_duplicate_lines = 0;
for raw_line in summary.lines() {
let normalized = collapse_inline_whitespace(raw_line);
if normalized.is_empty() {
continue;
}
let truncated = truncate_line(&normalized, max_line_chars);
let dedupe_key = dedupe_key(&truncated);
if !seen.insert(dedupe_key) {
removed_duplicate_lines += 1;
continue;
}
lines.push(truncated);
}
NormalizedSummary {
lines,
removed_duplicate_lines,
}
}
fn select_line_indexes(lines: &[String], budget: SummaryCompressionBudget) -> Vec<usize> {
let mut selected = BTreeSet::<usize>::new();
for priority in 0..=3 {
for (index, line) in lines.iter().enumerate() {
if selected.contains(&index) || line_priority(line) != priority {
continue;
}
let candidate = selected
.iter()
.map(|selected_index| lines[*selected_index].as_str())
.chain(std::iter::once(line.as_str()))
.collect::<Vec<_>>();
if candidate.len() > budget.max_lines {
continue;
}
if joined_char_count(&candidate) > budget.max_chars {
continue;
}
selected.insert(index);
}
}
selected.into_iter().collect()
}
fn push_line_with_budget(lines: &mut Vec<String>, line: String, budget: SummaryCompressionBudget) {
let candidate = lines
.iter()
.map(String::as_str)
.chain(std::iter::once(line.as_str()))
.collect::<Vec<_>>();
if candidate.len() <= budget.max_lines && joined_char_count(&candidate) <= budget.max_chars {
lines.push(line);
}
}
fn joined_char_count(lines: &[&str]) -> usize {
lines.iter().map(|line| line.chars().count()).sum::<usize>() + lines.len().saturating_sub(1)
}
fn line_priority(line: &str) -> usize {
if line == "Summary:" || line == "Conversation summary:" || is_core_detail(line) {
0
} else if is_section_header(line) {
1
} else if line.starts_with("- ") || line.starts_with(" - ") {
2
} else {
3
}
}
fn is_core_detail(line: &str) -> bool {
[
"- Scope:",
"- Current work:",
"- Pending work:",
"- Key files referenced:",
"- Tools mentioned:",
"- Recent user requests:",
"- Previously compacted context:",
"- Newly compacted context:",
]
.iter()
.any(|prefix| line.starts_with(prefix))
}
fn is_section_header(line: &str) -> bool {
line.ends_with(':')
}
fn omission_notice(omitted_lines: usize) -> String {
format!("- … {omitted_lines} additional line(s) omitted.")
}
fn collapse_inline_whitespace(line: &str) -> String {
line.split_whitespace().collect::<Vec<_>>().join(" ")
}
fn truncate_line(line: &str, max_chars: usize) -> String {
if max_chars == 0 || line.chars().count() <= max_chars {
return line.to_string();
}
if max_chars == 1 {
return "".to_string();
}
let mut truncated = line
.chars()
.take(max_chars.saturating_sub(1))
.collect::<String>();
truncated.push('…');
truncated
}
fn dedupe_key(line: &str) -> String {
line.to_ascii_lowercase()
}
#[cfg(test)]
mod tests {
use super::{compress_summary, compress_summary_text, SummaryCompressionBudget};
#[test]
fn collapses_whitespace_and_duplicate_lines() {
// given
let summary = "Conversation summary:\n\n- Scope: compact earlier messages.\n- Scope: compact earlier messages.\n- Current work: update runtime module.\n";
// when
let result = compress_summary(summary, SummaryCompressionBudget::default());
// then
assert_eq!(result.removed_duplicate_lines, 1);
assert!(result
.summary
.contains("- Scope: compact earlier messages."));
assert!(!result.summary.contains(" compact earlier"));
}
#[test]
fn keeps_core_lines_when_budget_is_tight() {
// given
let summary = [
"Conversation summary:",
"- Scope: 18 earlier messages compacted.",
"- Current work: finish summary compression.",
"- Key timeline:",
" - user: asked for a working implementation.",
" - assistant: inspected runtime compaction flow.",
" - tool: cargo check succeeded.",
]
.join("\n");
// when
let result = compress_summary(
&summary,
SummaryCompressionBudget {
max_chars: 120,
max_lines: 3,
max_line_chars: 80,
},
);
// then
assert!(result.summary.contains("Conversation summary:"));
assert!(result
.summary
.contains("- Scope: 18 earlier messages compacted."));
assert!(result
.summary
.contains("- Current work: finish summary compression."));
assert!(result.omitted_lines > 0);
}
#[test]
fn provides_a_default_text_only_helper() {
// given
let summary = "Summary:\n\nA short line.";
// when
let compressed = compress_summary_text(summary);
// then
assert_eq!(compressed, "Summary:\nA short line.");
}
}

View File

@ -0,0 +1,158 @@
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter};
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TaskPacket {
pub objective: String,
pub scope: String,
pub repo: String,
pub branch_policy: String,
pub acceptance_tests: Vec<String>,
pub commit_policy: String,
pub reporting_contract: String,
pub escalation_policy: String,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskPacketValidationError {
errors: Vec<String>,
}
impl TaskPacketValidationError {
#[must_use]
pub fn new(errors: Vec<String>) -> Self {
Self { errors }
}
#[must_use]
pub fn errors(&self) -> &[String] {
&self.errors
}
}
impl Display for TaskPacketValidationError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.errors.join("; "))
}
}
impl std::error::Error for TaskPacketValidationError {}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedPacket(TaskPacket);
impl ValidatedPacket {
#[must_use]
pub fn packet(&self) -> &TaskPacket {
&self.0
}
#[must_use]
pub fn into_inner(self) -> TaskPacket {
self.0
}
}
pub fn validate_packet(packet: TaskPacket) -> Result<ValidatedPacket, TaskPacketValidationError> {
let mut errors = Vec::new();
validate_required("objective", &packet.objective, &mut errors);
validate_required("scope", &packet.scope, &mut errors);
validate_required("repo", &packet.repo, &mut errors);
validate_required("branch_policy", &packet.branch_policy, &mut errors);
validate_required("commit_policy", &packet.commit_policy, &mut errors);
validate_required(
"reporting_contract",
&packet.reporting_contract,
&mut errors,
);
validate_required("escalation_policy", &packet.escalation_policy, &mut errors);
for (index, test) in packet.acceptance_tests.iter().enumerate() {
if test.trim().is_empty() {
errors.push(format!(
"acceptance_tests contains an empty value at index {index}"
));
}
}
if errors.is_empty() {
Ok(ValidatedPacket(packet))
} else {
Err(TaskPacketValidationError::new(errors))
}
}
fn validate_required(field: &str, value: &str, errors: &mut Vec<String>) {
if value.trim().is_empty() {
errors.push(format!("{field} must not be empty"));
}
}
#[cfg(test)]
mod tests {
use super::*;
fn sample_packet() -> TaskPacket {
TaskPacket {
objective: "Implement typed task packet format".to_string(),
scope: "runtime/task system".to_string(),
repo: "claw-code-parity".to_string(),
branch_policy: "origin/main only".to_string(),
acceptance_tests: vec![
"cargo build --workspace".to_string(),
"cargo test --workspace".to_string(),
],
commit_policy: "single verified commit".to_string(),
reporting_contract: "print build result, test result, commit sha".to_string(),
escalation_policy: "stop only on destructive ambiguity".to_string(),
}
}
#[test]
fn valid_packet_passes_validation() {
let packet = sample_packet();
let validated = validate_packet(packet.clone()).expect("packet should validate");
assert_eq!(validated.packet(), &packet);
assert_eq!(validated.into_inner(), packet);
}
#[test]
fn invalid_packet_accumulates_errors() {
let packet = TaskPacket {
objective: " ".to_string(),
scope: String::new(),
repo: String::new(),
branch_policy: "\t".to_string(),
acceptance_tests: vec!["ok".to_string(), " ".to_string()],
commit_policy: String::new(),
reporting_contract: String::new(),
escalation_policy: String::new(),
};
let error = validate_packet(packet).expect_err("packet should be rejected");
assert!(error.errors().len() >= 7);
assert!(error
.errors()
.contains(&"objective must not be empty".to_string()));
assert!(error
.errors()
.contains(&"scope must not be empty".to_string()));
assert!(error
.errors()
.contains(&"repo must not be empty".to_string()));
assert!(error
.errors()
.contains(&"acceptance_tests contains an empty value at index 1".to_string()));
}
#[test]
fn serialization_roundtrip_preserves_packet() {
let packet = sample_packet();
let serialized = serde_json::to_string(&packet).expect("packet should serialize");
let deserialized: TaskPacket =
serde_json::from_str(&serialized).expect("packet should deserialize");
assert_eq!(deserialized, packet);
}
}

View File

@ -0,0 +1,503 @@
#![allow(clippy::must_use_candidate, clippy::unnecessary_map_or)]
//! In-memory task registry for sub-agent task lifecycle management.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::{validate_packet, TaskPacket, TaskPacketValidationError};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskStatus {
Created,
Running,
Completed,
Failed,
Stopped,
}
impl std::fmt::Display for TaskStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Created => write!(f, "created"),
Self::Running => write!(f, "running"),
Self::Completed => write!(f, "completed"),
Self::Failed => write!(f, "failed"),
Self::Stopped => write!(f, "stopped"),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
pub task_id: String,
pub prompt: String,
pub description: Option<String>,
pub task_packet: Option<TaskPacket>,
pub status: TaskStatus,
pub created_at: u64,
pub updated_at: u64,
pub messages: Vec<TaskMessage>,
pub output: String,
pub team_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskMessage {
pub role: String,
pub content: String,
pub timestamp: u64,
}
#[derive(Debug, Clone, Default)]
pub struct TaskRegistry {
inner: Arc<Mutex<RegistryInner>>,
}
#[derive(Debug, Default)]
struct RegistryInner {
tasks: HashMap<String, Task>,
counter: u64,
}
fn now_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
impl TaskRegistry {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn create(&self, prompt: &str, description: Option<&str>) -> Task {
self.create_task(prompt.to_owned(), description.map(str::to_owned), None)
}
pub fn create_from_packet(
&self,
packet: TaskPacket,
) -> Result<Task, TaskPacketValidationError> {
let packet = validate_packet(packet)?.into_inner();
Ok(self.create_task(
packet.objective.clone(),
Some(packet.scope.clone()),
Some(packet),
))
}
fn create_task(
&self,
prompt: String,
description: Option<String>,
task_packet: Option<TaskPacket>,
) -> Task {
let mut inner = self.inner.lock().expect("registry lock poisoned");
inner.counter += 1;
let ts = now_secs();
let task_id = format!("task_{:08x}_{}", ts, inner.counter);
let task = Task {
task_id: task_id.clone(),
prompt,
description,
task_packet,
status: TaskStatus::Created,
created_at: ts,
updated_at: ts,
messages: Vec::new(),
output: String::new(),
team_id: None,
};
inner.tasks.insert(task_id, task.clone());
task
}
pub fn get(&self, task_id: &str) -> Option<Task> {
let inner = self.inner.lock().expect("registry lock poisoned");
inner.tasks.get(task_id).cloned()
}
pub fn list(&self, status_filter: Option<TaskStatus>) -> Vec<Task> {
let inner = self.inner.lock().expect("registry lock poisoned");
inner
.tasks
.values()
.filter(|t| status_filter.map_or(true, |s| t.status == s))
.cloned()
.collect()
}
pub fn stop(&self, task_id: &str) -> Result<Task, String> {
let mut inner = self.inner.lock().expect("registry lock poisoned");
let task = inner
.tasks
.get_mut(task_id)
.ok_or_else(|| format!("task not found: {task_id}"))?;
match task.status {
TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Stopped => {
return Err(format!(
"task {task_id} is already in terminal state: {}",
task.status
));
}
_ => {}
}
task.status = TaskStatus::Stopped;
task.updated_at = now_secs();
Ok(task.clone())
}
pub fn update(&self, task_id: &str, message: &str) -> Result<Task, String> {
let mut inner = self.inner.lock().expect("registry lock poisoned");
let task = inner
.tasks
.get_mut(task_id)
.ok_or_else(|| format!("task not found: {task_id}"))?;
task.messages.push(TaskMessage {
role: String::from("user"),
content: message.to_owned(),
timestamp: now_secs(),
});
task.updated_at = now_secs();
Ok(task.clone())
}
pub fn output(&self, task_id: &str) -> Result<String, String> {
let inner = self.inner.lock().expect("registry lock poisoned");
let task = inner
.tasks
.get(task_id)
.ok_or_else(|| format!("task not found: {task_id}"))?;
Ok(task.output.clone())
}
pub fn append_output(&self, task_id: &str, output: &str) -> Result<(), String> {
let mut inner = self.inner.lock().expect("registry lock poisoned");
let task = inner
.tasks
.get_mut(task_id)
.ok_or_else(|| format!("task not found: {task_id}"))?;
task.output.push_str(output);
task.updated_at = now_secs();
Ok(())
}
pub fn set_status(&self, task_id: &str, status: TaskStatus) -> Result<(), String> {
let mut inner = self.inner.lock().expect("registry lock poisoned");
let task = inner
.tasks
.get_mut(task_id)
.ok_or_else(|| format!("task not found: {task_id}"))?;
task.status = status;
task.updated_at = now_secs();
Ok(())
}
pub fn assign_team(&self, task_id: &str, team_id: &str) -> Result<(), String> {
let mut inner = self.inner.lock().expect("registry lock poisoned");
let task = inner
.tasks
.get_mut(task_id)
.ok_or_else(|| format!("task not found: {task_id}"))?;
task.team_id = Some(team_id.to_owned());
task.updated_at = now_secs();
Ok(())
}
pub fn remove(&self, task_id: &str) -> Option<Task> {
let mut inner = self.inner.lock().expect("registry lock poisoned");
inner.tasks.remove(task_id)
}
#[must_use]
pub fn len(&self) -> usize {
let inner = self.inner.lock().expect("registry lock poisoned");
inner.tasks.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn creates_and_retrieves_tasks() {
let registry = TaskRegistry::new();
let task = registry.create("Do something", Some("A test task"));
assert_eq!(task.status, TaskStatus::Created);
assert_eq!(task.prompt, "Do something");
assert_eq!(task.description.as_deref(), Some("A test task"));
assert_eq!(task.task_packet, None);
let fetched = registry.get(&task.task_id).expect("task should exist");
assert_eq!(fetched.task_id, task.task_id);
}
#[test]
fn creates_task_from_packet() {
let registry = TaskRegistry::new();
let packet = TaskPacket {
objective: "Ship task packet support".to_string(),
scope: "runtime/task system".to_string(),
repo: "claw-code-parity".to_string(),
branch_policy: "origin/main only".to_string(),
acceptance_tests: vec!["cargo test --workspace".to_string()],
commit_policy: "single commit".to_string(),
reporting_contract: "print commit sha".to_string(),
escalation_policy: "manual escalation".to_string(),
};
let task = registry
.create_from_packet(packet.clone())
.expect("packet-backed task should be created");
assert_eq!(task.prompt, packet.objective);
assert_eq!(task.description.as_deref(), Some("runtime/task system"));
assert_eq!(task.task_packet, Some(packet.clone()));
let fetched = registry.get(&task.task_id).expect("task should exist");
assert_eq!(fetched.task_packet, Some(packet));
}
#[test]
fn lists_tasks_with_optional_filter() {
let registry = TaskRegistry::new();
registry.create("Task A", None);
let task_b = registry.create("Task B", None);
registry
.set_status(&task_b.task_id, TaskStatus::Running)
.expect("set status should succeed");
let all = registry.list(None);
assert_eq!(all.len(), 2);
let running = registry.list(Some(TaskStatus::Running));
assert_eq!(running.len(), 1);
assert_eq!(running[0].task_id, task_b.task_id);
let created = registry.list(Some(TaskStatus::Created));
assert_eq!(created.len(), 1);
}
#[test]
fn stops_running_task() {
let registry = TaskRegistry::new();
let task = registry.create("Stoppable", None);
registry
.set_status(&task.task_id, TaskStatus::Running)
.unwrap();
let stopped = registry.stop(&task.task_id).expect("stop should succeed");
assert_eq!(stopped.status, TaskStatus::Stopped);
// Stopping again should fail
let result = registry.stop(&task.task_id);
assert!(result.is_err());
}
#[test]
fn updates_task_with_messages() {
let registry = TaskRegistry::new();
let task = registry.create("Messageable", None);
let updated = registry
.update(&task.task_id, "Here's more context")
.expect("update should succeed");
assert_eq!(updated.messages.len(), 1);
assert_eq!(updated.messages[0].content, "Here's more context");
assert_eq!(updated.messages[0].role, "user");
}
#[test]
fn appends_and_retrieves_output() {
let registry = TaskRegistry::new();
let task = registry.create("Output task", None);
registry
.append_output(&task.task_id, "line 1\n")
.expect("append should succeed");
registry
.append_output(&task.task_id, "line 2\n")
.expect("append should succeed");
let output = registry.output(&task.task_id).expect("output should exist");
assert_eq!(output, "line 1\nline 2\n");
}
#[test]
fn assigns_team_and_removes_task() {
let registry = TaskRegistry::new();
let task = registry.create("Team task", None);
registry
.assign_team(&task.task_id, "team_abc")
.expect("assign should succeed");
let fetched = registry.get(&task.task_id).unwrap();
assert_eq!(fetched.team_id.as_deref(), Some("team_abc"));
let removed = registry.remove(&task.task_id);
assert!(removed.is_some());
assert!(registry.get(&task.task_id).is_none());
assert!(registry.is_empty());
}
#[test]
fn rejects_operations_on_missing_task() {
let registry = TaskRegistry::new();
assert!(registry.stop("nonexistent").is_err());
assert!(registry.update("nonexistent", "msg").is_err());
assert!(registry.output("nonexistent").is_err());
assert!(registry.append_output("nonexistent", "data").is_err());
assert!(registry
.set_status("nonexistent", TaskStatus::Running)
.is_err());
}
#[test]
fn task_status_display_all_variants() {
// given
let cases = [
(TaskStatus::Created, "created"),
(TaskStatus::Running, "running"),
(TaskStatus::Completed, "completed"),
(TaskStatus::Failed, "failed"),
(TaskStatus::Stopped, "stopped"),
];
// when
let rendered: Vec<_> = cases
.into_iter()
.map(|(status, expected)| (status.to_string(), expected))
.collect();
// then
assert_eq!(
rendered,
vec![
("created".to_string(), "created"),
("running".to_string(), "running"),
("completed".to_string(), "completed"),
("failed".to_string(), "failed"),
("stopped".to_string(), "stopped"),
]
);
}
#[test]
fn stop_rejects_completed_task() {
// given
let registry = TaskRegistry::new();
let task = registry.create("done", None);
registry
.set_status(&task.task_id, TaskStatus::Completed)
.expect("set status should succeed");
// when
let result = registry.stop(&task.task_id);
// then
let error = result.expect_err("completed task should be rejected");
assert!(error.contains("already in terminal state"));
assert!(error.contains("completed"));
}
#[test]
fn stop_rejects_failed_task() {
// given
let registry = TaskRegistry::new();
let task = registry.create("failed", None);
registry
.set_status(&task.task_id, TaskStatus::Failed)
.expect("set status should succeed");
// when
let result = registry.stop(&task.task_id);
// then
let error = result.expect_err("failed task should be rejected");
assert!(error.contains("already in terminal state"));
assert!(error.contains("failed"));
}
#[test]
fn stop_succeeds_from_created_state() {
// given
let registry = TaskRegistry::new();
let task = registry.create("created task", None);
// when
let stopped = registry.stop(&task.task_id).expect("stop should succeed");
// then
assert_eq!(stopped.status, TaskStatus::Stopped);
assert!(stopped.updated_at >= task.updated_at);
}
#[test]
fn new_registry_is_empty() {
// given
let registry = TaskRegistry::new();
// when
let all_tasks = registry.list(None);
// then
assert!(registry.is_empty());
assert_eq!(registry.len(), 0);
assert!(all_tasks.is_empty());
}
#[test]
fn create_without_description() {
// given
let registry = TaskRegistry::new();
// when
let task = registry.create("Do the thing", None);
// then
assert!(task.task_id.starts_with("task_"));
assert_eq!(task.description, None);
assert_eq!(task.task_packet, None);
assert!(task.messages.is_empty());
assert!(task.output.is_empty());
assert_eq!(task.team_id, None);
}
#[test]
fn remove_nonexistent_returns_none() {
// given
let registry = TaskRegistry::new();
// when
let removed = registry.remove("missing");
// then
assert!(removed.is_none());
}
#[test]
fn assign_team_rejects_missing_task() {
// given
let registry = TaskRegistry::new();
// when
let result = registry.assign_team("missing", "team_123");
// then
let error = result.expect_err("missing task should be rejected");
assert_eq!(error, "task not found: missing");
}
}

View File

@ -0,0 +1,509 @@
#![allow(clippy::must_use_candidate)]
//! In-memory registries for Team and Cron lifecycle management.
//!
//! Provides TeamCreate/Delete and CronCreate/Delete/List runtime backing
//! to replace the stub implementations in the tools crate.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
fn now_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Team {
pub team_id: String,
pub name: String,
pub task_ids: Vec<String>,
pub status: TeamStatus,
pub created_at: u64,
pub updated_at: u64,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TeamStatus {
Created,
Running,
Completed,
Deleted,
}
impl std::fmt::Display for TeamStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Created => write!(f, "created"),
Self::Running => write!(f, "running"),
Self::Completed => write!(f, "completed"),
Self::Deleted => write!(f, "deleted"),
}
}
}
#[derive(Debug, Clone, Default)]
pub struct TeamRegistry {
inner: Arc<Mutex<TeamInner>>,
}
#[derive(Debug, Default)]
struct TeamInner {
teams: HashMap<String, Team>,
counter: u64,
}
impl TeamRegistry {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn create(&self, name: &str, task_ids: Vec<String>) -> Team {
let mut inner = self.inner.lock().expect("team registry lock poisoned");
inner.counter += 1;
let ts = now_secs();
let team_id = format!("team_{:08x}_{}", ts, inner.counter);
let team = Team {
team_id: team_id.clone(),
name: name.to_owned(),
task_ids,
status: TeamStatus::Created,
created_at: ts,
updated_at: ts,
};
inner.teams.insert(team_id, team.clone());
team
}
pub fn get(&self, team_id: &str) -> Option<Team> {
let inner = self.inner.lock().expect("team registry lock poisoned");
inner.teams.get(team_id).cloned()
}
pub fn list(&self) -> Vec<Team> {
let inner = self.inner.lock().expect("team registry lock poisoned");
inner.teams.values().cloned().collect()
}
pub fn delete(&self, team_id: &str) -> Result<Team, String> {
let mut inner = self.inner.lock().expect("team registry lock poisoned");
let team = inner
.teams
.get_mut(team_id)
.ok_or_else(|| format!("team not found: {team_id}"))?;
team.status = TeamStatus::Deleted;
team.updated_at = now_secs();
Ok(team.clone())
}
pub fn remove(&self, team_id: &str) -> Option<Team> {
let mut inner = self.inner.lock().expect("team registry lock poisoned");
inner.teams.remove(team_id)
}
#[must_use]
pub fn len(&self) -> usize {
let inner = self.inner.lock().expect("team registry lock poisoned");
inner.teams.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CronEntry {
pub cron_id: String,
pub schedule: String,
pub prompt: String,
pub description: Option<String>,
pub enabled: bool,
pub created_at: u64,
pub updated_at: u64,
pub last_run_at: Option<u64>,
pub run_count: u64,
}
#[derive(Debug, Clone, Default)]
pub struct CronRegistry {
inner: Arc<Mutex<CronInner>>,
}
#[derive(Debug, Default)]
struct CronInner {
entries: HashMap<String, CronEntry>,
counter: u64,
}
impl CronRegistry {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn create(&self, schedule: &str, prompt: &str, description: Option<&str>) -> CronEntry {
let mut inner = self.inner.lock().expect("cron registry lock poisoned");
inner.counter += 1;
let ts = now_secs();
let cron_id = format!("cron_{:08x}_{}", ts, inner.counter);
let entry = CronEntry {
cron_id: cron_id.clone(),
schedule: schedule.to_owned(),
prompt: prompt.to_owned(),
description: description.map(str::to_owned),
enabled: true,
created_at: ts,
updated_at: ts,
last_run_at: None,
run_count: 0,
};
inner.entries.insert(cron_id, entry.clone());
entry
}
pub fn get(&self, cron_id: &str) -> Option<CronEntry> {
let inner = self.inner.lock().expect("cron registry lock poisoned");
inner.entries.get(cron_id).cloned()
}
pub fn list(&self, enabled_only: bool) -> Vec<CronEntry> {
let inner = self.inner.lock().expect("cron registry lock poisoned");
inner
.entries
.values()
.filter(|e| !enabled_only || e.enabled)
.cloned()
.collect()
}
pub fn delete(&self, cron_id: &str) -> Result<CronEntry, String> {
let mut inner = self.inner.lock().expect("cron registry lock poisoned");
inner
.entries
.remove(cron_id)
.ok_or_else(|| format!("cron not found: {cron_id}"))
}
/// Disable a cron entry without removing it.
pub fn disable(&self, cron_id: &str) -> Result<(), String> {
let mut inner = self.inner.lock().expect("cron registry lock poisoned");
let entry = inner
.entries
.get_mut(cron_id)
.ok_or_else(|| format!("cron not found: {cron_id}"))?;
entry.enabled = false;
entry.updated_at = now_secs();
Ok(())
}
/// Record a cron run.
pub fn record_run(&self, cron_id: &str) -> Result<(), String> {
let mut inner = self.inner.lock().expect("cron registry lock poisoned");
let entry = inner
.entries
.get_mut(cron_id)
.ok_or_else(|| format!("cron not found: {cron_id}"))?;
entry.last_run_at = Some(now_secs());
entry.run_count += 1;
entry.updated_at = now_secs();
Ok(())
}
#[must_use]
pub fn len(&self) -> usize {
let inner = self.inner.lock().expect("cron registry lock poisoned");
inner.entries.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
#[cfg(test)]
mod tests {
use super::*;
// ── Team tests ──────────────────────────────────────
#[test]
fn creates_and_retrieves_team() {
let registry = TeamRegistry::new();
let team = registry.create("Alpha Squad", vec!["task_001".into(), "task_002".into()]);
assert_eq!(team.name, "Alpha Squad");
assert_eq!(team.task_ids.len(), 2);
assert_eq!(team.status, TeamStatus::Created);
let fetched = registry.get(&team.team_id).expect("team should exist");
assert_eq!(fetched.team_id, team.team_id);
}
#[test]
fn lists_and_deletes_teams() {
let registry = TeamRegistry::new();
let t1 = registry.create("Team A", vec![]);
let t2 = registry.create("Team B", vec![]);
let all = registry.list();
assert_eq!(all.len(), 2);
let deleted = registry.delete(&t1.team_id).expect("delete should succeed");
assert_eq!(deleted.status, TeamStatus::Deleted);
// Team is still listable (soft delete)
let still_there = registry.get(&t1.team_id).unwrap();
assert_eq!(still_there.status, TeamStatus::Deleted);
// Hard remove
registry.remove(&t2.team_id);
assert_eq!(registry.len(), 1);
}
#[test]
fn rejects_missing_team_operations() {
let registry = TeamRegistry::new();
assert!(registry.delete("nonexistent").is_err());
assert!(registry.get("nonexistent").is_none());
}
// ── Cron tests ──────────────────────────────────────
#[test]
fn creates_and_retrieves_cron() {
let registry = CronRegistry::new();
let entry = registry.create("0 * * * *", "Check status", Some("hourly check"));
assert_eq!(entry.schedule, "0 * * * *");
assert_eq!(entry.prompt, "Check status");
assert!(entry.enabled);
assert_eq!(entry.run_count, 0);
assert!(entry.last_run_at.is_none());
let fetched = registry.get(&entry.cron_id).expect("cron should exist");
assert_eq!(fetched.cron_id, entry.cron_id);
}
#[test]
fn lists_with_enabled_filter() {
let registry = CronRegistry::new();
let c1 = registry.create("* * * * *", "Task 1", None);
let c2 = registry.create("0 * * * *", "Task 2", None);
registry
.disable(&c1.cron_id)
.expect("disable should succeed");
let all = registry.list(false);
assert_eq!(all.len(), 2);
let enabled_only = registry.list(true);
assert_eq!(enabled_only.len(), 1);
assert_eq!(enabled_only[0].cron_id, c2.cron_id);
}
#[test]
fn deletes_cron_entry() {
let registry = CronRegistry::new();
let entry = registry.create("* * * * *", "To delete", None);
let deleted = registry
.delete(&entry.cron_id)
.expect("delete should succeed");
assert_eq!(deleted.cron_id, entry.cron_id);
assert!(registry.get(&entry.cron_id).is_none());
assert!(registry.is_empty());
}
#[test]
fn records_cron_runs() {
let registry = CronRegistry::new();
let entry = registry.create("*/5 * * * *", "Recurring", None);
registry.record_run(&entry.cron_id).unwrap();
registry.record_run(&entry.cron_id).unwrap();
let fetched = registry.get(&entry.cron_id).unwrap();
assert_eq!(fetched.run_count, 2);
assert!(fetched.last_run_at.is_some());
}
#[test]
fn rejects_missing_cron_operations() {
let registry = CronRegistry::new();
assert!(registry.delete("nonexistent").is_err());
assert!(registry.disable("nonexistent").is_err());
assert!(registry.record_run("nonexistent").is_err());
assert!(registry.get("nonexistent").is_none());
}
#[test]
fn team_status_display_all_variants() {
// given
let cases = [
(TeamStatus::Created, "created"),
(TeamStatus::Running, "running"),
(TeamStatus::Completed, "completed"),
(TeamStatus::Deleted, "deleted"),
];
// when
let rendered: Vec<_> = cases
.into_iter()
.map(|(status, expected)| (status.to_string(), expected))
.collect();
// then
assert_eq!(
rendered,
vec![
("created".to_string(), "created"),
("running".to_string(), "running"),
("completed".to_string(), "completed"),
("deleted".to_string(), "deleted"),
]
);
}
#[test]
fn new_team_registry_is_empty() {
// given
let registry = TeamRegistry::new();
// when
let teams = registry.list();
// then
assert!(registry.is_empty());
assert_eq!(registry.len(), 0);
assert!(teams.is_empty());
}
#[test]
fn team_remove_nonexistent_returns_none() {
// given
let registry = TeamRegistry::new();
// when
let removed = registry.remove("missing");
// then
assert!(removed.is_none());
}
#[test]
fn team_len_transitions() {
// given
let registry = TeamRegistry::new();
// when
let alpha = registry.create("Alpha", vec![]);
let beta = registry.create("Beta", vec![]);
let after_create = registry.len();
registry.remove(&alpha.team_id);
let after_first_remove = registry.len();
registry.remove(&beta.team_id);
// then
assert_eq!(after_create, 2);
assert_eq!(after_first_remove, 1);
assert_eq!(registry.len(), 0);
assert!(registry.is_empty());
}
#[test]
fn cron_list_all_disabled_returns_empty_for_enabled_only() {
// given
let registry = CronRegistry::new();
let first = registry.create("* * * * *", "Task 1", None);
let second = registry.create("0 * * * *", "Task 2", None);
registry
.disable(&first.cron_id)
.expect("disable should succeed");
registry
.disable(&second.cron_id)
.expect("disable should succeed");
// when
let enabled_only = registry.list(true);
let all_entries = registry.list(false);
// then
assert!(enabled_only.is_empty());
assert_eq!(all_entries.len(), 2);
}
#[test]
fn cron_create_without_description() {
// given
let registry = CronRegistry::new();
// when
let entry = registry.create("*/15 * * * *", "Check health", None);
// then
assert!(entry.cron_id.starts_with("cron_"));
assert_eq!(entry.description, None);
assert!(entry.enabled);
assert_eq!(entry.run_count, 0);
assert_eq!(entry.last_run_at, None);
}
#[test]
fn new_cron_registry_is_empty() {
// given
let registry = CronRegistry::new();
// when
let enabled_only = registry.list(true);
let all_entries = registry.list(false);
// then
assert!(registry.is_empty());
assert_eq!(registry.len(), 0);
assert!(enabled_only.is_empty());
assert!(all_entries.is_empty());
}
#[test]
fn cron_record_run_updates_timestamp_and_counter() {
// given
let registry = CronRegistry::new();
let entry = registry.create("*/5 * * * *", "Recurring", None);
// when
registry
.record_run(&entry.cron_id)
.expect("first run should succeed");
registry
.record_run(&entry.cron_id)
.expect("second run should succeed");
let fetched = registry.get(&entry.cron_id).expect("entry should exist");
// then
assert_eq!(fetched.run_count, 2);
assert!(fetched.last_run_at.is_some());
assert!(fetched.updated_at >= entry.updated_at);
}
#[test]
fn cron_disable_updates_timestamp() {
// given
let registry = CronRegistry::new();
let entry = registry.create("0 0 * * *", "Nightly", None);
// when
registry
.disable(&entry.cron_id)
.expect("disable should succeed");
let fetched = registry.get(&entry.cron_id).expect("entry should exist");
// then
assert!(!fetched.enabled);
assert!(fetched.updated_at >= entry.updated_at);
}
}

View File

@ -0,0 +1,299 @@
use std::path::{Path, PathBuf};
const TRUST_PROMPT_CUES: &[&str] = &[
"do you trust the files in this folder",
"trust the files in this folder",
"trust this folder",
"allow and continue",
"yes, proceed",
];
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrustPolicy {
AutoTrust,
RequireApproval,
Deny,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrustEvent {
TrustRequired { cwd: String },
TrustResolved { cwd: String, policy: TrustPolicy },
TrustDenied { cwd: String, reason: String },
}
#[derive(Debug, Clone, Default)]
pub struct TrustConfig {
allowlisted: Vec<PathBuf>,
denied: Vec<PathBuf>,
}
impl TrustConfig {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_allowlisted(mut self, path: impl Into<PathBuf>) -> Self {
self.allowlisted.push(path.into());
self
}
#[must_use]
pub fn with_denied(mut self, path: impl Into<PathBuf>) -> Self {
self.denied.push(path.into());
self
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrustDecision {
NotRequired,
Required {
policy: TrustPolicy,
events: Vec<TrustEvent>,
},
}
impl TrustDecision {
#[must_use]
pub fn policy(&self) -> Option<TrustPolicy> {
match self {
Self::NotRequired => None,
Self::Required { policy, .. } => Some(*policy),
}
}
#[must_use]
pub fn events(&self) -> &[TrustEvent] {
match self {
Self::NotRequired => &[],
Self::Required { events, .. } => events,
}
}
}
#[derive(Debug, Clone)]
pub struct TrustResolver {
config: TrustConfig,
}
impl TrustResolver {
#[must_use]
pub fn new(config: TrustConfig) -> Self {
Self { config }
}
#[must_use]
pub fn resolve(&self, cwd: &str, screen_text: &str) -> TrustDecision {
if !detect_trust_prompt(screen_text) {
return TrustDecision::NotRequired;
}
let mut events = vec![TrustEvent::TrustRequired {
cwd: cwd.to_owned(),
}];
if let Some(matched_root) = self
.config
.denied
.iter()
.find(|root| path_matches(cwd, root))
{
let reason = format!("cwd matches denied trust root: {}", matched_root.display());
events.push(TrustEvent::TrustDenied {
cwd: cwd.to_owned(),
reason,
});
return TrustDecision::Required {
policy: TrustPolicy::Deny,
events,
};
}
if self
.config
.allowlisted
.iter()
.any(|root| path_matches(cwd, root))
{
events.push(TrustEvent::TrustResolved {
cwd: cwd.to_owned(),
policy: TrustPolicy::AutoTrust,
});
return TrustDecision::Required {
policy: TrustPolicy::AutoTrust,
events,
};
}
TrustDecision::Required {
policy: TrustPolicy::RequireApproval,
events,
}
}
#[must_use]
pub fn trusts(&self, cwd: &str) -> bool {
!self
.config
.denied
.iter()
.any(|root| path_matches(cwd, root))
&& self
.config
.allowlisted
.iter()
.any(|root| path_matches(cwd, root))
}
}
#[must_use]
pub fn detect_trust_prompt(screen_text: &str) -> bool {
let lowered = screen_text.to_ascii_lowercase();
TRUST_PROMPT_CUES
.iter()
.any(|needle| lowered.contains(needle))
}
#[must_use]
pub fn path_matches_trusted_root(cwd: &str, trusted_root: &str) -> bool {
path_matches(cwd, &normalize_path(Path::new(trusted_root)))
}
fn path_matches(candidate: &str, root: &Path) -> bool {
let candidate = normalize_path(Path::new(candidate));
let root = normalize_path(root);
candidate == root || candidate.starts_with(&root)
}
fn normalize_path(path: &Path) -> PathBuf {
std::fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf())
}
#[cfg(test)]
mod tests {
use super::{
detect_trust_prompt, path_matches_trusted_root, TrustConfig, TrustDecision, TrustEvent,
TrustPolicy, TrustResolver,
};
#[test]
fn detects_known_trust_prompt_copy() {
// given
let screen_text = "Do you trust the files in this folder?\n1. Yes, proceed\n2. No";
// when
let detected = detect_trust_prompt(screen_text);
// then
assert!(detected);
}
#[test]
fn does_not_emit_events_when_prompt_is_absent() {
// given
let resolver = TrustResolver::new(TrustConfig::new().with_allowlisted("/tmp/worktrees"));
// when
let decision = resolver.resolve("/tmp/worktrees/repo-a", "Ready for your input\n>");
// then
assert_eq!(decision, TrustDecision::NotRequired);
assert_eq!(decision.events(), &[]);
assert_eq!(decision.policy(), None);
}
#[test]
fn auto_trusts_allowlisted_cwd_after_prompt_detection() {
// given
let resolver = TrustResolver::new(TrustConfig::new().with_allowlisted("/tmp/worktrees"));
// when
let decision = resolver.resolve(
"/tmp/worktrees/repo-a",
"Do you trust the files in this folder?\n1. Yes, proceed\n2. No",
);
// then
assert_eq!(decision.policy(), Some(TrustPolicy::AutoTrust));
assert_eq!(
decision.events(),
&[
TrustEvent::TrustRequired {
cwd: "/tmp/worktrees/repo-a".to_string(),
},
TrustEvent::TrustResolved {
cwd: "/tmp/worktrees/repo-a".to_string(),
policy: TrustPolicy::AutoTrust,
},
]
);
}
#[test]
fn requires_approval_for_unknown_cwd_after_prompt_detection() {
// given
let resolver = TrustResolver::new(TrustConfig::new().with_allowlisted("/tmp/worktrees"));
// when
let decision = resolver.resolve(
"/tmp/other/repo-b",
"Do you trust the files in this folder?\n1. Yes, proceed\n2. No",
);
// then
assert_eq!(decision.policy(), Some(TrustPolicy::RequireApproval));
assert_eq!(
decision.events(),
&[TrustEvent::TrustRequired {
cwd: "/tmp/other/repo-b".to_string(),
}]
);
}
#[test]
fn denied_root_takes_precedence_over_allowlist() {
// given
let resolver = TrustResolver::new(
TrustConfig::new()
.with_allowlisted("/tmp/worktrees")
.with_denied("/tmp/worktrees/repo-c"),
);
// when
let decision = resolver.resolve(
"/tmp/worktrees/repo-c",
"Do you trust the files in this folder?\n1. Yes, proceed\n2. No",
);
// then
assert_eq!(decision.policy(), Some(TrustPolicy::Deny));
assert_eq!(
decision.events(),
&[
TrustEvent::TrustRequired {
cwd: "/tmp/worktrees/repo-c".to_string(),
},
TrustEvent::TrustDenied {
cwd: "/tmp/worktrees/repo-c".to_string(),
reason: "cwd matches denied trust root: /tmp/worktrees/repo-c".to_string(),
},
]
);
}
#[test]
fn sibling_prefix_does_not_match_trusted_root() {
// given
let trusted_root = "/tmp/worktrees";
let sibling_path = "/tmp/worktrees-other/repo-d";
// when
let matched = path_matches_trusted_root(sibling_path, trusted_root);
// then
assert!(!matched);
}
}

View File

@ -1,11 +1,11 @@
use crate::session::Session;
use serde::{Deserialize, Serialize};
const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0;
const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0;
const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75;
const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5;
/// Per-million-token pricing used for cost estimation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ModelPricing {
pub input_cost_per_million: f64,
@ -26,7 +26,8 @@ impl ModelPricing {
}
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
/// Token counters accumulated for a conversation turn or session.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TokenUsage {
pub input_tokens: u32,
pub output_tokens: u32,
@ -34,6 +35,7 @@ pub struct TokenUsage {
pub cache_read_input_tokens: u32,
}
/// Estimated dollar cost derived from a [`TokenUsage`] sample.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct UsageCostEstimate {
pub input_cost_usd: f64,
@ -52,6 +54,7 @@ impl UsageCostEstimate {
}
}
/// Returns pricing metadata for a known model alias or family.
#[must_use]
pub fn pricing_for_model(model: &str) -> Option<ModelPricing> {
let normalized = model.to_ascii_lowercase();
@ -156,10 +159,12 @@ fn cost_for_tokens(tokens: u32, usd_per_million_tokens: f64) -> f64 {
}
#[must_use]
/// Formats a dollar-denominated value for CLI display.
pub fn format_usd(amount: f64) -> String {
format!("${amount:.4}")
}
/// Aggregates token usage across a running session.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UsageTracker {
latest_turn: TokenUsage,
@ -250,9 +255,9 @@ mod tests {
let cost = usage.estimate_cost_usd();
assert_eq!(format_usd(cost.input_cost_usd), "$15.0000");
assert_eq!(format_usd(cost.output_cost_usd), "$37.5000");
let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-6"));
let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-20250514"));
assert!(lines[0].contains("estimated_cost=$54.6750"));
assert!(lines[0].contains("model=claude-sonnet-4-6"));
assert!(lines[0].contains("model=claude-sonnet-4-20250514"));
assert!(lines[1].contains("cache_read=$0.3000"));
}
@ -265,7 +270,7 @@ mod tests {
cache_read_input_tokens: 0,
};
let haiku = pricing_for_model("claude-haiku-4-5-20251213").expect("haiku pricing");
let haiku = pricing_for_model("claude-haiku-4-5-20251001").expect("haiku pricing");
let opus = pricing_for_model("claude-opus-4-6").expect("opus pricing");
let haiku_cost = usage.estimate_cost_usd_with_pricing(haiku);
let opus_cost = usage.estimate_cost_usd_with_pricing(opus);
@ -287,9 +292,8 @@ mod tests {
#[test]
fn reconstructs_usage_from_session_messages() {
let session = Session {
version: 1,
messages: vec![ConversationMessage {
let mut session = Session::new();
session.messages = vec![ConversationMessage {
role: MessageRole::Assistant,
blocks: vec![ContentBlock::Text {
text: "done".to_string(),
@ -300,8 +304,7 @@ mod tests {
cache_creation_input_tokens: 1,
cache_read_input_tokens: 0,
}),
}],
};
}];
let tracker = UsageTracker::from_session(&session);
assert_eq!(tracker.turns(), 1);

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,386 @@
#![allow(clippy::doc_markdown, clippy::uninlined_format_args, unused_imports)]
//! Integration tests for cross-module wiring.
//!
//! These tests verify that adjacent modules in the runtime crate actually
//! connect correctly — catching wiring gaps that unit tests miss.
use std::time::Duration;
use runtime::green_contract::{GreenContract, GreenContractOutcome, GreenLevel};
use runtime::{
apply_policy, BranchFreshness, DiffScope, LaneBlocker, LaneContext, PolicyAction,
PolicyCondition, PolicyEngine, PolicyRule, ReconcileReason, ReviewStatus, StaleBranchAction,
StaleBranchPolicy,
};
/// stale_branch + policy_engine integration:
/// When a branch is detected stale, does it correctly flow through
/// PolicyCondition::StaleBranch to generate the expected action?
#[test]
fn stale_branch_detection_flows_into_policy_engine() {
// given — a stale branch context (2 hours behind main, threshold is 1 hour)
let stale_context = LaneContext::new(
"stale-lane",
0,
Duration::from_secs(2 * 60 * 60), // 2 hours stale
LaneBlocker::None,
ReviewStatus::Pending,
DiffScope::Full,
false,
);
let engine = PolicyEngine::new(vec![PolicyRule::new(
"stale-merge-forward",
PolicyCondition::StaleBranch,
PolicyAction::MergeForward,
10,
)]);
// when
let actions = engine.evaluate(&stale_context);
// then
assert_eq!(actions, vec![PolicyAction::MergeForward]);
}
/// stale_branch + policy_engine: Fresh branch does NOT trigger stale rules
#[test]
fn fresh_branch_does_not_trigger_stale_policy() {
let fresh_context = LaneContext::new(
"fresh-lane",
0,
Duration::from_secs(30 * 60), // 30 min stale — under 1 hour threshold
LaneBlocker::None,
ReviewStatus::Pending,
DiffScope::Full,
false,
);
let engine = PolicyEngine::new(vec![PolicyRule::new(
"stale-merge-forward",
PolicyCondition::StaleBranch,
PolicyAction::MergeForward,
10,
)]);
let actions = engine.evaluate(&fresh_context);
assert!(actions.is_empty());
}
/// green_contract + policy_engine integration:
/// A lane that meets its green contract should be mergeable
#[test]
fn green_contract_satisfied_allows_merge() {
let contract = GreenContract::new(GreenLevel::Workspace);
let satisfied = contract.is_satisfied_by(GreenLevel::Workspace);
assert!(satisfied);
let exceeded = contract.is_satisfied_by(GreenLevel::MergeReady);
assert!(exceeded);
let insufficient = contract.is_satisfied_by(GreenLevel::Package);
assert!(!insufficient);
}
/// green_contract + policy_engine:
/// Lane with green level below contract requirement gets blocked
#[test]
fn green_contract_unsatisfied_blocks_merge() {
let context = LaneContext::new(
"partial-green-lane",
1, // GreenLevel::Package as u8
Duration::from_secs(0),
LaneBlocker::None,
ReviewStatus::Pending,
DiffScope::Full,
false,
);
// This is a conceptual test — we need a way to express "requires workspace green"
// Currently LaneContext has raw green_level: u8, not a contract
// For now we just verify the policy condition works
let engine = PolicyEngine::new(vec![PolicyRule::new(
"workspace-green-required",
PolicyCondition::GreenAt { level: 3 }, // GreenLevel::Workspace
PolicyAction::MergeToDev,
10,
)]);
let actions = engine.evaluate(&context);
assert!(actions.is_empty()); // level 1 < 3, so no merge
}
/// reconciliation + policy_engine integration:
/// A reconciled lane should be handled by reconcile rules, not generic closeout
#[test]
fn reconciled_lane_matches_reconcile_condition() {
let context = LaneContext::reconciled("reconciled-lane");
let engine = PolicyEngine::new(vec![
PolicyRule::new(
"reconcile-first",
PolicyCondition::LaneReconciled,
PolicyAction::Reconcile {
reason: ReconcileReason::AlreadyMerged,
},
5,
),
PolicyRule::new(
"generic-closeout",
PolicyCondition::LaneCompleted,
PolicyAction::CloseoutLane,
30,
),
]);
let actions = engine.evaluate(&context);
// Both rules fire — reconcile (priority 5) first, then closeout (priority 30)
assert_eq!(
actions,
vec![
PolicyAction::Reconcile {
reason: ReconcileReason::AlreadyMerged,
},
PolicyAction::CloseoutLane,
]
);
}
/// stale_branch module: apply_policy generates correct actions
#[test]
fn stale_branch_apply_policy_produces_rebase_action() {
let stale = BranchFreshness::Stale {
commits_behind: 5,
missing_fixes: vec!["fix-123".to_string()],
};
let action = apply_policy(&stale, StaleBranchPolicy::AutoRebase);
assert_eq!(action, StaleBranchAction::Rebase);
}
#[test]
fn stale_branch_apply_policy_produces_merge_forward_action() {
let stale = BranchFreshness::Stale {
commits_behind: 3,
missing_fixes: vec![],
};
let action = apply_policy(&stale, StaleBranchPolicy::AutoMergeForward);
assert_eq!(action, StaleBranchAction::MergeForward);
}
#[test]
fn stale_branch_apply_policy_warn_only() {
let stale = BranchFreshness::Stale {
commits_behind: 2,
missing_fixes: vec!["fix-456".to_string()],
};
let action = apply_policy(&stale, StaleBranchPolicy::WarnOnly);
match action {
StaleBranchAction::Warn { message } => {
assert!(message.contains("2 commit(s) behind main"));
assert!(message.contains("fix-456"));
}
_ => panic!("expected Warn action, got {:?}", action),
}
}
#[test]
fn stale_branch_fresh_produces_noop() {
let fresh = BranchFreshness::Fresh;
let action = apply_policy(&fresh, StaleBranchPolicy::AutoRebase);
assert_eq!(action, StaleBranchAction::Noop);
}
/// Combined flow: stale detection + policy + action
#[test]
fn end_to_end_stale_lane_gets_merge_forward_action() {
// Simulating what a harness would do:
// 1. Detect branch freshness
// 2. Build lane context from freshness + other signals
// 3. Run policy engine
// 4. Return actions
// given: detected stale state
let _freshness = BranchFreshness::Stale {
commits_behind: 5,
missing_fixes: vec!["fix-123".to_string()],
};
// when: build context and evaluate policy
let context = LaneContext::new(
"lane-9411",
3, // Workspace green
Duration::from_secs(5 * 60 * 60), // 5 hours stale, definitely over threshold
LaneBlocker::None,
ReviewStatus::Approved,
DiffScope::Scoped,
false,
);
let engine = PolicyEngine::new(vec![
// Priority 5: Check if stale first
PolicyRule::new(
"auto-merge-forward-if-stale-and-approved",
PolicyCondition::And(vec![
PolicyCondition::StaleBranch,
PolicyCondition::ReviewPassed,
]),
PolicyAction::MergeForward,
5,
),
// Priority 10: Normal stale handling
PolicyRule::new(
"stale-warning",
PolicyCondition::StaleBranch,
PolicyAction::Notify {
channel: "#build-status".to_string(),
},
10,
),
]);
let actions = engine.evaluate(&context);
// then: both rules should fire (stale + approved matches both)
assert_eq!(
actions,
vec![
PolicyAction::MergeForward,
PolicyAction::Notify {
channel: "#build-status".to_string(),
},
]
);
}
/// Fresh branch with approved review should merge (not stale-blocked)
#[test]
fn fresh_approved_lane_gets_merge_action() {
let context = LaneContext::new(
"fresh-approved-lane",
3, // Workspace green
Duration::from_secs(30 * 60), // 30 min — under 1 hour threshold = fresh
LaneBlocker::None,
ReviewStatus::Approved,
DiffScope::Scoped,
false,
);
let engine = PolicyEngine::new(vec![PolicyRule::new(
"merge-if-green-approved-not-stale",
PolicyCondition::And(vec![
PolicyCondition::GreenAt { level: 3 },
PolicyCondition::ReviewPassed,
// NOT PolicyCondition::StaleBranch — fresh lanes bypass this
]),
PolicyAction::MergeToDev,
5,
)]);
let actions = engine.evaluate(&context);
assert_eq!(actions, vec![PolicyAction::MergeToDev]);
}
/// worker_boot + recovery_recipes + policy_engine integration:
/// When a session completes with a provider failure, does the worker
/// status transition trigger the correct recovery recipe, and does
/// the resulting recovery state feed into policy decisions?
#[test]
fn worker_provider_failure_flows_through_recovery_to_policy() {
use runtime::recovery_recipes::{
attempt_recovery, FailureScenario, RecoveryContext, RecoveryResult, RecoveryStep,
};
use runtime::worker_boot::{WorkerFailureKind, WorkerRegistry, WorkerStatus};
// given — a worker that encounters a provider failure during session completion
let registry = WorkerRegistry::new();
let worker = registry.create("/tmp/repo-recovery-test", &[], true);
// Worker reaches ready state
registry
.observe(&worker.worker_id, "Ready for your input\n>")
.expect("ready observe should succeed");
registry
.send_prompt(&worker.worker_id, Some("Run analysis"))
.expect("prompt send should succeed");
// Session completes with provider failure (finish="unknown", tokens=0)
let failed_worker = registry
.observe_completion(&worker.worker_id, "unknown", 0)
.expect("completion observe should succeed");
assert_eq!(failed_worker.status, WorkerStatus::Failed);
let failure = failed_worker
.last_error
.expect("worker should have recorded error");
assert_eq!(failure.kind, WorkerFailureKind::Provider);
// Bridge: WorkerFailureKind -> FailureScenario
let scenario = FailureScenario::from_worker_failure_kind(failure.kind);
assert_eq!(scenario, FailureScenario::ProviderFailure);
// Recovery recipe lookup and execution
let mut ctx = RecoveryContext::new();
let result = attempt_recovery(&scenario, &mut ctx);
// then — recovery should recommend RestartWorker step
assert!(
matches!(result, RecoveryResult::Recovered { steps_taken: 1 }),
"provider failure should recover via single RestartWorker step, got: {result:?}"
);
assert!(
ctx.events().iter().any(|e| {
matches!(
e,
runtime::recovery_recipes::RecoveryEvent::RecoveryAttempted {
result: RecoveryResult::Recovered { steps_taken: 1 },
..
}
)
}),
"recovery should emit structured attempt event"
);
// Policy integration: recovery success + green status = merge-ready
// (Simulating the policy check that would happen after successful recovery)
let recovery_success = matches!(result, RecoveryResult::Recovered { .. });
let green_level = 3; // Workspace green
let not_stale = Duration::from_secs(30 * 60); // 30 min — fresh
let post_recovery_context = LaneContext::new(
"recovered-lane",
green_level,
not_stale,
LaneBlocker::None,
ReviewStatus::Approved,
DiffScope::Scoped,
false,
);
let policy_engine = PolicyEngine::new(vec![
// Rule: if recovered from failure + green + approved -> merge
PolicyRule::new(
"merge-after-successful-recovery",
PolicyCondition::And(vec![
PolicyCondition::GreenAt { level: 3 },
PolicyCondition::ReviewPassed,
]),
PolicyAction::MergeToDev,
10,
),
]);
// Recovery success is a pre-condition; policy evaluates post-recovery context
assert!(
recovery_success,
"recovery must succeed for lane to proceed"
);
let actions = policy_engine.evaluate(&post_recovery_context);
assert_eq!(
actions,
vec![PolicyAction::MergeToDev],
"post-recovery green+approved lane should be merge-ready"
);
}

View File

@ -15,12 +15,20 @@ commands = { path = "../commands" }
compat-harness = { path = "../compat-harness" }
crossterm = "0.28"
pulldown-cmark = "0.13"
plugins = { path = "../plugins" }
rustyline = "15"
runtime = { path = "../runtime" }
serde_json = "1"
plugins = { path = "../plugins" }
serde = { version = "1", features = ["derive"] }
serde_json.workspace = true
syntect = "5"
tokio = { version = "1", features = ["rt-multi-thread", "time"] }
tokio = { version = "1", features = ["rt-multi-thread", "signal", "time"] }
tools = { path = "../tools" }
[lints]
workspace = true
[dev-dependencies]
mock-anthropic-service = { path = "../mock-anthropic-service" }
serde_json.workspace = true
tokio = { version = "1", features = ["rt-multi-thread"] }

View File

@ -0,0 +1,38 @@
use std::env;
use std::process::Command;
fn main() {
// Get git SHA (short hash)
let git_sha = Command::new("git")
.args(["rev-parse", "--short", "HEAD"])
.output()
.ok()
.and_then(|output| {
if output.status.success() {
String::from_utf8(output.stdout).ok()
} else {
None
}
})
.map_or_else(|| "unknown".to_string(), |s| s.trim().to_string());
println!("cargo:rustc-env=GIT_SHA={git_sha}");
// TARGET is always set by Cargo during build
let target = env::var("TARGET").unwrap_or_else(|_| "unknown".to_string());
println!("cargo:rustc-env=TARGET={target}");
// Build date from SOURCE_DATE_EPOCH (reproducible builds) or current date.
let build_date = std::env::var("SOURCE_DATE_EPOCH")
.ok()
.and_then(|epoch| epoch.parse::<i64>().ok())
.map_or_else(
|| "unknown".to_string(),
|_| std::env::var("BUILD_DATE").unwrap_or_else(|_| "unknown".to_string()),
);
println!("cargo:rustc-env=BUILD_DATE={build_date}");
// Rerun if git state changes
println!("cargo:rerun-if-changed=.git/HEAD");
println!("cargo:rerun-if-changed=.git/refs");
}

View File

@ -1,398 +0,0 @@
use std::io::{self, Write};
use std::path::PathBuf;
use crate::args::{OutputFormat, PermissionMode};
use crate::input::{LineEditor, ReadOutcome};
use crate::render::{Spinner, TerminalRenderer};
use runtime::{ConversationClient, ConversationMessage, RuntimeError, StreamEvent, UsageSummary};
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionConfig {
pub model: String,
pub permission_mode: PermissionMode,
pub config: Option<PathBuf>,
pub output_format: OutputFormat,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionState {
pub turns: usize,
pub compacted_messages: usize,
pub last_model: String,
pub last_usage: UsageSummary,
}
impl SessionState {
#[must_use]
pub fn new(model: impl Into<String>) -> Self {
Self {
turns: 0,
compacted_messages: 0,
last_model: model.into(),
last_usage: UsageSummary::default(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandResult {
Continue,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SlashCommand {
Help,
Status,
Compact,
Unknown(String),
}
impl SlashCommand {
#[must_use]
pub fn parse(input: &str) -> Option<Self> {
let trimmed = input.trim();
if !trimmed.starts_with('/') {
return None;
}
let command = trimmed
.trim_start_matches('/')
.split_whitespace()
.next()
.unwrap_or_default();
Some(match command {
"help" => Self::Help,
"status" => Self::Status,
"compact" => Self::Compact,
other => Self::Unknown(other.to_string()),
})
}
}
struct SlashCommandHandler {
command: SlashCommand,
summary: &'static str,
}
const SLASH_COMMAND_HANDLERS: &[SlashCommandHandler] = &[
SlashCommandHandler {
command: SlashCommand::Help,
summary: "Show command help",
},
SlashCommandHandler {
command: SlashCommand::Status,
summary: "Show current session status",
},
SlashCommandHandler {
command: SlashCommand::Compact,
summary: "Compact local session history",
},
];
pub struct CliApp {
config: SessionConfig,
renderer: TerminalRenderer,
state: SessionState,
conversation_client: ConversationClient,
conversation_history: Vec<ConversationMessage>,
}
impl CliApp {
pub fn new(config: SessionConfig) -> Result<Self, RuntimeError> {
let state = SessionState::new(config.model.clone());
let conversation_client = ConversationClient::from_env(config.model.clone())?;
Ok(Self {
config,
renderer: TerminalRenderer::new(),
state,
conversation_client,
conversation_history: Vec::new(),
})
}
pub fn run_repl(&mut self) -> io::Result<()> {
let mut editor = LineEditor::new(" ", Vec::new());
println!("Rusty Claude CLI interactive mode");
println!("Type /help for commands. Shift+Enter or Ctrl+J inserts a newline.");
loop {
match editor.read_line()? {
ReadOutcome::Submit(input) => {
if input.trim().is_empty() {
continue;
}
self.handle_submission(&input, &mut io::stdout())?;
}
ReadOutcome::Cancel => continue,
ReadOutcome::Exit => break,
}
}
Ok(())
}
pub fn run_prompt(&mut self, prompt: &str, out: &mut impl Write) -> io::Result<()> {
self.render_response(prompt, out)
}
pub fn handle_submission(
&mut self,
input: &str,
out: &mut impl Write,
) -> io::Result<CommandResult> {
if let Some(command) = SlashCommand::parse(input) {
return self.dispatch_slash_command(command, out);
}
self.state.turns += 1;
self.render_response(input, out)?;
Ok(CommandResult::Continue)
}
fn dispatch_slash_command(
&mut self,
command: SlashCommand,
out: &mut impl Write,
) -> io::Result<CommandResult> {
match command {
SlashCommand::Help => Self::handle_help(out),
SlashCommand::Status => self.handle_status(out),
SlashCommand::Compact => self.handle_compact(out),
SlashCommand::Unknown(name) => {
writeln!(out, "Unknown slash command: /{name}")?;
Ok(CommandResult::Continue)
}
}
}
fn handle_help(out: &mut impl Write) -> io::Result<CommandResult> {
writeln!(out, "Available commands:")?;
for handler in SLASH_COMMAND_HANDLERS {
let name = match handler.command {
SlashCommand::Help => "/help",
SlashCommand::Status => "/status",
SlashCommand::Compact => "/compact",
SlashCommand::Unknown(_) => continue,
};
writeln!(out, " {name:<9} {}", handler.summary)?;
}
Ok(CommandResult::Continue)
}
fn handle_status(&mut self, out: &mut impl Write) -> io::Result<CommandResult> {
writeln!(
out,
"status: turns={} model={} permission-mode={:?} output-format={:?} last-usage={} in/{} out config={}",
self.state.turns,
self.state.last_model,
self.config.permission_mode,
self.config.output_format,
self.state.last_usage.input_tokens,
self.state.last_usage.output_tokens,
self.config
.config
.as_ref()
.map_or_else(|| String::from("<none>"), |path| path.display().to_string())
)?;
Ok(CommandResult::Continue)
}
fn handle_compact(&mut self, out: &mut impl Write) -> io::Result<CommandResult> {
self.state.compacted_messages += self.state.turns;
self.state.turns = 0;
self.conversation_history.clear();
writeln!(
out,
"Compacted session history into a local summary ({} messages total compacted).",
self.state.compacted_messages
)?;
Ok(CommandResult::Continue)
}
fn handle_stream_event(
renderer: &TerminalRenderer,
event: StreamEvent,
stream_spinner: &mut Spinner,
tool_spinner: &mut Spinner,
saw_text: &mut bool,
turn_usage: &mut UsageSummary,
out: &mut impl Write,
) {
match event {
StreamEvent::TextDelta(delta) => {
if !*saw_text {
let _ =
stream_spinner.finish("Streaming response", renderer.color_theme(), out);
*saw_text = true;
}
let _ = write!(out, "{delta}");
let _ = out.flush();
}
StreamEvent::ToolCallStart { name, input } => {
if *saw_text {
let _ = writeln!(out);
}
let _ = tool_spinner.tick(
&format!("Running tool `{name}` with {input}"),
renderer.color_theme(),
out,
);
}
StreamEvent::ToolCallResult {
name,
output,
is_error,
} => {
let label = if is_error {
format!("Tool `{name}` failed")
} else {
format!("Tool `{name}` completed")
};
let _ = tool_spinner.finish(&label, renderer.color_theme(), out);
let rendered_output = format!("### Tool `{name}`\n\n```text\n{output}\n```\n");
let _ = renderer.stream_markdown(&rendered_output, out);
}
StreamEvent::Usage(usage) => {
*turn_usage = usage;
}
}
}
fn write_turn_output(
&self,
summary: &runtime::TurnSummary,
out: &mut impl Write,
) -> io::Result<()> {
match self.config.output_format {
OutputFormat::Text => {
writeln!(
out,
"\nToken usage: {} input / {} output",
self.state.last_usage.input_tokens, self.state.last_usage.output_tokens
)?;
}
OutputFormat::Json => {
writeln!(
out,
"{}",
serde_json::json!({
"message": summary.assistant_text,
"usage": {
"input_tokens": self.state.last_usage.input_tokens,
"output_tokens": self.state.last_usage.output_tokens,
}
})
)?;
}
OutputFormat::Ndjson => {
writeln!(
out,
"{}",
serde_json::json!({
"type": "message",
"text": summary.assistant_text,
"usage": {
"input_tokens": self.state.last_usage.input_tokens,
"output_tokens": self.state.last_usage.output_tokens,
}
})
)?;
}
}
Ok(())
}
fn render_response(&mut self, input: &str, out: &mut impl Write) -> io::Result<()> {
let mut stream_spinner = Spinner::new();
stream_spinner.tick(
"Opening conversation stream",
self.renderer.color_theme(),
out,
)?;
let mut turn_usage = UsageSummary::default();
let mut tool_spinner = Spinner::new();
let mut saw_text = false;
let renderer = &self.renderer;
let result =
self.conversation_client
.run_turn(&mut self.conversation_history, input, |event| {
Self::handle_stream_event(
renderer,
event,
&mut stream_spinner,
&mut tool_spinner,
&mut saw_text,
&mut turn_usage,
out,
);
});
let summary = match result {
Ok(summary) => summary,
Err(error) => {
stream_spinner.fail(
"Streaming response failed",
self.renderer.color_theme(),
out,
)?;
return Err(io::Error::other(error));
}
};
self.state.last_usage = summary.usage.clone();
if saw_text {
writeln!(out)?;
} else {
stream_spinner.finish("Streaming response", self.renderer.color_theme(), out)?;
}
self.write_turn_output(&summary, out)?;
let _ = turn_usage;
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::path::PathBuf;
use crate::args::{OutputFormat, PermissionMode};
use super::{CommandResult, SessionConfig, SlashCommand};
#[test]
fn parses_required_slash_commands() {
assert_eq!(SlashCommand::parse("/help"), Some(SlashCommand::Help));
assert_eq!(SlashCommand::parse(" /status "), Some(SlashCommand::Status));
assert_eq!(
SlashCommand::parse("/compact now"),
Some(SlashCommand::Compact)
);
}
#[test]
fn help_output_lists_commands() {
let mut out = Vec::new();
let result = super::CliApp::handle_help(&mut out).expect("help succeeds");
assert_eq!(result, CommandResult::Continue);
let output = String::from_utf8_lossy(&out);
assert!(output.contains("/help"));
assert!(output.contains("/status"));
assert!(output.contains("/compact"));
}
#[test]
fn session_state_tracks_config_values() {
let config = SessionConfig {
model: "claude".into(),
permission_mode: PermissionMode::WorkspaceWrite,
config: Some(PathBuf::from("settings.toml")),
output_format: OutputFormat::Text,
};
assert_eq!(config.model, "claude");
assert_eq!(config.permission_mode, PermissionMode::WorkspaceWrite);
assert_eq!(config.config, Some(PathBuf::from("settings.toml")));
}
}

View File

@ -1,102 +0,0 @@
use std::path::PathBuf;
use clap::{Parser, Subcommand, ValueEnum};
#[derive(Debug, Clone, Parser, PartialEq, Eq)]
#[command(
name = "rusty-claude-cli",
version,
about = "Rust Claude CLI prototype"
)]
pub struct Cli {
#[arg(long, default_value = "claude-opus-4-6")]
pub model: String,
#[arg(long, value_enum, default_value_t = PermissionMode::WorkspaceWrite)]
pub permission_mode: PermissionMode,
#[arg(long)]
pub config: Option<PathBuf>,
#[arg(long, value_enum, default_value_t = OutputFormat::Text)]
pub output_format: OutputFormat,
#[command(subcommand)]
pub command: Option<Command>,
}
#[derive(Debug, Clone, Subcommand, PartialEq, Eq)]
pub enum Command {
/// Read upstream TS sources and print extracted counts
DumpManifests,
/// Print the current bootstrap phase skeleton
BootstrapPlan,
/// Start the OAuth login flow
Login,
/// Clear saved OAuth credentials
Logout,
/// Run a non-interactive prompt and exit
Prompt { prompt: Vec<String> },
}
#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)]
pub enum PermissionMode {
ReadOnly,
WorkspaceWrite,
DangerFullAccess,
}
#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)]
pub enum OutputFormat {
Text,
Json,
Ndjson,
}
#[cfg(test)]
mod tests {
use clap::Parser;
use super::{Cli, Command, OutputFormat, PermissionMode};
#[test]
fn parses_requested_flags() {
let cli = Cli::parse_from([
"rusty-claude-cli",
"--model",
"claude-3-5-haiku",
"--permission-mode",
"read-only",
"--config",
"/tmp/config.toml",
"--output-format",
"ndjson",
"prompt",
"hello",
"world",
]);
assert_eq!(cli.model, "claude-3-5-haiku");
assert_eq!(cli.permission_mode, PermissionMode::ReadOnly);
assert_eq!(
cli.config.as_deref(),
Some(std::path::Path::new("/tmp/config.toml"))
);
assert_eq!(cli.output_format, OutputFormat::Ndjson);
assert_eq!(
cli.command,
Some(Command::Prompt {
prompt: vec!["hello".into(), "world".into()]
})
);
}
#[test]
fn parses_login_and_logout_commands() {
let login = Cli::parse_from(["rusty-claude-cli", "login"]);
assert_eq!(login.command, Some(Command::Login));
let logout = Cli::parse_from(["rusty-claude-cli", "logout"]);
assert_eq!(logout.command, Some(Command::Logout));
}
}

View File

@ -1,15 +1,15 @@
use std::fs;
use std::path::{Path, PathBuf};
const STARTER_CLAUDE_JSON: &str = concat!(
const STARTER_CLAW_JSON: &str = concat!(
"{\n",
" \"permissions\": {\n",
" \"defaultMode\": \"acceptEdits\"\n",
" \"defaultMode\": \"dontAsk\"\n",
" }\n",
"}\n",
);
const GITIGNORE_COMMENT: &str = "# Claude Code local artifacts";
const GITIGNORE_ENTRIES: [&str; 2] = [".claude/settings.local.json", ".claude/sessions/"];
const GITIGNORE_COMMENT: &str = "# Claw Code local artifacts";
const GITIGNORE_ENTRIES: [&str; 2] = [".claw/settings.local.json", ".claw/sessions/"];
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum InitStatus {
@ -80,16 +80,16 @@ struct RepoDetection {
pub(crate) fn initialize_repo(cwd: &Path) -> Result<InitReport, Box<dyn std::error::Error>> {
let mut artifacts = Vec::new();
let claude_dir = cwd.join(".claude");
let claw_dir = cwd.join(".claw");
artifacts.push(InitArtifact {
name: ".claude/",
status: ensure_dir(&claude_dir)?,
name: ".claw/",
status: ensure_dir(&claw_dir)?,
});
let claude_json = cwd.join(".claude.json");
let claw_json = cwd.join(".claw.json");
artifacts.push(InitArtifact {
name: ".claude.json",
status: write_file_if_missing(&claude_json, STARTER_CLAUDE_JSON)?,
name: ".claw.json",
status: write_file_if_missing(&claw_json, STARTER_CLAW_JSON)?,
});
let gitignore = cwd.join(".gitignore");
@ -164,7 +164,7 @@ pub(crate) fn render_init_claude_md(cwd: &Path) -> String {
let mut lines = vec![
"# CLAUDE.md".to_string(),
String::new(),
"This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.".to_string(),
"This file provides guidance to Claw Code (clawcode.dev) when working with code in this repository.".to_string(),
String::new(),
];
@ -209,7 +209,7 @@ pub(crate) fn render_init_claude_md(cwd: &Path) -> String {
lines.push("## Working agreement".to_string());
lines.push("- Prefer small, reviewable changes and keep generated bootstrap files aligned with actual repo workflows.".to_string());
lines.push("- Keep shared defaults in `.claude.json`; reserve `.claude/settings.local.json` for machine-local overrides.".to_string());
lines.push("- Keep shared defaults in `.claw.json`; reserve `.claw/settings.local.json` for machine-local overrides.".to_string());
lines.push("- Do not overwrite existing `CLAUDE.md` content automatically; update it intentionally when repo workflows change.".to_string());
lines.push(String::new());
@ -354,26 +354,27 @@ mod tests {
let report = initialize_repo(&root).expect("init should succeed");
let rendered = report.render();
assert!(rendered.contains(".claude/ created"));
assert!(rendered.contains(".claude.json created"));
assert!(rendered.contains(".claw/"));
assert!(rendered.contains(".claw.json"));
assert!(rendered.contains("created"));
assert!(rendered.contains(".gitignore created"));
assert!(rendered.contains("CLAUDE.md created"));
assert!(root.join(".claude").is_dir());
assert!(root.join(".claude.json").is_file());
assert!(root.join(".claw").is_dir());
assert!(root.join(".claw.json").is_file());
assert!(root.join("CLAUDE.md").is_file());
assert_eq!(
fs::read_to_string(root.join(".claude.json")).expect("read claude json"),
fs::read_to_string(root.join(".claw.json")).expect("read claw json"),
concat!(
"{\n",
" \"permissions\": {\n",
" \"defaultMode\": \"acceptEdits\"\n",
" \"defaultMode\": \"dontAsk\"\n",
" }\n",
"}\n",
)
);
let gitignore = fs::read_to_string(root.join(".gitignore")).expect("read gitignore");
assert!(gitignore.contains(".claude/settings.local.json"));
assert!(gitignore.contains(".claude/sessions/"));
assert!(gitignore.contains(".claw/settings.local.json"));
assert!(gitignore.contains(".claw/sessions/"));
let claude_md = fs::read_to_string(root.join("CLAUDE.md")).expect("read claude md");
assert!(claude_md.contains("Languages: Rust."));
assert!(claude_md.contains("cargo clippy --workspace --all-targets -- -D warnings"));
@ -386,8 +387,7 @@ mod tests {
let root = temp_dir();
fs::create_dir_all(&root).expect("create root");
fs::write(root.join("CLAUDE.md"), "custom guidance\n").expect("write existing claude md");
fs::write(root.join(".gitignore"), ".claude/settings.local.json\n")
.expect("write gitignore");
fs::write(root.join(".gitignore"), ".claw/settings.local.json\n").expect("write gitignore");
let first = initialize_repo(&root).expect("first init should succeed");
assert!(first
@ -395,8 +395,9 @@ mod tests {
.contains("CLAUDE.md skipped (already exists)"));
let second = initialize_repo(&root).expect("second init should succeed");
let second_rendered = second.render();
assert!(second_rendered.contains(".claude/ skipped (already exists)"));
assert!(second_rendered.contains(".claude.json skipped (already exists)"));
assert!(second_rendered.contains(".claw/"));
assert!(second_rendered.contains(".claw.json"));
assert!(second_rendered.contains("skipped (already exists)"));
assert!(second_rendered.contains(".gitignore skipped (already exists)"));
assert!(second_rendered.contains("CLAUDE.md skipped (already exists)"));
assert_eq!(
@ -404,8 +405,8 @@ mod tests {
"custom guidance\n"
);
let gitignore = fs::read_to_string(root.join(".gitignore")).expect("read gitignore");
assert_eq!(gitignore.matches(".claude/settings.local.json").count(), 1);
assert_eq!(gitignore.matches(".claude/sessions/").count(), 1);
assert_eq!(gitignore.matches(".claw/settings.local.json").count(), 1);
assert_eq!(gitignore.matches(".claw/sessions/").count(), 1);
fs::remove_dir_all(root).expect("cleanup temp dir");
}

View File

@ -1,166 +1,17 @@
use std::borrow::Cow;
use std::cell::RefCell;
use std::collections::BTreeSet;
use std::io::{self, IsTerminal, Write};
use crossterm::cursor::{MoveDown, MoveToColumn, MoveUp};
use crossterm::event::{self, Event, KeyCode, KeyEvent, KeyModifiers};
use crossterm::queue;
use crossterm::terminal::{disable_raw_mode, enable_raw_mode, Clear, ClearType};
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InputBuffer {
buffer: String,
cursor: usize,
}
impl InputBuffer {
#[must_use]
pub fn new() -> Self {
Self {
buffer: String::new(),
cursor: 0,
}
}
pub fn insert(&mut self, ch: char) {
self.buffer.insert(self.cursor, ch);
self.cursor += ch.len_utf8();
}
pub fn insert_newline(&mut self) {
self.insert('\n');
}
pub fn backspace(&mut self) {
if self.cursor == 0 {
return;
}
let previous = self.buffer[..self.cursor]
.char_indices()
.last()
.map_or(0, |(idx, _)| idx);
self.buffer.drain(previous..self.cursor);
self.cursor = previous;
}
pub fn move_left(&mut self) {
if self.cursor == 0 {
return;
}
self.cursor = self.buffer[..self.cursor]
.char_indices()
.last()
.map_or(0, |(idx, _)| idx);
}
pub fn move_right(&mut self) {
if self.cursor >= self.buffer.len() {
return;
}
if let Some(next) = self.buffer[self.cursor..].chars().next() {
self.cursor += next.len_utf8();
}
}
pub fn move_home(&mut self) {
self.cursor = 0;
}
pub fn move_end(&mut self) {
self.cursor = self.buffer.len();
}
#[must_use]
pub fn as_str(&self) -> &str {
&self.buffer
}
#[cfg(test)]
#[must_use]
pub fn cursor(&self) -> usize {
self.cursor
}
pub fn clear(&mut self) {
self.buffer.clear();
self.cursor = 0;
}
pub fn replace(&mut self, value: impl Into<String>) {
self.buffer = value.into();
self.cursor = self.buffer.len();
}
#[must_use]
fn current_command_prefix(&self) -> Option<&str> {
if self.cursor != self.buffer.len() {
return None;
}
let prefix = &self.buffer[..self.cursor];
if prefix.contains(char::is_whitespace) || !prefix.starts_with('/') {
return None;
}
Some(prefix)
}
pub fn complete_slash_command(&mut self, candidates: &[String]) -> bool {
let Some(prefix) = self.current_command_prefix() else {
return false;
};
let matches = candidates
.iter()
.filter(|candidate| candidate.starts_with(prefix))
.map(String::as_str)
.collect::<Vec<_>>();
if matches.is_empty() {
return false;
}
let replacement = longest_common_prefix(&matches);
if replacement == prefix {
return false;
}
self.replace(replacement);
true
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RenderedBuffer {
lines: Vec<String>,
cursor_row: u16,
cursor_col: u16,
}
impl RenderedBuffer {
#[must_use]
pub fn line_count(&self) -> usize {
self.lines.len()
}
fn write(&self, out: &mut impl Write) -> io::Result<()> {
for (index, line) in self.lines.iter().enumerate() {
if index > 0 {
writeln!(out)?;
}
write!(out, "{line}")?;
}
Ok(())
}
#[cfg(test)]
#[must_use]
pub fn lines(&self) -> &[String] {
&self.lines
}
#[cfg(test)]
#[must_use]
pub fn cursor_position(&self) -> (u16, u16) {
(self.cursor_row, self.cursor_col)
}
}
use rustyline::completion::{Completer, Pair};
use rustyline::error::ReadlineError;
use rustyline::highlight::{CmdKind, Highlighter};
use rustyline::hint::Hinter;
use rustyline::history::DefaultHistory;
use rustyline::validate::Validator;
use rustyline::{
Cmd, CompletionType, Config, Context, EditMode, Editor, Helper, KeyCode, KeyEvent, Modifiers,
};
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReadOutcome {
@ -169,25 +20,105 @@ pub enum ReadOutcome {
Exit,
}
struct SlashCommandHelper {
completions: Vec<String>,
current_line: RefCell<String>,
}
impl SlashCommandHelper {
fn new(completions: Vec<String>) -> Self {
Self {
completions: normalize_completions(completions),
current_line: RefCell::new(String::new()),
}
}
fn reset_current_line(&self) {
self.current_line.borrow_mut().clear();
}
fn current_line(&self) -> String {
self.current_line.borrow().clone()
}
fn set_current_line(&self, line: &str) {
let mut current = self.current_line.borrow_mut();
current.clear();
current.push_str(line);
}
fn set_completions(&mut self, completions: Vec<String>) {
self.completions = normalize_completions(completions);
}
}
impl Completer for SlashCommandHelper {
type Candidate = Pair;
fn complete(
&self,
line: &str,
pos: usize,
_ctx: &Context<'_>,
) -> rustyline::Result<(usize, Vec<Self::Candidate>)> {
let Some(prefix) = slash_command_prefix(line, pos) else {
return Ok((0, Vec::new()));
};
let matches = self
.completions
.iter()
.filter(|candidate| candidate.starts_with(prefix))
.map(|candidate| Pair {
display: candidate.clone(),
replacement: candidate.clone(),
})
.collect();
Ok((0, matches))
}
}
impl Hinter for SlashCommandHelper {
type Hint = String;
}
impl Highlighter for SlashCommandHelper {
fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> {
self.set_current_line(line);
Cow::Borrowed(line)
}
fn highlight_char(&self, line: &str, _pos: usize, _kind: CmdKind) -> bool {
self.set_current_line(line);
false
}
}
impl Validator for SlashCommandHelper {}
impl Helper for SlashCommandHelper {}
pub struct LineEditor {
prompt: String,
continuation_prompt: String,
history: Vec<String>,
history_index: Option<usize>,
draft: Option<String>,
completions: Vec<String>,
editor: Editor<SlashCommandHelper, DefaultHistory>,
}
impl LineEditor {
#[must_use]
pub fn new(prompt: impl Into<String>, completions: Vec<String>) -> Self {
let config = Config::builder()
.completion_type(CompletionType::List)
.edit_mode(EditMode::Emacs)
.build();
let mut editor = Editor::<SlashCommandHelper, DefaultHistory>::with_config(config)
.expect("rustyline editor should initialize");
editor.set_helper(Some(SlashCommandHelper::new(completions)));
editor.bind_sequence(KeyEvent(KeyCode::Char('J'), Modifiers::CTRL), Cmd::Newline);
editor.bind_sequence(KeyEvent(KeyCode::Enter, Modifiers::SHIFT), Cmd::Newline);
Self {
prompt: prompt.into(),
continuation_prompt: String::from("> "),
history: Vec::new(),
history_index: None,
draft: None,
completions,
editor,
}
}
@ -196,9 +127,14 @@ impl LineEditor {
if entry.trim().is_empty() {
return;
}
self.history.push(entry);
self.history_index = None;
self.draft = None;
let _ = self.editor.add_history_entry(entry);
}
pub fn set_completions(&mut self, completions: Vec<String>) {
if let Some(helper) = self.editor.helper_mut() {
helper.set_completions(completions);
}
}
pub fn read_line(&mut self) -> io::Result<ReadOutcome> {
@ -206,43 +142,41 @@ impl LineEditor {
return self.read_line_fallback();
}
enable_raw_mode()?;
let mut stdout = io::stdout();
let mut input = InputBuffer::new();
let mut rendered_lines = 1usize;
self.redraw(&mut stdout, &input, rendered_lines)?;
if let Some(helper) = self.editor.helper_mut() {
helper.reset_current_line();
}
loop {
let event = event::read()?;
if let Event::Key(key) = event {
match self.handle_key(key, &mut input) {
EditorAction::Continue => {
rendered_lines = self.redraw(&mut stdout, &input, rendered_lines)?;
}
EditorAction::Submit => {
disable_raw_mode()?;
writeln!(stdout)?;
self.history_index = None;
self.draft = None;
return Ok(ReadOutcome::Submit(input.as_str().to_owned()));
}
EditorAction::Cancel => {
disable_raw_mode()?;
writeln!(stdout)?;
self.history_index = None;
self.draft = None;
return Ok(ReadOutcome::Cancel);
}
EditorAction::Exit => {
disable_raw_mode()?;
writeln!(stdout)?;
self.history_index = None;
self.draft = None;
return Ok(ReadOutcome::Exit);
match self.editor.readline(&self.prompt) {
Ok(line) => Ok(ReadOutcome::Submit(line)),
Err(ReadlineError::Interrupted) => {
let has_input = !self.current_line().is_empty();
self.finish_interrupted_read()?;
if has_input {
Ok(ReadOutcome::Cancel)
} else {
Ok(ReadOutcome::Exit)
}
}
Err(ReadlineError::Eof) => {
self.finish_interrupted_read()?;
Ok(ReadOutcome::Exit)
}
Err(error) => Err(io::Error::other(error)),
}
}
fn current_line(&self) -> String {
self.editor
.helper()
.map_or_else(String::new, SlashCommandHelper::current_line)
}
fn finish_interrupted_read(&mut self) -> io::Result<()> {
if let Some(helper) = self.editor.helper_mut() {
helper.reset_current_line();
}
let mut stdout = io::stdout();
writeln!(stdout)
}
fn read_line_fallback(&self) -> io::Result<ReadOutcome> {
@ -261,388 +195,136 @@ impl LineEditor {
}
Ok(ReadOutcome::Submit(buffer))
}
#[allow(clippy::too_many_lines)]
fn handle_key(&mut self, key: KeyEvent, input: &mut InputBuffer) -> EditorAction {
match key {
KeyEvent {
code: KeyCode::Char('c'),
modifiers,
..
} if modifiers.contains(KeyModifiers::CONTROL) => {
if input.as_str().is_empty() {
EditorAction::Exit
} else {
input.clear();
self.history_index = None;
self.draft = None;
EditorAction::Cancel
}
}
KeyEvent {
code: KeyCode::Char('j'),
modifiers,
..
} if modifiers.contains(KeyModifiers::CONTROL) => {
input.insert_newline();
EditorAction::Continue
}
KeyEvent {
code: KeyCode::Enter,
modifiers,
..
} if modifiers.contains(KeyModifiers::SHIFT) => {
input.insert_newline();
EditorAction::Continue
}
KeyEvent {
code: KeyCode::Enter,
..
} => EditorAction::Submit,
KeyEvent {
code: KeyCode::Backspace,
..
} => {
input.backspace();
EditorAction::Continue
}
KeyEvent {
code: KeyCode::Left,
..
} => {
input.move_left();
EditorAction::Continue
}
KeyEvent {
code: KeyCode::Right,
..
} => {
input.move_right();
EditorAction::Continue
}
KeyEvent {
code: KeyCode::Up, ..
} => {
self.navigate_history_up(input);
EditorAction::Continue
}
KeyEvent {
code: KeyCode::Down,
..
} => {
self.navigate_history_down(input);
EditorAction::Continue
}
KeyEvent {
code: KeyCode::Tab, ..
} => {
input.complete_slash_command(&self.completions);
EditorAction::Continue
}
KeyEvent {
code: KeyCode::Home,
..
} => {
input.move_home();
EditorAction::Continue
}
KeyEvent {
code: KeyCode::End, ..
} => {
input.move_end();
EditorAction::Continue
}
KeyEvent {
code: KeyCode::Esc, ..
} => {
input.clear();
self.history_index = None;
self.draft = None;
EditorAction::Cancel
}
KeyEvent {
code: KeyCode::Char(ch),
modifiers,
..
} if modifiers.is_empty() || modifiers == KeyModifiers::SHIFT => {
input.insert(ch);
self.history_index = None;
self.draft = None;
EditorAction::Continue
}
_ => EditorAction::Continue,
}
}
fn navigate_history_up(&mut self, input: &mut InputBuffer) {
if self.history.is_empty() {
return;
}
match self.history_index {
Some(0) => {}
Some(index) => {
let next_index = index - 1;
input.replace(self.history[next_index].clone());
self.history_index = Some(next_index);
}
None => {
self.draft = Some(input.as_str().to_owned());
let next_index = self.history.len() - 1;
input.replace(self.history[next_index].clone());
self.history_index = Some(next_index);
}
}
}
fn navigate_history_down(&mut self, input: &mut InputBuffer) {
let Some(index) = self.history_index else {
return;
};
if index + 1 < self.history.len() {
let next_index = index + 1;
input.replace(self.history[next_index].clone());
self.history_index = Some(next_index);
return;
}
input.replace(self.draft.take().unwrap_or_default());
self.history_index = None;
}
fn redraw(
&self,
out: &mut impl Write,
input: &InputBuffer,
previous_line_count: usize,
) -> io::Result<usize> {
let rendered = render_buffer(&self.prompt, &self.continuation_prompt, input);
if previous_line_count > 1 {
queue!(out, MoveUp(saturating_u16(previous_line_count - 1)))?;
}
queue!(out, MoveToColumn(0), Clear(ClearType::FromCursorDown),)?;
rendered.write(out)?;
queue!(
out,
MoveUp(saturating_u16(rendered.line_count().saturating_sub(1))),
MoveToColumn(0),
)?;
if rendered.cursor_row > 0 {
queue!(out, MoveDown(rendered.cursor_row))?;
}
queue!(out, MoveToColumn(rendered.cursor_col))?;
out.flush()?;
Ok(rendered.line_count())
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EditorAction {
Continue,
Submit,
Cancel,
Exit,
fn slash_command_prefix(line: &str, pos: usize) -> Option<&str> {
if pos != line.len() {
return None;
}
let prefix = &line[..pos];
if !prefix.starts_with('/') {
return None;
}
Some(prefix)
}
#[must_use]
pub fn render_buffer(
prompt: &str,
continuation_prompt: &str,
input: &InputBuffer,
) -> RenderedBuffer {
let before_cursor = &input.as_str()[..input.cursor];
let cursor_row = saturating_u16(before_cursor.chars().filter(|ch| *ch == '\n').count());
let cursor_line = before_cursor.rsplit('\n').next().unwrap_or_default();
let cursor_prompt = if cursor_row == 0 {
prompt
} else {
continuation_prompt
};
let cursor_col = saturating_u16(cursor_prompt.chars().count() + cursor_line.chars().count());
let mut lines = Vec::new();
for (index, line) in input.as_str().split('\n').enumerate() {
let prefix = if index == 0 {
prompt
} else {
continuation_prompt
};
lines.push(format!("{prefix}{line}"));
}
if lines.is_empty() {
lines.push(prompt.to_string());
}
RenderedBuffer {
lines,
cursor_row,
cursor_col,
}
}
#[must_use]
fn longest_common_prefix(values: &[&str]) -> String {
let Some(first) = values.first() else {
return String::new();
};
let mut prefix = (*first).to_string();
for value in values.iter().skip(1) {
while !value.starts_with(&prefix) {
prefix.pop();
if prefix.is_empty() {
break;
}
}
}
prefix
}
#[must_use]
fn saturating_u16(value: usize) -> u16 {
u16::try_from(value).unwrap_or(u16::MAX)
fn normalize_completions(completions: Vec<String>) -> Vec<String> {
let mut seen = BTreeSet::new();
completions
.into_iter()
.filter(|candidate| candidate.starts_with('/'))
.filter(|candidate| seen.insert(candidate.clone()))
.collect()
}
#[cfg(test)]
mod tests {
use super::{render_buffer, InputBuffer, LineEditor};
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use super::{slash_command_prefix, LineEditor, SlashCommandHelper};
use rustyline::completion::Completer;
use rustyline::highlight::Highlighter;
use rustyline::history::{DefaultHistory, History};
use rustyline::Context;
fn key(code: KeyCode) -> KeyEvent {
KeyEvent::new(code, KeyModifiers::NONE)
#[test]
fn extracts_terminal_slash_command_prefixes_with_arguments() {
assert_eq!(slash_command_prefix("/he", 3), Some("/he"));
assert_eq!(slash_command_prefix("/help me", 8), Some("/help me"));
assert_eq!(
slash_command_prefix("/session switch ses", 19),
Some("/session switch ses")
);
assert_eq!(slash_command_prefix("hello", 5), None);
assert_eq!(slash_command_prefix("/help", 2), None);
}
#[test]
fn supports_basic_line_editing() {
let mut input = InputBuffer::new();
input.insert('h');
input.insert('i');
input.move_end();
input.insert_newline();
input.insert('x');
assert_eq!(input.as_str(), "hi\nx");
assert_eq!(input.cursor(), 4);
input.move_left();
input.backspace();
assert_eq!(input.as_str(), "hix");
assert_eq!(input.cursor(), 2);
}
#[test]
fn completes_unique_slash_command() {
let mut input = InputBuffer::new();
for ch in "/he".chars() {
input.insert(ch);
}
assert!(input.complete_slash_command(&[
fn completes_matching_slash_commands() {
let helper = SlashCommandHelper::new(vec![
"/help".to_string(),
"/hello".to_string(),
"/status".to_string(),
]));
assert_eq!(input.as_str(), "/hel");
]);
let history = DefaultHistory::new();
let ctx = Context::new(&history);
let (start, matches) = helper
.complete("/he", 3, &ctx)
.expect("completion should work");
assert!(input.complete_slash_command(&["/help".to_string(), "/status".to_string()]));
assert_eq!(input.as_str(), "/help");
}
#[test]
fn ignores_completion_when_prefix_is_not_a_slash_command() {
let mut input = InputBuffer::new();
for ch in "hello".chars() {
input.insert(ch);
}
assert!(!input.complete_slash_command(&["/help".to_string()]));
assert_eq!(input.as_str(), "hello");
}
#[test]
fn history_navigation_restores_current_draft() {
let mut editor = LineEditor::new(" ", vec![]);
editor.push_history("/help");
editor.push_history("status report");
let mut input = InputBuffer::new();
for ch in "draft".chars() {
input.insert(ch);
}
let _ = editor.handle_key(key(KeyCode::Up), &mut input);
assert_eq!(input.as_str(), "status report");
let _ = editor.handle_key(key(KeyCode::Up), &mut input);
assert_eq!(input.as_str(), "/help");
let _ = editor.handle_key(key(KeyCode::Down), &mut input);
assert_eq!(input.as_str(), "status report");
let _ = editor.handle_key(key(KeyCode::Down), &mut input);
assert_eq!(input.as_str(), "draft");
}
#[test]
fn tab_key_completes_from_editor_candidates() {
let mut editor = LineEditor::new(
" ",
vec![
"/help".to_string(),
"/status".to_string(),
"/session".to_string(),
],
);
let mut input = InputBuffer::new();
for ch in "/st".chars() {
input.insert(ch);
}
let _ = editor.handle_key(key(KeyCode::Tab), &mut input);
assert_eq!(input.as_str(), "/status");
}
#[test]
fn renders_multiline_buffers_with_continuation_prompt() {
let mut input = InputBuffer::new();
for ch in "hello\nworld".chars() {
if ch == '\n' {
input.insert_newline();
} else {
input.insert(ch);
}
}
let rendered = render_buffer(" ", "> ", &input);
assert_eq!(start, 0);
assert_eq!(
rendered.lines(),
&[" hello".to_string(), "> world".to_string()]
matches
.into_iter()
.map(|candidate| candidate.replacement)
.collect::<Vec<_>>(),
vec!["/help".to_string(), "/hello".to_string()]
);
assert_eq!(rendered.cursor_position(), (1, 7));
}
#[test]
fn ctrl_c_exits_only_when_buffer_is_empty() {
let mut editor = LineEditor::new(" ", vec![]);
let mut empty = InputBuffer::new();
assert!(matches!(
editor.handle_key(
KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL),
&mut empty,
),
super::EditorAction::Exit
));
fn completes_matching_slash_command_arguments() {
let helper = SlashCommandHelper::new(vec![
"/model".to_string(),
"/model opus".to_string(),
"/model sonnet".to_string(),
"/session switch alpha".to_string(),
]);
let history = DefaultHistory::new();
let ctx = Context::new(&history);
let (start, matches) = helper
.complete("/model o", 8, &ctx)
.expect("completion should work");
let mut filled = InputBuffer::new();
filled.insert('x');
assert!(matches!(
editor.handle_key(
KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL),
&mut filled,
),
super::EditorAction::Cancel
));
assert!(filled.as_str().is_empty());
assert_eq!(start, 0);
assert_eq!(
matches
.into_iter()
.map(|candidate| candidate.replacement)
.collect::<Vec<_>>(),
vec!["/model opus".to_string()]
);
}
#[test]
fn ignores_non_slash_command_completion_requests() {
let helper = SlashCommandHelper::new(vec!["/help".to_string()]);
let history = DefaultHistory::new();
let ctx = Context::new(&history);
let (_, matches) = helper
.complete("hello", 5, &ctx)
.expect("completion should work");
assert!(matches.is_empty());
}
#[test]
fn tracks_current_buffer_through_highlighter() {
let helper = SlashCommandHelper::new(Vec::new());
let _ = helper.highlight("draft", 5);
assert_eq!(helper.current_line(), "draft");
}
#[test]
fn push_history_ignores_blank_entries() {
let mut editor = LineEditor::new("> ", vec!["/help".to_string()]);
editor.push_history(" ");
editor.push_history("/help");
assert_eq!(editor.editor.history().len(), 1);
}
#[test]
fn set_completions_replaces_and_normalizes_candidates() {
let mut editor = LineEditor::new("> ", vec!["/help".to_string()]);
editor.set_completions(vec![
"/model opus".to_string(),
"/model opus".to_string(),
"status".to_string(),
]);
let helper = editor.editor.helper().expect("helper should exist");
assert_eq!(helper.completions, vec!["/model opus".to_string()]);
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,5 @@
use std::fmt::Write as FmtWrite;
use std::io::{self, Write};
use std::thread;
use std::time::Duration;
use crossterm::cursor::{MoveToColumn, RestorePosition, SavePosition};
use crossterm::style::{Color, Print, ResetColor, SetForegroundColor, Stylize};
@ -22,6 +20,7 @@ pub struct ColorTheme {
link: Color,
quote: Color,
table_border: Color,
code_block_border: Color,
spinner_active: Color,
spinner_done: Color,
spinner_failed: Color,
@ -37,6 +36,7 @@ impl Default for ColorTheme {
link: Color::Blue,
quote: Color::DarkGrey,
table_border: Color::DarkCyan,
code_block_border: Color::DarkGrey,
spinner_active: Color::Blue,
spinner_done: Color::Green,
spinner_failed: Color::Red,
@ -154,32 +154,63 @@ impl TableState {
struct RenderState {
emphasis: usize,
strong: usize,
heading_level: Option<u8>,
quote: usize,
list_stack: Vec<ListKind>,
link_stack: Vec<LinkState>,
table: Option<TableState>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct LinkState {
destination: String,
text: String,
}
impl RenderState {
fn style_text(&self, text: &str, theme: &ColorTheme) -> String {
let mut styled = text.to_string();
if self.strong > 0 {
styled = format!("{}", styled.bold().with(theme.strong));
let mut style = text.stylize();
if matches!(self.heading_level, Some(1 | 2)) || self.strong > 0 {
style = style.bold();
}
if self.emphasis > 0 {
styled = format!("{}", styled.italic().with(theme.emphasis));
}
if self.quote > 0 {
styled = format!("{}", styled.with(theme.quote));
}
styled
style = style.italic();
}
fn capture_target_mut<'a>(&'a mut self, output: &'a mut String) -> &'a mut String {
if let Some(table) = self.table.as_mut() {
&mut table.current_cell
} else {
output
if let Some(level) = self.heading_level {
style = match level {
1 => style.with(theme.heading),
2 => style.white(),
3 => style.with(Color::Blue),
_ => style.with(Color::Grey),
};
} else if self.strong > 0 {
style = style.with(theme.strong);
} else if self.emphasis > 0 {
style = style.with(theme.emphasis);
}
if self.quote > 0 {
style = style.with(theme.quote);
}
format!("{style}")
}
fn append_raw(&mut self, output: &mut String, text: &str) {
if let Some(link) = self.link_stack.last_mut() {
link.text.push_str(text);
} else if let Some(table) = self.table.as_mut() {
table.current_cell.push_str(text);
} else {
output.push_str(text);
}
}
fn append_styled(&mut self, output: &mut String, text: &str, theme: &ColorTheme) {
let styled = self.style_text(text, theme);
self.append_raw(output, &styled);
}
}
@ -218,13 +249,14 @@ impl TerminalRenderer {
#[must_use]
pub fn render_markdown(&self, markdown: &str) -> String {
let normalized = normalize_nested_fences(markdown);
let mut output = String::new();
let mut state = RenderState::default();
let mut code_language = String::new();
let mut code_buffer = String::new();
let mut in_code_block = false;
for event in Parser::new_ext(markdown, Options::all()) {
for event in Parser::new_ext(&normalized, Options::all()) {
self.render_event(
event,
&mut state,
@ -238,6 +270,11 @@ impl TerminalRenderer {
output.trim_end().to_string()
}
#[must_use]
pub fn markdown_to_ansi(&self, markdown: &str) -> String {
self.render_markdown(markdown)
}
#[allow(clippy::too_many_lines)]
fn render_event(
&self,
@ -249,15 +286,21 @@ impl TerminalRenderer {
in_code_block: &mut bool,
) {
match event {
Event::Start(Tag::Heading { level, .. }) => self.start_heading(level as u8, output),
Event::End(TagEnd::Heading(..) | TagEnd::Paragraph) => output.push_str("\n\n"),
Event::Start(Tag::Heading { level, .. }) => {
Self::start_heading(state, level as u8, output);
}
Event::End(TagEnd::Paragraph) => output.push_str("\n\n"),
Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output),
Event::End(TagEnd::BlockQuote(..)) => {
state.quote = state.quote.saturating_sub(1);
output.push('\n');
}
Event::End(TagEnd::Heading(..)) => {
state.heading_level = None;
output.push_str("\n\n");
}
Event::End(TagEnd::Item) | Event::SoftBreak | Event::HardBreak => {
state.capture_target_mut(output).push('\n');
state.append_raw(output, "\n");
}
Event::Start(Tag::List(first_item)) => {
let kind = match first_item {
@ -293,41 +336,52 @@ impl TerminalRenderer {
Event::Code(code) => {
let rendered =
format!("{}", format!("`{code}`").with(self.color_theme.inline_code));
state.capture_target_mut(output).push_str(&rendered);
state.append_raw(output, &rendered);
}
Event::Rule => output.push_str("---\n"),
Event::Text(text) => {
self.push_text(text.as_ref(), state, output, code_buffer, *in_code_block);
}
Event::Html(html) | Event::InlineHtml(html) => {
state.capture_target_mut(output).push_str(&html);
state.append_raw(output, &html);
}
Event::FootnoteReference(reference) => {
let _ = write!(state.capture_target_mut(output), "[{reference}]");
state.append_raw(output, &format!("[{reference}]"));
}
Event::TaskListMarker(done) => {
state
.capture_target_mut(output)
.push_str(if done { "[x] " } else { "[ ] " });
state.append_raw(output, if done { "[x] " } else { "[ ] " });
}
Event::InlineMath(math) | Event::DisplayMath(math) => {
state.capture_target_mut(output).push_str(&math);
state.append_raw(output, &math);
}
Event::Start(Tag::Link { dest_url, .. }) => {
state.link_stack.push(LinkState {
destination: dest_url.to_string(),
text: String::new(),
});
}
Event::End(TagEnd::Link) => {
if let Some(link) = state.link_stack.pop() {
let label = if link.text.is_empty() {
link.destination.clone()
} else {
link.text
};
let rendered = format!(
"{}",
format!("[{dest_url}]")
format!("[{label}]({})", link.destination)
.underlined()
.with(self.color_theme.link)
);
state.capture_target_mut(output).push_str(&rendered);
state.append_raw(output, &rendered);
}
}
Event::Start(Tag::Image { dest_url, .. }) => {
let rendered = format!(
"{}",
format!("[image:{dest_url}]").with(self.color_theme.link)
);
state.capture_target_mut(output).push_str(&rendered);
state.append_raw(output, &rendered);
}
Event::Start(Tag::Table(..)) => state.table = Some(TableState::default()),
Event::End(TagEnd::Table) => {
@ -369,19 +423,15 @@ impl TerminalRenderer {
}
}
Event::Start(Tag::Paragraph | Tag::MetadataBlock(..) | _)
| Event::End(TagEnd::Link | TagEnd::Image | TagEnd::MetadataBlock(..) | _) => {}
| Event::End(TagEnd::Image | TagEnd::MetadataBlock(..) | _) => {}
}
}
fn start_heading(&self, level: u8, output: &mut String) {
fn start_heading(state: &mut RenderState, level: u8, output: &mut String) {
state.heading_level = Some(level);
if !output.is_empty() {
output.push('\n');
let prefix = match level {
1 => "# ",
2 => "## ",
3 => "### ",
_ => "#### ",
};
let _ = write!(output, "{}", prefix.bold().with(self.color_theme.heading));
}
}
fn start_quote(&self, state: &mut RenderState, output: &mut String) {
@ -405,20 +455,27 @@ impl TerminalRenderer {
}
fn start_code_block(&self, code_language: &str, output: &mut String) {
if !code_language.is_empty() {
let label = if code_language.is_empty() {
"code".to_string()
} else {
code_language.to_string()
};
let _ = writeln!(
output,
"{}",
format!("╭─ {code_language}").with(self.color_theme.heading)
format!("╭─ {label}")
.bold()
.with(self.color_theme.code_block_border)
);
}
}
fn finish_code_block(&self, code_buffer: &str, code_language: &str, output: &mut String) {
output.push_str(&self.highlight_code(code_buffer, code_language));
if !code_language.is_empty() {
let _ = write!(output, "{}", "╰─".with(self.color_theme.heading));
}
let _ = write!(
output,
"{}",
"╰─".bold().with(self.color_theme.code_block_border)
);
output.push_str("\n\n");
}
@ -433,8 +490,7 @@ impl TerminalRenderer {
if in_code_block {
code_buffer.push_str(text);
} else {
let rendered = state.style_text(text, &self.color_theme);
state.capture_target_mut(output).push_str(&rendered);
state.append_styled(output, text, &self.color_theme);
}
}
@ -521,9 +577,10 @@ impl TerminalRenderer {
for line in LinesWithEndings::from(code) {
match syntax_highlighter.highlight_line(line, &self.syntax_set) {
Ok(ranges) => {
colored_output.push_str(&as_24_bit_terminal_escaped(&ranges[..], false));
let escaped = as_24_bit_terminal_escaped(&ranges[..], false);
colored_output.push_str(&apply_code_block_background(&escaped));
}
Err(_) => colored_output.push_str(line),
Err(_) => colored_output.push_str(&apply_code_block_background(line)),
}
}
@ -531,16 +588,296 @@ impl TerminalRenderer {
}
pub fn stream_markdown(&self, markdown: &str, out: &mut impl Write) -> io::Result<()> {
let rendered_markdown = self.render_markdown(markdown);
for chunk in rendered_markdown.split_inclusive(char::is_whitespace) {
write!(out, "{chunk}")?;
out.flush()?;
thread::sleep(Duration::from_millis(8));
let rendered_markdown = self.markdown_to_ansi(markdown);
write!(out, "{rendered_markdown}")?;
if !rendered_markdown.ends_with('\n') {
writeln!(out)?;
}
writeln!(out)
out.flush()
}
}
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct MarkdownStreamState {
pending: String,
}
impl MarkdownStreamState {
#[must_use]
pub fn push(&mut self, renderer: &TerminalRenderer, delta: &str) -> Option<String> {
self.pending.push_str(delta);
let split = find_stream_safe_boundary(&self.pending)?;
let ready = self.pending[..split].to_string();
self.pending.drain(..split);
Some(renderer.markdown_to_ansi(&ready))
}
#[must_use]
pub fn flush(&mut self, renderer: &TerminalRenderer) -> Option<String> {
if self.pending.trim().is_empty() {
self.pending.clear();
None
} else {
let pending = std::mem::take(&mut self.pending);
Some(renderer.markdown_to_ansi(&pending))
}
}
}
fn apply_code_block_background(line: &str) -> String {
let trimmed = line.trim_end_matches('\n');
let trailing_newline = if trimmed.len() == line.len() {
""
} else {
"\n"
};
let with_background = trimmed.replace("\u{1b}[0m", "\u{1b}[0;48;5;236m");
format!("\u{1b}[48;5;236m{with_background}\u{1b}[0m{trailing_newline}")
}
/// Pre-process raw markdown so that fenced code blocks whose body contains
/// fence markers of equal or greater length are wrapped with a longer fence.
///
/// LLMs frequently emit triple-backtick code blocks that contain triple-backtick
/// examples. CommonMark (and pulldown-cmark) treats the inner marker as the
/// closing fence, breaking the render. This function detects the situation and
/// upgrades the outer fence to use enough backticks (or tildes) that the inner
/// markers become ordinary content.
fn normalize_nested_fences(markdown: &str) -> String {
// A fence line is either "labeled" (has an info string ⇒ always an opener)
// or "bare" (no info string ⇒ could be opener or closer).
#[derive(Debug, Clone)]
struct FenceLine {
char: char,
len: usize,
has_info: bool,
indent: usize,
}
fn parse_fence_line(line: &str) -> Option<FenceLine> {
let trimmed = line.trim_end_matches('\n').trim_end_matches('\r');
let indent = trimmed.chars().take_while(|c| *c == ' ').count();
if indent > 3 {
return None;
}
let rest = &trimmed[indent..];
let ch = rest.chars().next()?;
if ch != '`' && ch != '~' {
return None;
}
let len = rest.chars().take_while(|c| *c == ch).count();
if len < 3 {
return None;
}
let after = &rest[len..];
if ch == '`' && after.contains('`') {
return None;
}
let has_info = !after.trim().is_empty();
Some(FenceLine {
char: ch,
len,
has_info,
indent,
})
}
let lines: Vec<&str> = markdown.split_inclusive('\n').collect();
// Handle final line that may lack trailing newline.
// split_inclusive already keeps the original chunks, including a
// final chunk without '\n' if the input doesn't end with one.
// First pass: classify every line.
let fence_info: Vec<Option<FenceLine>> = lines.iter().map(|l| parse_fence_line(l)).collect();
// Second pass: pair openers with closers using a stack, recording
// (opener_idx, closer_idx) pairs plus the max fence length found between
// them.
struct StackEntry {
line_idx: usize,
fence: FenceLine,
}
let mut stack: Vec<StackEntry> = Vec::new();
// Paired blocks: (opener_line, closer_line, max_inner_fence_len)
let mut pairs: Vec<(usize, usize, usize)> = Vec::new();
for (i, fi) in fence_info.iter().enumerate() {
let Some(fl) = fi else { continue };
if fl.has_info {
// Labeled fence ⇒ always an opener.
stack.push(StackEntry {
line_idx: i,
fence: fl.clone(),
});
} else {
// Bare fence ⇒ try to close the top of the stack if compatible.
let closes_top = stack
.last()
.is_some_and(|top| top.fence.char == fl.char && fl.len >= top.fence.len);
if closes_top {
let opener = stack.pop().unwrap();
// Find max fence length of any fence line strictly between
// opener and closer (these are the nested fences).
let inner_max = fence_info[opener.line_idx + 1..i]
.iter()
.filter_map(|fi| fi.as_ref().map(|f| f.len))
.max()
.unwrap_or(0);
pairs.push((opener.line_idx, i, inner_max));
} else {
// Treat as opener.
stack.push(StackEntry {
line_idx: i,
fence: fl.clone(),
});
}
}
}
// Determine which lines need rewriting. A pair needs rewriting when
// its opener length <= max inner fence length.
struct Rewrite {
char: char,
new_len: usize,
indent: usize,
}
let mut rewrites: std::collections::HashMap<usize, Rewrite> = std::collections::HashMap::new();
for (opener_idx, closer_idx, inner_max) in &pairs {
let opener_fl = fence_info[*opener_idx].as_ref().unwrap();
if opener_fl.len <= *inner_max {
let new_len = inner_max + 1;
let info_part = {
let trimmed = lines[*opener_idx]
.trim_end_matches('\n')
.trim_end_matches('\r');
let rest = &trimmed[opener_fl.indent..];
rest[opener_fl.len..].to_string()
};
rewrites.insert(
*opener_idx,
Rewrite {
char: opener_fl.char,
new_len,
indent: opener_fl.indent,
},
);
let closer_fl = fence_info[*closer_idx].as_ref().unwrap();
rewrites.insert(
*closer_idx,
Rewrite {
char: closer_fl.char,
new_len,
indent: closer_fl.indent,
},
);
// Store info string only in the opener; closer keeps the trailing
// portion which is already handled through the original line.
// Actually, we rebuild both lines from scratch below, including
// the info string for the opener.
let _ = info_part; // consumed in rebuild
}
}
if rewrites.is_empty() {
return markdown.to_string();
}
// Rebuild.
let mut out = String::with_capacity(markdown.len() + rewrites.len() * 4);
for (i, line) in lines.iter().enumerate() {
if let Some(rw) = rewrites.get(&i) {
let fence_str: String = std::iter::repeat(rw.char).take(rw.new_len).collect();
let indent_str: String = std::iter::repeat(' ').take(rw.indent).collect();
// Recover the original info string (if any) and trailing newline.
let trimmed = line.trim_end_matches('\n').trim_end_matches('\r');
let fi = fence_info[i].as_ref().unwrap();
let info = &trimmed[fi.indent + fi.len..];
let trailing = &line[trimmed.len()..];
out.push_str(&indent_str);
out.push_str(&fence_str);
out.push_str(info);
out.push_str(trailing);
} else {
out.push_str(line);
}
}
out
}
fn find_stream_safe_boundary(markdown: &str) -> Option<usize> {
let mut open_fence: Option<FenceMarker> = None;
let mut last_boundary = None;
for (offset, line) in markdown.split_inclusive('\n').scan(0usize, |cursor, line| {
let start = *cursor;
*cursor += line.len();
Some((start, line))
}) {
let line_without_newline = line.trim_end_matches('\n');
if let Some(opener) = open_fence {
if line_closes_fence(line_without_newline, opener) {
open_fence = None;
last_boundary = Some(offset + line.len());
}
continue;
}
if let Some(opener) = parse_fence_opener(line_without_newline) {
open_fence = Some(opener);
continue;
}
if line_without_newline.trim().is_empty() {
last_boundary = Some(offset + line.len());
}
}
last_boundary
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct FenceMarker {
character: char,
length: usize,
}
fn parse_fence_opener(line: &str) -> Option<FenceMarker> {
let indent = line.chars().take_while(|c| *c == ' ').count();
if indent > 3 {
return None;
}
let rest = &line[indent..];
let character = rest.chars().next()?;
if character != '`' && character != '~' {
return None;
}
let length = rest.chars().take_while(|c| *c == character).count();
if length < 3 {
return None;
}
let info_string = &rest[length..];
if character == '`' && info_string.contains('`') {
return None;
}
Some(FenceMarker { character, length })
}
fn line_closes_fence(line: &str, opener: FenceMarker) -> bool {
let indent = line.chars().take_while(|c| *c == ' ').count();
if indent > 3 {
return false;
}
let rest = &line[indent..];
let length = rest.chars().take_while(|c| *c == opener.character).count();
if length < opener.length {
return false;
}
rest[length..].chars().all(|c| c == ' ' || c == '\t')
}
fn visible_width(input: &str) -> usize {
strip_ansi(input).chars().count()
}
@ -569,7 +906,7 @@ fn strip_ansi(input: &str) -> String {
#[cfg(test)]
mod tests {
use super::{strip_ansi, Spinner, TerminalRenderer};
use super::{strip_ansi, MarkdownStreamState, Spinner, TerminalRenderer};
#[test]
fn renders_markdown_with_styling_and_lists() {
@ -583,16 +920,28 @@ mod tests {
assert!(markdown_output.contains('\u{1b}'));
}
#[test]
fn renders_links_as_colored_markdown_labels() {
let terminal_renderer = TerminalRenderer::new();
let markdown_output =
terminal_renderer.render_markdown("See [Claw](https://example.com/docs) now.");
let plain_text = strip_ansi(&markdown_output);
assert!(plain_text.contains("[Claw](https://example.com/docs)"));
assert!(markdown_output.contains('\u{1b}'));
}
#[test]
fn highlights_fenced_code_blocks() {
let terminal_renderer = TerminalRenderer::new();
let markdown_output =
terminal_renderer.render_markdown("```rust\nfn hi() { println!(\"hi\"); }\n```");
terminal_renderer.markdown_to_ansi("```rust\nfn hi() { println!(\"hi\"); }\n```");
let plain_text = strip_ansi(&markdown_output);
assert!(plain_text.contains("╭─ rust"));
assert!(plain_text.contains("fn hi"));
assert!(markdown_output.contains('\u{1b}'));
assert!(markdown_output.contains("[48;5;236m"));
}
#[test]
@ -623,6 +972,80 @@ mod tests {
assert!(markdown_output.contains('\u{1b}'));
}
#[test]
fn streaming_state_waits_for_complete_blocks() {
let renderer = TerminalRenderer::new();
let mut state = MarkdownStreamState::default();
assert_eq!(state.push(&renderer, "# Heading"), None);
let flushed = state
.push(&renderer, "\n\nParagraph\n\n")
.expect("completed block");
let plain_text = strip_ansi(&flushed);
assert!(plain_text.contains("Heading"));
assert!(plain_text.contains("Paragraph"));
assert_eq!(state.push(&renderer, "```rust\nfn main() {}\n"), None);
let code = state
.push(&renderer, "```\n")
.expect("closed code fence flushes");
assert!(strip_ansi(&code).contains("fn main()"));
}
#[test]
fn streaming_state_holds_outer_fence_with_nested_inner_fence() {
let renderer = TerminalRenderer::new();
let mut state = MarkdownStreamState::default();
assert_eq!(
state.push(&renderer, "````markdown\n```rust\nfn inner() {}\n"),
None,
"inner triple backticks must not close the outer four-backtick fence"
);
assert_eq!(
state.push(&renderer, "```\n"),
None,
"closing the inner fence must not flush the outer fence"
);
let flushed = state
.push(&renderer, "````\n")
.expect("closing the outer four-backtick fence flushes the buffered block");
let plain_text = strip_ansi(&flushed);
assert!(plain_text.contains("fn inner()"));
assert!(plain_text.contains("```rust"));
}
#[test]
fn streaming_state_distinguishes_backtick_and_tilde_fences() {
let renderer = TerminalRenderer::new();
let mut state = MarkdownStreamState::default();
assert_eq!(state.push(&renderer, "~~~text\n"), None);
assert_eq!(
state.push(&renderer, "```\nstill inside tilde fence\n"),
None,
"a backtick fence cannot close a tilde-opened fence"
);
assert_eq!(state.push(&renderer, "```\n"), None);
let flushed = state
.push(&renderer, "~~~\n")
.expect("matching tilde marker closes the fence");
let plain_text = strip_ansi(&flushed);
assert!(plain_text.contains("still inside tilde fence"));
}
#[test]
fn renders_nested_fenced_code_block_preserves_inner_markers() {
let terminal_renderer = TerminalRenderer::new();
let markdown_output =
terminal_renderer.markdown_to_ansi("````markdown\n```rust\nfn nested() {}\n```\n````");
let plain_text = strip_ansi(&markdown_output);
assert!(plain_text.contains("╭─ markdown"));
assert!(plain_text.contains("```rust"));
assert!(plain_text.contains("fn nested()"));
}
#[test]
fn spinner_advances_frames() {
let terminal_renderer = TerminalRenderer::new();

View File

@ -0,0 +1,298 @@
use std::fs;
use std::path::{Path, PathBuf};
use std::process::{Command, Output};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use runtime::Session;
static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);
#[test]
fn status_command_applies_model_and_permission_mode_flags() {
// given
let temp_dir = unique_temp_dir("status-flags");
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
// when
let output = Command::new(env!("CARGO_BIN_EXE_claw"))
.current_dir(&temp_dir)
.args([
"--model",
"sonnet",
"--permission-mode",
"read-only",
"status",
])
.output()
.expect("claw should launch");
// then
assert_success(&output);
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
assert!(stdout.contains("Status"));
assert!(stdout.contains("Model claude-sonnet-4-6"));
assert!(stdout.contains("Permission mode read-only"));
fs::remove_dir_all(temp_dir).expect("cleanup temp dir");
}
#[test]
fn resume_flag_loads_a_saved_session_and_dispatches_status() {
// given
let temp_dir = unique_temp_dir("resume-status");
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
let session_path = write_session(&temp_dir, "resume-status");
// when
let output = Command::new(env!("CARGO_BIN_EXE_claw"))
.current_dir(&temp_dir)
.args([
"--resume",
session_path.to_str().expect("utf8 path"),
"/status",
])
.output()
.expect("claw should launch");
// then
assert_success(&output);
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
assert!(stdout.contains("Status"));
assert!(stdout.contains("Messages 1"));
assert!(stdout.contains("Session "));
assert!(stdout.contains(session_path.to_str().expect("utf8 path")));
fs::remove_dir_all(temp_dir).expect("cleanup temp dir");
}
#[test]
fn slash_command_names_match_known_commands_and_suggest_nearby_unknown_ones() {
// given
let temp_dir = unique_temp_dir("slash-dispatch");
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
// when
let help_output = Command::new(env!("CARGO_BIN_EXE_claw"))
.current_dir(&temp_dir)
.arg("/help")
.output()
.expect("claw should launch");
let unknown_output = Command::new(env!("CARGO_BIN_EXE_claw"))
.current_dir(&temp_dir)
.arg("/zstats")
.output()
.expect("claw should launch");
// then
assert_success(&help_output);
let help_stdout = String::from_utf8(help_output.stdout).expect("stdout should be utf8");
assert!(help_stdout.contains("Interactive slash commands:"));
assert!(help_stdout.contains("/status"));
assert!(
!unknown_output.status.success(),
"stdout:\n{}\n\nstderr:\n{}",
String::from_utf8_lossy(&unknown_output.stdout),
String::from_utf8_lossy(&unknown_output.stderr)
);
let stderr = String::from_utf8(unknown_output.stderr).expect("stderr should be utf8");
assert!(stderr.contains("unknown slash command outside the REPL: /zstats"));
assert!(stderr.contains("Did you mean"));
assert!(stderr.contains("/status"));
fs::remove_dir_all(temp_dir).expect("cleanup temp dir");
}
#[test]
fn omc_namespaced_slash_commands_surface_a_targeted_compatibility_hint() {
let temp_dir = unique_temp_dir("slash-dispatch-omc");
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
let output = Command::new(env!("CARGO_BIN_EXE_claw"))
.current_dir(&temp_dir)
.arg("/oh-my-claudecode:hud")
.output()
.expect("claw should launch");
assert!(
!output.status.success(),
"stdout:\n{}\n\nstderr:\n{}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
let stderr = String::from_utf8(output.stderr).expect("stderr should be utf8");
assert!(stderr.contains("unknown slash command outside the REPL: /oh-my-claudecode:hud"));
assert!(stderr.contains("Claude Code/OMC plugin command"));
assert!(stderr.contains("does not yet load plugin slash commands"));
fs::remove_dir_all(temp_dir).expect("cleanup temp dir");
}
#[test]
fn config_command_loads_defaults_from_standard_config_locations() {
// given
let temp_dir = unique_temp_dir("config-defaults");
let config_home = temp_dir.join("home").join(".claw");
fs::create_dir_all(temp_dir.join(".claw")).expect("project config dir should exist");
fs::create_dir_all(&config_home).expect("home config dir should exist");
fs::write(config_home.join("settings.json"), r#"{"model":"haiku"}"#)
.expect("write user settings");
fs::write(temp_dir.join(".claw.json"), r#"{"model":"sonnet"}"#)
.expect("write project settings");
fs::write(
temp_dir.join(".claw").join("settings.local.json"),
r#"{"model":"opus"}"#,
)
.expect("write local settings");
let session_path = write_session(&temp_dir, "config-defaults");
// when
let output = command_in(&temp_dir)
.env("CLAW_CONFIG_HOME", &config_home)
.args([
"--resume",
session_path.to_str().expect("utf8 path"),
"/config",
"model",
])
.output()
.expect("claw should launch");
// then
assert_success(&output);
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
assert!(stdout.contains("Config"));
assert!(stdout.contains("Loaded files 3"));
assert!(stdout.contains("Merged section: model"));
assert!(stdout.contains("opus"));
assert!(stdout.contains(
config_home
.join("settings.json")
.to_str()
.expect("utf8 path")
));
assert!(stdout.contains(temp_dir.join(".claw.json").to_str().expect("utf8 path")));
assert!(stdout.contains(
temp_dir
.join(".claw")
.join("settings.local.json")
.to_str()
.expect("utf8 path")
));
fs::remove_dir_all(temp_dir).expect("cleanup temp dir");
}
#[test]
fn doctor_command_runs_as_a_local_shell_entrypoint() {
// given
let temp_dir = unique_temp_dir("doctor-entrypoint");
let config_home = temp_dir.join("home").join(".claw");
fs::create_dir_all(&config_home).expect("config home should exist");
// when
let output = command_in(&temp_dir)
.env("CLAW_CONFIG_HOME", &config_home)
.env_remove("ANTHROPIC_API_KEY")
.env_remove("ANTHROPIC_AUTH_TOKEN")
.env("ANTHROPIC_BASE_URL", "http://127.0.0.1:9")
.arg("doctor")
.output()
.expect("claw doctor should launch");
// then
assert_success(&output);
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
assert!(stdout.contains("Doctor"));
assert!(stdout.contains("Auth"));
assert!(stdout.contains("Config"));
assert!(stdout.contains("Workspace"));
assert!(stdout.contains("Sandbox"));
assert!(!stdout.contains("Thinking"));
fs::remove_dir_all(temp_dir).expect("cleanup temp dir");
}
#[test]
fn local_subcommand_help_does_not_fall_through_to_runtime_or_provider_calls() {
let temp_dir = unique_temp_dir("subcommand-help");
let config_home = temp_dir.join("home").join(".claw");
fs::create_dir_all(&config_home).expect("config home should exist");
let doctor_help = command_in(&temp_dir)
.env("CLAW_CONFIG_HOME", &config_home)
.env_remove("ANTHROPIC_API_KEY")
.env_remove("ANTHROPIC_AUTH_TOKEN")
.env("ANTHROPIC_BASE_URL", "http://127.0.0.1:9")
.args(["doctor", "--help"])
.output()
.expect("doctor help should launch");
let status_help = command_in(&temp_dir)
.env("CLAW_CONFIG_HOME", &config_home)
.env_remove("ANTHROPIC_API_KEY")
.env_remove("ANTHROPIC_AUTH_TOKEN")
.env("ANTHROPIC_BASE_URL", "http://127.0.0.1:9")
.args(["status", "--help"])
.output()
.expect("status help should launch");
assert_success(&doctor_help);
let doctor_stdout = String::from_utf8(doctor_help.stdout).expect("stdout should be utf8");
assert!(doctor_stdout.contains("Usage claw doctor"));
assert!(doctor_stdout.contains("local-only health report"));
assert!(!doctor_stdout.contains("Thinking"));
assert_success(&status_help);
let status_stdout = String::from_utf8(status_help.stdout).expect("stdout should be utf8");
assert!(status_stdout.contains("Usage claw status"));
assert!(status_stdout.contains("local workspace snapshot"));
assert!(!status_stdout.contains("Thinking"));
let doctor_stderr = String::from_utf8(doctor_help.stderr).expect("stderr should be utf8");
let status_stderr = String::from_utf8(status_help.stderr).expect("stderr should be utf8");
assert!(!doctor_stderr.contains("auth_unavailable"));
assert!(!status_stderr.contains("auth_unavailable"));
fs::remove_dir_all(temp_dir).expect("cleanup temp dir");
}
fn command_in(cwd: &Path) -> Command {
let mut command = Command::new(env!("CARGO_BIN_EXE_claw"));
command.current_dir(cwd);
command
}
fn write_session(root: &Path, label: &str) -> PathBuf {
let session_path = root.join(format!("{label}.jsonl"));
let mut session = Session::new();
session
.push_user_text(format!("session fixture for {label}"))
.expect("session write should succeed");
session
.save_to_path(&session_path)
.expect("session should persist");
session_path
}
fn assert_success(output: &Output) {
assert!(
output.status.success(),
"stdout:\n{}\n\nstderr:\n{}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
}
fn unique_temp_dir(label: &str) -> PathBuf {
let millis = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("clock should be after epoch")
.as_millis();
let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
std::env::temp_dir().join(format!(
"claw-{label}-{}-{millis}-{counter}",
std::process::id()
))
}

View File

@ -0,0 +1,159 @@
use std::fs;
use std::path::PathBuf;
use std::process::{Command, Output};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use mock_anthropic_service::{MockAnthropicService, SCENARIO_PREFIX};
static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);
#[test]
fn compact_flag_prints_only_final_assistant_text_without_tool_call_details() {
// given a workspace pointed at the mock Anthropic service and a fixture file
// that the read_file_roundtrip scenario will fetch through a tool call
let runtime = tokio::runtime::Runtime::new().expect("tokio runtime should build");
let server = runtime
.block_on(MockAnthropicService::spawn())
.expect("mock service should start");
let base_url = server.base_url();
let workspace = unique_temp_dir("compact-read-file");
let config_home = workspace.join("config-home");
let home = workspace.join("home");
fs::create_dir_all(&workspace).expect("workspace should exist");
fs::create_dir_all(&config_home).expect("config home should exist");
fs::create_dir_all(&home).expect("home should exist");
fs::write(workspace.join("fixture.txt"), "alpha parity line\n").expect("fixture should write");
// when we run claw in compact text mode against a tool-using scenario
let prompt = format!("{SCENARIO_PREFIX}read_file_roundtrip");
let output = run_claw(
&workspace,
&config_home,
&home,
&base_url,
&[
"--model",
"sonnet",
"--permission-mode",
"read-only",
"--allowedTools",
"read_file",
"--compact",
&prompt,
],
);
// then the command exits successfully and stdout contains exactly the final
// assistant text with no tool call IDs, JSON envelopes, or spinner output
assert!(
output.status.success(),
"compact run should succeed\nstdout:\n{}\n\nstderr:\n{}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr),
);
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
let trimmed = stdout.trim_end_matches('\n');
assert_eq!(
trimmed, "read_file roundtrip complete: alpha parity line",
"compact stdout should contain only the final assistant text"
);
assert!(
!stdout.contains("toolu_"),
"compact stdout must not leak tool_use_id ({stdout:?})"
);
assert!(
!stdout.contains("\"tool_uses\""),
"compact stdout must not leak json envelopes ({stdout:?})"
);
assert!(
!stdout.contains("Thinking"),
"compact stdout must not include the spinner banner ({stdout:?})"
);
fs::remove_dir_all(&workspace).expect("workspace cleanup should succeed");
}
#[test]
fn compact_flag_streaming_text_only_emits_final_message_text() {
// given a workspace pointed at the mock Anthropic service running the
// streaming_text scenario which only emits a single assistant text block
let runtime = tokio::runtime::Runtime::new().expect("tokio runtime should build");
let server = runtime
.block_on(MockAnthropicService::spawn())
.expect("mock service should start");
let base_url = server.base_url();
let workspace = unique_temp_dir("compact-streaming-text");
let config_home = workspace.join("config-home");
let home = workspace.join("home");
fs::create_dir_all(&workspace).expect("workspace should exist");
fs::create_dir_all(&config_home).expect("config home should exist");
fs::create_dir_all(&home).expect("home should exist");
// when we invoke claw with --compact for the streaming text scenario
let prompt = format!("{SCENARIO_PREFIX}streaming_text");
let output = run_claw(
&workspace,
&config_home,
&home,
&base_url,
&[
"--model",
"sonnet",
"--permission-mode",
"read-only",
"--compact",
&prompt,
],
);
// then stdout should be exactly the assistant text followed by a newline
assert!(
output.status.success(),
"compact streaming run should succeed\nstdout:\n{}\n\nstderr:\n{}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr),
);
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
assert_eq!(
stdout, "Mock streaming says hello from the parity harness.\n",
"compact streaming stdout should contain only the final assistant text"
);
fs::remove_dir_all(&workspace).expect("workspace cleanup should succeed");
}
fn run_claw(
cwd: &std::path::Path,
config_home: &std::path::Path,
home: &std::path::Path,
base_url: &str,
args: &[&str],
) -> Output {
let mut command = Command::new(env!("CARGO_BIN_EXE_claw"));
command
.current_dir(cwd)
.env_clear()
.env("ANTHROPIC_API_KEY", "test-compact-key")
.env("ANTHROPIC_BASE_URL", base_url)
.env("CLAW_CONFIG_HOME", config_home)
.env("HOME", home)
.env("NO_COLOR", "1")
.env("PATH", "/usr/bin:/bin")
.args(args);
command.output().expect("claw should launch")
}
fn unique_temp_dir(label: &str) -> PathBuf {
let millis = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("clock should be after epoch")
.as_millis();
let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
std::env::temp_dir().join(format!(
"claw-compact-{label}-{}-{millis}-{counter}",
std::process::id()
))
}

View File

@ -0,0 +1,884 @@
#![cfg(unix)]
use std::collections::BTreeMap;
use std::fs;
use std::io::Write;
use std::os::unix::fs::PermissionsExt;
use std::path::{Path, PathBuf};
use std::process::{Command, Output, Stdio};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use mock_anthropic_service::{MockAnthropicService, SCENARIO_PREFIX};
use serde_json::{json, Value};
static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);
#[test]
#[allow(clippy::too_many_lines)]
fn clean_env_cli_reaches_mock_anthropic_service_across_scripted_parity_scenarios() {
let manifest_entries = load_scenario_manifest();
let manifest = manifest_entries
.iter()
.cloned()
.map(|entry| (entry.name.clone(), entry))
.collect::<BTreeMap<_, _>>();
let runtime = tokio::runtime::Runtime::new().expect("tokio runtime should build");
let server = runtime
.block_on(MockAnthropicService::spawn())
.expect("mock service should start");
let base_url = server.base_url();
let cases = [
ScenarioCase {
name: "streaming_text",
permission_mode: "read-only",
allowed_tools: None,
stdin: None,
prepare: prepare_noop,
assert: assert_streaming_text,
extra_env: None,
resume_session: None,
},
ScenarioCase {
name: "read_file_roundtrip",
permission_mode: "read-only",
allowed_tools: Some("read_file"),
stdin: None,
prepare: prepare_read_fixture,
assert: assert_read_file_roundtrip,
extra_env: None,
resume_session: None,
},
ScenarioCase {
name: "grep_chunk_assembly",
permission_mode: "read-only",
allowed_tools: Some("grep_search"),
stdin: None,
prepare: prepare_grep_fixture,
assert: assert_grep_chunk_assembly,
extra_env: None,
resume_session: None,
},
ScenarioCase {
name: "write_file_allowed",
permission_mode: "workspace-write",
allowed_tools: Some("write_file"),
stdin: None,
prepare: prepare_noop,
assert: assert_write_file_allowed,
extra_env: None,
resume_session: None,
},
ScenarioCase {
name: "write_file_denied",
permission_mode: "read-only",
allowed_tools: Some("write_file"),
stdin: None,
prepare: prepare_noop,
assert: assert_write_file_denied,
extra_env: None,
resume_session: None,
},
ScenarioCase {
name: "multi_tool_turn_roundtrip",
permission_mode: "read-only",
allowed_tools: Some("read_file,grep_search"),
stdin: None,
prepare: prepare_multi_tool_fixture,
assert: assert_multi_tool_turn_roundtrip,
extra_env: None,
resume_session: None,
},
ScenarioCase {
name: "bash_stdout_roundtrip",
permission_mode: "danger-full-access",
allowed_tools: Some("bash"),
stdin: None,
prepare: prepare_noop,
assert: assert_bash_stdout_roundtrip,
extra_env: None,
resume_session: None,
},
ScenarioCase {
name: "bash_permission_prompt_approved",
permission_mode: "workspace-write",
allowed_tools: Some("bash"),
stdin: Some("y\n"),
prepare: prepare_noop,
assert: assert_bash_permission_prompt_approved,
extra_env: None,
resume_session: None,
},
ScenarioCase {
name: "bash_permission_prompt_denied",
permission_mode: "workspace-write",
allowed_tools: Some("bash"),
stdin: Some("n\n"),
prepare: prepare_noop,
assert: assert_bash_permission_prompt_denied,
extra_env: None,
resume_session: None,
},
ScenarioCase {
name: "plugin_tool_roundtrip",
permission_mode: "workspace-write",
allowed_tools: None,
stdin: None,
prepare: prepare_plugin_fixture,
assert: assert_plugin_tool_roundtrip,
extra_env: None,
resume_session: None,
},
ScenarioCase {
name: "auto_compact_triggered",
permission_mode: "read-only",
allowed_tools: None,
stdin: None,
prepare: prepare_noop,
assert: assert_auto_compact_triggered,
extra_env: None,
resume_session: None,
},
ScenarioCase {
name: "token_cost_reporting",
permission_mode: "read-only",
allowed_tools: None,
stdin: None,
prepare: prepare_noop,
assert: assert_token_cost_reporting,
extra_env: None,
resume_session: None,
},
];
let case_names = cases.iter().map(|case| case.name).collect::<Vec<_>>();
let manifest_names = manifest_entries
.iter()
.map(|entry| entry.name.as_str())
.collect::<Vec<_>>();
assert_eq!(
case_names, manifest_names,
"manifest and harness cases must stay aligned"
);
let mut scenario_reports = Vec::new();
for case in cases {
let workspace = HarnessWorkspace::new(unique_temp_dir(case.name));
workspace.create().expect("workspace should exist");
(case.prepare)(&workspace);
let run = run_case(case, &workspace, &base_url);
(case.assert)(&workspace, &run);
let manifest_entry = manifest
.get(case.name)
.unwrap_or_else(|| panic!("missing manifest entry for {}", case.name));
scenario_reports.push(build_scenario_report(
case.name,
manifest_entry,
&run.response,
));
fs::remove_dir_all(&workspace.root).expect("workspace cleanup should succeed");
}
let captured = runtime.block_on(server.captured_requests());
// After `be561bf` added count_tokens preflight, each turn sends an
// extra POST to `/v1/messages/count_tokens` before the messages POST.
// The original count (21) assumed messages-only requests. We now
// filter to `/v1/messages` and verify that subset matches the original
// scenario expectation.
let messages_only: Vec<_> = captured
.iter()
.filter(|r| r.path == "/v1/messages")
.collect();
assert_eq!(
messages_only.len(),
21,
"twelve scenarios should produce twenty-one /v1/messages requests (total captured: {}, includes count_tokens)",
captured.len()
);
assert!(messages_only.iter().all(|request| request.stream));
let scenarios = messages_only
.iter()
.map(|request| request.scenario.as_str())
.collect::<Vec<_>>();
assert_eq!(
scenarios,
vec![
"streaming_text",
"read_file_roundtrip",
"read_file_roundtrip",
"grep_chunk_assembly",
"grep_chunk_assembly",
"write_file_allowed",
"write_file_allowed",
"write_file_denied",
"write_file_denied",
"multi_tool_turn_roundtrip",
"multi_tool_turn_roundtrip",
"bash_stdout_roundtrip",
"bash_stdout_roundtrip",
"bash_permission_prompt_approved",
"bash_permission_prompt_approved",
"bash_permission_prompt_denied",
"bash_permission_prompt_denied",
"plugin_tool_roundtrip",
"plugin_tool_roundtrip",
"auto_compact_triggered",
"token_cost_reporting",
]
);
let mut request_counts = BTreeMap::new();
for request in &captured {
*request_counts
.entry(request.scenario.as_str())
.or_insert(0_usize) += 1;
}
for report in &mut scenario_reports {
report.request_count = *request_counts
.get(report.name.as_str())
.unwrap_or_else(|| panic!("missing request count for {}", report.name));
}
maybe_write_report(&scenario_reports);
}
#[derive(Clone, Copy)]
struct ScenarioCase {
name: &'static str,
permission_mode: &'static str,
allowed_tools: Option<&'static str>,
stdin: Option<&'static str>,
prepare: fn(&HarnessWorkspace),
assert: fn(&HarnessWorkspace, &ScenarioRun),
extra_env: Option<(&'static str, &'static str)>,
resume_session: Option<&'static str>,
}
struct HarnessWorkspace {
root: PathBuf,
config_home: PathBuf,
home: PathBuf,
}
impl HarnessWorkspace {
fn new(root: PathBuf) -> Self {
Self {
config_home: root.join("config-home"),
home: root.join("home"),
root,
}
}
fn create(&self) -> std::io::Result<()> {
fs::create_dir_all(&self.root)?;
fs::create_dir_all(&self.config_home)?;
fs::create_dir_all(&self.home)?;
Ok(())
}
}
struct ScenarioRun {
response: Value,
stdout: String,
}
#[derive(Debug, Clone)]
struct ScenarioManifestEntry {
name: String,
category: String,
description: String,
parity_refs: Vec<String>,
}
#[derive(Debug)]
struct ScenarioReport {
name: String,
category: String,
description: String,
parity_refs: Vec<String>,
iterations: u64,
request_count: usize,
tool_uses: Vec<String>,
tool_error_count: usize,
final_message: String,
}
fn run_case(case: ScenarioCase, workspace: &HarnessWorkspace, base_url: &str) -> ScenarioRun {
let mut command = Command::new(env!("CARGO_BIN_EXE_claw"));
command
.current_dir(&workspace.root)
.env_clear()
.env("ANTHROPIC_API_KEY", "test-parity-key")
.env("ANTHROPIC_BASE_URL", base_url)
.env("CLAW_CONFIG_HOME", &workspace.config_home)
.env("HOME", &workspace.home)
.env("NO_COLOR", "1")
.env("PATH", "/usr/bin:/bin")
.args([
"--model",
"sonnet",
"--permission-mode",
case.permission_mode,
"--output-format=json",
]);
if let Some(allowed_tools) = case.allowed_tools {
command.args(["--allowedTools", allowed_tools]);
}
if let Some((key, value)) = case.extra_env {
command.env(key, value);
}
if let Some(session_id) = case.resume_session {
command.args(["--resume", session_id]);
}
let prompt = format!("{SCENARIO_PREFIX}{}", case.name);
command.arg(prompt);
let output = if let Some(stdin) = case.stdin {
let mut child = command
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.expect("claw should launch");
child
.stdin
.as_mut()
.expect("stdin should be piped")
.write_all(stdin.as_bytes())
.expect("stdin should write");
child.wait_with_output().expect("claw should finish")
} else {
command.output().expect("claw should launch")
};
assert_success(&output);
let stdout = String::from_utf8_lossy(&output.stdout).into_owned();
ScenarioRun {
response: parse_json_output(&stdout),
stdout,
}
}
#[allow(dead_code)]
fn prepare_auto_compact_fixture(workspace: &HarnessWorkspace) {
let sessions_dir = workspace.root.join(".claw").join("sessions");
fs::create_dir_all(&sessions_dir).expect("sessions dir should exist");
// Write a pre-seeded session with 6 messages so auto-compact can remove them
let session_id = "parity-auto-compact-seed";
let session_jsonl = r#"{"type":"session_meta","version":3,"session_id":"parity-auto-compact-seed","created_at_ms":1743724800000,"updated_at_ms":1743724800000}
{"type":"message","message":{"role":"user","blocks":[{"type":"text","text":"step one of the parity scenario"}]}}
{"type":"message","message":{"role":"assistant","blocks":[{"type":"text","text":"acknowledged step one"}]}}
{"type":"message","message":{"role":"user","blocks":[{"type":"text","text":"step two of the parity scenario"}]}}
{"type":"message","message":{"role":"assistant","blocks":[{"type":"text","text":"acknowledged step two"}]}}
{"type":"message","message":{"role":"user","blocks":[{"type":"text","text":"step three of the parity scenario"}]}}
{"type":"message","message":{"role":"assistant","blocks":[{"type":"text","text":"acknowledged step three"}]}}
"#;
fs::write(
sessions_dir.join(format!("{session_id}.jsonl")),
session_jsonl,
)
.expect("pre-seeded session should write");
}
fn prepare_noop(_: &HarnessWorkspace) {}
fn prepare_read_fixture(workspace: &HarnessWorkspace) {
fs::write(workspace.root.join("fixture.txt"), "alpha parity line\n")
.expect("fixture should write");
}
fn prepare_grep_fixture(workspace: &HarnessWorkspace) {
fs::write(
workspace.root.join("fixture.txt"),
"alpha parity line\nbeta line\ngamma parity line\n",
)
.expect("grep fixture should write");
}
fn prepare_multi_tool_fixture(workspace: &HarnessWorkspace) {
fs::write(
workspace.root.join("fixture.txt"),
"alpha parity line\nbeta line\ngamma parity line\n",
)
.expect("multi tool fixture should write");
}
fn prepare_plugin_fixture(workspace: &HarnessWorkspace) {
let plugin_root = workspace
.root
.join("external-plugins")
.join("parity-plugin");
let tool_dir = plugin_root.join("tools");
let manifest_dir = plugin_root.join(".claude-plugin");
fs::create_dir_all(&tool_dir).expect("plugin tools dir");
fs::create_dir_all(&manifest_dir).expect("plugin manifest dir");
let script_path = tool_dir.join("echo-json.sh");
fs::write(
&script_path,
"#!/bin/sh\nINPUT=$(cat)\nprintf '{\"plugin\":\"%s\",\"tool\":\"%s\",\"input\":%s}\\n' \"$CLAWD_PLUGIN_ID\" \"$CLAWD_TOOL_NAME\" \"$INPUT\"\n",
)
.expect("plugin script should write");
let mut permissions = fs::metadata(&script_path)
.expect("plugin script metadata")
.permissions();
permissions.set_mode(0o755);
fs::set_permissions(&script_path, permissions).expect("plugin script should be executable");
fs::write(
manifest_dir.join("plugin.json"),
r#"{
"name": "parity-plugin",
"version": "1.0.0",
"description": "mock parity plugin",
"tools": [
{
"name": "plugin_echo",
"description": "Echo JSON input",
"inputSchema": {
"type": "object",
"properties": {
"message": { "type": "string" }
},
"required": ["message"],
"additionalProperties": false
},
"command": "./tools/echo-json.sh",
"requiredPermission": "workspace-write"
}
]
}"#,
)
.expect("plugin manifest should write");
fs::write(
workspace.config_home.join("settings.json"),
json!({
"enabledPlugins": {
"parity-plugin@external": true
},
"plugins": {
"externalDirectories": [plugin_root.parent().expect("plugin parent").display().to_string()]
}
})
.to_string(),
)
.expect("plugin settings should write");
}
fn assert_streaming_text(_: &HarnessWorkspace, run: &ScenarioRun) {
assert_eq!(
run.response["message"],
Value::String("Mock streaming says hello from the parity harness.".to_string())
);
assert_eq!(run.response["iterations"], Value::from(1));
assert_eq!(run.response["tool_uses"], Value::Array(Vec::new()));
assert_eq!(run.response["tool_results"], Value::Array(Vec::new()));
}
fn assert_read_file_roundtrip(workspace: &HarnessWorkspace, run: &ScenarioRun) {
assert_eq!(run.response["iterations"], Value::from(2));
assert_eq!(
run.response["tool_uses"][0]["name"],
Value::String("read_file".to_string())
);
assert_eq!(
run.response["tool_uses"][0]["input"],
Value::String(r#"{"path":"fixture.txt"}"#.to_string())
);
assert!(run.response["message"]
.as_str()
.expect("message text")
.contains("alpha parity line"));
let output = run.response["tool_results"][0]["output"]
.as_str()
.expect("tool output");
assert!(output.contains(&workspace.root.join("fixture.txt").display().to_string()));
assert!(output.contains("alpha parity line"));
}
fn assert_grep_chunk_assembly(_: &HarnessWorkspace, run: &ScenarioRun) {
assert_eq!(run.response["iterations"], Value::from(2));
assert_eq!(
run.response["tool_uses"][0]["name"],
Value::String("grep_search".to_string())
);
assert_eq!(
run.response["tool_uses"][0]["input"],
Value::String(
r#"{"pattern":"parity","path":"fixture.txt","output_mode":"count"}"#.to_string()
)
);
assert!(run.response["message"]
.as_str()
.expect("message text")
.contains("2 occurrences"));
assert_eq!(
run.response["tool_results"][0]["is_error"],
Value::Bool(false)
);
}
fn assert_write_file_allowed(workspace: &HarnessWorkspace, run: &ScenarioRun) {
assert_eq!(run.response["iterations"], Value::from(2));
assert_eq!(
run.response["tool_uses"][0]["name"],
Value::String("write_file".to_string())
);
assert!(run.response["message"]
.as_str()
.expect("message text")
.contains("generated/output.txt"));
let generated = workspace.root.join("generated").join("output.txt");
let contents = fs::read_to_string(&generated).expect("generated file should exist");
assert_eq!(contents, "created by mock service\n");
assert_eq!(
run.response["tool_results"][0]["is_error"],
Value::Bool(false)
);
}
fn assert_write_file_denied(workspace: &HarnessWorkspace, run: &ScenarioRun) {
assert_eq!(run.response["iterations"], Value::from(2));
assert_eq!(
run.response["tool_uses"][0]["name"],
Value::String("write_file".to_string())
);
let tool_output = run.response["tool_results"][0]["output"]
.as_str()
.expect("tool output");
assert!(tool_output.contains("requires workspace-write permission"));
assert_eq!(
run.response["tool_results"][0]["is_error"],
Value::Bool(true)
);
assert!(run.response["message"]
.as_str()
.expect("message text")
.contains("denied as expected"));
assert!(!workspace.root.join("generated").join("denied.txt").exists());
}
fn assert_multi_tool_turn_roundtrip(_: &HarnessWorkspace, run: &ScenarioRun) {
assert_eq!(run.response["iterations"], Value::from(2));
let tool_uses = run.response["tool_uses"]
.as_array()
.expect("tool uses array");
assert_eq!(
tool_uses.len(),
2,
"expected two tool uses in a single turn"
);
assert_eq!(tool_uses[0]["name"], Value::String("read_file".to_string()));
assert_eq!(
tool_uses[1]["name"],
Value::String("grep_search".to_string())
);
let tool_results = run.response["tool_results"]
.as_array()
.expect("tool results array");
assert_eq!(
tool_results.len(),
2,
"expected two tool results in a single turn"
);
assert!(run.response["message"]
.as_str()
.expect("message text")
.contains("alpha parity line"));
assert!(run.response["message"]
.as_str()
.expect("message text")
.contains("2 occurrences"));
}
fn assert_bash_stdout_roundtrip(_: &HarnessWorkspace, run: &ScenarioRun) {
assert_eq!(run.response["iterations"], Value::from(2));
assert_eq!(
run.response["tool_uses"][0]["name"],
Value::String("bash".to_string())
);
let tool_output = run.response["tool_results"][0]["output"]
.as_str()
.expect("tool output");
let parsed: Value = serde_json::from_str(tool_output).expect("bash output json");
assert_eq!(
parsed["stdout"],
Value::String("alpha from bash".to_string())
);
assert_eq!(
run.response["tool_results"][0]["is_error"],
Value::Bool(false)
);
assert!(run.response["message"]
.as_str()
.expect("message text")
.contains("alpha from bash"));
}
fn assert_bash_permission_prompt_approved(_: &HarnessWorkspace, run: &ScenarioRun) {
assert!(run.stdout.contains("Permission approval required"));
assert!(run.stdout.contains("Approve this tool call? [y/N]:"));
assert_eq!(run.response["iterations"], Value::from(2));
assert_eq!(
run.response["tool_results"][0]["is_error"],
Value::Bool(false)
);
let tool_output = run.response["tool_results"][0]["output"]
.as_str()
.expect("tool output");
let parsed: Value = serde_json::from_str(tool_output).expect("bash output json");
assert_eq!(
parsed["stdout"],
Value::String("approved via prompt".to_string())
);
assert!(run.response["message"]
.as_str()
.expect("message text")
.contains("approved and executed"));
}
fn assert_bash_permission_prompt_denied(_: &HarnessWorkspace, run: &ScenarioRun) {
assert!(run.stdout.contains("Permission approval required"));
assert!(run.stdout.contains("Approve this tool call? [y/N]:"));
assert_eq!(run.response["iterations"], Value::from(2));
let tool_output = run.response["tool_results"][0]["output"]
.as_str()
.expect("tool output");
assert!(tool_output.contains("denied by user approval prompt"));
assert_eq!(
run.response["tool_results"][0]["is_error"],
Value::Bool(true)
);
assert!(run.response["message"]
.as_str()
.expect("message text")
.contains("denied as expected"));
}
fn assert_plugin_tool_roundtrip(_: &HarnessWorkspace, run: &ScenarioRun) {
assert_eq!(run.response["iterations"], Value::from(2));
assert_eq!(
run.response["tool_uses"][0]["name"],
Value::String("plugin_echo".to_string())
);
let tool_output = run.response["tool_results"][0]["output"]
.as_str()
.expect("tool output");
let parsed: Value = serde_json::from_str(tool_output).expect("plugin output json");
assert_eq!(
parsed["plugin"],
Value::String("parity-plugin@external".to_string())
);
assert_eq!(parsed["tool"], Value::String("plugin_echo".to_string()));
assert_eq!(
parsed["input"]["message"],
Value::String("hello from plugin parity".to_string())
);
assert!(run.response["message"]
.as_str()
.expect("message text")
.contains("hello from plugin parity"));
}
fn assert_auto_compact_triggered(_: &HarnessWorkspace, run: &ScenarioRun) {
// Validates that the auto_compaction field is present in JSON output (format parity).
// Trigger behavior is covered by conversation::tests::auto_compacts_when_cumulative_input_threshold_is_crossed.
assert_eq!(run.response["iterations"], Value::from(1));
assert_eq!(run.response["tool_uses"], Value::Array(Vec::new()));
assert!(
run.response["message"]
.as_str()
.expect("message text")
.contains("auto compact parity complete."),
"expected auto compact message in response"
);
// auto_compaction key must be present in JSON (may be null for below-threshold sessions)
assert!(
run.response
.as_object()
.expect("response object")
.contains_key("auto_compaction"),
"auto_compaction key must be present in JSON output"
);
// Verify input_tokens field reflects the large mock token counts
let input_tokens = run.response["usage"]["input_tokens"]
.as_u64()
.expect("input_tokens should be present");
assert!(
input_tokens >= 50_000,
"input_tokens should reflect mock service value (got {input_tokens})"
);
}
fn assert_token_cost_reporting(_: &HarnessWorkspace, run: &ScenarioRun) {
assert_eq!(run.response["iterations"], Value::from(1));
assert!(run.response["message"]
.as_str()
.expect("message text")
.contains("token cost reporting parity complete."),);
let usage = &run.response["usage"];
assert!(
usage["input_tokens"].as_u64().unwrap_or(0) > 0,
"input_tokens should be non-zero"
);
assert!(
usage["output_tokens"].as_u64().unwrap_or(0) > 0,
"output_tokens should be non-zero"
);
assert!(
run.response["estimated_cost"]
.as_str()
.is_some_and(|cost| cost.starts_with('$')),
"estimated_cost should be a dollar-prefixed string"
);
}
fn parse_json_output(stdout: &str) -> Value {
if let Some(index) = stdout.rfind("{\"auto_compaction\"") {
return serde_json::from_str(&stdout[index..]).unwrap_or_else(|error| {
panic!("failed to parse JSON response from stdout: {error}\n{stdout}")
});
}
stdout
.lines()
.rev()
.find_map(|line| {
let trimmed = line.trim();
if trimmed.starts_with('{') && trimmed.ends_with('}') {
serde_json::from_str(trimmed).ok()
} else {
None
}
})
.unwrap_or_else(|| panic!("no JSON response line found in stdout:\n{stdout}"))
}
fn build_scenario_report(
name: &str,
manifest_entry: &ScenarioManifestEntry,
response: &Value,
) -> ScenarioReport {
ScenarioReport {
name: name.to_string(),
category: manifest_entry.category.clone(),
description: manifest_entry.description.clone(),
parity_refs: manifest_entry.parity_refs.clone(),
iterations: response["iterations"]
.as_u64()
.expect("iterations should exist"),
request_count: 0,
tool_uses: response["tool_uses"]
.as_array()
.expect("tool uses array")
.iter()
.filter_map(|value| value["name"].as_str().map(ToOwned::to_owned))
.collect(),
tool_error_count: response["tool_results"]
.as_array()
.expect("tool results array")
.iter()
.filter(|value| value["is_error"].as_bool().unwrap_or(false))
.count(),
final_message: response["message"]
.as_str()
.expect("message text")
.to_string(),
}
}
fn maybe_write_report(reports: &[ScenarioReport]) {
let Some(path) = std::env::var_os("MOCK_PARITY_REPORT_PATH") else {
return;
};
let payload = json!({
"scenario_count": reports.len(),
"request_count": reports.iter().map(|report| report.request_count).sum::<usize>(),
"scenarios": reports.iter().map(scenario_report_json).collect::<Vec<_>>(),
});
fs::write(
path,
serde_json::to_vec_pretty(&payload).expect("report json should serialize"),
)
.expect("report should write");
}
fn load_scenario_manifest() -> Vec<ScenarioManifestEntry> {
let manifest_path =
Path::new(env!("CARGO_MANIFEST_DIR")).join("../../mock_parity_scenarios.json");
let manifest = fs::read_to_string(&manifest_path).expect("scenario manifest should exist");
serde_json::from_str::<Vec<Value>>(&manifest)
.expect("scenario manifest should parse")
.into_iter()
.map(|entry| ScenarioManifestEntry {
name: entry["name"]
.as_str()
.expect("scenario name should be a string")
.to_string(),
category: entry["category"]
.as_str()
.expect("scenario category should be a string")
.to_string(),
description: entry["description"]
.as_str()
.expect("scenario description should be a string")
.to_string(),
parity_refs: entry["parity_refs"]
.as_array()
.expect("parity refs should be an array")
.iter()
.map(|value| {
value
.as_str()
.expect("parity ref should be a string")
.to_string()
})
.collect(),
})
.collect()
}
fn scenario_report_json(report: &ScenarioReport) -> Value {
json!({
"name": report.name,
"category": report.category,
"description": report.description,
"parity_refs": report.parity_refs,
"iterations": report.iterations,
"request_count": report.request_count,
"tool_uses": report.tool_uses,
"tool_error_count": report.tool_error_count,
"final_message": report.final_message,
})
}
fn assert_success(output: &Output) {
assert!(
output.status.success(),
"stdout:\n{}\n\nstderr:\n{}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
}
fn unique_temp_dir(label: &str) -> PathBuf {
let millis = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("clock should be after epoch")
.as_millis();
let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
std::env::temp_dir().join(format!(
"claw-mock-parity-{label}-{}-{millis}-{counter}",
std::process::id()
))
}

View File

@ -0,0 +1,429 @@
use std::fs;
use std::path::{Path, PathBuf};
use std::process::{Command, Output};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use serde_json::Value;
static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);
#[test]
fn help_emits_json_when_requested() {
let root = unique_temp_dir("help-json");
fs::create_dir_all(&root).expect("temp dir should exist");
let parsed = assert_json_command(&root, &["--output-format", "json", "help"]);
assert_eq!(parsed["kind"], "help");
assert!(parsed["message"]
.as_str()
.expect("help text")
.contains("Usage:"));
}
#[test]
fn version_emits_json_when_requested() {
let root = unique_temp_dir("version-json");
fs::create_dir_all(&root).expect("temp dir should exist");
let parsed = assert_json_command(&root, &["--output-format", "json", "version"]);
assert_eq!(parsed["kind"], "version");
assert_eq!(parsed["version"], env!("CARGO_PKG_VERSION"));
}
#[test]
fn status_and_sandbox_emit_json_when_requested() {
let root = unique_temp_dir("status-sandbox-json");
fs::create_dir_all(&root).expect("temp dir should exist");
let status = assert_json_command(&root, &["--output-format", "json", "status"]);
assert_eq!(status["kind"], "status");
assert!(status["workspace"]["cwd"].as_str().is_some());
let sandbox = assert_json_command(&root, &["--output-format", "json", "sandbox"]);
assert_eq!(sandbox["kind"], "sandbox");
assert!(sandbox["filesystem_mode"].as_str().is_some());
}
#[test]
fn inventory_commands_emit_structured_json_when_requested() {
let root = unique_temp_dir("inventory-json");
fs::create_dir_all(&root).expect("temp dir should exist");
let isolated_home = root.join("home");
let isolated_config = root.join("config-home");
let isolated_codex = root.join("codex-home");
fs::create_dir_all(&isolated_home).expect("isolated home should exist");
let agents = assert_json_command_with_env(
&root,
&["--output-format", "json", "agents"],
&[
("HOME", isolated_home.to_str().expect("utf8 home")),
(
"CLAW_CONFIG_HOME",
isolated_config.to_str().expect("utf8 config home"),
),
(
"CODEX_HOME",
isolated_codex.to_str().expect("utf8 codex home"),
),
],
);
assert_eq!(agents["kind"], "agents");
assert_eq!(agents["action"], "list");
assert_eq!(agents["count"], 0);
assert_eq!(agents["summary"]["active"], 0);
assert!(agents["agents"]
.as_array()
.expect("agents array")
.is_empty());
let mcp = assert_json_command(&root, &["--output-format", "json", "mcp"]);
assert_eq!(mcp["kind"], "mcp");
assert_eq!(mcp["action"], "list");
let skills = assert_json_command(&root, &["--output-format", "json", "skills"]);
assert_eq!(skills["kind"], "skills");
assert_eq!(skills["action"], "list");
}
#[test]
fn agents_command_emits_structured_agent_entries_when_requested() {
let root = unique_temp_dir("agents-json-populated");
let workspace = root.join("workspace");
let project_agents = workspace.join(".codex").join("agents");
let home = root.join("home");
let user_agents = home.join(".codex").join("agents");
let isolated_config = root.join("config-home");
let isolated_codex = root.join("codex-home");
fs::create_dir_all(&workspace).expect("workspace should exist");
write_agent(
&project_agents,
"planner",
"Project planner",
"gpt-5.4",
"medium",
);
write_agent(
&project_agents,
"verifier",
"Verification agent",
"gpt-5.4-mini",
"high",
);
write_agent(
&user_agents,
"planner",
"User planner",
"gpt-5.4-mini",
"high",
);
let parsed = assert_json_command_with_env(
&workspace,
&["--output-format", "json", "agents"],
&[
("HOME", home.to_str().expect("utf8 home")),
(
"CLAW_CONFIG_HOME",
isolated_config.to_str().expect("utf8 config home"),
),
(
"CODEX_HOME",
isolated_codex.to_str().expect("utf8 codex home"),
),
],
);
assert_eq!(parsed["kind"], "agents");
assert_eq!(parsed["action"], "list");
assert_eq!(parsed["count"], 3);
assert_eq!(parsed["summary"]["active"], 2);
assert_eq!(parsed["summary"]["shadowed"], 1);
assert_eq!(parsed["agents"][0]["name"], "planner");
assert_eq!(parsed["agents"][0]["source"]["id"], "project_claw");
assert_eq!(parsed["agents"][0]["active"], true);
assert_eq!(parsed["agents"][1]["name"], "verifier");
assert_eq!(parsed["agents"][2]["name"], "planner");
assert_eq!(parsed["agents"][2]["active"], false);
assert_eq!(parsed["agents"][2]["shadowed_by"]["id"], "project_claw");
}
#[test]
fn bootstrap_and_system_prompt_emit_json_when_requested() {
let root = unique_temp_dir("bootstrap-system-prompt-json");
fs::create_dir_all(&root).expect("temp dir should exist");
let plan = assert_json_command(&root, &["--output-format", "json", "bootstrap-plan"]);
assert_eq!(plan["kind"], "bootstrap-plan");
assert!(plan["phases"].as_array().expect("phases").len() > 1);
let prompt = assert_json_command(&root, &["--output-format", "json", "system-prompt"]);
assert_eq!(prompt["kind"], "system-prompt");
assert!(prompt["message"]
.as_str()
.expect("prompt text")
.contains("interactive agent"));
}
#[test]
fn dump_manifests_and_init_emit_json_when_requested() {
let root = unique_temp_dir("manifest-init-json");
fs::create_dir_all(&root).expect("temp dir should exist");
let upstream = write_upstream_fixture(&root);
let manifests = assert_json_command_with_env(
&root,
&["--output-format", "json", "dump-manifests"],
&[(
"CLAUDE_CODE_UPSTREAM",
upstream.to_str().expect("utf8 upstream"),
)],
);
assert_eq!(manifests["kind"], "dump-manifests");
assert_eq!(manifests["commands"], 1);
assert_eq!(manifests["tools"], 1);
let workspace = root.join("workspace");
fs::create_dir_all(&workspace).expect("workspace should exist");
let init = assert_json_command(&workspace, &["--output-format", "json", "init"]);
assert_eq!(init["kind"], "init");
assert!(workspace.join("CLAUDE.md").exists());
}
#[test]
fn doctor_and_resume_status_emit_json_when_requested() {
let root = unique_temp_dir("doctor-resume-json");
fs::create_dir_all(&root).expect("temp dir should exist");
let doctor = assert_json_command(&root, &["--output-format", "json", "doctor"]);
assert_eq!(doctor["kind"], "doctor");
assert!(doctor["message"].is_string());
let summary = doctor["summary"].as_object().expect("doctor summary");
assert!(summary["ok"].as_u64().is_some());
assert!(summary["warnings"].as_u64().is_some());
assert!(summary["failures"].as_u64().is_some());
let checks = doctor["checks"].as_array().expect("doctor checks");
assert_eq!(checks.len(), 5);
let check_names = checks
.iter()
.map(|check| {
assert!(check["status"].as_str().is_some());
assert!(check["summary"].as_str().is_some());
assert!(check["details"].is_array());
check["name"].as_str().expect("doctor check name")
})
.collect::<Vec<_>>();
assert_eq!(
check_names,
vec!["auth", "config", "workspace", "sandbox", "system"]
);
let workspace = checks
.iter()
.find(|check| check["name"] == "workspace")
.expect("workspace check");
assert!(workspace["cwd"].as_str().is_some());
assert!(workspace["in_git_repo"].is_boolean());
let sandbox = checks
.iter()
.find(|check| check["name"] == "sandbox")
.expect("sandbox check");
assert!(sandbox["filesystem_mode"].as_str().is_some());
assert!(sandbox["enabled"].is_boolean());
assert!(sandbox["fallback_reason"].is_null() || sandbox["fallback_reason"].is_string());
let session_path = root.join("session.jsonl");
fs::write(
&session_path,
"{\"type\":\"session_meta\",\"version\":3,\"session_id\":\"resume-json\",\"created_at_ms\":0,\"updated_at_ms\":0}\n{\"type\":\"message\",\"message\":{\"role\":\"user\",\"blocks\":[{\"type\":\"text\",\"text\":\"hello\"}]}}\n",
)
.expect("session should write");
let resumed = assert_json_command(
&root,
&[
"--output-format",
"json",
"--resume",
session_path.to_str().expect("utf8 session path"),
"/status",
],
);
assert_eq!(resumed["kind"], "status");
// model is null in resume mode (not known without --model flag)
assert!(resumed["model"].is_null());
assert_eq!(resumed["usage"]["messages"], 1);
assert!(resumed["workspace"]["cwd"].as_str().is_some());
assert!(resumed["sandbox"]["filesystem_mode"].as_str().is_some());
}
#[test]
fn resumed_inventory_commands_emit_structured_json_when_requested() {
let root = unique_temp_dir("resume-inventory-json");
let config_home = root.join("config-home");
let home = root.join("home");
fs::create_dir_all(&config_home).expect("config home should exist");
fs::create_dir_all(&home).expect("home should exist");
let session_path = root.join("session.jsonl");
fs::write(
&session_path,
"{\"type\":\"session_meta\",\"version\":3,\"session_id\":\"resume-inventory-json\",\"created_at_ms\":0,\"updated_at_ms\":0}\n{\"type\":\"message\",\"message\":{\"role\":\"user\",\"blocks\":[{\"type\":\"text\",\"text\":\"inventory\"}]}}\n",
)
.expect("session should write");
let mcp = assert_json_command_with_env(
&root,
&[
"--output-format",
"json",
"--resume",
session_path.to_str().expect("utf8 session path"),
"/mcp",
],
&[
(
"CLAW_CONFIG_HOME",
config_home.to_str().expect("utf8 config home"),
),
("HOME", home.to_str().expect("utf8 home")),
],
);
assert_eq!(mcp["kind"], "mcp");
assert_eq!(mcp["action"], "list");
assert!(mcp["servers"].is_array());
let skills = assert_json_command_with_env(
&root,
&[
"--output-format",
"json",
"--resume",
session_path.to_str().expect("utf8 session path"),
"/skills",
],
&[
(
"CLAW_CONFIG_HOME",
config_home.to_str().expect("utf8 config home"),
),
("HOME", home.to_str().expect("utf8 home")),
],
);
assert_eq!(skills["kind"], "skills");
assert_eq!(skills["action"], "list");
assert!(skills["summary"]["total"].is_number());
assert!(skills["skills"].is_array());
}
#[test]
fn resumed_version_and_init_emit_structured_json_when_requested() {
let root = unique_temp_dir("resume-version-init-json");
fs::create_dir_all(&root).expect("temp dir should exist");
let session_path = root.join("session.jsonl");
fs::write(
&session_path,
"{\"type\":\"session_meta\",\"version\":3,\"session_id\":\"resume-version-init-json\",\"created_at_ms\":0,\"updated_at_ms\":0}\n",
)
.expect("session should write");
let version = assert_json_command(
&root,
&[
"--output-format",
"json",
"--resume",
session_path.to_str().expect("utf8 session path"),
"/version",
],
);
assert_eq!(version["kind"], "version");
assert_eq!(version["version"], env!("CARGO_PKG_VERSION"));
let init = assert_json_command(
&root,
&[
"--output-format",
"json",
"--resume",
session_path.to_str().expect("utf8 session path"),
"/init",
],
);
assert_eq!(init["kind"], "init");
assert!(root.join("CLAUDE.md").exists());
}
fn assert_json_command(current_dir: &Path, args: &[&str]) -> Value {
assert_json_command_with_env(current_dir, args, &[])
}
fn assert_json_command_with_env(current_dir: &Path, args: &[&str], envs: &[(&str, &str)]) -> Value {
let output = run_claw(current_dir, args, envs);
assert!(
output.status.success(),
"stdout:\n{}\n\nstderr:\n{}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
serde_json::from_slice(&output.stdout).expect("stdout should be valid json")
}
fn run_claw(current_dir: &Path, args: &[&str], envs: &[(&str, &str)]) -> Output {
let mut command = Command::new(env!("CARGO_BIN_EXE_claw"));
command.current_dir(current_dir).args(args);
for (key, value) in envs {
command.env(key, value);
}
command.output().expect("claw should launch")
}
fn write_upstream_fixture(root: &Path) -> PathBuf {
let upstream = root.join("claw-code");
let src = upstream.join("src");
let entrypoints = src.join("entrypoints");
fs::create_dir_all(&entrypoints).expect("upstream entrypoints dir should exist");
fs::write(
src.join("commands.ts"),
"import FooCommand from './commands/foo'\n",
)
.expect("commands fixture should write");
fs::write(
src.join("tools.ts"),
"import ReadTool from './tools/read'\n",
)
.expect("tools fixture should write");
fs::write(
entrypoints.join("cli.tsx"),
"if (args[0] === '--version') {}\nstartupProfiler()\n",
)
.expect("cli fixture should write");
upstream
}
fn write_agent(root: &Path, name: &str, description: &str, model: &str, reasoning: &str) {
fs::create_dir_all(root).expect("agent root should exist");
fs::write(
root.join(format!("{name}.toml")),
format!(
"name = \"{name}\"\ndescription = \"{description}\"\nmodel = \"{model}\"\nmodel_reasoning_effort = \"{reasoning}\"\n"
),
)
.expect("agent fixture should write");
}
fn unique_temp_dir(label: &str) -> PathBuf {
let millis = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("clock should be after epoch")
.as_millis();
let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
std::env::temp_dir().join(format!(
"claw-output-format-{label}-{}-{millis}-{counter}",
std::process::id()
))
}

View File

@ -0,0 +1,555 @@
use std::fs;
use std::path::Path;
use std::path::PathBuf;
use std::process::{Command, Output};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use runtime::ContentBlock;
use runtime::Session;
use serde_json::Value;
static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);
#[test]
fn resumed_binary_accepts_slash_commands_with_arguments() {
// given
let temp_dir = unique_temp_dir("resume-slash-commands");
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
let session_path = temp_dir.join("session.jsonl");
let export_path = temp_dir.join("notes.txt");
let mut session = Session::new();
session
.push_user_text("ship the slash command harness")
.expect("session write should succeed");
session
.save_to_path(&session_path)
.expect("session should persist");
// when
let output = run_claw(
&temp_dir,
&[
"--resume",
session_path.to_str().expect("utf8 path"),
"/export",
export_path.to_str().expect("utf8 path"),
"/clear",
"--confirm",
],
);
// then
assert!(
output.status.success(),
"stdout:\n{}\n\nstderr:\n{}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
assert!(stdout.contains("Export"));
assert!(stdout.contains("wrote transcript"));
assert!(stdout.contains(export_path.to_str().expect("utf8 path")));
assert!(stdout.contains("Session cleared"));
assert!(stdout.contains("Mode resumed session reset"));
assert!(stdout.contains("Previous session"));
assert!(stdout.contains("Resume previous claw --resume"));
assert!(stdout.contains("Backup "));
assert!(stdout.contains("Session file "));
let export = fs::read_to_string(&export_path).expect("export file should exist");
assert!(export.contains("# Conversation Export"));
assert!(export.contains("ship the slash command harness"));
let restored = Session::load_from_path(&session_path).expect("cleared session should load");
assert!(restored.messages.is_empty());
let backup_path = stdout
.lines()
.find_map(|line| line.strip_prefix(" Backup "))
.map(PathBuf::from)
.expect("clear output should include backup path");
let backup = Session::load_from_path(&backup_path).expect("backup session should load");
assert_eq!(backup.messages.len(), 1);
assert!(matches!(
backup.messages[0].blocks.first(),
Some(ContentBlock::Text { text }) if text == "ship the slash command harness"
));
}
#[test]
fn status_command_applies_cli_flags_end_to_end() {
// given
let temp_dir = unique_temp_dir("status-command-flags");
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
// when
let output = run_claw(
&temp_dir,
&[
"--model",
"sonnet",
"--permission-mode",
"read-only",
"status",
],
);
// then
assert!(
output.status.success(),
"stdout:\n{}\n\nstderr:\n{}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
assert!(stdout.contains("Status"));
assert!(stdout.contains("Model claude-sonnet-4-6"));
assert!(stdout.contains("Permission mode read-only"));
}
#[test]
fn resumed_config_command_loads_settings_files_end_to_end() {
// given
let temp_dir = unique_temp_dir("resume-config");
let project_dir = temp_dir.join("project");
let config_home = temp_dir.join("home").join(".claw");
fs::create_dir_all(project_dir.join(".claw")).expect("project config dir should exist");
fs::create_dir_all(&config_home).expect("config home should exist");
let session_path = project_dir.join("session.jsonl");
Session::new()
.with_persistence_path(&session_path)
.save_to_path(&session_path)
.expect("session should persist");
fs::write(config_home.join("settings.json"), r#"{"model":"haiku"}"#)
.expect("user config should write");
fs::write(
project_dir.join(".claw").join("settings.local.json"),
r#"{"model":"opus"}"#,
)
.expect("local config should write");
// when
let output = run_claw_with_env(
&project_dir,
&[
"--resume",
session_path.to_str().expect("utf8 path"),
"/config",
"model",
],
&[("CLAW_CONFIG_HOME", config_home.to_str().expect("utf8 path"))],
);
// then
assert!(
output.status.success(),
"stdout:\n{}\n\nstderr:\n{}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
assert!(stdout.contains("Config"));
assert!(stdout.contains("Loaded files 2"));
assert!(stdout.contains(
config_home
.join("settings.json")
.to_str()
.expect("utf8 path")
));
assert!(stdout.contains(
project_dir
.join(".claw")
.join("settings.local.json")
.to_str()
.expect("utf8 path")
));
assert!(stdout.contains("Merged section: model"));
assert!(stdout.contains("opus"));
}
#[test]
fn resume_latest_restores_the_most_recent_managed_session() {
// given
let temp_dir = unique_temp_dir("resume-latest");
let project_dir = temp_dir.join("project");
let sessions_dir = project_dir.join(".claw").join("sessions");
fs::create_dir_all(&sessions_dir).expect("sessions dir should exist");
let older_path = sessions_dir.join("session-older.jsonl");
let newer_path = sessions_dir.join("session-newer.jsonl");
let mut older = Session::new().with_persistence_path(&older_path);
older
.push_user_text("older session")
.expect("older session write should succeed");
older
.save_to_path(&older_path)
.expect("older session should persist");
let mut newer = Session::new().with_persistence_path(&newer_path);
newer
.push_user_text("newer session")
.expect("newer session write should succeed");
newer
.push_user_text("resume me")
.expect("newer session write should succeed");
newer
.save_to_path(&newer_path)
.expect("newer session should persist");
// when
let output = run_claw(&project_dir, &["--resume", "latest", "/status"]);
// then
assert!(
output.status.success(),
"stdout:\n{}\n\nstderr:\n{}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
assert!(stdout.contains("Status"));
assert!(stdout.contains("Messages 2"));
assert!(stdout.contains(newer_path.to_str().expect("utf8 path")));
}
#[test]
fn resumed_status_command_emits_structured_json_when_requested() {
// given
let temp_dir = unique_temp_dir("resume-status-json");
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
let session_path = temp_dir.join("session.jsonl");
let mut session = Session::new();
session
.push_user_text("resume status json fixture")
.expect("session write should succeed");
session
.save_to_path(&session_path)
.expect("session should persist");
// when
let output = run_claw(
&temp_dir,
&[
"--output-format",
"json",
"--resume",
session_path.to_str().expect("utf8 path"),
"/status",
],
);
// then
assert!(
output.status.success(),
"stdout:\n{}\n\nstderr:\n{}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
let parsed: Value =
serde_json::from_str(stdout.trim()).expect("resume status output should be json");
assert_eq!(parsed["kind"], "status");
// model is null in resume mode (not known without --model flag)
assert!(parsed["model"].is_null());
assert_eq!(parsed["permission_mode"], "danger-full-access");
assert_eq!(parsed["usage"]["messages"], 1);
assert!(parsed["usage"]["turns"].is_number());
assert!(parsed["workspace"]["cwd"].as_str().is_some());
assert_eq!(
parsed["workspace"]["session"],
session_path.to_str().expect("utf8 path")
);
assert!(parsed["workspace"]["changed_files"].is_number());
assert_eq!(parsed["workspace"]["loaded_config_files"].as_u64(), Some(0));
assert!(parsed["sandbox"]["filesystem_mode"].as_str().is_some());
}
#[test]
fn resumed_status_surfaces_persisted_model() {
// given — create a session with model already set
let temp_dir = unique_temp_dir("resume-status-model");
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
let session_path = temp_dir.join("session.jsonl");
let mut session = Session::new();
session.model = Some("claude-sonnet-4-6".to_string());
session
.push_user_text("model persistence fixture")
.expect("write ok");
session.save_to_path(&session_path).expect("persist ok");
// when
let output = run_claw(
&temp_dir,
&[
"--output-format",
"json",
"--resume",
session_path.to_str().expect("utf8 path"),
"/status",
],
);
// then
assert!(
output.status.success(),
"stderr:\n{}",
String::from_utf8_lossy(&output.stderr)
);
let stdout = String::from_utf8(output.stdout).expect("utf8");
let parsed: Value = serde_json::from_str(stdout.trim()).expect("should be json");
assert_eq!(parsed["kind"], "status");
assert_eq!(
parsed["model"], "claude-sonnet-4-6",
"model should round-trip through session metadata"
);
}
#[test]
fn resumed_sandbox_command_emits_structured_json_when_requested() {
// given
let temp_dir = unique_temp_dir("resume-sandbox-json");
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
let session_path = temp_dir.join("session.jsonl");
Session::new()
.save_to_path(&session_path)
.expect("session should persist");
// when
let output = run_claw(
&temp_dir,
&[
"--output-format",
"json",
"--resume",
session_path.to_str().expect("utf8 path"),
"/sandbox",
],
);
// then
assert!(
output.status.success(),
"stdout:\n{}\n\nstderr:\n{}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
let parsed: Value =
serde_json::from_str(stdout.trim()).expect("resume sandbox output should be json");
assert_eq!(parsed["kind"], "sandbox");
assert!(parsed["enabled"].is_boolean());
assert!(parsed["active"].is_boolean());
assert!(parsed["supported"].is_boolean());
assert!(parsed["filesystem_mode"].as_str().is_some());
assert!(parsed["allowed_mounts"].is_array());
assert!(parsed["markers"].is_array());
}
#[test]
fn resumed_version_command_emits_structured_json() {
let temp_dir = unique_temp_dir("resume-version-json");
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
let session_path = temp_dir.join("session.jsonl");
Session::new()
.save_to_path(&session_path)
.expect("session should persist");
let output = run_claw(
&temp_dir,
&[
"--output-format",
"json",
"--resume",
session_path.to_str().expect("utf8 path"),
"/version",
],
);
assert!(
output.status.success(),
"stderr:\n{}",
String::from_utf8_lossy(&output.stderr)
);
let stdout = String::from_utf8(output.stdout).expect("utf8");
let parsed: Value = serde_json::from_str(stdout.trim()).expect("should be json");
assert_eq!(parsed["kind"], "version");
assert!(parsed["version"].as_str().is_some());
assert!(parsed["git_sha"].as_str().is_some());
assert!(parsed["target"].as_str().is_some());
}
#[test]
fn resumed_export_command_emits_structured_json() {
let temp_dir = unique_temp_dir("resume-export-json");
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
let session_path = temp_dir.join("session.jsonl");
let mut session = Session::new();
session
.push_user_text("export json fixture")
.expect("write ok");
session.save_to_path(&session_path).expect("persist ok");
let output = run_claw(
&temp_dir,
&[
"--output-format",
"json",
"--resume",
session_path.to_str().expect("utf8 path"),
"/export",
],
);
assert!(
output.status.success(),
"stderr:\n{}",
String::from_utf8_lossy(&output.stderr)
);
let stdout = String::from_utf8(output.stdout).expect("utf8");
let parsed: Value = serde_json::from_str(stdout.trim()).expect("should be json");
assert_eq!(parsed["kind"], "export");
assert!(parsed["file"].as_str().is_some());
assert_eq!(parsed["message_count"], 1);
}
#[test]
fn resumed_help_command_emits_structured_json() {
let temp_dir = unique_temp_dir("resume-help-json");
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
let session_path = temp_dir.join("session.jsonl");
Session::new()
.save_to_path(&session_path)
.expect("persist ok");
let output = run_claw(
&temp_dir,
&[
"--output-format",
"json",
"--resume",
session_path.to_str().expect("utf8 path"),
"/help",
],
);
assert!(
output.status.success(),
"stderr:\n{}",
String::from_utf8_lossy(&output.stderr)
);
let stdout = String::from_utf8(output.stdout).expect("utf8");
let parsed: Value = serde_json::from_str(stdout.trim()).expect("should be json");
assert_eq!(parsed["kind"], "help");
assert!(parsed["text"].as_str().is_some());
let text = parsed["text"].as_str().unwrap();
assert!(text.contains("/status"), "help text should list /status");
}
#[test]
fn resumed_no_command_emits_restored_json() {
let temp_dir = unique_temp_dir("resume-no-cmd-json");
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
let session_path = temp_dir.join("session.jsonl");
let mut session = Session::new();
session
.push_user_text("restored json fixture")
.expect("write ok");
session.save_to_path(&session_path).expect("persist ok");
let output = run_claw(
&temp_dir,
&[
"--output-format",
"json",
"--resume",
session_path.to_str().expect("utf8 path"),
],
);
assert!(
output.status.success(),
"stderr:\n{}",
String::from_utf8_lossy(&output.stderr)
);
let stdout = String::from_utf8(output.stdout).expect("utf8");
let parsed: Value = serde_json::from_str(stdout.trim()).expect("should be json");
assert_eq!(parsed["kind"], "restored");
assert!(parsed["session_id"].as_str().is_some());
assert!(parsed["path"].as_str().is_some());
assert_eq!(parsed["message_count"], 1);
}
#[test]
fn resumed_stub_command_emits_not_implemented_json() {
let temp_dir = unique_temp_dir("resume-stub-json");
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
let session_path = temp_dir.join("session.jsonl");
Session::new()
.save_to_path(&session_path)
.expect("persist ok");
let output = run_claw(
&temp_dir,
&[
"--output-format",
"json",
"--resume",
session_path.to_str().expect("utf8 path"),
"/allowed-tools",
],
);
// Stub commands exit with code 2
assert!(!output.status.success());
let stderr = String::from_utf8(output.stderr).expect("utf8");
let parsed: Value = serde_json::from_str(stderr.trim()).expect("should be json");
assert_eq!(parsed["type"], "error");
assert!(
parsed["error"]
.as_str()
.unwrap()
.contains("not yet implemented"),
"error should say not yet implemented: {:?}",
parsed["error"]
);
}
fn run_claw(current_dir: &Path, args: &[&str]) -> Output {
run_claw_with_env(current_dir, args, &[])
}
fn run_claw_with_env(current_dir: &Path, args: &[&str], envs: &[(&str, &str)]) -> Output {
let mut command = Command::new(env!("CARGO_BIN_EXE_claw"));
command.current_dir(current_dir).args(args);
for (key, value) in envs {
command.env(key, value);
}
command.output().expect("claw should launch")
}
fn unique_temp_dir(label: &str) -> PathBuf {
let millis = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("clock should be after epoch")
.as_millis();
let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
std::env::temp_dir().join(format!(
"claw-{label}-{}-{millis}-{counter}",
std::process::id()
))
}

View File

@ -20,7 +20,7 @@ use runtime::{
ToolError, ToolExecutor,
};
use api::{
max_tokens_for_model, resolve_startup_auth_source, AuthSource, ClawApiClient,
max_tokens_for_model, resolve_startup_auth_source, AnthropicClient, AuthSource,
ContentBlockDelta, InputContentBlock, InputMessage, MessageRequest,
MessageResponse, OutputContentBlock,
StreamEvent as ApiStreamEvent, ToolChoice, ToolResultContentBlock,
@ -104,7 +104,7 @@ impl SessionEvent {
// ── ServerApiClient实现 runtime::ApiClient trait ────────────────────
struct ServerApiClient {
client: ClawApiClient,
client: AnthropicClient,
model: String,
tool_registry: GlobalToolRegistry,
allowed_tools: Option<BTreeSet<String>>,
@ -127,6 +127,12 @@ impl ApiClient for ServerApiClient {
.then(|| self.tool_registry.definitions(self.allowed_tools.as_ref())),
tool_choice: self.enable_tools.then_some(ToolChoice::Auto),
stream: true,
temperature: None,
top_p: None,
frequency_penalty: None,
presence_penalty: None,
stop: None,
reasoning_effort: None,
};
let rt = tokio::runtime::Runtime::new()
@ -192,7 +198,6 @@ impl ApiClient for ServerApiClient {
session_id: self.session_id.clone(),
thinking: thinking.clone(),
});
events.push(AssistantEvent::ThinkingDelta(thinking));
}
}
ContentBlockDelta::SignatureDelta { .. } => {}
@ -303,7 +308,6 @@ fn push_output_block(
session_id: session_id.to_string(),
thinking: thinking.clone(),
});
events.push(AssistantEvent::ThinkingDelta(thinking));
}
}
OutputContentBlock::RedactedThinking { .. } => {}
@ -412,17 +416,6 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec<InputMessage> {
ContentBlock::Text { text } => {
InputContentBlock::Text { text: text.clone() }
}
ContentBlock::Thinking {
thinking,
signature,
} => InputContentBlock::Thinking {
thinking: thinking.clone(),
signature: signature.clone(),
},
ContentBlock::RedactedThinking { data } => InputContentBlock::RedactedThinking {
data: serde_json::from_str(&data.render())
.unwrap_or(serde_json::Value::Null),
},
ContentBlock::ToolUse { id, name, input } => InputContentBlock::ToolUse {
id: id.clone(),
name: name.clone(),
@ -454,6 +447,7 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec<InputMessage> {
fn permission_policy(mode: PermissionMode, tool_registry: &GlobalToolRegistry) -> PermissionPolicy {
tool_registry
.permission_specs(None)
.unwrap_or_default()
.into_iter()
.fold(PermissionPolicy::new(mode), |policy, (name, req)| {
policy.with_tool_requirement(name, req)
@ -632,7 +626,7 @@ impl Session {
let (events_tx, _) = broadcast::channel(BROADCAST_CAPACITY);
let auth = resolve_server_auth_source(cwd)?;
let client = ClawApiClient::from_auth(auth).with_base_url(api::read_base_url());
let client = AnthropicClient::from_auth(auth).with_base_url(api::read_base_url());
let api_client = ServerApiClient {
client,
@ -658,7 +652,7 @@ impl Session {
tool_executor,
policy,
system_prompt,
feature_config.clone(),
&feature_config,
);
Ok(Self {

View File

@ -0,0 +1,13 @@
[package]
name = "telemetry"
version.workspace = true
edition.workspace = true
license.workspace = true
publish.workspace = true
[dependencies]
serde = { version = "1", features = ["derive"] }
serde_json = "1"
[lints]
workspace = true

526
crates/telemetry/src/lib.rs Normal file
View File

@ -0,0 +1,526 @@
use std::fmt::{Debug, Formatter};
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
pub const DEFAULT_ANTHROPIC_VERSION: &str = "2023-06-01";
pub const DEFAULT_APP_NAME: &str = "claude-code";
pub const DEFAULT_RUNTIME: &str = "rust";
pub const DEFAULT_AGENTIC_BETA: &str = "claude-code-20250219";
pub const DEFAULT_PROMPT_CACHING_SCOPE_BETA: &str = "prompt-caching-scope-2026-01-05";
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ClientIdentity {
pub app_name: String,
pub app_version: String,
pub runtime: String,
}
impl ClientIdentity {
#[must_use]
pub fn new(app_name: impl Into<String>, app_version: impl Into<String>) -> Self {
Self {
app_name: app_name.into(),
app_version: app_version.into(),
runtime: DEFAULT_RUNTIME.to_string(),
}
}
#[must_use]
pub fn with_runtime(mut self, runtime: impl Into<String>) -> Self {
self.runtime = runtime.into();
self
}
#[must_use]
pub fn user_agent(&self) -> String {
format!("{}/{}", self.app_name, self.app_version)
}
}
impl Default for ClientIdentity {
fn default() -> Self {
Self::new(DEFAULT_APP_NAME, env!("CARGO_PKG_VERSION"))
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AnthropicRequestProfile {
pub anthropic_version: String,
pub client_identity: ClientIdentity,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub betas: Vec<String>,
#[serde(default, skip_serializing_if = "Map::is_empty")]
pub extra_body: Map<String, Value>,
}
impl AnthropicRequestProfile {
#[must_use]
pub fn new(client_identity: ClientIdentity) -> Self {
Self {
anthropic_version: DEFAULT_ANTHROPIC_VERSION.to_string(),
client_identity,
betas: vec![
DEFAULT_AGENTIC_BETA.to_string(),
DEFAULT_PROMPT_CACHING_SCOPE_BETA.to_string(),
],
extra_body: Map::new(),
}
}
#[must_use]
pub fn with_beta(mut self, beta: impl Into<String>) -> Self {
let beta = beta.into();
if !self.betas.contains(&beta) {
self.betas.push(beta);
}
self
}
#[must_use]
pub fn with_extra_body(mut self, key: impl Into<String>, value: Value) -> Self {
self.extra_body.insert(key.into(), value);
self
}
#[must_use]
pub fn header_pairs(&self) -> Vec<(String, String)> {
let mut headers = vec![
(
"anthropic-version".to_string(),
self.anthropic_version.clone(),
),
("user-agent".to_string(), self.client_identity.user_agent()),
];
if !self.betas.is_empty() {
headers.push(("anthropic-beta".to_string(), self.betas.join(",")));
}
headers
}
pub fn render_json_body<T: Serialize>(&self, request: &T) -> Result<Value, serde_json::Error> {
let mut body = serde_json::to_value(request)?;
let object = body.as_object_mut().ok_or_else(|| {
serde_json::Error::io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"request body must serialize to a JSON object",
))
})?;
for (key, value) in &self.extra_body {
object.insert(key.clone(), value.clone());
}
if !self.betas.is_empty() {
object.insert(
"betas".to_string(),
Value::Array(self.betas.iter().cloned().map(Value::String).collect()),
);
}
Ok(body)
}
}
impl Default for AnthropicRequestProfile {
fn default() -> Self {
Self::new(ClientIdentity::default())
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AnalyticsEvent {
pub namespace: String,
pub action: String,
#[serde(default, skip_serializing_if = "Map::is_empty")]
pub properties: Map<String, Value>,
}
impl AnalyticsEvent {
#[must_use]
pub fn new(namespace: impl Into<String>, action: impl Into<String>) -> Self {
Self {
namespace: namespace.into(),
action: action.into(),
properties: Map::new(),
}
}
#[must_use]
pub fn with_property(mut self, key: impl Into<String>, value: Value) -> Self {
self.properties.insert(key.into(), value);
self
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SessionTraceRecord {
pub session_id: String,
pub sequence: u64,
pub name: String,
pub timestamp_ms: u64,
#[serde(default, skip_serializing_if = "Map::is_empty")]
pub attributes: Map<String, Value>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TelemetryEvent {
HttpRequestStarted {
session_id: String,
attempt: u32,
method: String,
path: String,
#[serde(default, skip_serializing_if = "Map::is_empty")]
attributes: Map<String, Value>,
},
HttpRequestSucceeded {
session_id: String,
attempt: u32,
method: String,
path: String,
status: u16,
#[serde(default, skip_serializing_if = "Option::is_none")]
request_id: Option<String>,
#[serde(default, skip_serializing_if = "Map::is_empty")]
attributes: Map<String, Value>,
},
HttpRequestFailed {
session_id: String,
attempt: u32,
method: String,
path: String,
error: String,
retryable: bool,
#[serde(default, skip_serializing_if = "Map::is_empty")]
attributes: Map<String, Value>,
},
Analytics(AnalyticsEvent),
SessionTrace(SessionTraceRecord),
}
pub trait TelemetrySink: Send + Sync {
fn record(&self, event: TelemetryEvent);
}
#[derive(Default)]
pub struct MemoryTelemetrySink {
events: Mutex<Vec<TelemetryEvent>>,
}
impl MemoryTelemetrySink {
#[must_use]
pub fn events(&self) -> Vec<TelemetryEvent> {
self.events
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.clone()
}
}
impl TelemetrySink for MemoryTelemetrySink {
fn record(&self, event: TelemetryEvent) {
self.events
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.push(event);
}
}
pub struct JsonlTelemetrySink {
path: PathBuf,
file: Mutex<File>,
}
impl Debug for JsonlTelemetrySink {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("JsonlTelemetrySink")
.field("path", &self.path)
.finish_non_exhaustive()
}
}
impl JsonlTelemetrySink {
pub fn new(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
let path = path.as_ref().to_path_buf();
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let file = OpenOptions::new().create(true).append(true).open(&path)?;
Ok(Self {
path,
file: Mutex::new(file),
})
}
#[must_use]
pub fn path(&self) -> &Path {
&self.path
}
}
impl TelemetrySink for JsonlTelemetrySink {
fn record(&self, event: TelemetryEvent) {
let Ok(line) = serde_json::to_string(&event) else {
return;
};
let mut file = self
.file
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let _ = writeln!(file, "{line}");
let _ = file.flush();
}
}
#[derive(Clone)]
pub struct SessionTracer {
session_id: String,
sequence: Arc<AtomicU64>,
sink: Arc<dyn TelemetrySink>,
}
impl Debug for SessionTracer {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SessionTracer")
.field("session_id", &self.session_id)
.finish_non_exhaustive()
}
}
impl SessionTracer {
#[must_use]
pub fn new(session_id: impl Into<String>, sink: Arc<dyn TelemetrySink>) -> Self {
Self {
session_id: session_id.into(),
sequence: Arc::new(AtomicU64::new(0)),
sink,
}
}
#[must_use]
pub fn session_id(&self) -> &str {
&self.session_id
}
pub fn record(&self, name: impl Into<String>, attributes: Map<String, Value>) {
let record = SessionTraceRecord {
session_id: self.session_id.clone(),
sequence: self.sequence.fetch_add(1, Ordering::Relaxed),
name: name.into(),
timestamp_ms: current_timestamp_ms(),
attributes,
};
self.sink.record(TelemetryEvent::SessionTrace(record));
}
pub fn record_http_request_started(
&self,
attempt: u32,
method: impl Into<String>,
path: impl Into<String>,
attributes: Map<String, Value>,
) {
let method = method.into();
let path = path.into();
self.sink.record(TelemetryEvent::HttpRequestStarted {
session_id: self.session_id.clone(),
attempt,
method: method.clone(),
path: path.clone(),
attributes: attributes.clone(),
});
self.record(
"http_request_started",
merge_trace_fields(method, path, attempt, attributes),
);
}
pub fn record_http_request_succeeded(
&self,
attempt: u32,
method: impl Into<String>,
path: impl Into<String>,
status: u16,
request_id: Option<String>,
attributes: Map<String, Value>,
) {
let method = method.into();
let path = path.into();
self.sink.record(TelemetryEvent::HttpRequestSucceeded {
session_id: self.session_id.clone(),
attempt,
method: method.clone(),
path: path.clone(),
status,
request_id: request_id.clone(),
attributes: attributes.clone(),
});
let mut trace_attributes = merge_trace_fields(method, path, attempt, attributes);
trace_attributes.insert("status".to_string(), Value::from(status));
if let Some(request_id) = request_id {
trace_attributes.insert("request_id".to_string(), Value::String(request_id));
}
self.record("http_request_succeeded", trace_attributes);
}
pub fn record_http_request_failed(
&self,
attempt: u32,
method: impl Into<String>,
path: impl Into<String>,
error: impl Into<String>,
retryable: bool,
attributes: Map<String, Value>,
) {
let method = method.into();
let path = path.into();
let error = error.into();
self.sink.record(TelemetryEvent::HttpRequestFailed {
session_id: self.session_id.clone(),
attempt,
method: method.clone(),
path: path.clone(),
error: error.clone(),
retryable,
attributes: attributes.clone(),
});
let mut trace_attributes = merge_trace_fields(method, path, attempt, attributes);
trace_attributes.insert("error".to_string(), Value::String(error));
trace_attributes.insert("retryable".to_string(), Value::Bool(retryable));
self.record("http_request_failed", trace_attributes);
}
pub fn record_analytics(&self, event: AnalyticsEvent) {
let mut attributes = event.properties.clone();
attributes.insert(
"namespace".to_string(),
Value::String(event.namespace.clone()),
);
attributes.insert("action".to_string(), Value::String(event.action.clone()));
self.sink.record(TelemetryEvent::Analytics(event));
self.record("analytics", attributes);
}
}
fn merge_trace_fields(
method: String,
path: String,
attempt: u32,
mut attributes: Map<String, Value>,
) -> Map<String, Value> {
attributes.insert("method".to_string(), Value::String(method));
attributes.insert("path".to_string(), Value::String(path));
attributes.insert("attempt".to_string(), Value::from(attempt));
attributes
}
fn current_timestamp_ms() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis()
.try_into()
.unwrap_or(u64::MAX)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn request_profile_emits_headers_and_merges_body() {
let profile = AnthropicRequestProfile::new(
ClientIdentity::new("claude-code", "1.2.3").with_runtime("rust-cli"),
)
.with_beta("tools-2026-04-01")
.with_extra_body("metadata", serde_json::json!({"source": "test"}));
assert_eq!(
profile.header_pairs(),
vec![
(
"anthropic-version".to_string(),
DEFAULT_ANTHROPIC_VERSION.to_string()
),
("user-agent".to_string(), "claude-code/1.2.3".to_string()),
(
"anthropic-beta".to_string(),
"claude-code-20250219,prompt-caching-scope-2026-01-05,tools-2026-04-01"
.to_string(),
),
]
);
let body = profile
.render_json_body(&serde_json::json!({"model": "claude-sonnet"}))
.expect("body should serialize");
assert_eq!(
body["metadata"]["source"],
Value::String("test".to_string())
);
assert_eq!(
body["betas"],
serde_json::json!([
"claude-code-20250219",
"prompt-caching-scope-2026-01-05",
"tools-2026-04-01"
])
);
}
#[test]
fn session_tracer_records_structured_events_and_trace_sequence() {
let sink = Arc::new(MemoryTelemetrySink::default());
let tracer = SessionTracer::new("session-123", sink.clone());
tracer.record_http_request_started(1, "POST", "/v1/messages", Map::new());
tracer.record_analytics(
AnalyticsEvent::new("cli", "prompt_sent")
.with_property("model", Value::String("claude-opus".to_string())),
);
let events = sink.events();
assert!(matches!(
&events[0],
TelemetryEvent::HttpRequestStarted {
session_id,
attempt: 1,
method,
path,
..
} if session_id == "session-123" && method == "POST" && path == "/v1/messages"
));
assert!(matches!(
&events[1],
TelemetryEvent::SessionTrace(SessionTraceRecord { sequence: 0, name, .. })
if name == "http_request_started"
));
assert!(matches!(&events[2], TelemetryEvent::Analytics(_)));
assert!(matches!(
&events[3],
TelemetryEvent::SessionTrace(SessionTraceRecord { sequence: 1, name, .. })
if name == "analytics"
));
}
#[test]
fn jsonl_sink_persists_events() {
let path =
std::env::temp_dir().join(format!("telemetry-jsonl-{}.log", current_timestamp_ms()));
let sink = JsonlTelemetrySink::new(&path).expect("sink should create file");
sink.record(TelemetryEvent::Analytics(
AnalyticsEvent::new("cli", "turn_completed").with_property("ok", Value::Bool(true)),
));
let contents = std::fs::read_to_string(&path).expect("telemetry log should be readable");
assert!(contents.contains("\"type\":\"analytics\""));
assert!(contents.contains("\"action\":\"turn_completed\""));
let _ = std::fs::remove_file(path);
}
}

View File

@ -7,6 +7,8 @@ publish.workspace = true
[dependencies]
api = { path = "../api" }
commands = { path = "../commands" }
flate2 = "1"
plugins = { path = "../plugins" }
runtime = { path = "../runtime" }
reqwest = { version = "0.12", default-features = false, features = ["blocking", "rustls-tls"] }

View File

@ -0,0 +1,181 @@
//! Lane completion detector — automatically marks lanes as completed when
//! session finishes successfully with green tests and pushed code.
//!
//! This bridges the gap where `LaneContext::completed` was a passive bool
//! that nothing automatically set. Now completion is detected from:
//! - Agent output shows Finished status
//! - No errors/blockers present
//! - Tests passed (green status)
//! - Code pushed (has output file)
use runtime::{
evaluate, LaneBlocker, LaneContext, PolicyAction, PolicyCondition, PolicyEngine, PolicyRule,
ReviewStatus,
};
use crate::AgentOutput;
/// Detects if a lane should be automatically marked as completed.
///
/// Returns `Some(LaneContext)` with `completed = true` if all conditions met,
/// `None` if lane should remain active.
#[allow(dead_code)]
pub(crate) fn detect_lane_completion(
output: &AgentOutput,
test_green: bool,
has_pushed: bool,
) -> Option<LaneContext> {
// Must be finished without errors
if output.error.is_some() {
return None;
}
// Must have finished status
if !output.status.eq_ignore_ascii_case("completed")
&& !output.status.eq_ignore_ascii_case("finished")
{
return None;
}
// Must have no current blocker
if output.current_blocker.is_some() {
return None;
}
// Must have green tests
if !test_green {
return None;
}
// Must have pushed code
if !has_pushed {
return None;
}
// All conditions met — create completed context
Some(LaneContext {
lane_id: output.agent_id.clone(),
green_level: 3, // Workspace green
branch_freshness: std::time::Duration::from_secs(0),
blocker: LaneBlocker::None,
review_status: ReviewStatus::Approved,
diff_scope: runtime::DiffScope::Scoped,
completed: true,
reconciled: false,
})
}
/// Evaluates policy actions for a completed lane.
#[allow(dead_code)]
pub(crate) fn evaluate_completed_lane(context: &LaneContext) -> Vec<PolicyAction> {
let engine = PolicyEngine::new(vec![
PolicyRule::new(
"closeout-completed-lane",
PolicyCondition::And(vec![
PolicyCondition::LaneCompleted,
PolicyCondition::GreenAt { level: 3 },
]),
PolicyAction::CloseoutLane,
10,
),
PolicyRule::new(
"cleanup-completed-session",
PolicyCondition::LaneCompleted,
PolicyAction::CleanupSession,
5,
),
]);
evaluate(&engine, context)
}
#[cfg(test)]
mod tests {
use super::*;
use runtime::{DiffScope, LaneBlocker};
fn test_output() -> AgentOutput {
AgentOutput {
agent_id: "test-lane-1".to_string(),
name: "Test Agent".to_string(),
description: "Test".to_string(),
subagent_type: None,
model: None,
status: "Finished".to_string(),
output_file: "/tmp/test.output".to_string(),
manifest_file: "/tmp/test.manifest".to_string(),
created_at: "2024-01-01T00:00:00Z".to_string(),
started_at: Some("2024-01-01T00:00:00Z".to_string()),
completed_at: Some("2024-01-01T00:00:00Z".to_string()),
lane_events: vec![],
derived_state: "working".to_string(),
current_blocker: None,
error: None,
}
}
#[test]
fn detects_completion_when_all_conditions_met() {
let output = test_output();
let result = detect_lane_completion(&output, true, true);
assert!(result.is_some());
let context = result.unwrap();
assert!(context.completed);
assert_eq!(context.green_level, 3);
assert_eq!(context.blocker, LaneBlocker::None);
}
#[test]
fn no_completion_when_error_present() {
let mut output = test_output();
output.error = Some("Build failed".to_string());
let result = detect_lane_completion(&output, true, true);
assert!(result.is_none());
}
#[test]
fn no_completion_when_not_finished() {
let mut output = test_output();
output.status = "Running".to_string();
let result = detect_lane_completion(&output, true, true);
assert!(result.is_none());
}
#[test]
fn no_completion_when_tests_not_green() {
let output = test_output();
let result = detect_lane_completion(&output, false, true);
assert!(result.is_none());
}
#[test]
fn no_completion_when_not_pushed() {
let output = test_output();
let result = detect_lane_completion(&output, true, false);
assert!(result.is_none());
}
#[test]
fn evaluate_triggers_closeout_for_completed_lane() {
let context = LaneContext {
lane_id: "completed-lane".to_string(),
green_level: 3,
branch_freshness: std::time::Duration::from_secs(0),
blocker: LaneBlocker::None,
review_status: ReviewStatus::Approved,
diff_scope: DiffScope::Scoped,
completed: true,
reconciled: false,
};
let actions = evaluate_completed_lane(&context);
assert!(actions.contains(&PolicyAction::CloseoutLane));
assert!(actions.contains(&PolicyAction::CleanupSession));
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,548 @@
//! Minimal PDF text extraction.
//!
//! Reads a PDF file, locates `/Contents` stream objects, decompresses with
//! flate2 when the stream uses `/FlateDecode`, and extracts text operators
//! found between `BT` / `ET` markers.
use std::io::Read as _;
use std::path::Path;
/// Extract all readable text from a PDF file.
///
/// Returns the concatenated text found inside BT/ET operators across all
/// content streams. Non-text pages or encrypted PDFs yield an empty string
/// rather than an error.
pub fn extract_text(path: &Path) -> Result<String, String> {
let data = std::fs::read(path).map_err(|e| format!("failed to read PDF: {e}"))?;
Ok(extract_text_from_bytes(&data))
}
/// Core extraction from raw PDF bytes — useful for testing without touching the
/// filesystem.
pub(crate) fn extract_text_from_bytes(data: &[u8]) -> String {
let mut all_text = String::new();
let mut offset = 0;
while offset < data.len() {
let Some(stream_start) = find_subsequence(&data[offset..], b"stream") else {
break;
};
let abs_start = offset + stream_start;
// Determine the byte offset right after "stream\r\n" or "stream\n".
let content_start = skip_stream_eol(data, abs_start + b"stream".len());
let Some(end_rel) = find_subsequence(&data[content_start..], b"endstream") else {
break;
};
let content_end = content_start + end_rel;
// Look backwards from "stream" for a FlateDecode hint in the object
// dictionary. We scan at most 512 bytes before the stream keyword.
let dict_window_start = abs_start.saturating_sub(512);
let dict_window = &data[dict_window_start..abs_start];
let is_flate = find_subsequence(dict_window, b"FlateDecode").is_some();
// Only process streams whose parent dictionary references /Contents or
// looks like a page content stream (contains /Length). We intentionally
// keep this loose to cover both inline and referenced content streams.
let raw = &data[content_start..content_end];
let decompressed;
let stream_bytes: &[u8] = if is_flate {
if let Ok(buf) = inflate(raw) {
decompressed = buf;
&decompressed
} else {
offset = content_end;
continue;
}
} else {
raw
};
let text = extract_bt_et_text(stream_bytes);
if !text.is_empty() {
if !all_text.is_empty() {
all_text.push('\n');
}
all_text.push_str(&text);
}
offset = content_end;
}
all_text
}
/// Inflate (zlib / deflate) compressed data via `flate2`.
fn inflate(data: &[u8]) -> Result<Vec<u8>, String> {
let mut decoder = flate2::read::ZlibDecoder::new(data);
let mut buf = Vec::new();
decoder
.read_to_end(&mut buf)
.map_err(|e| format!("flate2 inflate error: {e}"))?;
Ok(buf)
}
/// Extract text from PDF content-stream operators between BT and ET markers.
///
/// Handles the common text-showing operators:
/// - `Tj` — show a string
/// - `TJ` — show an array of strings/numbers
/// - `'` — move to next line and show string
/// - `"` — set spacing, move to next line and show string
fn extract_bt_et_text(stream: &[u8]) -> String {
let text = String::from_utf8_lossy(stream);
let mut result = String::new();
let mut in_bt = false;
for line in text.lines() {
let trimmed = line.trim();
if trimmed == "BT" {
in_bt = true;
continue;
}
if trimmed == "ET" {
in_bt = false;
continue;
}
if !in_bt {
continue;
}
// Tj operator: (text) Tj
if trimmed.ends_with("Tj") {
if let Some(s) = extract_parenthesized_string(trimmed) {
if !result.is_empty() && !result.ends_with('\n') {
result.push(' ');
}
result.push_str(&s);
}
}
// TJ operator: [ (text) 123 (text) ] TJ
else if trimmed.ends_with("TJ") {
let extracted = extract_tj_array(trimmed);
if !extracted.is_empty() {
if !result.is_empty() && !result.ends_with('\n') {
result.push(' ');
}
result.push_str(&extracted);
}
}
// ' operator: (text) ' and " operator: aw ac (text) "
else if is_newline_show_operator(trimmed) {
if let Some(s) = extract_parenthesized_string(trimmed) {
if !result.is_empty() {
result.push('\n');
}
result.push_str(&s);
}
}
}
result
}
/// Returns `true` when `trimmed` looks like a `'` or `"` text-show operator.
fn is_newline_show_operator(trimmed: &str) -> bool {
(trimmed.ends_with('\'') && trimmed.len() > 1)
|| (trimmed.ends_with('"') && trimmed.contains('('))
}
/// Pull the text from the first `(…)` group, handling escaped parens and
/// common PDF escape sequences.
fn extract_parenthesized_string(input: &str) -> Option<String> {
let open = input.find('(')?;
let bytes = input.as_bytes();
let mut depth = 0;
let mut result = String::new();
let mut i = open;
while i < bytes.len() {
match bytes[i] {
b'(' => {
if depth > 0 {
result.push('(');
}
depth += 1;
}
b')' => {
depth -= 1;
if depth == 0 {
return Some(result);
}
result.push(')');
}
b'\\' if i + 1 < bytes.len() => {
i += 1;
match bytes[i] {
b'n' => result.push('\n'),
b'r' => result.push('\r'),
b't' => result.push('\t'),
b'\\' => result.push('\\'),
b'(' => result.push('('),
b')' => result.push(')'),
// Octal sequences — up to 3 digits.
d @ b'0'..=b'7' => {
let mut octal = u32::from(d - b'0');
for _ in 0..2 {
if i + 1 < bytes.len()
&& bytes[i + 1].is_ascii_digit()
&& bytes[i + 1] <= b'7'
{
i += 1;
octal = octal * 8 + u32::from(bytes[i] - b'0');
} else {
break;
}
}
if let Some(ch) = char::from_u32(octal) {
result.push(ch);
}
}
other => result.push(char::from(other)),
}
}
ch => result.push(char::from(ch)),
}
i += 1;
}
None // unbalanced
}
/// Extract concatenated strings from a TJ array like `[ (Hello) -120 (World) ] TJ`.
fn extract_tj_array(input: &str) -> String {
let mut result = String::new();
let Some(bracket_start) = input.find('[') else {
return result;
};
let Some(bracket_end) = input.rfind(']') else {
return result;
};
let inner = &input[bracket_start + 1..bracket_end];
let mut i = 0;
let bytes = inner.as_bytes();
while i < bytes.len() {
if bytes[i] == b'(' {
// Reconstruct the parenthesized string and extract it.
if let Some(s) = extract_parenthesized_string(&inner[i..]) {
result.push_str(&s);
// Skip past the closing paren.
let mut depth = 0u32;
for &b in &bytes[i..] {
i += 1;
if b == b'(' {
depth += 1;
} else if b == b')' {
depth -= 1;
if depth == 0 {
break;
}
}
}
continue;
}
}
i += 1;
}
result
}
/// Skip past the end-of-line marker that immediately follows the `stream`
/// keyword. Per the PDF spec this is either `\r\n` or `\n`.
fn skip_stream_eol(data: &[u8], pos: usize) -> usize {
if pos < data.len() && data[pos] == b'\r' {
if pos + 1 < data.len() && data[pos + 1] == b'\n' {
return pos + 2;
}
return pos + 1;
}
if pos < data.len() && data[pos] == b'\n' {
return pos + 1;
}
pos
}
/// Simple byte-subsequence search.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
haystack
.windows(needle.len())
.position(|window| window == needle)
}
/// Check if a user-supplied path looks like a PDF file reference.
#[must_use]
pub fn looks_like_pdf_path(text: &str) -> Option<&str> {
for token in text.split_whitespace() {
let cleaned = token.trim_matches(|c: char| c == '\'' || c == '"' || c == '`');
if let Some(dot_pos) = cleaned.rfind('.') {
if cleaned[dot_pos + 1..].eq_ignore_ascii_case("pdf") && dot_pos > 0 {
return Some(cleaned);
}
}
}
None
}
/// Auto-extract text from a PDF path mentioned in a user prompt.
///
/// Returns `Some((path, extracted_text))` when a `.pdf` path is detected and
/// the file exists, otherwise `None`.
#[must_use]
pub fn maybe_extract_pdf_from_prompt(prompt: &str) -> Option<(String, String)> {
let pdf_path = looks_like_pdf_path(prompt)?;
let path = Path::new(pdf_path);
if !path.exists() {
return None;
}
let text = extract_text(path).ok()?;
if text.is_empty() {
return None;
}
Some((pdf_path.to_string(), text))
}
#[cfg(test)]
mod tests {
use super::*;
/// Build a minimal valid PDF with a single page containing uncompressed
/// text. This is the smallest PDF structure that exercises the BT/ET
/// extraction path.
fn build_simple_pdf(text: &str) -> Vec<u8> {
let content_stream = format!("BT\n/F1 12 Tf\n({text}) Tj\nET");
let stream_bytes = content_stream.as_bytes();
let mut pdf = Vec::new();
// Header
pdf.extend_from_slice(b"%PDF-1.4\n");
// Object 1 — Catalog
let obj1_offset = pdf.len();
pdf.extend_from_slice(b"1 0 obj\n<< /Type /Catalog /Pages 2 0 R >>\nendobj\n");
// Object 2 — Pages
let obj2_offset = pdf.len();
pdf.extend_from_slice(b"2 0 obj\n<< /Type /Pages /Kids [3 0 R] /Count 1 >>\nendobj\n");
// Object 3 — Page
let obj3_offset = pdf.len();
pdf.extend_from_slice(
b"3 0 obj\n<< /Type /Page /Parent 2 0 R /Contents 4 0 R >>\nendobj\n",
);
// Object 4 — Content stream (uncompressed)
let obj4_offset = pdf.len();
let length = stream_bytes.len();
let header = format!("4 0 obj\n<< /Length {length} >>\nstream\n");
pdf.extend_from_slice(header.as_bytes());
pdf.extend_from_slice(stream_bytes);
pdf.extend_from_slice(b"\nendstream\nendobj\n");
// Cross-reference table
let xref_offset = pdf.len();
pdf.extend_from_slice(b"xref\n0 5\n");
pdf.extend_from_slice(b"0000000000 65535 f \n");
pdf.extend_from_slice(format!("{obj1_offset:010} 00000 n \n").as_bytes());
pdf.extend_from_slice(format!("{obj2_offset:010} 00000 n \n").as_bytes());
pdf.extend_from_slice(format!("{obj3_offset:010} 00000 n \n").as_bytes());
pdf.extend_from_slice(format!("{obj4_offset:010} 00000 n \n").as_bytes());
// Trailer
pdf.extend_from_slice(b"trailer\n<< /Size 5 /Root 1 0 R >>\n");
pdf.extend_from_slice(format!("startxref\n{xref_offset}\n%%EOF\n").as_bytes());
pdf
}
/// Build a minimal PDF with flate-compressed content stream.
fn build_flate_pdf(text: &str) -> Vec<u8> {
use flate2::write::ZlibEncoder;
use flate2::Compression;
use std::io::Write as _;
let content_stream = format!("BT\n/F1 12 Tf\n({text}) Tj\nET");
let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
encoder
.write_all(content_stream.as_bytes())
.expect("compress");
let compressed = encoder.finish().expect("finish");
let mut pdf = Vec::new();
pdf.extend_from_slice(b"%PDF-1.4\n");
let obj1_offset = pdf.len();
pdf.extend_from_slice(b"1 0 obj\n<< /Type /Catalog /Pages 2 0 R >>\nendobj\n");
let obj2_offset = pdf.len();
pdf.extend_from_slice(b"2 0 obj\n<< /Type /Pages /Kids [3 0 R] /Count 1 >>\nendobj\n");
let obj3_offset = pdf.len();
pdf.extend_from_slice(
b"3 0 obj\n<< /Type /Page /Parent 2 0 R /Contents 4 0 R >>\nendobj\n",
);
let obj4_offset = pdf.len();
let length = compressed.len();
let header = format!("4 0 obj\n<< /Length {length} /Filter /FlateDecode >>\nstream\n");
pdf.extend_from_slice(header.as_bytes());
pdf.extend_from_slice(&compressed);
pdf.extend_from_slice(b"\nendstream\nendobj\n");
let xref_offset = pdf.len();
pdf.extend_from_slice(b"xref\n0 5\n");
pdf.extend_from_slice(b"0000000000 65535 f \n");
pdf.extend_from_slice(format!("{obj1_offset:010} 00000 n \n").as_bytes());
pdf.extend_from_slice(format!("{obj2_offset:010} 00000 n \n").as_bytes());
pdf.extend_from_slice(format!("{obj3_offset:010} 00000 n \n").as_bytes());
pdf.extend_from_slice(format!("{obj4_offset:010} 00000 n \n").as_bytes());
pdf.extend_from_slice(b"trailer\n<< /Size 5 /Root 1 0 R >>\n");
pdf.extend_from_slice(format!("startxref\n{xref_offset}\n%%EOF\n").as_bytes());
pdf
}
#[test]
fn extracts_uncompressed_text_from_minimal_pdf() {
// given
let pdf_bytes = build_simple_pdf("Hello World");
// when
let text = extract_text_from_bytes(&pdf_bytes);
// then
assert_eq!(text, "Hello World");
}
#[test]
fn extracts_text_from_flate_compressed_stream() {
// given
let pdf_bytes = build_flate_pdf("Compressed PDF Text");
// when
let text = extract_text_from_bytes(&pdf_bytes);
// then
assert_eq!(text, "Compressed PDF Text");
}
#[test]
fn handles_tj_array_operator() {
// given
let stream = b"BT\n/F1 12 Tf\n[ (Hello) -120 ( World) ] TJ\nET";
// Build a raw PDF with TJ array operator instead of simple Tj.
let content_stream = std::str::from_utf8(stream).unwrap();
let raw = format!(
"%PDF-1.4\n1 0 obj\n<< /Type /Catalog >>\nendobj\n\
2 0 obj\n<< /Length {} >>\nstream\n{}\nendstream\nendobj\n%%EOF\n",
content_stream.len(),
content_stream
);
let pdf_bytes = raw.into_bytes();
// when
let text = extract_text_from_bytes(&pdf_bytes);
// then
assert_eq!(text, "Hello World");
}
#[test]
fn handles_escaped_parentheses() {
// given
let content = b"BT\n(Hello \\(World\\)) Tj\nET";
let raw = format!(
"%PDF-1.4\n1 0 obj\n<< /Length {} >>\nstream\n",
content.len()
);
let mut pdf_bytes = raw.into_bytes();
pdf_bytes.extend_from_slice(content);
pdf_bytes.extend_from_slice(b"\nendstream\nendobj\n%%EOF\n");
// when
let text = extract_text_from_bytes(&pdf_bytes);
// then
assert_eq!(text, "Hello (World)");
}
#[test]
fn returns_empty_for_non_pdf_data() {
// given
let data = b"This is not a PDF file at all";
// when
let text = extract_text_from_bytes(data);
// then
assert!(text.is_empty());
}
#[test]
fn extracts_text_from_file_on_disk() {
// given
let pdf_bytes = build_simple_pdf("Disk Test");
let dir = std::env::temp_dir().join("clawd-pdf-extract-test");
std::fs::create_dir_all(&dir).unwrap();
let pdf_path = dir.join("test.pdf");
std::fs::write(&pdf_path, &pdf_bytes).unwrap();
// when
let text = extract_text(&pdf_path).unwrap();
// then
assert_eq!(text, "Disk Test");
// cleanup
let _ = std::fs::remove_dir_all(&dir);
}
#[test]
fn looks_like_pdf_path_detects_pdf_references() {
// given / when / then
assert_eq!(
looks_like_pdf_path("Please read /tmp/report.pdf"),
Some("/tmp/report.pdf")
);
assert_eq!(looks_like_pdf_path("Check file.PDF now"), Some("file.PDF"));
assert_eq!(looks_like_pdf_path("no pdf here"), None);
}
#[test]
fn maybe_extract_pdf_from_prompt_returns_none_for_missing_file() {
// given
let prompt = "Read /tmp/nonexistent-abc123.pdf please";
// when
let result = maybe_extract_pdf_from_prompt(prompt);
// then
assert!(result.is_none());
}
#[test]
fn maybe_extract_pdf_from_prompt_extracts_existing_file() {
// given
let pdf_bytes = build_simple_pdf("Auto Extracted");
let dir = std::env::temp_dir().join("clawd-pdf-auto-extract-test");
std::fs::create_dir_all(&dir).unwrap();
let pdf_path = dir.join("auto.pdf");
std::fs::write(&pdf_path, &pdf_bytes).unwrap();
let prompt = format!("Summarize {}", pdf_path.display());
// when
let result = maybe_extract_pdf_from_prompt(&prompt);
// then
let (path, text) = result.expect("should extract");
assert_eq!(path, pdf_path.display().to_string());
assert_eq!(text, "Auto Extracted");
// cleanup
let _ = std::fs::remove_dir_all(&dir);
}
}

View File

@ -5,9 +5,9 @@ import { theme, Skeleton, Spin, Popover } from 'antd';
import { XMarkdown } from '@ant-design/x-markdown';
// 助手气泡 body 撑满可用宽度,避免 Mermaid 等内容宽度受文本行长度影响
const bubbleStyle = document.createElement('style');
bubbleStyle.textContent = '.ant-bubble-start > .ant-bubble-body { width: 80%; }';
document.head.appendChild(bubbleStyle);
// const bubbleStyle = document.createElement('style');
// bubbleStyle.textContent = '.ant-bubble-start > .ant-bubble-body { width: 100%; }';
// document.head.appendChild(bubbleStyle);
import type { ComponentProps, Token } from '@ant-design/x-markdown';
import Latex from '@ant-design/x-markdown/plugins/latex';
import '@ant-design/x-markdown/themes/light.css';
@ -18,8 +18,6 @@ import WelcomeScreen from './WelcomeScreen';
// ── XMarkdown 插件配置 ────────────────────────────────────────────────
// LaTeX 数学公式插件:解析 $...$ / $$...$$ / \(...\) / \[...\]
// 自定义脚注插件:解析 [^1] 语法 → <footnote> 标签
const footnoteExtension = {
name: 'footnote',
level: 'inline' as const,
@ -281,14 +279,12 @@ const ChatView: React.FC<ChatViewProps> = ({
variant: 'filled' as const,
shape: 'round' as const,
avatar: <UserOutlined />,
// styles: { content: { width: '80%' } },
},
assistant: {
placement: 'start' as const,
variant: 'borderless' as const,
avatar: <RobotOutlined />,
streaming: true,
styles: { content: { width: '80%' } },
header: (_content: unknown, { status }: { status?: string }) => {
if (status === 'loading' || status === 'updating') {
return (