feat: 合并上游 Rust 实现,扩展 API/运行时/工具链能力
将 claw-code/rust/crates 的完整实现合并到主 workspace,涵盖
9 个 crate 的更新与 2 个新 crate 的引入。
API 层:
- 用原生 Anthropic 客户端(anthropic.rs)替换 claw_provider,
新增 prompt cache 减少重复请求开销
- 新增 HTTP 客户端构建器统一代理配置,OpenAI 兼容端增加
DashScope/Qwen 支持与抖动重试
- MessageRequest 扩展 temperature/top_p 等模型调参字段
- SSE 解析器增加 provider 上下文感知的错误信息
运行时(~11,000 行新增):
- 新增 bash 命令安全校验、分支锁碰撞检测、配置文件校验
- 新增会话存储与控制面、MCP 生命周期状态机与服务端实现
- 新增权限执行引擎、策略引擎、插件生命周期管理
- 新增 worker 启动编排、任务/定时任务注册表、信任解析器
- 保留 Windows cmd /C fallback
命令/插件/工具:
- commands 大幅重写,扩展 sandbox、doctor、plan 等 slash 命令
- plugins 新增 PostToolUseFailure hook 与宽容加载机制
- tools 新增 PDF 提取与 lane 补全工具
新增 crate:mock-anthropic-service(测试)、telemetry(遥测)
适配 claw-cli/server:ClawApiClient→AnthropicClient 重命名,
SlashCommand::parse 返回 Result,移除 session 级 Thinking 变体,
TokenUsage/ConversationMessage 补充序列化支持
This commit is contained in:
parent
4a04faf926
commit
d8d77824f4
56
Cargo.lock
generated
56
Cargo.lock
generated
@ -34,6 +34,7 @@ dependencies = [
|
||||
"runtime",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"telemetry",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
@ -347,12 +348,24 @@ version = "0.15.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
|
||||
|
||||
[[package]]
|
||||
name = "endian-type"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d"
|
||||
|
||||
[[package]]
|
||||
name = "env_home"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7f84e12ccf0a7ddc17a6c41c93326024c42920d7ee630d04950e6926645c0fe"
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.2"
|
||||
@ -945,6 +958,15 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mock-anthropic-service"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"api",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nibble_vec"
|
||||
version = "0.1.0"
|
||||
@ -1340,14 +1362,15 @@ name = "runtime"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"glob",
|
||||
"lsp",
|
||||
"plugins",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"telemetry",
|
||||
"tokio",
|
||||
"walkdir",
|
||||
"which",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1431,9 +1454,12 @@ dependencies = [
|
||||
"commands",
|
||||
"compat-harness",
|
||||
"crossterm",
|
||||
"mock-anthropic-service",
|
||||
"plugins",
|
||||
"pulldown-cmark",
|
||||
"runtime",
|
||||
"rustyline",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"syntect",
|
||||
"tokio",
|
||||
@ -1721,6 +1747,14 @@ dependencies = [
|
||||
"yaml-rust",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "telemetry"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "2.0.18"
|
||||
@ -1852,6 +1886,8 @@ name = "tools"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"api",
|
||||
"commands",
|
||||
"flate2",
|
||||
"plugins",
|
||||
"reqwest",
|
||||
"runtime",
|
||||
@ -2129,6 +2165,18 @@ dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "which"
|
||||
version = "7.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "24d643ce3fd3e5b54854602a080f34fb10ab75e0b813ee32d00ca2b44fa74762"
|
||||
dependencies = [
|
||||
"either",
|
||||
"env_home",
|
||||
"rustix 1.1.4",
|
||||
"winsafe",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
@ -2384,6 +2432,12 @@ version = "0.53.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650"
|
||||
|
||||
[[package]]
|
||||
name = "winsafe"
|
||||
version = "0.0.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904"
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen"
|
||||
version = "0.51.0"
|
||||
|
||||
32
Cargo.toml
32
Cargo.toml
@ -21,3 +21,35 @@ pedantic = { level = "warn", priority = -1 }
|
||||
module_name_repetitions = "allow"
|
||||
missing_panics_doc = "allow"
|
||||
missing_errors_doc = "allow"
|
||||
uninlined_format_args = "allow"
|
||||
map_unwrap_or = "allow"
|
||||
doc_markdown = "allow"
|
||||
redundant_pattern_matching = "allow"
|
||||
items_after_statements = "allow"
|
||||
must_use_candidate = "allow"
|
||||
io_other_error = "allow"
|
||||
implicit_clone = "allow"
|
||||
redundant_closure_for_method_calls = "allow"
|
||||
unnecessary_map_or = "allow"
|
||||
manual_let_else = "allow"
|
||||
format_push_string = "allow"
|
||||
match_same_arms = "allow"
|
||||
similar_names = "allow"
|
||||
needless_continue = "allow"
|
||||
assigning_clones = "allow"
|
||||
cast_possible_truncation = "allow"
|
||||
cast_possible_wrap = "allow"
|
||||
cast_sign_loss = "allow"
|
||||
cmp_owned = "allow"
|
||||
collapsible_if = "allow"
|
||||
too_many_lines = "allow"
|
||||
wildcard_in_or_patterns = "allow"
|
||||
explicit_counter_loop = "allow"
|
||||
manual_repeat_n = "allow"
|
||||
manual_str_repeat = "allow"
|
||||
needless_borrow = "allow"
|
||||
needless_pass_by_value = "allow"
|
||||
single_match_else = "allow"
|
||||
too_many_arguments = "allow"
|
||||
unnested_or_patterns = "allow"
|
||||
unused_self = "allow"
|
||||
|
||||
@ -10,6 +10,7 @@ reqwest = { version = "0.12", default-features = false, features = ["json", "rus
|
||||
runtime = { path = "../runtime" }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json.workspace = true
|
||||
telemetry = { path = "../telemetry" }
|
||||
tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] }
|
||||
|
||||
[lints]
|
||||
|
||||
@ -1,70 +1,91 @@
|
||||
use crate::error::ApiError;
|
||||
use crate::providers::claw_provider::{self, AuthSource, ClawApiClient};
|
||||
use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats};
|
||||
use crate::providers::anthropic::{self, AnthropicClient, AuthSource};
|
||||
use crate::providers::openai_compat::{self, OpenAiCompatClient, OpenAiCompatConfig};
|
||||
use crate::providers::{self, Provider, ProviderKind};
|
||||
use crate::providers::{self, ProviderKind};
|
||||
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
|
||||
|
||||
async fn send_via_provider<P: Provider>(
|
||||
provider: &P,
|
||||
request: &MessageRequest,
|
||||
) -> Result<MessageResponse, ApiError> {
|
||||
provider.send_message(request).await
|
||||
}
|
||||
|
||||
async fn stream_via_provider<P: Provider>(
|
||||
provider: &P,
|
||||
request: &MessageRequest,
|
||||
) -> Result<P::Stream, ApiError> {
|
||||
provider.stream_message(request).await
|
||||
}
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ProviderClient {
|
||||
ClawApi(ClawApiClient),
|
||||
Anthropic(AnthropicClient),
|
||||
Xai(OpenAiCompatClient),
|
||||
OpenAi(OpenAiCompatClient),
|
||||
}
|
||||
|
||||
impl ProviderClient {
|
||||
pub fn from_model(model: &str) -> Result<Self, ApiError> {
|
||||
Self::from_model_with_default_auth(model, None)
|
||||
Self::from_model_with_anthropic_auth(model, None)
|
||||
}
|
||||
|
||||
pub fn from_model_with_default_auth(
|
||||
pub fn from_model_with_anthropic_auth(
|
||||
model: &str,
|
||||
default_auth: Option<AuthSource>,
|
||||
anthropic_auth: Option<AuthSource>,
|
||||
) -> Result<Self, ApiError> {
|
||||
let resolved_model = providers::resolve_model_alias(model);
|
||||
match providers::detect_provider_kind(&resolved_model) {
|
||||
ProviderKind::ClawApi => Ok(Self::ClawApi(match default_auth {
|
||||
Some(auth) => ClawApiClient::from_auth(auth),
|
||||
None => ClawApiClient::from_env()?,
|
||||
ProviderKind::Anthropic => Ok(Self::Anthropic(match anthropic_auth {
|
||||
Some(auth) => AnthropicClient::from_auth(auth),
|
||||
None => AnthropicClient::from_env()?,
|
||||
})),
|
||||
ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env(
|
||||
OpenAiCompatConfig::xai(),
|
||||
)?)),
|
||||
ProviderKind::OpenAi => Ok(Self::OpenAi(OpenAiCompatClient::from_env(
|
||||
OpenAiCompatConfig::openai(),
|
||||
)?)),
|
||||
ProviderKind::OpenAi => {
|
||||
// DashScope models (qwen-*) also return ProviderKind::OpenAi because they
|
||||
// speak the OpenAI wire format, but they need the DashScope config which
|
||||
// reads DASHSCOPE_API_KEY and points at dashscope.aliyuncs.com.
|
||||
let config = match providers::metadata_for_model(&resolved_model) {
|
||||
Some(meta) if meta.auth_env == "DASHSCOPE_API_KEY" => {
|
||||
OpenAiCompatConfig::dashscope()
|
||||
}
|
||||
_ => OpenAiCompatConfig::openai(),
|
||||
};
|
||||
Ok(Self::OpenAi(OpenAiCompatClient::from_env(config)?))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn provider_kind(&self) -> ProviderKind {
|
||||
match self {
|
||||
Self::ClawApi(_) => ProviderKind::ClawApi,
|
||||
Self::Anthropic(_) => ProviderKind::Anthropic,
|
||||
Self::Xai(_) => ProviderKind::Xai,
|
||||
Self::OpenAi(_) => ProviderKind::OpenAi,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_prompt_cache(self, prompt_cache: PromptCache) -> Self {
|
||||
match self {
|
||||
Self::Anthropic(client) => Self::Anthropic(client.with_prompt_cache(prompt_cache)),
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn prompt_cache_stats(&self) -> Option<PromptCacheStats> {
|
||||
match self {
|
||||
Self::Anthropic(client) => client.prompt_cache_stats(),
|
||||
Self::Xai(_) | Self::OpenAi(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn take_last_prompt_cache_record(&self) -> Option<PromptCacheRecord> {
|
||||
match self {
|
||||
Self::Anthropic(client) => client.take_last_prompt_cache_record(),
|
||||
Self::Xai(_) | Self::OpenAi(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_message(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
) -> Result<MessageResponse, ApiError> {
|
||||
match self {
|
||||
Self::ClawApi(client) => send_via_provider(client, request).await,
|
||||
Self::Xai(client) | Self::OpenAi(client) => send_via_provider(client, request).await,
|
||||
Self::Anthropic(client) => client.send_message(request).await,
|
||||
Self::Xai(client) | Self::OpenAi(client) => client.send_message(request).await,
|
||||
}
|
||||
}
|
||||
|
||||
@ -73,10 +94,12 @@ impl ProviderClient {
|
||||
request: &MessageRequest,
|
||||
) -> Result<MessageStream, ApiError> {
|
||||
match self {
|
||||
Self::ClawApi(client) => stream_via_provider(client, request)
|
||||
Self::Anthropic(client) => client
|
||||
.stream_message(request)
|
||||
.await
|
||||
.map(MessageStream::ClawApi),
|
||||
Self::Xai(client) | Self::OpenAi(client) => stream_via_provider(client, request)
|
||||
.map(MessageStream::Anthropic),
|
||||
Self::Xai(client) | Self::OpenAi(client) => client
|
||||
.stream_message(request)
|
||||
.await
|
||||
.map(MessageStream::OpenAiCompat),
|
||||
}
|
||||
@ -85,7 +108,7 @@ impl ProviderClient {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum MessageStream {
|
||||
ClawApi(claw_provider::MessageStream),
|
||||
Anthropic(anthropic::MessageStream),
|
||||
OpenAiCompat(openai_compat::MessageStream),
|
||||
}
|
||||
|
||||
@ -93,25 +116,25 @@ impl MessageStream {
|
||||
#[must_use]
|
||||
pub fn request_id(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::ClawApi(stream) => stream.request_id(),
|
||||
Self::Anthropic(stream) => stream.request_id(),
|
||||
Self::OpenAiCompat(stream) => stream.request_id(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
|
||||
match self {
|
||||
Self::ClawApi(stream) => stream.next_event().await,
|
||||
Self::Anthropic(stream) => stream.next_event().await,
|
||||
Self::OpenAiCompat(stream) => stream.next_event().await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub use claw_provider::{
|
||||
pub use anthropic::{
|
||||
oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source, OAuthTokenSet,
|
||||
};
|
||||
#[must_use]
|
||||
pub fn read_base_url() -> String {
|
||||
claw_provider::read_base_url()
|
||||
anthropic::read_base_url()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
@ -121,8 +144,21 @@ pub fn read_xai_base_url() -> String {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
use super::ProviderClient;
|
||||
use crate::providers::{detect_provider_kind, resolve_model_alias, ProviderKind};
|
||||
|
||||
/// Serializes every test in this module that mutates process-wide
|
||||
/// environment variables so concurrent test threads cannot observe
|
||||
/// each other's partially-applied state.
|
||||
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolves_existing_and_grok_aliases() {
|
||||
assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6");
|
||||
@ -135,7 +171,71 @@ mod tests {
|
||||
assert_eq!(detect_provider_kind("grok-3"), ProviderKind::Xai);
|
||||
assert_eq!(
|
||||
detect_provider_kind("claude-sonnet-4-6"),
|
||||
ProviderKind::ClawApi
|
||||
ProviderKind::Anthropic
|
||||
);
|
||||
}
|
||||
|
||||
/// Snapshot-restore guard for a single environment variable. Mirrors
|
||||
/// the pattern used in `providers/mod.rs` tests: captures the original
|
||||
/// value on construction, applies the override, and restores on drop so
|
||||
/// tests leave the process env untouched even when they panic.
|
||||
struct EnvVarGuard {
|
||||
key: &'static str,
|
||||
original: Option<std::ffi::OsString>,
|
||||
}
|
||||
|
||||
impl EnvVarGuard {
|
||||
fn set(key: &'static str, value: Option<&str>) -> Self {
|
||||
let original = std::env::var_os(key);
|
||||
match value {
|
||||
Some(value) => std::env::set_var(key, value),
|
||||
None => std::env::remove_var(key),
|
||||
}
|
||||
Self { key, original }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for EnvVarGuard {
|
||||
fn drop(&mut self) {
|
||||
match self.original.take() {
|
||||
Some(value) => std::env::set_var(self.key, value),
|
||||
None => std::env::remove_var(self.key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dashscope_model_uses_dashscope_config_not_openai() {
|
||||
// Regression: qwen-plus was being routed to OpenAiCompatConfig::openai()
|
||||
// which reads OPENAI_API_KEY and points at api.openai.com, when it should
|
||||
// use OpenAiCompatConfig::dashscope() which reads DASHSCOPE_API_KEY and
|
||||
// points at dashscope.aliyuncs.com.
|
||||
let _lock = env_lock();
|
||||
let _dashscope = EnvVarGuard::set("DASHSCOPE_API_KEY", Some("test-dashscope-key"));
|
||||
let _openai = EnvVarGuard::set("OPENAI_API_KEY", None);
|
||||
|
||||
let client = ProviderClient::from_model("qwen-plus");
|
||||
|
||||
// Must succeed (not fail with "missing OPENAI_API_KEY")
|
||||
assert!(
|
||||
client.is_ok(),
|
||||
"qwen-plus with DASHSCOPE_API_KEY set should build successfully, got: {:?}",
|
||||
client.err()
|
||||
);
|
||||
|
||||
// Verify it's the OpenAi variant pointed at the DashScope base URL.
|
||||
match client.unwrap() {
|
||||
ProviderClient::OpenAi(openai_client) => {
|
||||
assert!(
|
||||
openai_client.base_url().contains("dashscope.aliyuncs.com"),
|
||||
"qwen-plus should route to DashScope base URL (contains 'dashscope.aliyuncs.com'), got: {}",
|
||||
openai_client.base_url()
|
||||
);
|
||||
}
|
||||
other => panic!(
|
||||
"Expected ProviderClient::OpenAi for qwen-plus, got: {:?}",
|
||||
other
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,22 +2,55 @@ use std::env::VarError;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::time::Duration;
|
||||
|
||||
const GENERIC_FATAL_WRAPPER_MARKERS: &[&str] = &[
|
||||
"something went wrong while processing your request",
|
||||
"please try again, or use /new to start a fresh session",
|
||||
];
|
||||
|
||||
const CONTEXT_WINDOW_ERROR_MARKERS: &[&str] = &[
|
||||
"maximum context length",
|
||||
"context window",
|
||||
"context length",
|
||||
"too many tokens",
|
||||
"prompt is too long",
|
||||
"input is too long",
|
||||
"request is too large",
|
||||
];
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ApiError {
|
||||
MissingCredentials {
|
||||
provider: &'static str,
|
||||
env_vars: &'static [&'static str],
|
||||
/// Optional, runtime-computed hint appended to the error Display
|
||||
/// output. Populated when the provider resolver can infer what the
|
||||
/// user probably intended (e.g. an OpenAI key is set but Anthropic
|
||||
/// was selected because no Anthropic credentials exist).
|
||||
hint: Option<String>,
|
||||
},
|
||||
ContextWindowExceeded {
|
||||
model: String,
|
||||
estimated_input_tokens: u32,
|
||||
requested_output_tokens: u32,
|
||||
estimated_total_tokens: u32,
|
||||
context_window_tokens: u32,
|
||||
},
|
||||
ExpiredOAuthToken,
|
||||
Auth(String),
|
||||
InvalidApiKeyEnv(VarError),
|
||||
Http(reqwest::Error),
|
||||
Io(std::io::Error),
|
||||
Json(serde_json::Error),
|
||||
Json {
|
||||
provider: String,
|
||||
model: String,
|
||||
body_snippet: String,
|
||||
source: serde_json::Error,
|
||||
},
|
||||
Api {
|
||||
status: reqwest::StatusCode,
|
||||
error_type: Option<String>,
|
||||
message: Option<String>,
|
||||
request_id: Option<String>,
|
||||
body: String,
|
||||
retryable: bool,
|
||||
},
|
||||
@ -38,7 +71,48 @@ impl ApiError {
|
||||
provider: &'static str,
|
||||
env_vars: &'static [&'static str],
|
||||
) -> Self {
|
||||
Self::MissingCredentials { provider, env_vars }
|
||||
Self::MissingCredentials {
|
||||
provider,
|
||||
env_vars,
|
||||
hint: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a `MissingCredentials` error carrying an extra, runtime-computed
|
||||
/// hint string that the Display impl appends after the canonical "missing
|
||||
/// <provider> credentials" message. Used by the provider resolver to
|
||||
/// suggest the likely fix when the user has credentials for a different
|
||||
/// provider already in the environment.
|
||||
#[must_use]
|
||||
pub fn missing_credentials_with_hint(
|
||||
provider: &'static str,
|
||||
env_vars: &'static [&'static str],
|
||||
hint: impl Into<String>,
|
||||
) -> Self {
|
||||
Self::MissingCredentials {
|
||||
provider,
|
||||
env_vars,
|
||||
hint: Some(hint.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a `Self::Json` enriched with the provider name, the model that
|
||||
/// was requested, and the first 200 characters of the raw response body so
|
||||
/// that callers can diagnose deserialization failures without re-running
|
||||
/// the request.
|
||||
#[must_use]
|
||||
pub fn json_deserialize(
|
||||
provider: impl Into<String>,
|
||||
model: impl Into<String>,
|
||||
body: &str,
|
||||
source: serde_json::Error,
|
||||
) -> Self {
|
||||
Self::Json {
|
||||
provider: provider.into(),
|
||||
model: model.into(),
|
||||
body_snippet: truncate_body_snippet(body, 200),
|
||||
source,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
@ -48,11 +122,106 @@ impl ApiError {
|
||||
Self::Api { retryable, .. } => *retryable,
|
||||
Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(),
|
||||
Self::MissingCredentials { .. }
|
||||
| Self::ContextWindowExceeded { .. }
|
||||
| Self::ExpiredOAuthToken
|
||||
| Self::Auth(_)
|
||||
| Self::InvalidApiKeyEnv(_)
|
||||
| Self::Io(_)
|
||||
| Self::Json(_)
|
||||
| Self::Json { .. }
|
||||
| Self::InvalidSseFrame(_)
|
||||
| Self::BackoffOverflow { .. } => false,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn request_id(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::Api { request_id, .. } => request_id.as_deref(),
|
||||
Self::RetriesExhausted { last_error, .. } => last_error.request_id(),
|
||||
Self::MissingCredentials { .. }
|
||||
| Self::ContextWindowExceeded { .. }
|
||||
| Self::ExpiredOAuthToken
|
||||
| Self::Auth(_)
|
||||
| Self::InvalidApiKeyEnv(_)
|
||||
| Self::Http(_)
|
||||
| Self::Io(_)
|
||||
| Self::Json { .. }
|
||||
| Self::InvalidSseFrame(_)
|
||||
| Self::BackoffOverflow { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn safe_failure_class(&self) -> &'static str {
|
||||
match self {
|
||||
Self::RetriesExhausted { .. } if self.is_context_window_failure() => "context_window",
|
||||
Self::RetriesExhausted { .. } if self.is_generic_fatal_wrapper() => {
|
||||
"provider_retry_exhausted"
|
||||
}
|
||||
Self::RetriesExhausted { last_error, .. } => last_error.safe_failure_class(),
|
||||
Self::MissingCredentials { .. } | Self::ExpiredOAuthToken | Self::Auth(_) => {
|
||||
"provider_auth"
|
||||
}
|
||||
Self::Api { status, .. } if matches!(status.as_u16(), 401 | 403) => "provider_auth",
|
||||
Self::ContextWindowExceeded { .. } => "context_window",
|
||||
Self::Api { .. } if self.is_context_window_failure() => "context_window",
|
||||
Self::Api { status, .. } if status.as_u16() == 429 => "provider_rate_limit",
|
||||
Self::Api { .. } if self.is_generic_fatal_wrapper() => "provider_internal",
|
||||
Self::Api { .. } => "provider_error",
|
||||
Self::Http(_) | Self::InvalidSseFrame(_) | Self::BackoffOverflow { .. } => {
|
||||
"provider_transport"
|
||||
}
|
||||
Self::InvalidApiKeyEnv(_) | Self::Io(_) | Self::Json { .. } => "runtime_io",
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_generic_fatal_wrapper(&self) -> bool {
|
||||
match self {
|
||||
Self::Api { message, body, .. } => {
|
||||
message
|
||||
.as_deref()
|
||||
.is_some_and(looks_like_generic_fatal_wrapper)
|
||||
|| looks_like_generic_fatal_wrapper(body)
|
||||
}
|
||||
Self::RetriesExhausted { last_error, .. } => last_error.is_generic_fatal_wrapper(),
|
||||
Self::MissingCredentials { .. }
|
||||
| Self::ContextWindowExceeded { .. }
|
||||
| Self::ExpiredOAuthToken
|
||||
| Self::Auth(_)
|
||||
| Self::InvalidApiKeyEnv(_)
|
||||
| Self::Http(_)
|
||||
| Self::Io(_)
|
||||
| Self::Json { .. }
|
||||
| Self::InvalidSseFrame(_)
|
||||
| Self::BackoffOverflow { .. } => false,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_context_window_failure(&self) -> bool {
|
||||
match self {
|
||||
Self::ContextWindowExceeded { .. } => true,
|
||||
Self::Api {
|
||||
status,
|
||||
message,
|
||||
body,
|
||||
..
|
||||
} => {
|
||||
matches!(status.as_u16(), 400 | 413 | 422)
|
||||
&& (message
|
||||
.as_deref()
|
||||
.is_some_and(looks_like_context_window_error)
|
||||
|| looks_like_context_window_error(body))
|
||||
}
|
||||
Self::RetriesExhausted { last_error, .. } => last_error.is_context_window_failure(),
|
||||
Self::MissingCredentials { .. }
|
||||
| Self::ExpiredOAuthToken
|
||||
| Self::Auth(_)
|
||||
| Self::InvalidApiKeyEnv(_)
|
||||
| Self::Http(_)
|
||||
| Self::Io(_)
|
||||
| Self::Json { .. }
|
||||
| Self::InvalidSseFrame(_)
|
||||
| Self::BackoffOverflow { .. } => false,
|
||||
}
|
||||
@ -62,10 +231,43 @@ impl ApiError {
|
||||
impl Display for ApiError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::MissingCredentials { provider, env_vars } => write!(
|
||||
Self::MissingCredentials {
|
||||
provider,
|
||||
env_vars,
|
||||
hint,
|
||||
} => {
|
||||
write!(
|
||||
f,
|
||||
"missing {provider} credentials; export {} before calling the {provider} API",
|
||||
env_vars.join(" or ")
|
||||
)?;
|
||||
if cfg!(target_os = "windows") {
|
||||
if let Some(primary) = env_vars.first() {
|
||||
write!(
|
||||
f,
|
||||
" (on Windows, environment variables set in PowerShell only persist for the current session; use `setx {primary} <value>` to make it permanent, then open a new terminal, or place a `.env` file containing `{primary}=<value>` in the current working directory)"
|
||||
)?;
|
||||
} else {
|
||||
write!(
|
||||
f,
|
||||
" (on Windows, environment variables set in PowerShell only persist for the current session; use `setx` to make them permanent, then open a new terminal, or place a `.env` file in the current working directory)"
|
||||
)?;
|
||||
}
|
||||
}
|
||||
if let Some(hint) = hint {
|
||||
write!(f, " — hint: {hint}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Self::ContextWindowExceeded {
|
||||
model,
|
||||
estimated_input_tokens,
|
||||
requested_output_tokens,
|
||||
estimated_total_tokens,
|
||||
context_window_tokens,
|
||||
} => write!(
|
||||
f,
|
||||
"context_window_blocked for {model}: estimated input {estimated_input_tokens} + requested output {requested_output_tokens} = {estimated_total_tokens} tokens exceeds the {context_window_tokens}-token context window; compact the session or reduce request size before retrying"
|
||||
),
|
||||
Self::ExpiredOAuthToken => {
|
||||
write!(
|
||||
@ -79,19 +281,37 @@ impl Display for ApiError {
|
||||
}
|
||||
Self::Http(error) => write!(f, "http error: {error}"),
|
||||
Self::Io(error) => write!(f, "io error: {error}"),
|
||||
Self::Json(error) => write!(f, "json error: {error}"),
|
||||
Self::Json {
|
||||
provider,
|
||||
model,
|
||||
body_snippet,
|
||||
source,
|
||||
} => write!(
|
||||
f,
|
||||
"failed to parse {provider} response for model {model}: {source}; first 200 chars of body: {body_snippet}"
|
||||
),
|
||||
Self::Api {
|
||||
status,
|
||||
error_type,
|
||||
message,
|
||||
request_id,
|
||||
body,
|
||||
..
|
||||
} => match (error_type, message) {
|
||||
(Some(error_type), Some(message)) => {
|
||||
write!(f, "api returned {status} ({error_type}): {message}")
|
||||
} => {
|
||||
if let (Some(error_type), Some(message)) = (error_type, message) {
|
||||
write!(f, "api returned {status} ({error_type})")?;
|
||||
if let Some(request_id) = request_id {
|
||||
write!(f, " [trace {request_id}]")?;
|
||||
}
|
||||
write!(f, ": {message}")
|
||||
} else {
|
||||
write!(f, "api returned {status}")?;
|
||||
if let Some(request_id) = request_id {
|
||||
write!(f, " [trace {request_id}]")?;
|
||||
}
|
||||
write!(f, ": {body}")
|
||||
}
|
||||
}
|
||||
_ => write!(f, "api returned {status}: {body}"),
|
||||
},
|
||||
Self::RetriesExhausted {
|
||||
attempts,
|
||||
last_error,
|
||||
@ -124,7 +344,12 @@ impl From<std::io::Error> for ApiError {
|
||||
|
||||
impl From<serde_json::Error> for ApiError {
|
||||
fn from(value: serde_json::Error) -> Self {
|
||||
Self::Json(value)
|
||||
Self::Json {
|
||||
provider: "unknown".to_string(),
|
||||
model: "unknown".to_string(),
|
||||
body_snippet: String::new(),
|
||||
source: value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -133,3 +358,215 @@ impl From<VarError> for ApiError {
|
||||
Self::InvalidApiKeyEnv(value)
|
||||
}
|
||||
}
|
||||
|
||||
fn looks_like_generic_fatal_wrapper(text: &str) -> bool {
|
||||
let lowered = text.to_ascii_lowercase();
|
||||
GENERIC_FATAL_WRAPPER_MARKERS
|
||||
.iter()
|
||||
.any(|marker| lowered.contains(marker))
|
||||
}
|
||||
|
||||
fn looks_like_context_window_error(text: &str) -> bool {
|
||||
let lowered = text.to_ascii_lowercase();
|
||||
CONTEXT_WINDOW_ERROR_MARKERS
|
||||
.iter()
|
||||
.any(|marker| lowered.contains(marker))
|
||||
}
|
||||
|
||||
/// Truncate `body` so the resulting snippet contains at most `max_chars`
|
||||
/// characters (counted by Unicode scalar values, not bytes), preserving the
|
||||
/// leading slice of the body that the caller most often needs to inspect.
|
||||
fn truncate_body_snippet(body: &str, max_chars: usize) -> String {
|
||||
let mut taken_chars = 0;
|
||||
let mut byte_end = 0;
|
||||
for (offset, character) in body.char_indices() {
|
||||
if taken_chars >= max_chars {
|
||||
break;
|
||||
}
|
||||
taken_chars += 1;
|
||||
byte_end = offset + character.len_utf8();
|
||||
}
|
||||
if taken_chars >= max_chars && byte_end < body.len() {
|
||||
format!("{}…", &body[..byte_end])
|
||||
} else {
|
||||
body[..byte_end].to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{truncate_body_snippet, ApiError};
|
||||
|
||||
#[test]
|
||||
fn json_deserialize_error_includes_provider_model_and_truncated_body_snippet() {
|
||||
let raw_body = format!("{}{}", "x".repeat(190), "_TAIL_PAST_200_CHARS_MARKER_");
|
||||
let source = serde_json::from_str::<serde_json::Value>("{not json")
|
||||
.expect_err("invalid json should fail to parse");
|
||||
|
||||
let error = ApiError::json_deserialize("Anthropic", "claude-opus-4-6", &raw_body, source);
|
||||
let rendered = error.to_string();
|
||||
|
||||
assert!(
|
||||
rendered.starts_with("failed to parse Anthropic response for model claude-opus-4-6: "),
|
||||
"rendered error should lead with provider and model: {rendered}"
|
||||
);
|
||||
assert!(
|
||||
rendered.contains("first 200 chars of body: "),
|
||||
"rendered error should label the body snippet: {rendered}"
|
||||
);
|
||||
let snippet = rendered
|
||||
.split("first 200 chars of body: ")
|
||||
.nth(1)
|
||||
.expect("snippet section should be present");
|
||||
assert!(
|
||||
snippet.starts_with(&"x".repeat(190)),
|
||||
"snippet should preserve the leading characters of the body: {snippet}"
|
||||
);
|
||||
assert!(
|
||||
snippet.ends_with('…'),
|
||||
"snippet should signal truncation with an ellipsis: {snippet}"
|
||||
);
|
||||
assert!(
|
||||
!snippet.contains("_TAIL_PAST_200_CHARS_MARKER_"),
|
||||
"snippet should drop characters past the 200-char cap: {snippet}"
|
||||
);
|
||||
assert_eq!(error.safe_failure_class(), "runtime_io");
|
||||
assert_eq!(error.request_id(), None);
|
||||
assert!(!error.is_retryable());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_body_snippet_keeps_short_bodies_intact() {
|
||||
assert_eq!(truncate_body_snippet("hello", 200), "hello");
|
||||
assert_eq!(truncate_body_snippet("", 200), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_body_snippet_caps_long_bodies_at_max_chars() {
|
||||
let body = "a".repeat(250);
|
||||
let snippet = truncate_body_snippet(&body, 200);
|
||||
assert_eq!(snippet.chars().count(), 201, "200 chars + ellipsis");
|
||||
assert!(snippet.ends_with('…'));
|
||||
assert!(snippet.starts_with(&"a".repeat(200)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_body_snippet_does_not_split_multibyte_characters() {
|
||||
let body = "한글한글한글한글한글한글";
|
||||
let snippet = truncate_body_snippet(body, 4);
|
||||
assert_eq!(snippet, "한글한글…");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_generic_fatal_wrapper_and_classifies_it_as_provider_internal() {
|
||||
let error = ApiError::Api {
|
||||
status: reqwest::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
error_type: Some("api_error".to_string()),
|
||||
message: Some(
|
||||
"Something went wrong while processing your request. Please try again, or use /new to start a fresh session."
|
||||
.to_string(),
|
||||
),
|
||||
request_id: Some("req_jobdori_123".to_string()),
|
||||
body: String::new(),
|
||||
retryable: true,
|
||||
};
|
||||
|
||||
assert!(error.is_generic_fatal_wrapper());
|
||||
assert_eq!(error.safe_failure_class(), "provider_internal");
|
||||
assert_eq!(error.request_id(), Some("req_jobdori_123"));
|
||||
assert!(error.to_string().contains("[trace req_jobdori_123]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn retries_exhausted_preserves_nested_request_id_and_failure_class() {
|
||||
let error = ApiError::RetriesExhausted {
|
||||
attempts: 3,
|
||||
last_error: Box::new(ApiError::Api {
|
||||
status: reqwest::StatusCode::BAD_GATEWAY,
|
||||
error_type: Some("api_error".to_string()),
|
||||
message: Some(
|
||||
"Something went wrong while processing your request. Please try again, or use /new to start a fresh session."
|
||||
.to_string(),
|
||||
),
|
||||
request_id: Some("req_nested_456".to_string()),
|
||||
body: String::new(),
|
||||
retryable: true,
|
||||
}),
|
||||
};
|
||||
|
||||
assert!(error.is_generic_fatal_wrapper());
|
||||
assert_eq!(error.safe_failure_class(), "provider_retry_exhausted");
|
||||
assert_eq!(error.request_id(), Some("req_nested_456"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classifies_provider_context_window_errors() {
|
||||
let error = ApiError::Api {
|
||||
status: reqwest::StatusCode::BAD_REQUEST,
|
||||
error_type: Some("invalid_request_error".to_string()),
|
||||
message: Some(
|
||||
"This model's maximum context length is 200000 tokens, but your request used 230000 tokens."
|
||||
.to_string(),
|
||||
),
|
||||
request_id: Some("req_ctx_123".to_string()),
|
||||
body: String::new(),
|
||||
retryable: false,
|
||||
};
|
||||
|
||||
assert!(error.is_context_window_failure());
|
||||
assert_eq!(error.safe_failure_class(), "context_window");
|
||||
assert_eq!(error.request_id(), Some("req_ctx_123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_credentials_without_hint_renders_the_canonical_message() {
|
||||
// given
|
||||
let error = ApiError::missing_credentials(
|
||||
"Anthropic",
|
||||
&["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
|
||||
);
|
||||
|
||||
// when
|
||||
let rendered = error.to_string();
|
||||
|
||||
// then
|
||||
assert!(
|
||||
rendered.starts_with(
|
||||
"missing Anthropic credentials; export ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY before calling the Anthropic API"
|
||||
),
|
||||
"rendered error should lead with the canonical missing-credential message: {rendered}"
|
||||
);
|
||||
assert!(
|
||||
!rendered.contains(" — hint: "),
|
||||
"no hint should be appended when none is supplied: {rendered}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_credentials_with_hint_appends_the_hint_after_base_message() {
|
||||
// given
|
||||
let error = ApiError::missing_credentials_with_hint(
|
||||
"Anthropic",
|
||||
&["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
|
||||
"I see OPENAI_API_KEY is set — if you meant to use the OpenAI-compat provider, prefix your model name with `openai/` so prefix routing selects it.",
|
||||
);
|
||||
|
||||
// when
|
||||
let rendered = error.to_string();
|
||||
|
||||
// then
|
||||
assert!(
|
||||
rendered.starts_with("missing Anthropic credentials;"),
|
||||
"hint should be appended, not replace the base message: {rendered}"
|
||||
);
|
||||
let hint_marker = " — hint: I see OPENAI_API_KEY is set — if you meant to use the OpenAI-compat provider, prefix your model name with `openai/` so prefix routing selects it.";
|
||||
assert!(
|
||||
rendered.ends_with(hint_marker),
|
||||
"rendered error should end with the hint: {rendered}"
|
||||
);
|
||||
// Classification semantics are unaffected by the presence of a hint.
|
||||
assert_eq!(error.safe_failure_class(), "provider_auth");
|
||||
assert!(!error.is_retryable());
|
||||
assert_eq!(error.request_id(), None);
|
||||
}
|
||||
}
|
||||
|
||||
344
crates/api/src/http_client.rs
Normal file
344
crates/api/src/http_client.rs
Normal file
@ -0,0 +1,344 @@
|
||||
use crate::error::ApiError;
|
||||
|
||||
const HTTP_PROXY_KEYS: [&str; 2] = ["HTTP_PROXY", "http_proxy"];
|
||||
const HTTPS_PROXY_KEYS: [&str; 2] = ["HTTPS_PROXY", "https_proxy"];
|
||||
const NO_PROXY_KEYS: [&str; 2] = ["NO_PROXY", "no_proxy"];
|
||||
|
||||
/// Snapshot of the proxy-related environment variables that influence the
|
||||
/// outbound HTTP client. Captured up front so callers can inspect, log, and
|
||||
/// test the resolved configuration without re-reading the process environment.
|
||||
///
|
||||
/// When `proxy_url` is set it acts as a single catch-all proxy for both
|
||||
/// HTTP and HTTPS traffic, taking precedence over the per-scheme fields.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct ProxyConfig {
|
||||
pub http_proxy: Option<String>,
|
||||
pub https_proxy: Option<String>,
|
||||
pub no_proxy: Option<String>,
|
||||
/// Optional unified proxy URL that applies to both HTTP and HTTPS.
|
||||
/// When set, this takes precedence over `http_proxy` and `https_proxy`.
|
||||
pub proxy_url: Option<String>,
|
||||
}
|
||||
|
||||
impl ProxyConfig {
|
||||
/// Read proxy settings from the live process environment, honouring both
|
||||
/// the upper- and lower-case spellings used by curl, git, and friends.
|
||||
#[must_use]
|
||||
pub fn from_env() -> Self {
|
||||
Self::from_lookup(|key| std::env::var(key).ok())
|
||||
}
|
||||
|
||||
/// Create a proxy configuration from a single URL that applies to both
|
||||
/// HTTP and HTTPS traffic. This is the config-file alternative to setting
|
||||
/// `HTTP_PROXY` and `HTTPS_PROXY` environment variables separately.
|
||||
#[must_use]
|
||||
pub fn from_proxy_url(url: impl Into<String>) -> Self {
|
||||
Self {
|
||||
proxy_url: Some(url.into()),
|
||||
..Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn from_lookup<F>(mut lookup: F) -> Self
|
||||
where
|
||||
F: FnMut(&str) -> Option<String>,
|
||||
{
|
||||
Self {
|
||||
http_proxy: first_non_empty(&HTTP_PROXY_KEYS, &mut lookup),
|
||||
https_proxy: first_non_empty(&HTTPS_PROXY_KEYS, &mut lookup),
|
||||
no_proxy: first_non_empty(&NO_PROXY_KEYS, &mut lookup),
|
||||
proxy_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.proxy_url.is_none() && self.http_proxy.is_none() && self.https_proxy.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a `reqwest::Client` that honours the standard `HTTP_PROXY`,
|
||||
/// `HTTPS_PROXY`, and `NO_PROXY` environment variables. When no proxy is
|
||||
/// configured the client behaves identically to `reqwest::Client::new()`.
|
||||
pub fn build_http_client() -> Result<reqwest::Client, ApiError> {
|
||||
build_http_client_with(&ProxyConfig::from_env())
|
||||
}
|
||||
|
||||
/// Infallible counterpart to [`build_http_client`] for constructors that
|
||||
/// historically returned `Self` rather than `Result<Self, _>`. When the proxy
|
||||
/// configuration is malformed we fall back to a default client so that
|
||||
/// callers retain the previous behaviour and the failure surfaces on the
|
||||
/// first outbound request instead of at construction time.
|
||||
#[must_use]
|
||||
pub fn build_http_client_or_default() -> reqwest::Client {
|
||||
build_http_client().unwrap_or_else(|_| reqwest::Client::new())
|
||||
}
|
||||
|
||||
/// Build a `reqwest::Client` from an explicit [`ProxyConfig`]. Used by tests
|
||||
/// and by callers that want to override process-level environment lookups.
|
||||
///
|
||||
/// When `config.proxy_url` is set it overrides the per-scheme `http_proxy`
|
||||
/// and `https_proxy` fields and is registered as both an HTTP and HTTPS
|
||||
/// proxy so a single value can route every outbound request.
|
||||
pub fn build_http_client_with(config: &ProxyConfig) -> Result<reqwest::Client, ApiError> {
|
||||
let mut builder = reqwest::Client::builder().no_proxy();
|
||||
|
||||
let no_proxy = config
|
||||
.no_proxy
|
||||
.as_deref()
|
||||
.and_then(reqwest::NoProxy::from_string);
|
||||
|
||||
let (http_proxy_url, https_proxy_url) = match config.proxy_url.as_deref() {
|
||||
Some(unified) => (Some(unified), Some(unified)),
|
||||
None => (config.http_proxy.as_deref(), config.https_proxy.as_deref()),
|
||||
};
|
||||
|
||||
if let Some(url) = https_proxy_url {
|
||||
let mut proxy = reqwest::Proxy::https(url)?;
|
||||
if let Some(filter) = no_proxy.clone() {
|
||||
proxy = proxy.no_proxy(Some(filter));
|
||||
}
|
||||
builder = builder.proxy(proxy);
|
||||
}
|
||||
|
||||
if let Some(url) = http_proxy_url {
|
||||
let mut proxy = reqwest::Proxy::http(url)?;
|
||||
if let Some(filter) = no_proxy.clone() {
|
||||
proxy = proxy.no_proxy(Some(filter));
|
||||
}
|
||||
builder = builder.proxy(proxy);
|
||||
}
|
||||
|
||||
Ok(builder.build()?)
|
||||
}
|
||||
|
||||
fn first_non_empty<F>(keys: &[&str], lookup: &mut F) -> Option<String>
|
||||
where
|
||||
F: FnMut(&str) -> Option<String>,
|
||||
{
|
||||
keys.iter()
|
||||
.find_map(|key| lookup(key).filter(|value| !value.is_empty()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{build_http_client_with, ProxyConfig};
|
||||
|
||||
fn config_from_map(pairs: &[(&str, &str)]) -> ProxyConfig {
|
||||
let map: HashMap<String, String> = pairs
|
||||
.iter()
|
||||
.map(|(key, value)| ((*key).to_string(), (*value).to_string()))
|
||||
.collect();
|
||||
ProxyConfig::from_lookup(|key| map.get(key).cloned())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_config_is_empty_when_no_env_vars_are_set() {
|
||||
// given
|
||||
let config = config_from_map(&[]);
|
||||
|
||||
// when
|
||||
let empty = config.is_empty();
|
||||
|
||||
// then
|
||||
assert!(empty);
|
||||
assert_eq!(config, ProxyConfig::default());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_config_reads_uppercase_http_https_and_no_proxy() {
|
||||
// given
|
||||
let pairs = [
|
||||
("HTTP_PROXY", "http://proxy.internal:3128"),
|
||||
("HTTPS_PROXY", "http://secure.internal:3129"),
|
||||
("NO_PROXY", "localhost,127.0.0.1,.corp"),
|
||||
];
|
||||
|
||||
// when
|
||||
let config = config_from_map(&pairs);
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
config.http_proxy.as_deref(),
|
||||
Some("http://proxy.internal:3128")
|
||||
);
|
||||
assert_eq!(
|
||||
config.https_proxy.as_deref(),
|
||||
Some("http://secure.internal:3129")
|
||||
);
|
||||
assert_eq!(
|
||||
config.no_proxy.as_deref(),
|
||||
Some("localhost,127.0.0.1,.corp")
|
||||
);
|
||||
assert!(!config.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_config_falls_back_to_lowercase_keys() {
|
||||
// given
|
||||
let pairs = [
|
||||
("http_proxy", "http://lower.internal:3128"),
|
||||
("https_proxy", "http://lower-secure.internal:3129"),
|
||||
("no_proxy", ".lower"),
|
||||
];
|
||||
|
||||
// when
|
||||
let config = config_from_map(&pairs);
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
config.http_proxy.as_deref(),
|
||||
Some("http://lower.internal:3128")
|
||||
);
|
||||
assert_eq!(
|
||||
config.https_proxy.as_deref(),
|
||||
Some("http://lower-secure.internal:3129")
|
||||
);
|
||||
assert_eq!(config.no_proxy.as_deref(), Some(".lower"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_config_prefers_uppercase_over_lowercase_when_both_set() {
|
||||
// given
|
||||
let pairs = [
|
||||
("HTTP_PROXY", "http://upper.internal:3128"),
|
||||
("http_proxy", "http://lower.internal:3128"),
|
||||
];
|
||||
|
||||
// when
|
||||
let config = config_from_map(&pairs);
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
config.http_proxy.as_deref(),
|
||||
Some("http://upper.internal:3128")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_config_treats_empty_strings_as_unset() {
|
||||
// given
|
||||
let pairs = [("HTTP_PROXY", ""), ("http_proxy", "")];
|
||||
|
||||
// when
|
||||
let config = config_from_map(&pairs);
|
||||
|
||||
// then
|
||||
assert!(config.http_proxy.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_http_client_succeeds_when_no_proxy_is_configured() {
|
||||
// given
|
||||
let config = ProxyConfig::default();
|
||||
|
||||
// when
|
||||
let result = build_http_client_with(&config);
|
||||
|
||||
// then
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_http_client_succeeds_with_valid_http_and_https_proxies() {
|
||||
// given
|
||||
let config = ProxyConfig {
|
||||
http_proxy: Some("http://proxy.internal:3128".to_string()),
|
||||
https_proxy: Some("http://secure.internal:3129".to_string()),
|
||||
no_proxy: Some("localhost,127.0.0.1".to_string()),
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
// when
|
||||
let result = build_http_client_with(&config);
|
||||
|
||||
// then
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_http_client_returns_http_error_for_invalid_proxy_url() {
|
||||
// given
|
||||
let config = ProxyConfig {
|
||||
http_proxy: None,
|
||||
https_proxy: Some("not a url".to_string()),
|
||||
no_proxy: None,
|
||||
proxy_url: None,
|
||||
};
|
||||
|
||||
// when
|
||||
let result = build_http_client_with(&config);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("invalid proxy URL must be reported as a build failure");
|
||||
assert!(
|
||||
matches!(error, crate::error::ApiError::Http(_)),
|
||||
"expected ApiError::Http for invalid proxy URL, got: {error:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_proxy_url_sets_unified_field_and_leaves_per_scheme_empty() {
|
||||
// given / when
|
||||
let config = ProxyConfig::from_proxy_url("http://unified.internal:3128");
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
config.proxy_url.as_deref(),
|
||||
Some("http://unified.internal:3128")
|
||||
);
|
||||
assert!(config.http_proxy.is_none());
|
||||
assert!(config.https_proxy.is_none());
|
||||
assert!(!config.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_http_client_succeeds_with_unified_proxy_url() {
|
||||
// given
|
||||
let config = ProxyConfig {
|
||||
proxy_url: Some("http://unified.internal:3128".to_string()),
|
||||
no_proxy: Some("localhost".to_string()),
|
||||
..ProxyConfig::default()
|
||||
};
|
||||
|
||||
// when
|
||||
let result = build_http_client_with(&config);
|
||||
|
||||
// then
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_url_takes_precedence_over_per_scheme_fields() {
|
||||
// given – both per-scheme and unified are set
|
||||
let config = ProxyConfig {
|
||||
http_proxy: Some("http://per-scheme.internal:1111".to_string()),
|
||||
https_proxy: Some("http://per-scheme.internal:2222".to_string()),
|
||||
no_proxy: None,
|
||||
proxy_url: Some("http://unified.internal:3128".to_string()),
|
||||
};
|
||||
|
||||
// when – building succeeds (the unified URL is valid)
|
||||
let result = build_http_client_with(&config);
|
||||
|
||||
// then
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_http_client_returns_error_for_invalid_unified_proxy_url() {
|
||||
// given
|
||||
let config = ProxyConfig::from_proxy_url("not a url");
|
||||
|
||||
// when
|
||||
let result = build_http_client_with(&config);
|
||||
|
||||
// then
|
||||
assert!(
|
||||
matches!(result, Err(crate::error::ApiError::Http(_))),
|
||||
"invalid unified proxy URL should fail: {result:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -1,5 +1,7 @@
|
||||
mod client;
|
||||
mod error;
|
||||
mod http_client;
|
||||
mod prompt_cache;
|
||||
mod providers;
|
||||
mod sse;
|
||||
mod types;
|
||||
@ -9,10 +11,18 @@ pub use client::{
|
||||
resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient,
|
||||
};
|
||||
pub use error::ApiError;
|
||||
pub use providers::claw_provider::{AuthSource, ClawApiClient, ClawApiClient as ApiClient};
|
||||
pub use http_client::{
|
||||
build_http_client, build_http_client_or_default, build_http_client_with, ProxyConfig,
|
||||
};
|
||||
pub use prompt_cache::{
|
||||
CacheBreakEvent, PromptCache, PromptCacheConfig, PromptCachePaths, PromptCacheRecord,
|
||||
PromptCacheStats,
|
||||
};
|
||||
pub use providers::anthropic::{AnthropicClient, AnthropicClient as ApiClient, AuthSource};
|
||||
pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig};
|
||||
pub use providers::{
|
||||
detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind,
|
||||
detect_provider_kind, max_tokens_for_model, max_tokens_for_model_with_override,
|
||||
resolve_model_alias, ProviderKind,
|
||||
};
|
||||
pub use sse::{parse_frame, SseParser};
|
||||
pub use types::{
|
||||
@ -21,3 +31,9 @@ pub use types::{
|
||||
MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
|
||||
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
|
||||
};
|
||||
|
||||
pub use telemetry::{
|
||||
AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, JsonlTelemetrySink,
|
||||
MemoryTelemetrySink, SessionTraceRecord, SessionTracer, TelemetryEvent, TelemetrySink,
|
||||
DEFAULT_ANTHROPIC_VERSION,
|
||||
};
|
||||
|
||||
735
crates/api/src/prompt_cache.rs
Normal file
735
crates/api/src/prompt_cache.rs
Normal file
@ -0,0 +1,735 @@
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::types::{MessageRequest, MessageResponse, Usage};
|
||||
|
||||
const DEFAULT_COMPLETION_TTL_SECS: u64 = 30;
|
||||
const DEFAULT_PROMPT_TTL_SECS: u64 = 5 * 60;
|
||||
const DEFAULT_BREAK_MIN_DROP: u32 = 2_000;
|
||||
const MAX_SANITIZED_LENGTH: usize = 80;
|
||||
const REQUEST_FINGERPRINT_VERSION: u32 = 1;
|
||||
const REQUEST_FINGERPRINT_PREFIX: &str = "v1";
|
||||
const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
|
||||
const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PromptCacheConfig {
|
||||
pub session_id: String,
|
||||
pub completion_ttl: Duration,
|
||||
pub prompt_ttl: Duration,
|
||||
pub cache_break_min_drop: u32,
|
||||
}
|
||||
|
||||
impl PromptCacheConfig {
|
||||
#[must_use]
|
||||
pub fn new(session_id: impl Into<String>) -> Self {
|
||||
Self {
|
||||
session_id: session_id.into(),
|
||||
completion_ttl: Duration::from_secs(DEFAULT_COMPLETION_TTL_SECS),
|
||||
prompt_ttl: Duration::from_secs(DEFAULT_PROMPT_TTL_SECS),
|
||||
cache_break_min_drop: DEFAULT_BREAK_MIN_DROP,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PromptCacheConfig {
|
||||
fn default() -> Self {
|
||||
Self::new("default")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct PromptCachePaths {
|
||||
pub root: PathBuf,
|
||||
pub session_dir: PathBuf,
|
||||
pub completion_dir: PathBuf,
|
||||
pub session_state_path: PathBuf,
|
||||
pub stats_path: PathBuf,
|
||||
}
|
||||
|
||||
impl PromptCachePaths {
|
||||
#[must_use]
|
||||
pub fn for_session(session_id: &str) -> Self {
|
||||
let root = base_cache_root();
|
||||
let session_dir = root.join(sanitize_path_segment(session_id));
|
||||
let completion_dir = session_dir.join("completions");
|
||||
Self {
|
||||
root,
|
||||
session_state_path: session_dir.join("session-state.json"),
|
||||
stats_path: session_dir.join("stats.json"),
|
||||
session_dir,
|
||||
completion_dir,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn completion_entry_path(&self, request_hash: &str) -> PathBuf {
|
||||
self.completion_dir.join(format!("{request_hash}.json"))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct PromptCacheStats {
|
||||
pub tracked_requests: u64,
|
||||
pub completion_cache_hits: u64,
|
||||
pub completion_cache_misses: u64,
|
||||
pub completion_cache_writes: u64,
|
||||
pub expected_invalidations: u64,
|
||||
pub unexpected_cache_breaks: u64,
|
||||
pub total_cache_creation_input_tokens: u64,
|
||||
pub total_cache_read_input_tokens: u64,
|
||||
pub last_cache_creation_input_tokens: Option<u32>,
|
||||
pub last_cache_read_input_tokens: Option<u32>,
|
||||
pub last_request_hash: Option<String>,
|
||||
pub last_completion_cache_key: Option<String>,
|
||||
pub last_break_reason: Option<String>,
|
||||
pub last_cache_source: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct CacheBreakEvent {
|
||||
pub unexpected: bool,
|
||||
pub reason: String,
|
||||
pub previous_cache_read_input_tokens: u32,
|
||||
pub current_cache_read_input_tokens: u32,
|
||||
pub token_drop: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PromptCacheRecord {
|
||||
pub cache_break: Option<CacheBreakEvent>,
|
||||
pub stats: PromptCacheStats,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PromptCache {
|
||||
inner: Arc<Mutex<PromptCacheInner>>,
|
||||
}
|
||||
|
||||
impl PromptCache {
|
||||
#[must_use]
|
||||
pub fn new(session_id: impl Into<String>) -> Self {
|
||||
Self::with_config(PromptCacheConfig::new(session_id))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_config(config: PromptCacheConfig) -> Self {
|
||||
let paths = PromptCachePaths::for_session(&config.session_id);
|
||||
let stats = read_json::<PromptCacheStats>(&paths.stats_path).unwrap_or_default();
|
||||
let previous = read_json::<TrackedPromptState>(&paths.session_state_path);
|
||||
Self {
|
||||
inner: Arc::new(Mutex::new(PromptCacheInner {
|
||||
config,
|
||||
paths,
|
||||
stats,
|
||||
previous,
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn paths(&self) -> PromptCachePaths {
|
||||
self.lock().paths.clone()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn stats(&self) -> PromptCacheStats {
|
||||
self.lock().stats.clone()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn lookup_completion(&self, request: &MessageRequest) -> Option<MessageResponse> {
|
||||
let request_hash = request_hash_hex(request);
|
||||
let (paths, ttl) = {
|
||||
let inner = self.lock();
|
||||
(inner.paths.clone(), inner.config.completion_ttl)
|
||||
};
|
||||
let entry_path = paths.completion_entry_path(&request_hash);
|
||||
let entry = read_json::<CompletionCacheEntry>(&entry_path);
|
||||
let Some(entry) = entry else {
|
||||
let mut inner = self.lock();
|
||||
inner.stats.completion_cache_misses += 1;
|
||||
inner.stats.last_completion_cache_key = Some(request_hash);
|
||||
persist_state(&inner);
|
||||
return None;
|
||||
};
|
||||
|
||||
if entry.fingerprint_version != current_fingerprint_version() {
|
||||
let mut inner = self.lock();
|
||||
inner.stats.completion_cache_misses += 1;
|
||||
inner.stats.last_completion_cache_key = Some(request_hash.clone());
|
||||
let _ = fs::remove_file(entry_path);
|
||||
persist_state(&inner);
|
||||
return None;
|
||||
}
|
||||
|
||||
let expired = now_unix_secs().saturating_sub(entry.cached_at_unix_secs) >= ttl.as_secs();
|
||||
let mut inner = self.lock();
|
||||
inner.stats.last_completion_cache_key = Some(request_hash.clone());
|
||||
if expired {
|
||||
inner.stats.completion_cache_misses += 1;
|
||||
let _ = fs::remove_file(entry_path);
|
||||
persist_state(&inner);
|
||||
return None;
|
||||
}
|
||||
|
||||
inner.stats.completion_cache_hits += 1;
|
||||
apply_usage_to_stats(
|
||||
&mut inner.stats,
|
||||
&entry.response.usage,
|
||||
&request_hash,
|
||||
"completion-cache",
|
||||
);
|
||||
inner.previous = Some(TrackedPromptState::from_usage(
|
||||
request,
|
||||
&entry.response.usage,
|
||||
));
|
||||
persist_state(&inner);
|
||||
Some(entry.response)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn record_response(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
response: &MessageResponse,
|
||||
) -> PromptCacheRecord {
|
||||
self.record_usage_internal(request, &response.usage, Some(response))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord {
|
||||
self.record_usage_internal(request, usage, None)
|
||||
}
|
||||
|
||||
fn record_usage_internal(
|
||||
&self,
|
||||
request: &MessageRequest,
|
||||
usage: &Usage,
|
||||
response: Option<&MessageResponse>,
|
||||
) -> PromptCacheRecord {
|
||||
let request_hash = request_hash_hex(request);
|
||||
let mut inner = self.lock();
|
||||
let previous = inner.previous.clone();
|
||||
let current = TrackedPromptState::from_usage(request, usage);
|
||||
let cache_break = detect_cache_break(&inner.config, previous.as_ref(), ¤t);
|
||||
|
||||
inner.stats.tracked_requests += 1;
|
||||
apply_usage_to_stats(&mut inner.stats, usage, &request_hash, "api-response");
|
||||
if let Some(event) = &cache_break {
|
||||
if event.unexpected {
|
||||
inner.stats.unexpected_cache_breaks += 1;
|
||||
} else {
|
||||
inner.stats.expected_invalidations += 1;
|
||||
}
|
||||
inner.stats.last_break_reason = Some(event.reason.clone());
|
||||
}
|
||||
|
||||
inner.previous = Some(current);
|
||||
if let Some(response) = response {
|
||||
write_completion_entry(&inner.paths, &request_hash, response);
|
||||
inner.stats.completion_cache_writes += 1;
|
||||
}
|
||||
persist_state(&inner);
|
||||
|
||||
PromptCacheRecord {
|
||||
cache_break,
|
||||
stats: inner.stats.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn lock(&self) -> std::sync::MutexGuard<'_, PromptCacheInner> {
|
||||
self.inner
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PromptCacheInner {
|
||||
config: PromptCacheConfig,
|
||||
paths: PromptCachePaths,
|
||||
stats: PromptCacheStats,
|
||||
previous: Option<TrackedPromptState>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct CompletionCacheEntry {
|
||||
cached_at_unix_secs: u64,
|
||||
#[serde(default = "current_fingerprint_version")]
|
||||
fingerprint_version: u32,
|
||||
response: MessageResponse,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
struct TrackedPromptState {
|
||||
observed_at_unix_secs: u64,
|
||||
#[serde(default = "current_fingerprint_version")]
|
||||
fingerprint_version: u32,
|
||||
model_hash: u64,
|
||||
system_hash: u64,
|
||||
tools_hash: u64,
|
||||
messages_hash: u64,
|
||||
cache_read_input_tokens: u32,
|
||||
}
|
||||
|
||||
impl TrackedPromptState {
|
||||
fn from_usage(request: &MessageRequest, usage: &Usage) -> Self {
|
||||
let hashes = RequestFingerprints::from_request(request);
|
||||
Self {
|
||||
observed_at_unix_secs: now_unix_secs(),
|
||||
fingerprint_version: current_fingerprint_version(),
|
||||
model_hash: hashes.model,
|
||||
system_hash: hashes.system,
|
||||
tools_hash: hashes.tools,
|
||||
messages_hash: hashes.messages,
|
||||
cache_read_input_tokens: usage.cache_read_input_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct RequestFingerprints {
|
||||
model: u64,
|
||||
system: u64,
|
||||
tools: u64,
|
||||
messages: u64,
|
||||
}
|
||||
|
||||
impl RequestFingerprints {
|
||||
fn from_request(request: &MessageRequest) -> Self {
|
||||
Self {
|
||||
model: hash_serializable(&request.model),
|
||||
system: hash_serializable(&request.system),
|
||||
tools: hash_serializable(&request.tools),
|
||||
messages: hash_serializable(&request.messages),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn detect_cache_break(
|
||||
config: &PromptCacheConfig,
|
||||
previous: Option<&TrackedPromptState>,
|
||||
current: &TrackedPromptState,
|
||||
) -> Option<CacheBreakEvent> {
|
||||
let previous = previous?;
|
||||
if previous.fingerprint_version != current.fingerprint_version {
|
||||
return Some(CacheBreakEvent {
|
||||
unexpected: false,
|
||||
reason: format!(
|
||||
"fingerprint version changed (v{} -> v{})",
|
||||
previous.fingerprint_version, current.fingerprint_version
|
||||
),
|
||||
previous_cache_read_input_tokens: previous.cache_read_input_tokens,
|
||||
current_cache_read_input_tokens: current.cache_read_input_tokens,
|
||||
token_drop: previous
|
||||
.cache_read_input_tokens
|
||||
.saturating_sub(current.cache_read_input_tokens),
|
||||
});
|
||||
}
|
||||
let token_drop = previous
|
||||
.cache_read_input_tokens
|
||||
.saturating_sub(current.cache_read_input_tokens);
|
||||
if token_drop < config.cache_break_min_drop {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut reasons = Vec::new();
|
||||
if previous.model_hash != current.model_hash {
|
||||
reasons.push("model changed");
|
||||
}
|
||||
if previous.system_hash != current.system_hash {
|
||||
reasons.push("system prompt changed");
|
||||
}
|
||||
if previous.tools_hash != current.tools_hash {
|
||||
reasons.push("tool definitions changed");
|
||||
}
|
||||
if previous.messages_hash != current.messages_hash {
|
||||
reasons.push("message payload changed");
|
||||
}
|
||||
|
||||
let elapsed = current
|
||||
.observed_at_unix_secs
|
||||
.saturating_sub(previous.observed_at_unix_secs);
|
||||
|
||||
let (unexpected, reason) = if reasons.is_empty() {
|
||||
if elapsed > config.prompt_ttl.as_secs() {
|
||||
(
|
||||
false,
|
||||
format!("possible prompt cache TTL expiry after {elapsed}s"),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
true,
|
||||
"cache read tokens dropped while prompt fingerprint remained stable".to_string(),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
(false, reasons.join(", "))
|
||||
};
|
||||
|
||||
Some(CacheBreakEvent {
|
||||
unexpected,
|
||||
reason,
|
||||
previous_cache_read_input_tokens: previous.cache_read_input_tokens,
|
||||
current_cache_read_input_tokens: current.cache_read_input_tokens,
|
||||
token_drop,
|
||||
})
|
||||
}
|
||||
|
||||
fn apply_usage_to_stats(
|
||||
stats: &mut PromptCacheStats,
|
||||
usage: &Usage,
|
||||
request_hash: &str,
|
||||
source: &str,
|
||||
) {
|
||||
stats.total_cache_creation_input_tokens += u64::from(usage.cache_creation_input_tokens);
|
||||
stats.total_cache_read_input_tokens += u64::from(usage.cache_read_input_tokens);
|
||||
stats.last_cache_creation_input_tokens = Some(usage.cache_creation_input_tokens);
|
||||
stats.last_cache_read_input_tokens = Some(usage.cache_read_input_tokens);
|
||||
stats.last_request_hash = Some(request_hash.to_string());
|
||||
stats.last_cache_source = Some(source.to_string());
|
||||
}
|
||||
|
||||
fn persist_state(inner: &PromptCacheInner) {
|
||||
let _ = ensure_cache_dirs(&inner.paths);
|
||||
let _ = write_json(&inner.paths.stats_path, &inner.stats);
|
||||
if let Some(previous) = &inner.previous {
|
||||
let _ = write_json(&inner.paths.session_state_path, previous);
|
||||
}
|
||||
}
|
||||
|
||||
fn write_completion_entry(
|
||||
paths: &PromptCachePaths,
|
||||
request_hash: &str,
|
||||
response: &MessageResponse,
|
||||
) {
|
||||
let _ = ensure_cache_dirs(paths);
|
||||
let entry = CompletionCacheEntry {
|
||||
cached_at_unix_secs: now_unix_secs(),
|
||||
fingerprint_version: current_fingerprint_version(),
|
||||
response: response.clone(),
|
||||
};
|
||||
let _ = write_json(&paths.completion_entry_path(request_hash), &entry);
|
||||
}
|
||||
|
||||
fn ensure_cache_dirs(paths: &PromptCachePaths) -> std::io::Result<()> {
|
||||
fs::create_dir_all(&paths.completion_dir)
|
||||
}
|
||||
|
||||
fn write_json<T: Serialize>(path: &Path, value: &T) -> std::io::Result<()> {
|
||||
let json = serde_json::to_vec_pretty(value)
|
||||
.map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?;
|
||||
fs::write(path, json)
|
||||
}
|
||||
|
||||
fn read_json<T: for<'de> Deserialize<'de>>(path: &Path) -> Option<T> {
|
||||
let bytes = fs::read(path).ok()?;
|
||||
serde_json::from_slice(&bytes).ok()
|
||||
}
|
||||
|
||||
fn request_hash_hex(request: &MessageRequest) -> String {
|
||||
format!(
|
||||
"{REQUEST_FINGERPRINT_PREFIX}-{:016x}",
|
||||
hash_serializable(request)
|
||||
)
|
||||
}
|
||||
|
||||
fn hash_serializable<T: Serialize>(value: &T) -> u64 {
|
||||
let json = serde_json::to_vec(value).unwrap_or_default();
|
||||
stable_hash_bytes(&json)
|
||||
}
|
||||
|
||||
fn sanitize_path_segment(value: &str) -> String {
|
||||
let sanitized: String = value
|
||||
.chars()
|
||||
.map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '-' })
|
||||
.collect();
|
||||
if sanitized.len() <= MAX_SANITIZED_LENGTH {
|
||||
return sanitized;
|
||||
}
|
||||
let suffix = format!("-{:x}", hash_string(value));
|
||||
format!(
|
||||
"{}{}",
|
||||
&sanitized[..MAX_SANITIZED_LENGTH.saturating_sub(suffix.len())],
|
||||
suffix
|
||||
)
|
||||
}
|
||||
|
||||
fn hash_string(value: &str) -> u64 {
|
||||
stable_hash_bytes(value.as_bytes())
|
||||
}
|
||||
|
||||
fn base_cache_root() -> PathBuf {
|
||||
if let Some(config_home) = std::env::var_os("CLAUDE_CONFIG_HOME") {
|
||||
return PathBuf::from(config_home)
|
||||
.join("cache")
|
||||
.join("prompt-cache");
|
||||
}
|
||||
if let Some(home) = std::env::var_os("HOME") {
|
||||
return PathBuf::from(home)
|
||||
.join(".claude")
|
||||
.join("cache")
|
||||
.join("prompt-cache");
|
||||
}
|
||||
std::env::temp_dir().join("claude-prompt-cache")
|
||||
}
|
||||
|
||||
fn now_unix_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map_or(0, |duration| duration.as_secs())
|
||||
}
|
||||
|
||||
const fn current_fingerprint_version() -> u32 {
|
||||
REQUEST_FINGERPRINT_VERSION
|
||||
}
|
||||
|
||||
fn stable_hash_bytes(bytes: &[u8]) -> u64 {
|
||||
let mut hash = FNV_OFFSET_BASIS;
|
||||
for byte in bytes {
|
||||
hash ^= u64::from(*byte);
|
||||
hash = hash.wrapping_mul(FNV_PRIME);
|
||||
}
|
||||
hash
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use super::{
|
||||
detect_cache_break, read_json, request_hash_hex, sanitize_path_segment, PromptCache,
|
||||
PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX,
|
||||
};
|
||||
use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage};
|
||||
|
||||
fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn path_builder_sanitizes_session_identifier() {
|
||||
let paths = PromptCachePaths::for_session("session:/with spaces");
|
||||
let session_dir = paths
|
||||
.session_dir
|
||||
.file_name()
|
||||
.and_then(|value| value.to_str())
|
||||
.expect("session dir name");
|
||||
assert_eq!(session_dir, "session--with-spaces");
|
||||
assert!(paths.completion_dir.ends_with("completions"));
|
||||
assert!(paths.stats_path.ends_with("stats.json"));
|
||||
assert!(paths.session_state_path.ends_with("session-state.json"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_fingerprint_drives_unexpected_break_detection() {
|
||||
let request = sample_request("same");
|
||||
let previous = TrackedPromptState::from_usage(
|
||||
&request,
|
||||
&Usage {
|
||||
input_tokens: 0,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 6_000,
|
||||
output_tokens: 0,
|
||||
},
|
||||
);
|
||||
let current = TrackedPromptState::from_usage(
|
||||
&request,
|
||||
&Usage {
|
||||
input_tokens: 0,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 1_000,
|
||||
output_tokens: 0,
|
||||
},
|
||||
);
|
||||
let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), ¤t)
|
||||
.expect("break should be detected");
|
||||
assert!(event.unexpected);
|
||||
assert!(event.reason.contains("stable"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn changed_prompt_marks_break_as_expected() {
|
||||
let previous_request = sample_request("first");
|
||||
let current_request = sample_request("second");
|
||||
let previous = TrackedPromptState::from_usage(
|
||||
&previous_request,
|
||||
&Usage {
|
||||
input_tokens: 0,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 6_000,
|
||||
output_tokens: 0,
|
||||
},
|
||||
);
|
||||
let current = TrackedPromptState::from_usage(
|
||||
¤t_request,
|
||||
&Usage {
|
||||
input_tokens: 0,
|
||||
cache_creation_input_tokens: 0,
|
||||
cache_read_input_tokens: 1_000,
|
||||
output_tokens: 0,
|
||||
},
|
||||
);
|
||||
let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), ¤t)
|
||||
.expect("break should be detected");
|
||||
assert!(!event.unexpected);
|
||||
assert!(event.reason.contains("message payload changed"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn completion_cache_round_trip_persists_recent_response() {
|
||||
let _guard = test_env_lock();
|
||||
let temp_root = std::env::temp_dir().join(format!(
|
||||
"prompt-cache-test-{}-{}",
|
||||
std::process::id(),
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time")
|
||||
.as_nanos()
|
||||
));
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
|
||||
let cache = PromptCache::new("unit-test-session");
|
||||
let request = sample_request("cache me");
|
||||
let response = sample_response(42, 12, "cached");
|
||||
|
||||
assert!(cache.lookup_completion(&request).is_none());
|
||||
let record = cache.record_response(&request, &response);
|
||||
assert!(record.cache_break.is_none());
|
||||
|
||||
let cached = cache
|
||||
.lookup_completion(&request)
|
||||
.expect("cached response should load");
|
||||
assert_eq!(cached.content, response.content);
|
||||
|
||||
let stats = cache.stats();
|
||||
assert_eq!(stats.completion_cache_hits, 1);
|
||||
assert_eq!(stats.completion_cache_misses, 1);
|
||||
assert_eq!(stats.completion_cache_writes, 1);
|
||||
|
||||
let persisted = read_json::<super::PromptCacheStats>(&cache.paths().stats_path)
|
||||
.expect("stats should persist");
|
||||
assert_eq!(persisted.completion_cache_hits, 1);
|
||||
|
||||
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn distinct_requests_do_not_collide_in_completion_cache() {
|
||||
let _guard = test_env_lock();
|
||||
let temp_root = std::env::temp_dir().join(format!(
|
||||
"prompt-cache-distinct-{}-{}",
|
||||
std::process::id(),
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time")
|
||||
.as_nanos()
|
||||
));
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
|
||||
let cache = PromptCache::new("distinct-request-session");
|
||||
let first_request = sample_request("first");
|
||||
let second_request = sample_request("second");
|
||||
|
||||
let response = sample_response(42, 12, "cached");
|
||||
let _ = cache.record_response(&first_request, &response);
|
||||
|
||||
assert!(cache.lookup_completion(&second_request).is_none());
|
||||
|
||||
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expired_completion_entries_are_not_reused() {
|
||||
let _guard = test_env_lock();
|
||||
let temp_root = std::env::temp_dir().join(format!(
|
||||
"prompt-cache-expired-{}-{}",
|
||||
std::process::id(),
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time")
|
||||
.as_nanos()
|
||||
));
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
|
||||
let cache = PromptCache::with_config(PromptCacheConfig {
|
||||
session_id: "expired-session".to_string(),
|
||||
completion_ttl: Duration::ZERO,
|
||||
..PromptCacheConfig::default()
|
||||
});
|
||||
let request = sample_request("expire me");
|
||||
let response = sample_response(7, 3, "stale");
|
||||
|
||||
let _ = cache.record_response(&request, &response);
|
||||
|
||||
assert!(cache.lookup_completion(&request).is_none());
|
||||
let stats = cache.stats();
|
||||
assert_eq!(stats.completion_cache_hits, 0);
|
||||
assert_eq!(stats.completion_cache_misses, 1);
|
||||
|
||||
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_path_caps_long_values() {
|
||||
let long_value = "x".repeat(200);
|
||||
let sanitized = sanitize_path_segment(&long_value);
|
||||
assert!(sanitized.len() <= 80);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_hashes_are_versioned_and_stable() {
|
||||
let request = sample_request("stable");
|
||||
let first = request_hash_hex(&request);
|
||||
let second = request_hash_hex(&request);
|
||||
assert_eq!(first, second);
|
||||
assert!(first.starts_with(REQUEST_FINGERPRINT_PREFIX));
|
||||
}
|
||||
|
||||
fn sample_request(text: &str) -> MessageRequest {
|
||||
MessageRequest {
|
||||
model: "claude-3-7-sonnet-latest".to_string(),
|
||||
max_tokens: 64,
|
||||
messages: vec![InputMessage::user_text(text)],
|
||||
system: Some("system".to_string()),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
stream: false,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_response(
|
||||
cache_read_input_tokens: u32,
|
||||
output_tokens: u32,
|
||||
text: &str,
|
||||
) -> MessageResponse {
|
||||
MessageResponse {
|
||||
id: "msg_test".to_string(),
|
||||
kind: "message".to_string(),
|
||||
role: "assistant".to_string(),
|
||||
content: vec![OutputContentBlock::Text {
|
||||
text: text.to_string(),
|
||||
}],
|
||||
model: "claude-3-7-sonnet-latest".to_string(),
|
||||
stop_reason: Some("end_turn".to_string()),
|
||||
stop_sequence: None,
|
||||
usage: Usage {
|
||||
input_tokens: 10,
|
||||
cache_creation_input_tokens: 5,
|
||||
cache_read_input_tokens,
|
||||
output_tokens,
|
||||
},
|
||||
request_id: Some("req_test".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,11 +1,11 @@
|
||||
use crate::error::ApiError;
|
||||
use crate::types::StreamEvent;
|
||||
use serde_json::Value;
|
||||
use reqwest::StatusCode;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct SseParser {
|
||||
buffer: Vec<u8>,
|
||||
provider: Option<String>,
|
||||
model: Option<String>,
|
||||
}
|
||||
|
||||
impl SseParser {
|
||||
@ -14,12 +14,23 @@ impl SseParser {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Attach the provider name and model to this parser so that JSON
|
||||
/// deserialization failures within streamed frames carry enough context
|
||||
/// for callers to understand which upstream produced the unparseable
|
||||
/// payload.
|
||||
#[must_use]
|
||||
pub fn with_context(mut self, provider: impl Into<String>, model: impl Into<String>) -> Self {
|
||||
self.provider = Some(provider.into());
|
||||
self.model = Some(model.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn push(&mut self, chunk: &[u8]) -> Result<Vec<StreamEvent>, ApiError> {
|
||||
self.buffer.extend_from_slice(chunk);
|
||||
let mut events = Vec::new();
|
||||
|
||||
while let Some(frame) = self.next_frame() {
|
||||
if let Some(event) = parse_frame(&frame)? {
|
||||
if let Some(event) = self.parse_frame_with_context(&frame)? {
|
||||
events.push(event);
|
||||
}
|
||||
}
|
||||
@ -33,12 +44,18 @@ impl SseParser {
|
||||
}
|
||||
|
||||
let trailing = std::mem::take(&mut self.buffer);
|
||||
match parse_frame(&String::from_utf8_lossy(&trailing))? {
|
||||
match self.parse_frame_with_context(&String::from_utf8_lossy(&trailing))? {
|
||||
Some(event) => Ok(vec![event]),
|
||||
None => Ok(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_frame_with_context(&self, frame: &str) -> Result<Option<StreamEvent>, ApiError> {
|
||||
let provider = self.provider.as_deref().unwrap_or("unknown");
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
parse_frame_with_provider(frame, provider, model)
|
||||
}
|
||||
|
||||
fn next_frame(&mut self) -> Option<String> {
|
||||
let separator = self
|
||||
.buffer
|
||||
@ -63,6 +80,14 @@ impl SseParser {
|
||||
}
|
||||
|
||||
pub fn parse_frame(frame: &str) -> Result<Option<StreamEvent>, ApiError> {
|
||||
parse_frame_with_provider(frame, "unknown", "unknown")
|
||||
}
|
||||
|
||||
pub(crate) fn parse_frame_with_provider(
|
||||
frame: &str,
|
||||
provider: &str,
|
||||
model: &str,
|
||||
) -> Result<Option<StreamEvent>, ApiError> {
|
||||
let trimmed = frame.trim();
|
||||
if trimmed.is_empty() {
|
||||
return Ok(None);
|
||||
@ -97,75 +122,9 @@ pub fn parse_frame(frame: &str) -> Result<Option<StreamEvent>, ApiError> {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if matches!(event_name, Some("error")) {
|
||||
return Err(parse_error_event(&payload));
|
||||
}
|
||||
|
||||
// Some "Anthropic-compatible" gateways put the event type in the SSE `event:` field,
|
||||
// and omit the `{ "type": ... }` discriminator from the JSON `data:` payload.
|
||||
// Our Rust enums are tagged with `#[serde(tag = "type")]`, so we synthesize it here.
|
||||
match serde_json::from_str::<StreamEvent>(&payload) {
|
||||
Ok(event) => Ok(Some(event)),
|
||||
Err(error) => {
|
||||
// Best-effort: if we have an SSE event name and the payload is a JSON object
|
||||
// without a `type` field, inject it and retry.
|
||||
let Some(event_name) = event_name else {
|
||||
return Err(ApiError::from(error));
|
||||
};
|
||||
let Ok(Value::Object(mut object)) = serde_json::from_str::<Value>(&payload) else {
|
||||
return Err(ApiError::from(error));
|
||||
};
|
||||
if object
|
||||
.get("type")
|
||||
.and_then(Value::as_str)
|
||||
.is_some_and(|value| value == "error")
|
||||
{
|
||||
return Err(parse_error_object(&object, payload));
|
||||
}
|
||||
if object.contains_key("type") {
|
||||
return Err(ApiError::from(error));
|
||||
}
|
||||
object.insert("type".to_string(), Value::String(event_name.to_string()));
|
||||
serde_json::from_value::<StreamEvent>(Value::Object(object))
|
||||
serde_json::from_str::<StreamEvent>(&payload)
|
||||
.map(Some)
|
||||
.map_err(ApiError::from)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_error_event(payload: &str) -> ApiError {
|
||||
match serde_json::from_str::<Value>(payload) {
|
||||
Ok(Value::Object(object)) => parse_error_object(&object, payload.to_string()),
|
||||
_ => ApiError::Api {
|
||||
status: StatusCode::BAD_GATEWAY,
|
||||
error_type: Some("stream_error".to_string()),
|
||||
message: Some(payload.to_string()),
|
||||
body: payload.to_string(),
|
||||
retryable: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_error_object(object: &serde_json::Map<String, Value>, body: String) -> ApiError {
|
||||
let nested = object.get("error").and_then(Value::as_object);
|
||||
let error_type = nested
|
||||
.and_then(|error| error.get("type"))
|
||||
.or_else(|| object.get("type"))
|
||||
.and_then(Value::as_str)
|
||||
.map(ToOwned::to_owned);
|
||||
let message = nested
|
||||
.and_then(|error| error.get("message"))
|
||||
.or_else(|| object.get("message"))
|
||||
.and_then(Value::as_str)
|
||||
.map(ToOwned::to_owned);
|
||||
|
||||
ApiError::Api {
|
||||
status: StatusCode::BAD_GATEWAY,
|
||||
error_type,
|
||||
message,
|
||||
body,
|
||||
retryable: false,
|
||||
}
|
||||
.map_err(|error| ApiError::json_deserialize(provider, model, &payload, error))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -263,26 +222,6 @@ mod tests {
|
||||
assert_eq!(event, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_event_name_when_payload_omits_type() {
|
||||
let frame = concat!("event: message_stop\n", "data: {}\n\n");
|
||||
let event = parse_frame(frame).expect("frame should parse");
|
||||
assert_eq!(event, Some(StreamEvent::MessageStop(crate::types::MessageStopEvent {})));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn surfaces_stream_error_events() {
|
||||
let frame = concat!(
|
||||
"event: error\n",
|
||||
"data: {\"error\":{\"type\":\"invalid_request_error\",\"message\":\"bad input\"}}\n\n"
|
||||
);
|
||||
let error = parse_frame(frame).expect_err("error frame should surface");
|
||||
assert_eq!(
|
||||
error.to_string(),
|
||||
"api returned 502 Bad Gateway (invalid_request_error): bad input"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_split_json_across_data_lines() {
|
||||
let frame = concat!(
|
||||
@ -364,4 +303,28 @@ mod tests {
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_message_delta_frame_with_empty_usage_when_parsed_then_usage_defaults_to_zero() {
|
||||
// given
|
||||
let frame = concat!(
|
||||
"event: message_delta\n",
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{}}\n\n"
|
||||
);
|
||||
|
||||
// when
|
||||
let event = parse_frame(frame).expect("frame should parse");
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
event,
|
||||
Some(StreamEvent::MessageDelta(crate::types::MessageDeltaEvent {
|
||||
delta: MessageDelta {
|
||||
stop_reason: Some("end_turn".to_string()),
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: Usage::default(),
|
||||
}))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
use runtime::{pricing_for_model, TokenUsage, UsageCostEstimate};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
|
||||
pub struct MessageRequest {
|
||||
pub model: String,
|
||||
pub max_tokens: u32,
|
||||
@ -14,6 +15,22 @@ pub struct MessageRequest {
|
||||
pub tool_choice: Option<ToolChoice>,
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
pub stream: bool,
|
||||
/// OpenAI-compatible tuning parameters. Optional — omitted from payload when None.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub frequency_penalty: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub presence_penalty: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<Vec<String>>,
|
||||
/// Reasoning effort level for OpenAI-compatible reasoning models (e.g. `o4-mini`).
|
||||
/// Accepted values: `"low"`, `"medium"`, `"high"`. Omitted when `None`.
|
||||
/// Silently ignored by backends that do not support it.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_effort: Option<String>,
|
||||
}
|
||||
|
||||
impl MessageRequest {
|
||||
@ -75,14 +92,6 @@ pub enum InputContentBlock {
|
||||
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
|
||||
is_error: bool,
|
||||
},
|
||||
Thinking {
|
||||
thinking: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
signature: Option<String>,
|
||||
},
|
||||
RedactedThinking {
|
||||
data: Value,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
@ -120,6 +129,7 @@ pub struct MessageResponse {
|
||||
pub stop_reason: Option<String>,
|
||||
#[serde(default)]
|
||||
pub stop_sequence: Option<String>,
|
||||
#[serde(default)]
|
||||
pub usage: Usage,
|
||||
#[serde(default)]
|
||||
pub request_id: Option<String>,
|
||||
@ -154,20 +164,44 @@ pub enum OutputContentBlock {
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
#[serde(default)]
|
||||
pub input_tokens: u32,
|
||||
#[serde(default)]
|
||||
pub cache_creation_input_tokens: u32,
|
||||
#[serde(default)]
|
||||
pub cache_read_input_tokens: u32,
|
||||
#[serde(default)]
|
||||
pub output_tokens: u32,
|
||||
}
|
||||
|
||||
impl Usage {
|
||||
#[must_use]
|
||||
pub const fn total_tokens(&self) -> u32 {
|
||||
self.input_tokens + self.output_tokens
|
||||
self.input_tokens
|
||||
+ self.output_tokens
|
||||
+ self.cache_creation_input_tokens
|
||||
+ self.cache_read_input_tokens
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn token_usage(&self) -> TokenUsage {
|
||||
TokenUsage {
|
||||
input_tokens: self.input_tokens,
|
||||
output_tokens: self.output_tokens,
|
||||
cache_creation_input_tokens: self.cache_creation_input_tokens,
|
||||
cache_read_input_tokens: self.cache_read_input_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn estimated_cost_usd(&self, model: &str) -> UsageCostEstimate {
|
||||
let usage = self.token_usage();
|
||||
pricing_for_model(model).map_or_else(
|
||||
|| usage.estimate_cost_usd(),
|
||||
|pricing| usage.estimate_cost_usd_with_pricing(pricing),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -179,6 +213,7 @@ pub struct MessageStartEvent {
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct MessageDeltaEvent {
|
||||
pub delta: MessageDelta,
|
||||
#[serde(default)]
|
||||
pub usage: Usage,
|
||||
}
|
||||
|
||||
@ -229,3 +264,47 @@ pub enum StreamEvent {
|
||||
ContentBlockStop(ContentBlockStopEvent),
|
||||
MessageStop(MessageStopEvent),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use runtime::format_usd;
|
||||
|
||||
use super::{MessageResponse, Usage};
|
||||
|
||||
#[test]
|
||||
fn usage_total_tokens_includes_cache_tokens() {
|
||||
let usage = Usage {
|
||||
input_tokens: 10,
|
||||
cache_creation_input_tokens: 2,
|
||||
cache_read_input_tokens: 3,
|
||||
output_tokens: 4,
|
||||
};
|
||||
|
||||
assert_eq!(usage.total_tokens(), 19);
|
||||
assert_eq!(usage.token_usage().total_tokens(), 19);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn message_response_estimates_cost_from_model_usage() {
|
||||
let response = MessageResponse {
|
||||
id: "msg_cost".to_string(),
|
||||
kind: "message".to_string(),
|
||||
role: "assistant".to_string(),
|
||||
content: Vec::new(),
|
||||
model: "claude-sonnet-4-20250514".to_string(),
|
||||
stop_reason: Some("end_turn".to_string()),
|
||||
stop_sequence: None,
|
||||
usage: Usage {
|
||||
input_tokens: 1_000_000,
|
||||
cache_creation_input_tokens: 100_000,
|
||||
cache_read_input_tokens: 200_000,
|
||||
output_tokens: 500_000,
|
||||
},
|
||||
request_id: None,
|
||||
};
|
||||
|
||||
let cost = response.usage.estimated_cost_usd(&response.model);
|
||||
assert_eq!(format_usd(cost.total_cost_usd()), "$54.6750");
|
||||
assert_eq!(response.total_tokens(), 1_800_000);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,17 +1,27 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Mutex as StdMutex, OnceLock};
|
||||
use std::time::Duration;
|
||||
|
||||
use api::{
|
||||
ApiClient, ApiError, AuthSource, ContentBlockDelta, ContentBlockDeltaEvent,
|
||||
AnthropicClient, ApiClient, ApiError, AuthSource, ContentBlockDelta, ContentBlockDeltaEvent,
|
||||
ContentBlockStartEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest,
|
||||
OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition,
|
||||
OutputContentBlock, PromptCache, PromptCacheConfig, ProviderClient, StreamEvent, ToolChoice,
|
||||
ToolDefinition,
|
||||
};
|
||||
use serde_json::json;
|
||||
use telemetry::{ClientIdentity, MemoryTelemetrySink, SessionTracer, TelemetryEvent};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
static LOCK: OnceLock<StdMutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| StdMutex::new(()))
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_message_posts_json_and_parses_response() {
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
@ -20,8 +30,8 @@ async fn send_message_posts_json_and_parses_response() {
|
||||
"\"id\":\"msg_test\",",
|
||||
"\"type\":\"message\",",
|
||||
"\"role\":\"assistant\",",
|
||||
"\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claw\"}],",
|
||||
"\"model\":\"claude-sonnet-4-6\",",
|
||||
"\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claude\"}],",
|
||||
"\"model\":\"claude-3-7-sonnet-latest\",",
|
||||
"\"stop_reason\":\"end_turn\",",
|
||||
"\"stop_sequence\":null,",
|
||||
"\"usage\":{\"input_tokens\":12,\"output_tokens\":4},",
|
||||
@ -45,10 +55,12 @@ async fn send_message_posts_json_and_parses_response() {
|
||||
assert_eq!(response.id, "msg_test");
|
||||
assert_eq!(response.total_tokens(), 16);
|
||||
assert_eq!(response.request_id.as_deref(), Some("req_body_123"));
|
||||
assert_eq!(response.usage.cache_creation_input_tokens, 0);
|
||||
assert_eq!(response.usage.cache_read_input_tokens, 0);
|
||||
assert_eq!(
|
||||
response.content,
|
||||
vec![OutputContentBlock::Text {
|
||||
text: "Hello from Claw".to_string(),
|
||||
text: "Hello from Claude".to_string(),
|
||||
}]
|
||||
);
|
||||
|
||||
@ -64,23 +76,258 @@ async fn send_message_posts_json_and_parses_response() {
|
||||
request.headers.get("authorization").map(String::as_str),
|
||||
Some("Bearer proxy-token")
|
||||
);
|
||||
assert_eq!(
|
||||
request.headers.get("anthropic-version").map(String::as_str),
|
||||
Some("2023-06-01")
|
||||
);
|
||||
assert_eq!(
|
||||
request.headers.get("user-agent").map(String::as_str),
|
||||
Some("claude-code/0.1.0")
|
||||
);
|
||||
assert_eq!(
|
||||
request.headers.get("anthropic-beta").map(String::as_str),
|
||||
Some("claude-code-20250219,prompt-caching-scope-2026-01-05")
|
||||
);
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_str(&request.body).expect("request body should be json");
|
||||
assert_eq!(
|
||||
body.get("model").and_then(serde_json::Value::as_str),
|
||||
Some("claude-sonnet-4-6")
|
||||
Some("claude-3-7-sonnet-latest")
|
||||
);
|
||||
assert!(body.get("stream").is_none());
|
||||
assert_eq!(body["tools"][0]["name"], json!("get_weather"));
|
||||
assert_eq!(body["tool_choice"]["type"], json!("auto"));
|
||||
assert!(
|
||||
body.get("betas").is_none(),
|
||||
"betas must travel via the anthropic-beta header, not the request body"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_message_blocks_oversized_requests_before_the_http_call() {
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let server = spawn_server(
|
||||
state.clone(),
|
||||
vec![http_response("200 OK", "application/json", "{}")],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = AnthropicClient::new("test-key").with_base_url(server.base_url());
|
||||
let error = client
|
||||
.send_message(&MessageRequest {
|
||||
model: "claude-sonnet-4-6".to_string(),
|
||||
max_tokens: 64_000,
|
||||
messages: vec![InputMessage {
|
||||
role: "user".to_string(),
|
||||
content: vec![InputContentBlock::Text {
|
||||
text: "x".repeat(600_000),
|
||||
}],
|
||||
}],
|
||||
system: Some("Keep the answer short.".to_string()),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
stream: false,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect_err("oversized request should fail local context-window preflight");
|
||||
|
||||
assert!(matches!(error, ApiError::ContextWindowExceeded { .. }));
|
||||
assert!(
|
||||
state.lock().await.is_empty(),
|
||||
"preflight failure should avoid any upstream HTTP request"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_message_applies_request_profile_and_records_telemetry() {
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let server = spawn_server(
|
||||
state.clone(),
|
||||
vec![http_response_with_headers(
|
||||
"200 OK",
|
||||
"application/json",
|
||||
concat!(
|
||||
"{",
|
||||
"\"id\":\"msg_profile\",",
|
||||
"\"type\":\"message\",",
|
||||
"\"role\":\"assistant\",",
|
||||
"\"content\":[{\"type\":\"text\",\"text\":\"ok\"}],",
|
||||
"\"model\":\"claude-3-7-sonnet-latest\",",
|
||||
"\"stop_reason\":\"end_turn\",",
|
||||
"\"stop_sequence\":null,",
|
||||
"\"usage\":{\"input_tokens\":1,\"cache_creation_input_tokens\":2,\"cache_read_input_tokens\":3,\"output_tokens\":1}",
|
||||
"}"
|
||||
),
|
||||
&[("request-id", "req_profile_123")],
|
||||
)],
|
||||
)
|
||||
.await;
|
||||
let sink = Arc::new(MemoryTelemetrySink::default());
|
||||
|
||||
let client = AnthropicClient::new("test-key")
|
||||
.with_base_url(server.base_url())
|
||||
.with_client_identity(ClientIdentity::new("claude-code", "9.9.9").with_runtime("rust-cli"))
|
||||
.with_beta("tools-2026-04-01")
|
||||
.with_extra_body_param("metadata", json!({"source": "clawd-code"}))
|
||||
.with_session_tracer(SessionTracer::new("session-telemetry", sink.clone()));
|
||||
|
||||
let response = client
|
||||
.send_message(&sample_request(false))
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(response.request_id.as_deref(), Some("req_profile_123"));
|
||||
|
||||
let captured = state.lock().await;
|
||||
let request = captured.first().expect("server should capture request");
|
||||
assert_eq!(
|
||||
request.headers.get("anthropic-beta").map(String::as_str),
|
||||
Some("claude-code-20250219,prompt-caching-scope-2026-01-05,tools-2026-04-01")
|
||||
);
|
||||
assert_eq!(
|
||||
request.headers.get("user-agent").map(String::as_str),
|
||||
Some("claude-code/9.9.9")
|
||||
);
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_str(&request.body).expect("request body should be json");
|
||||
assert_eq!(body["metadata"]["source"], json!("clawd-code"));
|
||||
assert!(
|
||||
body.get("betas").is_none(),
|
||||
"betas must travel via the anthropic-beta header, not the request body"
|
||||
);
|
||||
|
||||
let events = sink.events();
|
||||
assert_eq!(events.len(), 6);
|
||||
assert!(matches!(
|
||||
&events[0],
|
||||
TelemetryEvent::HttpRequestStarted {
|
||||
session_id,
|
||||
attempt: 1,
|
||||
method,
|
||||
path,
|
||||
..
|
||||
} if session_id == "session-telemetry" && method == "POST" && path == "/v1/messages"
|
||||
));
|
||||
assert!(matches!(
|
||||
&events[1],
|
||||
TelemetryEvent::SessionTrace(trace) if trace.name == "http_request_started"
|
||||
));
|
||||
assert!(matches!(
|
||||
&events[2],
|
||||
TelemetryEvent::HttpRequestSucceeded {
|
||||
request_id,
|
||||
status: 200,
|
||||
..
|
||||
} if request_id.as_deref() == Some("req_profile_123")
|
||||
));
|
||||
assert!(matches!(
|
||||
&events[3],
|
||||
TelemetryEvent::SessionTrace(trace) if trace.name == "http_request_succeeded"
|
||||
));
|
||||
assert!(matches!(
|
||||
&events[4],
|
||||
TelemetryEvent::Analytics(event)
|
||||
if event.namespace == "api"
|
||||
&& event.action == "message_usage"
|
||||
&& event.properties.get("request_id") == Some(&json!("req_profile_123"))
|
||||
&& event.properties.get("total_tokens") == Some(&json!(7))
|
||||
&& event.properties.get("estimated_cost_usd") == Some(&json!("$0.0001"))
|
||||
));
|
||||
assert!(matches!(
|
||||
&events[5],
|
||||
TelemetryEvent::SessionTrace(trace) if trace.name == "analytics"
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_message_parses_prompt_cache_token_usage_from_response() {
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let body = concat!(
|
||||
"{",
|
||||
"\"id\":\"msg_cache_tokens\",",
|
||||
"\"type\":\"message\",",
|
||||
"\"role\":\"assistant\",",
|
||||
"\"content\":[{\"type\":\"text\",\"text\":\"Cache tokens\"}],",
|
||||
"\"model\":\"claude-3-7-sonnet-latest\",",
|
||||
"\"stop_reason\":\"end_turn\",",
|
||||
"\"stop_sequence\":null,",
|
||||
"\"usage\":{\"input_tokens\":12,\"cache_creation_input_tokens\":321,\"cache_read_input_tokens\":654,\"output_tokens\":4}",
|
||||
"}"
|
||||
);
|
||||
let server = spawn_server(
|
||||
state,
|
||||
vec![http_response("200 OK", "application/json", body)],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = AnthropicClient::new("test-key").with_base_url(server.base_url());
|
||||
let response = client
|
||||
.send_message(&sample_request(false))
|
||||
.await
|
||||
.expect("request should succeed");
|
||||
|
||||
assert_eq!(response.usage.input_tokens, 12);
|
||||
assert_eq!(response.usage.cache_creation_input_tokens, 321);
|
||||
assert_eq!(response.usage.cache_read_input_tokens, 654);
|
||||
assert_eq!(response.usage.output_tokens, 4);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn given_empty_usage_object_when_send_message_parses_response_then_usage_defaults_to_zero() {
|
||||
// given
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let body = concat!(
|
||||
"{",
|
||||
"\"id\":\"msg_empty_usage\",",
|
||||
"\"type\":\"message\",",
|
||||
"\"role\":\"assistant\",",
|
||||
"\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claude\"}],",
|
||||
"\"model\":\"claude-3-7-sonnet-latest\",",
|
||||
"\"stop_reason\":\"end_turn\",",
|
||||
"\"stop_sequence\":null,",
|
||||
"\"usage\":{}",
|
||||
"}"
|
||||
);
|
||||
let server = spawn_server(
|
||||
state,
|
||||
vec![http_response("200 OK", "application/json", body)],
|
||||
)
|
||||
.await;
|
||||
let client = AnthropicClient::new("test-key").with_base_url(server.base_url());
|
||||
|
||||
// when
|
||||
let response = client
|
||||
.send_message(&sample_request(false))
|
||||
.await
|
||||
.expect("response with empty usage object should still parse");
|
||||
|
||||
// then
|
||||
assert_eq!(response.id, "msg_empty_usage");
|
||||
assert_eq!(response.total_tokens(), 0);
|
||||
assert_eq!(response.usage.input_tokens, 0);
|
||||
assert_eq!(response.usage.cache_creation_input_tokens, 0);
|
||||
assert_eq!(response.usage.cache_read_input_tokens, 0);
|
||||
assert_eq!(response.usage.output_tokens, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[allow(clippy::await_holding_lock)]
|
||||
async fn stream_message_parses_sse_events_with_tool_use() {
|
||||
let _guard = env_lock();
|
||||
let temp_root = std::env::temp_dir().join(format!(
|
||||
"api-stream-cache-{}-{}",
|
||||
std::process::id(),
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.expect("time")
|
||||
.as_nanos()
|
||||
));
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let sse = concat!(
|
||||
"event: message_start\n",
|
||||
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n",
|
||||
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":13,\"cache_read_input_tokens\":21,\"output_tokens\":0}}}\n\n",
|
||||
"event: content_block_start\n",
|
||||
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n",
|
||||
"event: content_block_delta\n",
|
||||
@ -88,7 +335,7 @@ async fn stream_message_parses_sse_events_with_tool_use() {
|
||||
"event: content_block_stop\n",
|
||||
"data: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
|
||||
"event: message_delta\n",
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n",
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":34,\"cache_read_input_tokens\":55,\"output_tokens\":1}}\n\n",
|
||||
"event: message_stop\n",
|
||||
"data: {\"type\":\"message_stop\"}\n\n",
|
||||
"data: [DONE]\n\n"
|
||||
@ -106,7 +353,8 @@ async fn stream_message_parses_sse_events_with_tool_use() {
|
||||
|
||||
let client = ApiClient::new("test-key")
|
||||
.with_auth_token(Some("proxy-token".to_string()))
|
||||
.with_base_url(server.base_url());
|
||||
.with_base_url(server.base_url())
|
||||
.with_prompt_cache(PromptCache::new("stream-session"));
|
||||
let mut stream = client
|
||||
.stream_message(&sample_request(false))
|
||||
.await
|
||||
@ -160,6 +408,20 @@ async fn stream_message_parses_sse_events_with_tool_use() {
|
||||
let captured = state.lock().await;
|
||||
let request = captured.first().expect("server should capture request");
|
||||
assert!(request.body.contains("\"stream\":true"));
|
||||
|
||||
let cache_stats = client
|
||||
.prompt_cache_stats()
|
||||
.expect("prompt cache stats should exist");
|
||||
assert_eq!(cache_stats.tracked_requests, 1);
|
||||
assert_eq!(cache_stats.last_cache_creation_input_tokens, Some(34));
|
||||
assert_eq!(cache_stats.last_cache_read_input_tokens, Some(55));
|
||||
assert_eq!(
|
||||
cache_stats.last_cache_source.as_deref(),
|
||||
Some("api-response")
|
||||
);
|
||||
|
||||
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@ -176,7 +438,7 @@ async fn retries_retryable_failures_before_succeeding() {
|
||||
http_response(
|
||||
"200 OK",
|
||||
"application/json",
|
||||
"{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
|
||||
"{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
|
||||
),
|
||||
],
|
||||
)
|
||||
@ -196,28 +458,28 @@ async fn retries_retryable_failures_before_succeeding() {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn provider_client_dispatches_api_requests() {
|
||||
async fn provider_client_dispatches_anthropic_requests() {
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let server = spawn_server(
|
||||
state.clone(),
|
||||
vec![http_response(
|
||||
"200 OK",
|
||||
"application/json",
|
||||
"{\"id\":\"msg_provider\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Dispatched\"}],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
|
||||
"{\"id\":\"msg_provider\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Dispatched\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
|
||||
)],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = ProviderClient::from_model_with_default_auth(
|
||||
let client = ProviderClient::from_model_with_anthropic_auth(
|
||||
"claude-sonnet-4-6",
|
||||
Some(AuthSource::ApiKey("test-key".to_string())),
|
||||
)
|
||||
.expect("api provider client should be constructed");
|
||||
.expect("anthropic provider client should be constructed");
|
||||
let client = match client {
|
||||
ProviderClient::ClawApi(client) => {
|
||||
ProviderClient::ClawApi(client.with_base_url(server.base_url()))
|
||||
ProviderClient::Anthropic(client) => {
|
||||
ProviderClient::Anthropic(client.with_base_url(server.base_url()))
|
||||
}
|
||||
other => panic!("expected default provider, got {other:?}"),
|
||||
other => panic!("expected anthropic provider, got {other:?}"),
|
||||
};
|
||||
|
||||
let response = client
|
||||
@ -284,13 +546,194 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn retries_multiple_retryable_failures_with_exponential_backoff_and_jitter() {
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let server = spawn_server(
|
||||
state.clone(),
|
||||
vec![
|
||||
http_response(
|
||||
"429 Too Many Requests",
|
||||
"application/json",
|
||||
"{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down\"}}",
|
||||
),
|
||||
http_response(
|
||||
"500 Internal Server Error",
|
||||
"application/json",
|
||||
"{\"type\":\"error\",\"error\":{\"type\":\"api_error\",\"message\":\"boom\"}}",
|
||||
),
|
||||
http_response(
|
||||
"503 Service Unavailable",
|
||||
"application/json",
|
||||
"{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"busy\"}}",
|
||||
),
|
||||
http_response(
|
||||
"429 Too Many Requests",
|
||||
"application/json",
|
||||
"{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down again\"}}",
|
||||
),
|
||||
http_response(
|
||||
"503 Service Unavailable",
|
||||
"application/json",
|
||||
"{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"still busy\"}}",
|
||||
),
|
||||
http_response(
|
||||
"200 OK",
|
||||
"application/json",
|
||||
"{\"id\":\"msg_exp_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered after 5\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
|
||||
),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = ApiClient::new("test-key")
|
||||
.with_base_url(server.base_url())
|
||||
.with_retry_policy(8, Duration::from_millis(1), Duration::from_millis(4));
|
||||
let started_at = std::time::Instant::now();
|
||||
|
||||
let response = client
|
||||
.send_message(&sample_request(false))
|
||||
.await
|
||||
.expect("8-retry policy should absorb 5 retryable failures");
|
||||
|
||||
let elapsed = started_at.elapsed();
|
||||
assert_eq!(response.total_tokens(), 5);
|
||||
assert_eq!(
|
||||
state.lock().await.len(),
|
||||
6,
|
||||
"client should issue 1 original + 5 retry requests before the 200"
|
||||
);
|
||||
// Jittered sleeps are bounded by 2 * max_backoff per retry (base + jitter),
|
||||
// so 5 sleeps fit comfortably below this upper bound with generous slack.
|
||||
assert!(
|
||||
elapsed < Duration::from_secs(5),
|
||||
"retries should complete promptly, took {elapsed:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[allow(clippy::await_holding_lock)]
|
||||
async fn send_message_reuses_recent_completion_cache_entries() {
|
||||
let _guard = env_lock();
|
||||
let temp_root = std::env::temp_dir().join(format!(
|
||||
"api-prompt-cache-{}-{}",
|
||||
std::process::id(),
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.expect("time")
|
||||
.as_nanos()
|
||||
));
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
|
||||
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let server = spawn_server(
|
||||
state.clone(),
|
||||
vec![http_response(
|
||||
"200 OK",
|
||||
"application/json",
|
||||
"{\"id\":\"msg_cached\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Cached once\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":4000,\"output_tokens\":2}}",
|
||||
)],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = AnthropicClient::new("test-key")
|
||||
.with_base_url(server.base_url())
|
||||
.with_prompt_cache(PromptCache::new("integration-session"));
|
||||
|
||||
let first = client
|
||||
.send_message(&sample_request(false))
|
||||
.await
|
||||
.expect("first request should succeed");
|
||||
let second = client
|
||||
.send_message(&sample_request(false))
|
||||
.await
|
||||
.expect("second request should reuse cache");
|
||||
|
||||
assert_eq!(first.content, second.content);
|
||||
assert_eq!(state.lock().await.len(), 1);
|
||||
|
||||
let cache_stats = client
|
||||
.prompt_cache_stats()
|
||||
.expect("prompt cache stats should exist");
|
||||
assert_eq!(cache_stats.completion_cache_hits, 1);
|
||||
assert_eq!(cache_stats.completion_cache_misses, 1);
|
||||
assert_eq!(cache_stats.completion_cache_writes, 1);
|
||||
|
||||
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[allow(clippy::await_holding_lock)]
|
||||
async fn send_message_tracks_unexpected_prompt_cache_breaks() {
|
||||
let _guard = env_lock();
|
||||
let temp_root = std::env::temp_dir().join(format!(
|
||||
"api-prompt-break-{}-{}",
|
||||
std::process::id(),
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.expect("time")
|
||||
.as_nanos()
|
||||
));
|
||||
std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root);
|
||||
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let server = spawn_server(
|
||||
state,
|
||||
vec![
|
||||
http_response(
|
||||
"200 OK",
|
||||
"application/json",
|
||||
"{\"id\":\"msg_one\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"One\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":6000,\"output_tokens\":2}}",
|
||||
),
|
||||
http_response(
|
||||
"200 OK",
|
||||
"application/json",
|
||||
"{\"id\":\"msg_two\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Two\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":0,\"cache_read_input_tokens\":1000,\"output_tokens\":2}}",
|
||||
),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let request = sample_request(false);
|
||||
let client = AnthropicClient::new("test-key")
|
||||
.with_base_url(server.base_url())
|
||||
.with_prompt_cache(PromptCache::with_config(PromptCacheConfig {
|
||||
session_id: "break-session".to_string(),
|
||||
completion_ttl: Duration::from_secs(0),
|
||||
..PromptCacheConfig::default()
|
||||
}));
|
||||
|
||||
client
|
||||
.send_message(&request)
|
||||
.await
|
||||
.expect("first response should succeed");
|
||||
client
|
||||
.send_message(&request)
|
||||
.await
|
||||
.expect("second response should succeed");
|
||||
|
||||
let cache_stats = client
|
||||
.prompt_cache_stats()
|
||||
.expect("prompt cache stats should exist");
|
||||
assert_eq!(cache_stats.unexpected_cache_breaks, 1);
|
||||
assert_eq!(
|
||||
cache_stats.last_break_reason.as_deref(),
|
||||
Some("cache read tokens dropped while prompt fingerprint remained stable")
|
||||
);
|
||||
|
||||
std::fs::remove_dir_all(temp_root).expect("cleanup temp root");
|
||||
std::env::remove_var("CLAUDE_CONFIG_HOME");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "requires ANTHROPIC_API_KEY and network access"]
|
||||
async fn live_stream_smoke_test() {
|
||||
let client = ApiClient::from_env().expect("ANTHROPIC_API_KEY must be set");
|
||||
let mut stream = client
|
||||
.stream_message(&MessageRequest {
|
||||
model: std::env::var("CLAW_MODEL").unwrap_or_else(|_| "claude-sonnet-4-6".to_string()),
|
||||
model: std::env::var("ANTHROPIC_MODEL")
|
||||
.unwrap_or_else(|_| "claude-3-7-sonnet-latest".to_string()),
|
||||
max_tokens: 32,
|
||||
messages: vec![InputMessage::user_text(
|
||||
"Reply with exactly: hello from rust",
|
||||
@ -299,6 +742,7 @@ async fn live_stream_smoke_test() {
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
stream: false,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect("live stream should start");
|
||||
@ -450,7 +894,7 @@ fn http_response_with_headers(
|
||||
|
||||
fn sample_request(stream: bool) -> MessageRequest {
|
||||
MessageRequest {
|
||||
model: "claude-sonnet-4-6".to_string(),
|
||||
model: "claude-3-7-sonnet-latest".to_string(),
|
||||
max_tokens: 64,
|
||||
messages: vec![InputMessage {
|
||||
role: "user".to_string(),
|
||||
@ -479,5 +923,6 @@ fn sample_request(stream: bool) -> MessageRequest {
|
||||
}]),
|
||||
tool_choice: Some(ToolChoice::Auto),
|
||||
stream,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
@ -4,9 +4,10 @@ use std::sync::Arc;
|
||||
use std::sync::{Mutex as StdMutex, OnceLock};
|
||||
|
||||
use api::{
|
||||
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
|
||||
InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig,
|
||||
OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition,
|
||||
ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
|
||||
ContentBlockStopEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest,
|
||||
OpenAiCompatClient, OpenAiCompatConfig, OutputContentBlock, ProviderClient, StreamEvent,
|
||||
ToolChoice, ToolDefinition,
|
||||
};
|
||||
use serde_json::json;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
@ -62,6 +63,43 @@ async fn send_message_uses_openai_compatible_endpoint_and_auth() {
|
||||
assert_eq!(body["tools"][0]["type"], json!("function"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_message_blocks_oversized_xai_requests_before_the_http_call() {
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let server = spawn_server(
|
||||
state.clone(),
|
||||
vec![http_response("200 OK", "application/json", "{}")],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai())
|
||||
.with_base_url(server.base_url());
|
||||
let error = client
|
||||
.send_message(&MessageRequest {
|
||||
model: "grok-3".to_string(),
|
||||
max_tokens: 64_000,
|
||||
messages: vec![InputMessage {
|
||||
role: "user".to_string(),
|
||||
content: vec![InputContentBlock::Text {
|
||||
text: "x".repeat(300_000),
|
||||
}],
|
||||
}],
|
||||
system: Some("Keep the answer short.".to_string()),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
stream: false,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
.expect_err("oversized request should fail local context-window preflight");
|
||||
|
||||
assert!(matches!(error, ApiError::ContextWindowExceeded { .. }));
|
||||
assert!(
|
||||
state.lock().await.is_empty(),
|
||||
"preflight failure should avoid any upstream HTTP request"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_message_accepts_full_chat_completions_endpoint_override() {
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
@ -195,6 +233,83 @@ async fn stream_message_normalizes_text_and_multiple_tool_calls() {
|
||||
assert!(request.body.contains("\"stream\":true"));
|
||||
}
|
||||
|
||||
#[allow(clippy::await_holding_lock)]
|
||||
#[tokio::test]
|
||||
async fn openai_streaming_requests_opt_into_usage_chunks() {
|
||||
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||
let sse = concat!(
|
||||
"data: {\"id\":\"chatcmpl_openai_stream\",\"model\":\"gpt-5\",\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n",
|
||||
"data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n",
|
||||
"data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}\n\n",
|
||||
"data: [DONE]\n\n"
|
||||
);
|
||||
let server = spawn_server(
|
||||
state.clone(),
|
||||
vec![http_response_with_headers(
|
||||
"200 OK",
|
||||
"text/event-stream",
|
||||
sse,
|
||||
&[("x-request-id", "req_openai_stream")],
|
||||
)],
|
||||
)
|
||||
.await;
|
||||
|
||||
let client = OpenAiCompatClient::new("openai-test-key", OpenAiCompatConfig::openai())
|
||||
.with_base_url(server.base_url());
|
||||
let mut stream = client
|
||||
.stream_message(&sample_request(false))
|
||||
.await
|
||||
.expect("stream should start");
|
||||
|
||||
assert_eq!(stream.request_id(), Some("req_openai_stream"));
|
||||
|
||||
let mut events = Vec::new();
|
||||
while let Some(event) = stream.next_event().await.expect("event should parse") {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
assert!(matches!(events[0], StreamEvent::MessageStart(_)));
|
||||
assert!(matches!(
|
||||
events[1],
|
||||
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
|
||||
content_block: OutputContentBlock::Text { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
assert!(matches!(
|
||||
events[2],
|
||||
StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
|
||||
delta: ContentBlockDelta::TextDelta { .. },
|
||||
..
|
||||
})
|
||||
));
|
||||
assert!(matches!(
|
||||
events[3],
|
||||
StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 })
|
||||
));
|
||||
assert!(matches!(
|
||||
events[4],
|
||||
StreamEvent::MessageDelta(MessageDeltaEvent { .. })
|
||||
));
|
||||
assert!(matches!(events[5], StreamEvent::MessageStop(_)));
|
||||
|
||||
match &events[4] {
|
||||
StreamEvent::MessageDelta(MessageDeltaEvent { usage, .. }) => {
|
||||
assert_eq!(usage.input_tokens, 9);
|
||||
assert_eq!(usage.output_tokens, 4);
|
||||
}
|
||||
other => panic!("expected message delta, got {other:?}"),
|
||||
}
|
||||
|
||||
let captured = state.lock().await;
|
||||
let request = captured.first().expect("captured request");
|
||||
assert_eq!(request.path, "/chat/completions");
|
||||
let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
|
||||
assert_eq!(body["stream"], json!(true));
|
||||
assert_eq!(body["stream_options"], json!({"include_usage": true}));
|
||||
}
|
||||
|
||||
#[allow(clippy::await_holding_lock)]
|
||||
#[tokio::test]
|
||||
async fn provider_client_dispatches_xai_requests_from_env() {
|
||||
let _lock = env_lock();
|
||||
@ -382,6 +497,7 @@ fn sample_request(stream: bool) -> MessageRequest {
|
||||
}]),
|
||||
tool_choice: Some(ToolChoice::Auto),
|
||||
stream,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
@ -389,7 +505,7 @@ fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
static LOCK: OnceLock<StdMutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| StdMutex::new(()))
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
}
|
||||
|
||||
struct ScopedEnvVar {
|
||||
|
||||
@ -22,7 +22,9 @@ fn provider_client_reports_missing_xai_credentials_for_grok_models() {
|
||||
.expect_err("grok requests without XAI_API_KEY should fail fast");
|
||||
|
||||
match error {
|
||||
ApiError::MissingCredentials { provider, env_vars } => {
|
||||
ApiError::MissingCredentials {
|
||||
provider, env_vars, ..
|
||||
} => {
|
||||
assert_eq!(provider, "xAI");
|
||||
assert_eq!(env_vars, &["XAI_API_KEY"]);
|
||||
}
|
||||
@ -31,18 +33,18 @@ fn provider_client_reports_missing_xai_credentials_for_grok_models() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_client_uses_explicit_auth_without_env_lookup() {
|
||||
fn provider_client_uses_explicit_anthropic_auth_without_env_lookup() {
|
||||
let _lock = env_lock();
|
||||
let _api_key = EnvVarGuard::set("ANTHROPIC_API_KEY", None);
|
||||
let _auth_token = EnvVarGuard::set("ANTHROPIC_AUTH_TOKEN", None);
|
||||
let _anthropic_api_key = EnvVarGuard::set("ANTHROPIC_API_KEY", None);
|
||||
let _anthropic_auth_token = EnvVarGuard::set("ANTHROPIC_AUTH_TOKEN", None);
|
||||
|
||||
let client = ProviderClient::from_model_with_default_auth(
|
||||
let client = ProviderClient::from_model_with_anthropic_auth(
|
||||
"claude-sonnet-4-6",
|
||||
Some(AuthSource::ApiKey("claw-test-key".to_string())),
|
||||
Some(AuthSource::ApiKey("anthropic-test-key".to_string())),
|
||||
)
|
||||
.expect("explicit auth should avoid env lookup");
|
||||
.expect("explicit anthropic auth should avoid env lookup");
|
||||
|
||||
assert_eq!(client.provider_kind(), ProviderKind::ClawApi);
|
||||
assert_eq!(client.provider_kind(), ProviderKind::Anthropic);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -57,7 +59,7 @@ fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
}
|
||||
|
||||
struct EnvVarGuard {
|
||||
|
||||
173
crates/api/tests/proxy_integration.rs
Normal file
173
crates/api/tests/proxy_integration.rs
Normal file
@ -0,0 +1,173 @@
|
||||
use std::ffi::OsString;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
use api::{build_http_client_with, ProxyConfig};
|
||||
|
||||
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
}
|
||||
|
||||
struct EnvVarGuard {
|
||||
key: &'static str,
|
||||
original: Option<OsString>,
|
||||
}
|
||||
|
||||
impl EnvVarGuard {
|
||||
fn set(key: &'static str, value: Option<&str>) -> Self {
|
||||
let original = std::env::var_os(key);
|
||||
match value {
|
||||
Some(value) => std::env::set_var(key, value),
|
||||
None => std::env::remove_var(key),
|
||||
}
|
||||
Self { key, original }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for EnvVarGuard {
|
||||
fn drop(&mut self) {
|
||||
match &self.original {
|
||||
Some(value) => std::env::set_var(self.key, value),
|
||||
None => std::env::remove_var(self.key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_config_from_env_reads_uppercase_proxy_vars() {
|
||||
// given
|
||||
let _lock = env_lock();
|
||||
let _http = EnvVarGuard::set("HTTP_PROXY", Some("http://proxy.corp:3128"));
|
||||
let _https = EnvVarGuard::set("HTTPS_PROXY", Some("http://secure.corp:3129"));
|
||||
let _no = EnvVarGuard::set("NO_PROXY", Some("localhost,127.0.0.1"));
|
||||
let _http_lower = EnvVarGuard::set("http_proxy", None);
|
||||
let _https_lower = EnvVarGuard::set("https_proxy", None);
|
||||
let _no_lower = EnvVarGuard::set("no_proxy", None);
|
||||
|
||||
// when
|
||||
let config = ProxyConfig::from_env();
|
||||
|
||||
// then
|
||||
assert_eq!(config.http_proxy.as_deref(), Some("http://proxy.corp:3128"));
|
||||
assert_eq!(
|
||||
config.https_proxy.as_deref(),
|
||||
Some("http://secure.corp:3129")
|
||||
);
|
||||
assert_eq!(config.no_proxy.as_deref(), Some("localhost,127.0.0.1"));
|
||||
assert!(config.proxy_url.is_none());
|
||||
assert!(!config.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_config_from_env_reads_lowercase_proxy_vars() {
|
||||
// given
|
||||
let _lock = env_lock();
|
||||
let _http = EnvVarGuard::set("HTTP_PROXY", None);
|
||||
let _https = EnvVarGuard::set("HTTPS_PROXY", None);
|
||||
let _no = EnvVarGuard::set("NO_PROXY", None);
|
||||
let _http_lower = EnvVarGuard::set("http_proxy", Some("http://lower.corp:3128"));
|
||||
let _https_lower = EnvVarGuard::set("https_proxy", Some("http://lower-secure.corp:3129"));
|
||||
let _no_lower = EnvVarGuard::set("no_proxy", Some(".internal"));
|
||||
|
||||
// when
|
||||
let config = ProxyConfig::from_env();
|
||||
|
||||
// then
|
||||
assert_eq!(config.http_proxy.as_deref(), Some("http://lower.corp:3128"));
|
||||
assert_eq!(
|
||||
config.https_proxy.as_deref(),
|
||||
Some("http://lower-secure.corp:3129")
|
||||
);
|
||||
assert_eq!(config.no_proxy.as_deref(), Some(".internal"));
|
||||
assert!(!config.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_config_from_env_is_empty_when_no_vars_set() {
|
||||
// given
|
||||
let _lock = env_lock();
|
||||
let _http = EnvVarGuard::set("HTTP_PROXY", None);
|
||||
let _https = EnvVarGuard::set("HTTPS_PROXY", None);
|
||||
let _no = EnvVarGuard::set("NO_PROXY", None);
|
||||
let _http_lower = EnvVarGuard::set("http_proxy", None);
|
||||
let _https_lower = EnvVarGuard::set("https_proxy", None);
|
||||
let _no_lower = EnvVarGuard::set("no_proxy", None);
|
||||
|
||||
// when
|
||||
let config = ProxyConfig::from_env();
|
||||
|
||||
// then
|
||||
assert!(config.is_empty());
|
||||
assert!(config.http_proxy.is_none());
|
||||
assert!(config.https_proxy.is_none());
|
||||
assert!(config.no_proxy.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_config_from_env_treats_empty_values_as_unset() {
|
||||
// given
|
||||
let _lock = env_lock();
|
||||
let _http = EnvVarGuard::set("HTTP_PROXY", Some(""));
|
||||
let _https = EnvVarGuard::set("HTTPS_PROXY", Some(""));
|
||||
let _http_lower = EnvVarGuard::set("http_proxy", Some(""));
|
||||
let _https_lower = EnvVarGuard::set("https_proxy", Some(""));
|
||||
let _no = EnvVarGuard::set("NO_PROXY", Some(""));
|
||||
let _no_lower = EnvVarGuard::set("no_proxy", Some(""));
|
||||
|
||||
// when
|
||||
let config = ProxyConfig::from_env();
|
||||
|
||||
// then
|
||||
assert!(config.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_client_with_env_proxy_config_succeeds() {
|
||||
// given
|
||||
let _lock = env_lock();
|
||||
let _http = EnvVarGuard::set("HTTP_PROXY", Some("http://proxy.corp:3128"));
|
||||
let _https = EnvVarGuard::set("HTTPS_PROXY", Some("http://secure.corp:3129"));
|
||||
let _no = EnvVarGuard::set("NO_PROXY", Some("localhost"));
|
||||
let _http_lower = EnvVarGuard::set("http_proxy", None);
|
||||
let _https_lower = EnvVarGuard::set("https_proxy", None);
|
||||
let _no_lower = EnvVarGuard::set("no_proxy", None);
|
||||
let config = ProxyConfig::from_env();
|
||||
|
||||
// when
|
||||
let result = build_http_client_with(&config);
|
||||
|
||||
// then
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_client_with_proxy_url_config_succeeds() {
|
||||
// given
|
||||
let config = ProxyConfig::from_proxy_url("http://unified.corp:3128");
|
||||
|
||||
// when
|
||||
let result = build_http_client_with(&config);
|
||||
|
||||
// then
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_config_from_env_prefers_uppercase_over_lowercase() {
|
||||
// given
|
||||
let _lock = env_lock();
|
||||
let _http_upper = EnvVarGuard::set("HTTP_PROXY", Some("http://upper.corp:3128"));
|
||||
let _http_lower = EnvVarGuard::set("http_proxy", Some("http://lower.corp:3128"));
|
||||
let _https = EnvVarGuard::set("HTTPS_PROXY", None);
|
||||
let _https_lower = EnvVarGuard::set("https_proxy", None);
|
||||
let _no = EnvVarGuard::set("NO_PROXY", None);
|
||||
let _no_lower = EnvVarGuard::set("no_proxy", None);
|
||||
|
||||
// when
|
||||
let config = ProxyConfig::from_env();
|
||||
|
||||
// then
|
||||
assert_eq!(config.http_proxy.as_deref(), Some("http://upper.corp:3128"));
|
||||
}
|
||||
@ -16,7 +16,7 @@ use std::thread;
|
||||
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use api::{
|
||||
resolve_startup_auth_source, AuthSource, ClawApiClient, ContentBlockDelta, InputContentBlock,
|
||||
resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock,
|
||||
InputMessage, MessageRequest, MessageResponse, OutputContentBlock,
|
||||
StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock,
|
||||
};
|
||||
@ -329,7 +329,7 @@ fn join_optional_args(args: &[String]) -> Option<String> {
|
||||
|
||||
fn parse_direct_slash_cli_action(rest: &[String]) -> Result<CliAction, String> {
|
||||
let raw = rest.join(" ");
|
||||
match SlashCommand::parse(&raw) {
|
||||
match SlashCommand::parse(&raw).map_err(|e| e.to_string())? {
|
||||
Some(SlashCommand::Help) => Ok(CliAction::Help),
|
||||
Some(SlashCommand::Agents { args }) => Ok(CliAction::Agents { args }),
|
||||
Some(SlashCommand::Skills { args }) => Ok(CliAction::Skills { args }),
|
||||
@ -484,7 +484,7 @@ fn dump_manifests() {
|
||||
}
|
||||
|
||||
fn print_bootstrap_plan() {
|
||||
for phase in runtime::BootstrapPlan::claw_default().phases() {
|
||||
for phase in runtime::BootstrapPlan::claude_code_default().phases() {
|
||||
println!("- {phase:?}");
|
||||
}
|
||||
}
|
||||
@ -541,7 +541,7 @@ fn run_login() -> Result<(), Box<dyn std::error::Error>> {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "oauth state mismatch").into());
|
||||
}
|
||||
|
||||
let client = ClawApiClient::from_auth(AuthSource::None).with_base_url(api::read_base_url());
|
||||
let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(api::read_base_url());
|
||||
let exchange_request =
|
||||
OAuthTokenExchangeRequest::from_config(oauth, code, state, pkce.verifier, redirect_uri);
|
||||
let runtime = tokio::runtime::Runtime::new()?;
|
||||
@ -650,7 +650,7 @@ fn resume_session(session_path: &Path, commands: &[String]) {
|
||||
|
||||
let mut session = session;
|
||||
for raw_command in commands {
|
||||
let Some(command) = SlashCommand::parse(raw_command) else {
|
||||
let Ok(Some(command)) = SlashCommand::parse(raw_command) else {
|
||||
eprintln!("unsupported resumed command: {raw_command}");
|
||||
std::process::exit(2);
|
||||
};
|
||||
@ -987,8 +987,6 @@ fn run_resume_command(
|
||||
}
|
||||
SlashCommand::Bughunter { .. }
|
||||
| SlashCommand::Branch { .. }
|
||||
| SlashCommand::Worktree { .. }
|
||||
| SlashCommand::CommitPushPr { .. }
|
||||
| SlashCommand::Commit
|
||||
| SlashCommand::Pr { .. }
|
||||
| SlashCommand::Issue { .. }
|
||||
@ -1000,7 +998,7 @@ fn run_resume_command(
|
||||
| SlashCommand::Permissions { .. }
|
||||
| SlashCommand::Session { .. }
|
||||
| SlashCommand::Plugins { .. }
|
||||
| SlashCommand::Unknown(_) => Err("unsupported resumed slash command".into()),
|
||||
| _ => Err("unsupported resumed slash command".into()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -1024,7 +1022,7 @@ fn run_repl(
|
||||
cli.persist_session()?;
|
||||
break;
|
||||
}
|
||||
if let Some(command) = SlashCommand::parse(trimmed) {
|
||||
if let Ok(Some(command)) = SlashCommand::parse(trimmed) {
|
||||
if cli.handle_repl_command(command)? {
|
||||
cli.persist_session()?;
|
||||
}
|
||||
@ -1336,24 +1334,14 @@ impl LiveCli {
|
||||
);
|
||||
false
|
||||
}
|
||||
SlashCommand::Worktree { .. } => {
|
||||
eprintln!(
|
||||
"{}",
|
||||
render_mode_unavailable("worktree", "git worktree commands")
|
||||
);
|
||||
false
|
||||
}
|
||||
SlashCommand::CommitPushPr { .. } => {
|
||||
eprintln!(
|
||||
"{}",
|
||||
render_mode_unavailable("commit-push-pr", "commit + push + PR automation")
|
||||
);
|
||||
false
|
||||
}
|
||||
SlashCommand::Unknown(name) => {
|
||||
eprintln!("{}", render_unknown_repl_command(&name));
|
||||
false
|
||||
}
|
||||
_ => {
|
||||
eprintln!("command not available in this mode");
|
||||
false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -2505,12 +2493,6 @@ fn render_export_text(session: &Session) -> String {
|
||||
for block in &message.blocks {
|
||||
match block {
|
||||
ContentBlock::Text { text } => lines.push(text.clone()),
|
||||
ContentBlock::Thinking { thinking, .. } => {
|
||||
lines.push(format!("[thinking] {thinking}"));
|
||||
}
|
||||
ContentBlock::RedactedThinking { .. } => {
|
||||
lines.push("[thinking] <redacted>".to_string());
|
||||
}
|
||||
ContentBlock::ToolUse { id, name, input } => {
|
||||
lines.push(format!("[tool_use id={id} name={name}] {input}"));
|
||||
}
|
||||
@ -2995,7 +2977,7 @@ fn build_runtime(
|
||||
CliToolExecutor::new(allowed_tools.clone(), emit_output, tool_registry.clone()),
|
||||
permission_policy(permission_mode, &tool_registry),
|
||||
system_prompt,
|
||||
feature_config,
|
||||
&feature_config,
|
||||
))
|
||||
}
|
||||
|
||||
@ -3047,7 +3029,7 @@ impl runtime::PermissionPrompter for CliPermissionPrompter {
|
||||
|
||||
struct DefaultRuntimeClient {
|
||||
runtime: tokio::runtime::Runtime,
|
||||
client: ClawApiClient,
|
||||
client: AnthropicClient,
|
||||
model: String,
|
||||
enable_tools: bool,
|
||||
emit_output: bool,
|
||||
@ -3067,7 +3049,7 @@ impl DefaultRuntimeClient {
|
||||
) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
Ok(Self {
|
||||
runtime: tokio::runtime::Runtime::new()?,
|
||||
client: ClawApiClient::from_auth(resolve_cli_auth_source()?)
|
||||
client: AnthropicClient::from_auth(resolve_cli_auth_source()?)
|
||||
.with_base_url(api::read_base_url()),
|
||||
model,
|
||||
enable_tools,
|
||||
@ -3105,6 +3087,12 @@ impl ApiClient for DefaultRuntimeClient {
|
||||
.then(|| filter_tool_specs(&self.tool_registry, self.allowed_tools.as_ref())),
|
||||
tool_choice: self.enable_tools.then_some(ToolChoice::Auto),
|
||||
stream: true,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
frequency_penalty: None,
|
||||
presence_penalty: None,
|
||||
stop: None,
|
||||
reasoning_effort: None,
|
||||
};
|
||||
|
||||
self.runtime.block_on(async {
|
||||
@ -3173,7 +3161,6 @@ impl ApiClient for DefaultRuntimeClient {
|
||||
.and_then(|()| out.flush())
|
||||
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
||||
}
|
||||
events.push(AssistantEvent::ThinkingDelta(thinking));
|
||||
}
|
||||
}
|
||||
ContentBlockDelta::SignatureDelta { .. } => {}
|
||||
@ -3254,9 +3241,7 @@ fn final_assistant_text(summary: &runtime::TurnSummary) -> String {
|
||||
.iter()
|
||||
.filter_map(|block| match block {
|
||||
ContentBlock::Text { text } => Some(text.as_str()),
|
||||
ContentBlock::Thinking { thinking, .. } => Some(thinking.as_str()),
|
||||
ContentBlock::RedactedThinking { .. }
|
||||
| ContentBlock::ToolUse { .. }
|
||||
ContentBlock::ToolUse { .. }
|
||||
| ContentBlock::ToolResult { .. } => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
@ -3276,9 +3261,7 @@ fn collect_tool_uses(summary: &runtime::TurnSummary) -> Vec<serde_json::Value> {
|
||||
"name": name,
|
||||
"input": input,
|
||||
})),
|
||||
ContentBlock::Thinking { .. }
|
||||
| ContentBlock::RedactedThinking { .. }
|
||||
| ContentBlock::Text { .. }
|
||||
ContentBlock::Text { .. }
|
||||
| ContentBlock::ToolResult { .. } => None,
|
||||
})
|
||||
.collect()
|
||||
@ -3301,9 +3284,7 @@ fn collect_tool_results(summary: &runtime::TurnSummary) -> Vec<serde_json::Value
|
||||
"output": output,
|
||||
"is_error": is_error,
|
||||
})),
|
||||
ContentBlock::Thinking { .. }
|
||||
| ContentBlock::RedactedThinking { .. }
|
||||
| ContentBlock::Text { .. }
|
||||
ContentBlock::Text { .. }
|
||||
| ContentBlock::ToolUse { .. } => None,
|
||||
})
|
||||
.collect()
|
||||
@ -3851,7 +3832,6 @@ fn push_output_block(
|
||||
write!(out, "\x1b[2m{thinking}\x1b[0m")
|
||||
.and_then(|()| out.flush())
|
||||
.map_err(|error| RuntimeError::new(error.to_string()))?;
|
||||
events.push(AssistantEvent::ThinkingDelta(thinking));
|
||||
}
|
||||
}
|
||||
OutputContentBlock::RedactedThinking { .. } => {}
|
||||
@ -3942,7 +3922,7 @@ impl ToolExecutor for CliToolExecutor {
|
||||
}
|
||||
|
||||
fn permission_policy(mode: PermissionMode, tool_registry: &GlobalToolRegistry) -> PermissionPolicy {
|
||||
tool_registry.permission_specs(None).into_iter().fold(
|
||||
tool_registry.permission_specs(None).unwrap_or_default().into_iter().fold(
|
||||
PermissionPolicy::new(mode),
|
||||
|policy, (name, required_permission)| {
|
||||
policy.with_tool_requirement(name, required_permission)
|
||||
@ -3963,16 +3943,6 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec<InputMessage> {
|
||||
.iter()
|
||||
.map(|block| match block {
|
||||
ContentBlock::Text { text } => InputContentBlock::Text { text: text.clone() },
|
||||
ContentBlock::Thinking {
|
||||
thinking,
|
||||
signature,
|
||||
} => InputContentBlock::Thinking {
|
||||
thinking: thinking.clone(),
|
||||
signature: signature.clone(),
|
||||
},
|
||||
ContentBlock::RedactedThinking { data } => InputContentBlock::RedactedThinking {
|
||||
data: serde_json::from_str(&data.render()).unwrap_or(serde_json::Value::Null),
|
||||
},
|
||||
ContentBlock::ToolUse { id, name, input } => InputContentBlock::ToolUse {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
@ -4735,39 +4705,39 @@ mod tests {
|
||||
#[test]
|
||||
fn clear_command_requires_explicit_confirmation_flag() {
|
||||
assert_eq!(
|
||||
SlashCommand::parse("/clear"),
|
||||
Some(SlashCommand::Clear { confirm: false })
|
||||
SlashCommand::parse("/clear").map_err(|e| e.to_string()),
|
||||
Ok(Some(SlashCommand::Clear { confirm: false }))
|
||||
);
|
||||
assert_eq!(
|
||||
SlashCommand::parse("/clear --confirm"),
|
||||
Some(SlashCommand::Clear { confirm: true })
|
||||
SlashCommand::parse("/clear --confirm").map_err(|e| e.to_string()),
|
||||
Ok(Some(SlashCommand::Clear { confirm: true }))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_resume_and_config_slash_commands() {
|
||||
assert_eq!(
|
||||
SlashCommand::parse("/resume saved-session.json"),
|
||||
Some(SlashCommand::Resume {
|
||||
SlashCommand::parse("/resume saved-session.json").map_err(|e| e.to_string()),
|
||||
Ok(Some(SlashCommand::Resume {
|
||||
session_path: Some("saved-session.json".to_string())
|
||||
})
|
||||
}))
|
||||
);
|
||||
assert_eq!(
|
||||
SlashCommand::parse("/clear --confirm"),
|
||||
Some(SlashCommand::Clear { confirm: true })
|
||||
SlashCommand::parse("/clear --confirm").map_err(|e| e.to_string()),
|
||||
Ok(Some(SlashCommand::Clear { confirm: true }))
|
||||
);
|
||||
assert_eq!(
|
||||
SlashCommand::parse("/config"),
|
||||
Some(SlashCommand::Config { section: None })
|
||||
SlashCommand::parse("/config").map_err(|e| e.to_string()),
|
||||
Ok(Some(SlashCommand::Config { section: None }))
|
||||
);
|
||||
assert_eq!(
|
||||
SlashCommand::parse("/config env"),
|
||||
Some(SlashCommand::Config {
|
||||
SlashCommand::parse("/config env").map_err(|e| e.to_string()),
|
||||
Ok(Some(SlashCommand::Config {
|
||||
section: Some("env".to_string())
|
||||
})
|
||||
}))
|
||||
);
|
||||
assert_eq!(SlashCommand::parse("/memory"), Some(SlashCommand::Memory));
|
||||
assert_eq!(SlashCommand::parse("/init"), Some(SlashCommand::Init));
|
||||
assert_eq!(SlashCommand::parse("/memory").map_err(|e| e.to_string()), Ok(Some(SlashCommand::Memory)));
|
||||
assert_eq!(SlashCommand::parse("/init").map_err(|e| e.to_string()), Ok(Some(SlashCommand::Init)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -70,16 +70,12 @@ fn upstream_repo_candidates(primary_repo_root: &Path) -> Vec<PathBuf> {
|
||||
}
|
||||
|
||||
for ancestor in primary_repo_root.ancestors().take(4) {
|
||||
candidates.push(ancestor.join("claude-code"));
|
||||
candidates.push(ancestor.join("claw-code"));
|
||||
candidates.push(ancestor.join("clawd-code"));
|
||||
}
|
||||
|
||||
candidates.push(
|
||||
primary_repo_root
|
||||
.join("reference-source")
|
||||
.join("claude-code"),
|
||||
);
|
||||
candidates.push(primary_repo_root.join("vendor").join("claude-code"));
|
||||
candidates.push(primary_repo_root.join("reference-source").join("claw-code"));
|
||||
candidates.push(primary_repo_root.join("vendor").join("claw-code"));
|
||||
|
||||
let mut deduped = Vec::new();
|
||||
for candidate in candidates {
|
||||
|
||||
@ -41,6 +41,7 @@ mod tests {
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn write_mock_server_script(root: &std::path::Path) -> PathBuf {
|
||||
let script_path = root.join("mock_lsp_server.py");
|
||||
fs::write(
|
||||
|
||||
18
crates/mock-anthropic-service/Cargo.toml
Normal file
18
crates/mock-anthropic-service/Cargo.toml
Normal file
@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "mock-anthropic-service"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
publish.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "mock-anthropic-service"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
api = { path = "../api" }
|
||||
serde_json.workspace = true
|
||||
tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "signal", "sync"] }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
1123
crates/mock-anthropic-service/src/lib.rs
Normal file
1123
crates/mock-anthropic-service/src/lib.rs
Normal file
File diff suppressed because it is too large
Load Diff
34
crates/mock-anthropic-service/src/main.rs
Normal file
34
crates/mock-anthropic-service/src/main.rs
Normal file
@ -0,0 +1,34 @@
|
||||
use std::env;
|
||||
|
||||
use mock_anthropic_service::MockAnthropicService;
|
||||
|
||||
#[tokio::main(flavor = "multi_thread")]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let mut bind_addr = String::from("127.0.0.1:0");
|
||||
let mut args = env::args().skip(1);
|
||||
while let Some(arg) = args.next() {
|
||||
match arg.as_str() {
|
||||
"--bind" => {
|
||||
bind_addr = args
|
||||
.next()
|
||||
.ok_or_else(|| "missing value for --bind".to_string())?;
|
||||
}
|
||||
flag if flag.starts_with("--bind=") => {
|
||||
bind_addr = flag[7..].to_string();
|
||||
}
|
||||
"--help" | "-h" => {
|
||||
println!("Usage: mock-anthropic-service [--bind HOST:PORT]");
|
||||
return Ok(());
|
||||
}
|
||||
other => {
|
||||
return Err(format!("unsupported argument: {other}").into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let server = MockAnthropicService::spawn_on(&bind_addr).await?;
|
||||
println!("MOCK_ANTHROPIC_BASE_URL={}", server.base_url());
|
||||
tokio::signal::ctrl_c().await?;
|
||||
drop(server);
|
||||
Ok(())
|
||||
}
|
||||
@ -1,6 +1,4 @@
|
||||
use std::ffi::OsStr;
|
||||
#[cfg(not(windows))]
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
use serde_json::json;
|
||||
@ -11,6 +9,7 @@ use crate::{PluginError, PluginHooks, PluginRegistry};
|
||||
pub enum HookEvent {
|
||||
PreToolUse,
|
||||
PostToolUse,
|
||||
PostToolUseFailure,
|
||||
}
|
||||
|
||||
impl HookEvent {
|
||||
@ -18,6 +17,7 @@ impl HookEvent {
|
||||
match self {
|
||||
Self::PreToolUse => "PreToolUse",
|
||||
Self::PostToolUse => "PostToolUse",
|
||||
Self::PostToolUseFailure => "PostToolUseFailure",
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -25,6 +25,7 @@ impl HookEvent {
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct HookRunResult {
|
||||
denied: bool,
|
||||
failed: bool,
|
||||
messages: Vec<String>,
|
||||
}
|
||||
|
||||
@ -33,6 +34,7 @@ impl HookRunResult {
|
||||
pub fn allow(messages: Vec<String>) -> Self {
|
||||
Self {
|
||||
denied: false,
|
||||
failed: false,
|
||||
messages,
|
||||
}
|
||||
}
|
||||
@ -42,6 +44,11 @@ impl HookRunResult {
|
||||
self.denied
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_failed(&self) -> bool {
|
||||
self.failed
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn messages(&self) -> &[String] {
|
||||
&self.messages
|
||||
@ -65,7 +72,7 @@ impl HookRunner {
|
||||
|
||||
#[must_use]
|
||||
pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
|
||||
self.run_commands(
|
||||
Self::run_commands(
|
||||
HookEvent::PreToolUse,
|
||||
&self.hooks.pre_tool_use,
|
||||
tool_name,
|
||||
@ -83,7 +90,7 @@ impl HookRunner {
|
||||
tool_output: &str,
|
||||
is_error: bool,
|
||||
) -> HookRunResult {
|
||||
self.run_commands(
|
||||
Self::run_commands(
|
||||
HookEvent::PostToolUse,
|
||||
&self.hooks.post_tool_use,
|
||||
tool_name,
|
||||
@ -93,8 +100,24 @@ impl HookRunner {
|
||||
)
|
||||
}
|
||||
|
||||
fn run_commands(
|
||||
#[must_use]
|
||||
pub fn run_post_tool_use_failure(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
tool_input: &str,
|
||||
tool_error: &str,
|
||||
) -> HookRunResult {
|
||||
Self::run_commands(
|
||||
HookEvent::PostToolUseFailure,
|
||||
&self.hooks.post_tool_use_failure,
|
||||
tool_name,
|
||||
tool_input,
|
||||
Some(tool_error),
|
||||
true,
|
||||
)
|
||||
}
|
||||
|
||||
fn run_commands(
|
||||
event: HookEvent,
|
||||
commands: &[String],
|
||||
tool_name: &str,
|
||||
@ -106,20 +129,12 @@ impl HookRunner {
|
||||
return HookRunResult::allow(Vec::new());
|
||||
}
|
||||
|
||||
let payload = json!({
|
||||
"hook_event_name": event.as_str(),
|
||||
"tool_name": tool_name,
|
||||
"tool_input": parse_tool_input(tool_input),
|
||||
"tool_input_json": tool_input,
|
||||
"tool_output": tool_output,
|
||||
"tool_result_is_error": is_error,
|
||||
})
|
||||
.to_string();
|
||||
let payload = hook_payload(event, tool_name, tool_input, tool_output, is_error).to_string();
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
||||
for command in commands {
|
||||
match self.run_command(
|
||||
match Self::run_command(
|
||||
command,
|
||||
event,
|
||||
tool_name,
|
||||
@ -139,19 +154,26 @@ impl HookRunner {
|
||||
}));
|
||||
return HookRunResult {
|
||||
denied: true,
|
||||
failed: false,
|
||||
messages,
|
||||
};
|
||||
}
|
||||
HookCommandOutcome::Failed { message } => {
|
||||
messages.push(message);
|
||||
return HookRunResult {
|
||||
denied: false,
|
||||
failed: true,
|
||||
messages,
|
||||
};
|
||||
}
|
||||
HookCommandOutcome::Warn { message } => messages.push(message),
|
||||
}
|
||||
}
|
||||
|
||||
HookRunResult::allow(messages)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments, clippy::unused_self)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn run_command(
|
||||
&self,
|
||||
command: &str,
|
||||
event: HookEvent,
|
||||
tool_name: &str,
|
||||
@ -180,7 +202,7 @@ impl HookRunner {
|
||||
match output.status.code() {
|
||||
Some(0) => HookCommandOutcome::Allow { message },
|
||||
Some(2) => HookCommandOutcome::Deny { message },
|
||||
Some(code) => HookCommandOutcome::Warn {
|
||||
Some(code) => HookCommandOutcome::Failed {
|
||||
message: format_hook_warning(
|
||||
command,
|
||||
code,
|
||||
@ -188,7 +210,7 @@ impl HookRunner {
|
||||
stderr.as_str(),
|
||||
),
|
||||
},
|
||||
None => HookCommandOutcome::Warn {
|
||||
None => HookCommandOutcome::Failed {
|
||||
message: format!(
|
||||
"{} hook `{command}` terminated by signal while handling `{tool_name}`",
|
||||
event.as_str()
|
||||
@ -196,7 +218,7 @@ impl HookRunner {
|
||||
},
|
||||
}
|
||||
}
|
||||
Err(error) => HookCommandOutcome::Warn {
|
||||
Err(error) => HookCommandOutcome::Failed {
|
||||
message: format!(
|
||||
"{} hook `{command}` failed to start for `{tool_name}`: {error}",
|
||||
event.as_str()
|
||||
@ -209,7 +231,34 @@ impl HookRunner {
|
||||
enum HookCommandOutcome {
|
||||
Allow { message: Option<String> },
|
||||
Deny { message: Option<String> },
|
||||
Warn { message: String },
|
||||
Failed { message: String },
|
||||
}
|
||||
|
||||
fn hook_payload(
|
||||
event: HookEvent,
|
||||
tool_name: &str,
|
||||
tool_input: &str,
|
||||
tool_output: Option<&str>,
|
||||
is_error: bool,
|
||||
) -> serde_json::Value {
|
||||
match event {
|
||||
HookEvent::PostToolUseFailure => json!({
|
||||
"hook_event_name": event.as_str(),
|
||||
"tool_name": tool_name,
|
||||
"tool_input": parse_tool_input(tool_input),
|
||||
"tool_input_json": tool_input,
|
||||
"tool_error": tool_output,
|
||||
"tool_result_is_error": true,
|
||||
}),
|
||||
_ => json!({
|
||||
"hook_event_name": event.as_str(),
|
||||
"tool_name": tool_name,
|
||||
"tool_input": parse_tool_input(tool_input),
|
||||
"tool_input_json": tool_input,
|
||||
"tool_output": tool_output,
|
||||
"tool_result_is_error": is_error,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_tool_input(tool_input: &str) -> serde_json::Value {
|
||||
@ -217,8 +266,7 @@ fn parse_tool_input(tool_input: &str) -> serde_json::Value {
|
||||
}
|
||||
|
||||
fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String {
|
||||
let mut message =
|
||||
format!("Hook `{command}` exited with status {code}; allowing tool execution to continue");
|
||||
let mut message = format!("Hook `{command}` exited with status {code}");
|
||||
if let Some(stdout) = stdout.filter(|stdout| !stdout.is_empty()) {
|
||||
message.push_str(": ");
|
||||
message.push_str(stdout);
|
||||
@ -288,7 +336,28 @@ impl CommandWithStdin {
|
||||
let mut child = self.command.spawn()?;
|
||||
if let Some(mut child_stdin) = child.stdin.take() {
|
||||
use std::io::Write as _;
|
||||
child_stdin.write_all(stdin)?;
|
||||
// Tolerate BrokenPipe: a hook script that runs to completion
|
||||
// (or exits early without reading stdin) closes its stdin
|
||||
// before the parent finishes writing the JSON payload, and
|
||||
// the kernel raises EPIPE on the parent's write_all. That is
|
||||
// not a hook failure — the child still exited cleanly and we
|
||||
// still need to wait_with_output() to capture stdout/stderr
|
||||
// and the real exit code. Other write errors (e.g. EIO,
|
||||
// permission, OOM) still propagate.
|
||||
//
|
||||
// This was the root cause of the Linux CI flake on
|
||||
// hooks::tests::collects_and_runs_hooks_from_enabled_plugins
|
||||
// (ROADMAP #25, runs 24120271422 / 24120538408 / 24121392171
|
||||
// / 24121776826): the test hook scripts run in microseconds
|
||||
// and the parent's stdin write races against child exit.
|
||||
// macOS pipes happen to buffer the small payload before the
|
||||
// child exits; Linux pipes do not, so the race shows up
|
||||
// deterministically on ubuntu runners.
|
||||
match child_stdin.write_all(stdin) {
|
||||
Ok(()) => {}
|
||||
Err(error) if error.kind() == std::io::ErrorKind::BrokenPipe => {}
|
||||
Err(error) => return Err(error),
|
||||
}
|
||||
}
|
||||
child.wait_with_output()
|
||||
}
|
||||
@ -310,23 +379,55 @@ mod tests {
|
||||
std::env::temp_dir().join(format!("plugins-hook-runner-{label}-{nanos}"))
|
||||
}
|
||||
|
||||
fn write_hook_plugin(root: &Path, name: &str, pre_message: &str, post_message: &str) {
|
||||
fs::create_dir_all(root.join(".claw-plugin")).expect("manifest dir");
|
||||
fn make_executable(path: &Path) {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let perms = fs::Permissions::from_mode(0o755);
|
||||
fs::set_permissions(path, perms)
|
||||
.unwrap_or_else(|e| panic!("chmod +x {}: {e}", path.display()));
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
let _ = path;
|
||||
}
|
||||
|
||||
fn write_hook_plugin(
|
||||
root: &Path,
|
||||
name: &str,
|
||||
pre_message: &str,
|
||||
post_message: &str,
|
||||
failure_message: &str,
|
||||
) {
|
||||
fs::create_dir_all(root.join(".claude-plugin")).expect("manifest dir");
|
||||
fs::create_dir_all(root.join("hooks")).expect("hooks dir");
|
||||
|
||||
let pre_path = root.join("hooks").join("pre.sh");
|
||||
fs::write(
|
||||
root.join("hooks").join("pre.sh"),
|
||||
&pre_path,
|
||||
format!("#!/bin/sh\nprintf '%s\\n' '{pre_message}'\n"),
|
||||
)
|
||||
.expect("write pre hook");
|
||||
make_executable(&pre_path);
|
||||
|
||||
let post_path = root.join("hooks").join("post.sh");
|
||||
fs::write(
|
||||
root.join("hooks").join("post.sh"),
|
||||
&post_path,
|
||||
format!("#!/bin/sh\nprintf '%s\\n' '{post_message}'\n"),
|
||||
)
|
||||
.expect("write post hook");
|
||||
make_executable(&post_path);
|
||||
|
||||
let failure_path = root.join("hooks").join("failure.sh");
|
||||
fs::write(
|
||||
root.join(".claw-plugin").join("plugin.json"),
|
||||
&failure_path,
|
||||
format!("#!/bin/sh\nprintf '%s\\n' '{failure_message}'\n"),
|
||||
)
|
||||
.expect("write failure hook");
|
||||
make_executable(&failure_path);
|
||||
fs::write(
|
||||
root.join(".claude-plugin").join("plugin.json"),
|
||||
format!(
|
||||
"{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"hook plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"]\n }}\n}}"
|
||||
"{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"hook plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"],\n \"PostToolUseFailure\": [\"./hooks/failure.sh\"]\n }}\n}}"
|
||||
),
|
||||
)
|
||||
.expect("write plugin manifest");
|
||||
@ -334,6 +435,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn collects_and_runs_hooks_from_enabled_plugins() {
|
||||
// given
|
||||
let config_home = temp_dir("config");
|
||||
let first_source_root = temp_dir("source-a");
|
||||
let second_source_root = temp_dir("source-b");
|
||||
@ -342,12 +444,14 @@ mod tests {
|
||||
"first",
|
||||
"plugin pre one",
|
||||
"plugin post one",
|
||||
"plugin failure one",
|
||||
);
|
||||
write_hook_plugin(
|
||||
&second_source_root,
|
||||
"second",
|
||||
"plugin pre two",
|
||||
"plugin post two",
|
||||
"plugin failure two",
|
||||
);
|
||||
|
||||
let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home));
|
||||
@ -359,8 +463,10 @@ mod tests {
|
||||
.expect("second plugin install should succeed");
|
||||
let registry = manager.plugin_registry().expect("registry should build");
|
||||
|
||||
// when
|
||||
let runner = HookRunner::from_registry(®istry).expect("plugin hooks should load");
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#),
|
||||
HookRunResult::allow(vec![
|
||||
@ -375,6 +481,13 @@ mod tests {
|
||||
"plugin post two".to_string(),
|
||||
])
|
||||
);
|
||||
assert_eq!(
|
||||
runner.run_post_tool_use_failure("Read", r#"{"path":"README.md"}"#, "tool failed",),
|
||||
HookRunResult::allow(vec![
|
||||
"plugin failure one".to_string(),
|
||||
"plugin failure two".to_string(),
|
||||
])
|
||||
);
|
||||
|
||||
let _ = fs::remove_dir_all(config_home);
|
||||
let _ = fs::remove_dir_all(first_source_root);
|
||||
@ -383,14 +496,68 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn pre_tool_use_denies_when_plugin_hook_exits_two() {
|
||||
// given
|
||||
let runner = HookRunner::new(crate::PluginHooks {
|
||||
pre_tool_use: vec!["printf 'blocked by plugin'; exit 2".to_string()],
|
||||
post_tool_use: Vec::new(),
|
||||
post_tool_use_failure: Vec::new(),
|
||||
});
|
||||
|
||||
// when
|
||||
let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#);
|
||||
|
||||
// then
|
||||
assert!(result.is_denied());
|
||||
assert_eq!(result.messages(), &["blocked by plugin".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn propagates_plugin_hook_failures() {
|
||||
// given
|
||||
let runner = HookRunner::new(crate::PluginHooks {
|
||||
pre_tool_use: vec![
|
||||
"printf 'broken plugin hook'; exit 1".to_string(),
|
||||
"printf 'later plugin hook'".to_string(),
|
||||
],
|
||||
post_tool_use: Vec::new(),
|
||||
post_tool_use_failure: Vec::new(),
|
||||
});
|
||||
|
||||
// when
|
||||
let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#);
|
||||
|
||||
// then
|
||||
assert!(result.is_failed());
|
||||
assert!(result
|
||||
.messages()
|
||||
.iter()
|
||||
.any(|message| message.contains("broken plugin hook")));
|
||||
assert!(!result
|
||||
.messages()
|
||||
.iter()
|
||||
.any(|message| message == "later plugin hook"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(unix)]
|
||||
fn generated_hook_scripts_are_executable() {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
// given
|
||||
let root = temp_dir("exec-guard");
|
||||
write_hook_plugin(&root, "exec-check", "pre", "post", "fail");
|
||||
|
||||
// then
|
||||
for script in ["pre.sh", "post.sh", "failure.sh"] {
|
||||
let path = root.join("hooks").join(script);
|
||||
let mode = fs::metadata(&path)
|
||||
.unwrap_or_else(|e| panic!("{script} metadata: {e}"))
|
||||
.permissions()
|
||||
.mode();
|
||||
assert!(
|
||||
mode & 0o111 != 0,
|
||||
"{script} must have at least one execute bit set, got mode {mode:#o}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,7 +18,7 @@ const BUNDLED_MARKETPLACE: &str = "bundled";
|
||||
const SETTINGS_FILE_NAME: &str = "settings.json";
|
||||
const REGISTRY_FILE_NAME: &str = "installed.json";
|
||||
const MANIFEST_FILE_NAME: &str = "plugin.json";
|
||||
const MANIFEST_RELATIVE_PATH: &str = ".claw-plugin/plugin.json";
|
||||
const MANIFEST_RELATIVE_PATH: &str = ".claude-plugin/plugin.json";
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
@ -67,12 +67,16 @@ pub struct PluginHooks {
|
||||
pub pre_tool_use: Vec<String>,
|
||||
#[serde(rename = "PostToolUse", default)]
|
||||
pub post_tool_use: Vec<String>,
|
||||
#[serde(rename = "PostToolUseFailure", default)]
|
||||
pub post_tool_use_failure: Vec<String>,
|
||||
}
|
||||
|
||||
impl PluginHooks {
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.pre_tool_use.is_empty() && self.post_tool_use.is_empty()
|
||||
self.pre_tool_use.is_empty()
|
||||
&& self.post_tool_use.is_empty()
|
||||
&& self.post_tool_use_failure.is_empty()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
@ -85,6 +89,9 @@ impl PluginHooks {
|
||||
.post_tool_use
|
||||
.extend(other.post_tool_use.iter().cloned());
|
||||
merged
|
||||
.post_tool_use_failure
|
||||
.extend(other.post_tool_use_failure.iter().cloned());
|
||||
merged
|
||||
}
|
||||
}
|
||||
|
||||
@ -302,14 +309,14 @@ impl PluginTool {
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.env("CLAW_PLUGIN_ID", &self.plugin_id)
|
||||
.env("CLAW_PLUGIN_NAME", &self.plugin_name)
|
||||
.env("CLAW_TOOL_NAME", &self.definition.name)
|
||||
.env("CLAW_TOOL_INPUT", &input_json);
|
||||
.env("CLAWD_PLUGIN_ID", &self.plugin_id)
|
||||
.env("CLAWD_PLUGIN_NAME", &self.plugin_name)
|
||||
.env("CLAWD_TOOL_NAME", &self.definition.name)
|
||||
.env("CLAWD_TOOL_INPUT", &input_json);
|
||||
if let Some(root) = &self.root {
|
||||
process
|
||||
.current_dir(root)
|
||||
.env("CLAW_PLUGIN_ROOT", root.display().to_string());
|
||||
.env("CLAWD_PLUGIN_ROOT", root.display().to_string());
|
||||
}
|
||||
|
||||
let mut child = process.spawn()?;
|
||||
@ -648,6 +655,106 @@ pub struct PluginSummary {
|
||||
pub enabled: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PluginLoadFailure {
|
||||
pub plugin_root: PathBuf,
|
||||
pub kind: PluginKind,
|
||||
pub source: String,
|
||||
error: Box<PluginError>,
|
||||
}
|
||||
|
||||
impl PluginLoadFailure {
|
||||
#[must_use]
|
||||
pub fn new(plugin_root: PathBuf, kind: PluginKind, source: String, error: PluginError) -> Self {
|
||||
Self {
|
||||
plugin_root,
|
||||
kind,
|
||||
source,
|
||||
error: Box::new(error),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn error(&self) -> &PluginError {
|
||||
self.error.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for PluginLoadFailure {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"failed to load {} plugin from `{}` (source: {}): {}",
|
||||
self.kind,
|
||||
self.plugin_root.display(),
|
||||
self.source,
|
||||
self.error()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PluginRegistryReport {
|
||||
registry: PluginRegistry,
|
||||
failures: Vec<PluginLoadFailure>,
|
||||
}
|
||||
|
||||
impl PluginRegistryReport {
|
||||
#[must_use]
|
||||
pub fn new(registry: PluginRegistry, failures: Vec<PluginLoadFailure>) -> Self {
|
||||
Self { registry, failures }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn registry(&self) -> &PluginRegistry {
|
||||
&self.registry
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn failures(&self) -> &[PluginLoadFailure] {
|
||||
&self.failures
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn has_failures(&self) -> bool {
|
||||
!self.failures.is_empty()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn summaries(&self) -> Vec<PluginSummary> {
|
||||
self.registry.summaries()
|
||||
}
|
||||
|
||||
pub fn into_registry(self) -> Result<PluginRegistry, PluginError> {
|
||||
if self.failures.is_empty() {
|
||||
Ok(self.registry)
|
||||
} else {
|
||||
Err(PluginError::LoadFailures(self.failures))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct PluginDiscovery {
|
||||
plugins: Vec<PluginDefinition>,
|
||||
failures: Vec<PluginLoadFailure>,
|
||||
}
|
||||
|
||||
impl PluginDiscovery {
|
||||
fn push_plugin(&mut self, plugin: PluginDefinition) {
|
||||
self.plugins.push(plugin);
|
||||
}
|
||||
|
||||
fn push_failure(&mut self, failure: PluginLoadFailure) {
|
||||
self.failures.push(failure);
|
||||
}
|
||||
|
||||
fn extend(&mut self, other: Self) {
|
||||
self.plugins.extend(other.plugins);
|
||||
self.failures.extend(other.failures);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, PartialEq)]
|
||||
pub struct PluginRegistry {
|
||||
plugins: Vec<RegisteredPlugin>,
|
||||
@ -802,6 +909,10 @@ pub enum PluginManifestValidationError {
|
||||
kind: &'static str,
|
||||
path: PathBuf,
|
||||
},
|
||||
PathIsDirectory {
|
||||
kind: &'static str,
|
||||
path: PathBuf,
|
||||
},
|
||||
InvalidToolInputSchema {
|
||||
tool_name: String,
|
||||
},
|
||||
@ -809,6 +920,9 @@ pub enum PluginManifestValidationError {
|
||||
tool_name: String,
|
||||
permission: String,
|
||||
},
|
||||
UnsupportedManifestContract {
|
||||
detail: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl Display for PluginManifestValidationError {
|
||||
@ -838,6 +952,9 @@ impl Display for PluginManifestValidationError {
|
||||
Self::MissingPath { kind, path } => {
|
||||
write!(f, "{kind} path `{}` does not exist", path.display())
|
||||
}
|
||||
Self::PathIsDirectory { kind, path } => {
|
||||
write!(f, "{kind} path `{}` must point to a file", path.display())
|
||||
}
|
||||
Self::InvalidToolInputSchema { tool_name } => {
|
||||
write!(
|
||||
f,
|
||||
@ -851,6 +968,7 @@ impl Display for PluginManifestValidationError {
|
||||
f,
|
||||
"plugin tool `{tool_name}` requiredPermission `{permission}` must be read-only, workspace-write, or danger-full-access"
|
||||
),
|
||||
Self::UnsupportedManifestContract { detail } => f.write_str(detail),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -860,6 +978,7 @@ pub enum PluginError {
|
||||
Io(std::io::Error),
|
||||
Json(serde_json::Error),
|
||||
ManifestValidation(Vec<PluginManifestValidationError>),
|
||||
LoadFailures(Vec<PluginLoadFailure>),
|
||||
InvalidManifest(String),
|
||||
NotFound(String),
|
||||
CommandFailed(String),
|
||||
@ -879,6 +998,15 @@ impl Display for PluginError {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Self::LoadFailures(failures) => {
|
||||
for (index, failure) in failures.iter().enumerate() {
|
||||
if index > 0 {
|
||||
write!(f, "; ")?;
|
||||
}
|
||||
write!(f, "{failure}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Self::InvalidManifest(message)
|
||||
| Self::NotFound(message)
|
||||
| Self::CommandFailed(message) => write!(f, "{message}"),
|
||||
@ -935,15 +1063,23 @@ impl PluginManager {
|
||||
}
|
||||
|
||||
pub fn plugin_registry(&self) -> Result<PluginRegistry, PluginError> {
|
||||
Ok(PluginRegistry::new(
|
||||
self.discover_plugins()?
|
||||
.into_iter()
|
||||
.map(|plugin| {
|
||||
let enabled = self.is_enabled(plugin.metadata());
|
||||
RegisteredPlugin::new(plugin, enabled)
|
||||
})
|
||||
.collect(),
|
||||
))
|
||||
self.plugin_registry_report()?.into_registry()
|
||||
}
|
||||
|
||||
pub fn plugin_registry_report(&self) -> Result<PluginRegistryReport, PluginError> {
|
||||
self.sync_bundled_plugins()?;
|
||||
|
||||
let mut discovery = PluginDiscovery::default();
|
||||
discovery.plugins.extend(builtin_plugins());
|
||||
|
||||
let installed = self.discover_installed_plugins_with_failures()?;
|
||||
discovery.extend(installed);
|
||||
|
||||
let external =
|
||||
self.discover_external_directory_plugins_with_failures(&discovery.plugins)?;
|
||||
discovery.extend(external);
|
||||
|
||||
Ok(self.build_registry_report(discovery))
|
||||
}
|
||||
|
||||
pub fn list_plugins(&self) -> Result<Vec<PluginSummary>, PluginError> {
|
||||
@ -955,11 +1091,12 @@ impl PluginManager {
|
||||
}
|
||||
|
||||
pub fn discover_plugins(&self) -> Result<Vec<PluginDefinition>, PluginError> {
|
||||
self.sync_bundled_plugins()?;
|
||||
let mut plugins = builtin_plugins();
|
||||
plugins.extend(self.discover_installed_plugins()?);
|
||||
plugins.extend(self.discover_external_directory_plugins(&plugins)?);
|
||||
Ok(plugins)
|
||||
Ok(self
|
||||
.plugin_registry()?
|
||||
.plugins
|
||||
.into_iter()
|
||||
.map(|plugin| plugin.definition)
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub fn aggregated_hooks(&self) -> Result<PluginHooks, PluginError> {
|
||||
@ -1094,9 +1231,9 @@ impl PluginManager {
|
||||
})
|
||||
}
|
||||
|
||||
fn discover_installed_plugins(&self) -> Result<Vec<PluginDefinition>, PluginError> {
|
||||
fn discover_installed_plugins_with_failures(&self) -> Result<PluginDiscovery, PluginError> {
|
||||
let mut registry = self.load_registry()?;
|
||||
let mut plugins = Vec::new();
|
||||
let mut discovery = PluginDiscovery::default();
|
||||
let mut seen_ids = BTreeSet::<String>::new();
|
||||
let mut seen_paths = BTreeSet::<PathBuf>::new();
|
||||
let mut stale_registry_ids = Vec::new();
|
||||
@ -1111,10 +1248,21 @@ impl PluginManager {
|
||||
|| install_path.display().to_string(),
|
||||
|record| describe_install_source(&record.source),
|
||||
);
|
||||
let plugin = load_plugin_definition(&install_path, kind, source, kind.marketplace())?;
|
||||
match load_plugin_definition(&install_path, kind, source.clone(), kind.marketplace()) {
|
||||
Ok(plugin) => {
|
||||
if seen_ids.insert(plugin.metadata().id.clone()) {
|
||||
seen_paths.insert(install_path);
|
||||
plugins.push(plugin);
|
||||
discovery.push_plugin(plugin);
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
discovery.push_failure(PluginLoadFailure::new(
|
||||
install_path,
|
||||
kind,
|
||||
source,
|
||||
error,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1127,15 +1275,27 @@ impl PluginManager {
|
||||
stale_registry_ids.push(record.id.clone());
|
||||
continue;
|
||||
}
|
||||
let plugin = load_plugin_definition(
|
||||
let source = describe_install_source(&record.source);
|
||||
match load_plugin_definition(
|
||||
&record.install_path,
|
||||
record.kind,
|
||||
describe_install_source(&record.source),
|
||||
source.clone(),
|
||||
record.kind.marketplace(),
|
||||
)?;
|
||||
) {
|
||||
Ok(plugin) => {
|
||||
if seen_ids.insert(plugin.metadata().id.clone()) {
|
||||
seen_paths.insert(record.install_path.clone());
|
||||
plugins.push(plugin);
|
||||
discovery.push_plugin(plugin);
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
discovery.push_failure(PluginLoadFailure::new(
|
||||
record.install_path.clone(),
|
||||
record.kind,
|
||||
source,
|
||||
error,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1146,47 +1306,51 @@ impl PluginManager {
|
||||
self.store_registry(®istry)?;
|
||||
}
|
||||
|
||||
Ok(plugins)
|
||||
Ok(discovery)
|
||||
}
|
||||
|
||||
fn discover_external_directory_plugins(
|
||||
fn discover_external_directory_plugins_with_failures(
|
||||
&self,
|
||||
existing_plugins: &[PluginDefinition],
|
||||
) -> Result<Vec<PluginDefinition>, PluginError> {
|
||||
let mut plugins = Vec::new();
|
||||
) -> Result<PluginDiscovery, PluginError> {
|
||||
let mut discovery = PluginDiscovery::default();
|
||||
|
||||
for directory in &self.config.external_dirs {
|
||||
for root in discover_plugin_dirs(directory)? {
|
||||
let plugin = load_plugin_definition(
|
||||
let source = root.display().to_string();
|
||||
match load_plugin_definition(
|
||||
&root,
|
||||
PluginKind::External,
|
||||
root.display().to_string(),
|
||||
source.clone(),
|
||||
EXTERNAL_MARKETPLACE,
|
||||
)?;
|
||||
) {
|
||||
Ok(plugin) => {
|
||||
if existing_plugins
|
||||
.iter()
|
||||
.chain(plugins.iter())
|
||||
.chain(discovery.plugins.iter())
|
||||
.all(|existing| existing.metadata().id != plugin.metadata().id)
|
||||
{
|
||||
plugins.push(plugin);
|
||||
discovery.push_plugin(plugin);
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
discovery.push_failure(PluginLoadFailure::new(
|
||||
root,
|
||||
PluginKind::External,
|
||||
source,
|
||||
error,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(plugins)
|
||||
Ok(discovery)
|
||||
}
|
||||
|
||||
fn installed_plugin_registry(&self) -> Result<PluginRegistry, PluginError> {
|
||||
pub fn installed_plugin_registry_report(&self) -> Result<PluginRegistryReport, PluginError> {
|
||||
self.sync_bundled_plugins()?;
|
||||
Ok(PluginRegistry::new(
|
||||
self.discover_installed_plugins()?
|
||||
.into_iter()
|
||||
.map(|plugin| {
|
||||
let enabled = self.is_enabled(plugin.metadata());
|
||||
RegisteredPlugin::new(plugin, enabled)
|
||||
})
|
||||
.collect(),
|
||||
))
|
||||
Ok(self.build_registry_report(self.discover_installed_plugins_with_failures()?))
|
||||
}
|
||||
|
||||
fn sync_bundled_plugins(&self) -> Result<(), PluginError> {
|
||||
@ -1332,6 +1496,26 @@ impl PluginManager {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn installed_plugin_registry(&self) -> Result<PluginRegistry, PluginError> {
|
||||
self.installed_plugin_registry_report()?.into_registry()
|
||||
}
|
||||
|
||||
fn build_registry_report(&self, discovery: PluginDiscovery) -> PluginRegistryReport {
|
||||
PluginRegistryReport::new(
|
||||
PluginRegistry::new(
|
||||
discovery
|
||||
.plugins
|
||||
.into_iter()
|
||||
.map(|plugin| {
|
||||
let enabled = self.is_enabled(plugin.metadata());
|
||||
RegisteredPlugin::new(plugin, enabled)
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
discovery.failures,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
@ -1414,10 +1598,73 @@ fn load_manifest_from_path(
|
||||
manifest_path.display()
|
||||
))
|
||||
})?;
|
||||
let raw_manifest: RawPluginManifest = serde_json::from_str(&contents)?;
|
||||
let raw_json: Value = serde_json::from_str(&contents)?;
|
||||
let compatibility_errors = detect_claude_code_manifest_contract_gaps(&raw_json);
|
||||
if !compatibility_errors.is_empty() {
|
||||
return Err(PluginError::ManifestValidation(compatibility_errors));
|
||||
}
|
||||
let raw_manifest: RawPluginManifest = serde_json::from_value(raw_json)?;
|
||||
build_plugin_manifest(root, raw_manifest)
|
||||
}
|
||||
|
||||
fn detect_claude_code_manifest_contract_gaps(
|
||||
raw_manifest: &Value,
|
||||
) -> Vec<PluginManifestValidationError> {
|
||||
let Some(root) = raw_manifest.as_object() else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
let mut errors = Vec::new();
|
||||
|
||||
for (field, detail) in [
|
||||
(
|
||||
"skills",
|
||||
"plugin manifest field `skills` uses the Claude Code plugin contract; `claw` does not load plugin-managed skills and instead discovers skills from local roots such as `.claw/skills`, `.omc/skills`, `.agents/skills`, `~/.omc/skills`, and `~/.claude/skills/omc-learned`.",
|
||||
),
|
||||
(
|
||||
"mcpServers",
|
||||
"plugin manifest field `mcpServers` uses the Claude Code plugin contract; `claw` does not import MCP servers from plugin manifests.",
|
||||
),
|
||||
(
|
||||
"agents",
|
||||
"plugin manifest field `agents` uses the Claude Code plugin contract; `claw` does not load plugin-managed agent markdown catalogs from plugin manifests.",
|
||||
),
|
||||
] {
|
||||
if root.contains_key(field) {
|
||||
errors.push(PluginManifestValidationError::UnsupportedManifestContract {
|
||||
detail: detail.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if root
|
||||
.get("commands")
|
||||
.and_then(Value::as_array)
|
||||
.is_some_and(|commands| commands.iter().any(Value::is_string))
|
||||
{
|
||||
errors.push(PluginManifestValidationError::UnsupportedManifestContract {
|
||||
detail: "plugin manifest field `commands` uses Claude Code-style directory globs; `claw` slash dispatch is still built-in and does not load plugin slash command markdown files.".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(hooks) = root.get("hooks").and_then(Value::as_object) {
|
||||
for hook_name in hooks.keys() {
|
||||
if !matches!(
|
||||
hook_name.as_str(),
|
||||
"PreToolUse" | "PostToolUse" | "PostToolUseFailure"
|
||||
) {
|
||||
errors.push(PluginManifestValidationError::UnsupportedManifestContract {
|
||||
detail: format!(
|
||||
"plugin hook `{hook_name}` uses the Claude Code lifecycle contract; `claw` plugins currently support only PreToolUse, PostToolUse, and PostToolUseFailure."
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
errors
|
||||
}
|
||||
|
||||
fn plugin_manifest_path(root: &Path) -> Result<PathBuf, PluginError> {
|
||||
let direct_path = root.join(MANIFEST_FILE_NAME);
|
||||
if direct_path.exists() {
|
||||
@ -1449,6 +1696,12 @@ fn build_plugin_manifest(
|
||||
let permissions = build_manifest_permissions(&raw.permissions, &mut errors);
|
||||
validate_command_entries(root, raw.hooks.pre_tool_use.iter(), "hook", &mut errors);
|
||||
validate_command_entries(root, raw.hooks.post_tool_use.iter(), "hook", &mut errors);
|
||||
validate_command_entries(
|
||||
root,
|
||||
raw.hooks.post_tool_use_failure.iter(),
|
||||
"hook",
|
||||
&mut errors,
|
||||
);
|
||||
validate_command_entries(
|
||||
root,
|
||||
raw.lifecycle.init.iter(),
|
||||
@ -1676,6 +1929,8 @@ fn validate_command_entry(
|
||||
};
|
||||
if !path.exists() {
|
||||
errors.push(PluginManifestValidationError::MissingPath { kind, path });
|
||||
} else if !path.is_file() {
|
||||
errors.push(PluginManifestValidationError::PathIsDirectory { kind, path });
|
||||
}
|
||||
}
|
||||
|
||||
@ -1691,6 +1946,11 @@ fn resolve_hooks(root: &Path, hooks: &PluginHooks) -> PluginHooks {
|
||||
.iter()
|
||||
.map(|entry| resolve_hook_entry(root, entry))
|
||||
.collect(),
|
||||
post_tool_use_failure: hooks
|
||||
.post_tool_use_failure
|
||||
.iter()
|
||||
.map(|entry| resolve_hook_entry(root, entry))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -1739,7 +1999,12 @@ fn validate_hook_paths(root: Option<&Path>, hooks: &PluginHooks) -> Result<(), P
|
||||
let Some(root) = root else {
|
||||
return Ok(());
|
||||
};
|
||||
for entry in hooks.pre_tool_use.iter().chain(hooks.post_tool_use.iter()) {
|
||||
for entry in hooks
|
||||
.pre_tool_use
|
||||
.iter()
|
||||
.chain(hooks.post_tool_use.iter())
|
||||
.chain(hooks.post_tool_use_failure.iter())
|
||||
{
|
||||
validate_command_path(root, entry, "hook")?;
|
||||
}
|
||||
Ok(())
|
||||
@ -1783,6 +2048,12 @@ fn validate_command_path(root: &Path, entry: &str, kind: &str) -> Result<(), Plu
|
||||
path.display()
|
||||
)));
|
||||
}
|
||||
if !path.is_file() {
|
||||
return Err(PluginError::InvalidManifest(format!(
|
||||
"{kind} path `{}` must point to a file",
|
||||
path.display()
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -2094,6 +2365,30 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
fn write_directory_path_plugin(root: &Path, name: &str) {
|
||||
fs::create_dir_all(root.join("hooks").join("pre-dir")).expect("hook dir");
|
||||
fs::create_dir_all(root.join("tools").join("tool-dir")).expect("tool dir");
|
||||
fs::create_dir_all(root.join("commands").join("sync-dir")).expect("command dir");
|
||||
fs::create_dir_all(root.join("lifecycle").join("init-dir")).expect("lifecycle dir");
|
||||
write_file(
|
||||
root.join(MANIFEST_FILE_NAME).as_path(),
|
||||
format!(
|
||||
"{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"directory path plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre-dir\"]\n }},\n \"lifecycle\": {{\n \"Init\": [\"./lifecycle/init-dir\"]\n }},\n \"tools\": [\n {{\n \"name\": \"dir_tool\",\n \"description\": \"Directory tool\",\n \"inputSchema\": {{\"type\": \"object\"}},\n \"command\": \"./tools/tool-dir\"\n }}\n ],\n \"commands\": [\n {{\n \"name\": \"sync\",\n \"description\": \"Directory command\",\n \"command\": \"./commands/sync-dir\"\n }}\n ]\n}}"
|
||||
)
|
||||
.as_str(),
|
||||
);
|
||||
}
|
||||
|
||||
fn write_broken_failure_hook_plugin(root: &Path, name: &str) {
|
||||
write_file(
|
||||
root.join(MANIFEST_RELATIVE_PATH).as_path(),
|
||||
format!(
|
||||
"{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"broken plugin\",\n \"hooks\": {{\n \"PostToolUseFailure\": [\"./hooks/missing-failure.sh\"]\n }}\n}}"
|
||||
)
|
||||
.as_str(),
|
||||
);
|
||||
}
|
||||
|
||||
fn write_lifecycle_plugin(root: &Path, name: &str, version: &str) -> PathBuf {
|
||||
let log_path = root.join("lifecycle.log");
|
||||
write_file(
|
||||
@ -2122,7 +2417,7 @@ mod tests {
|
||||
let script_path = root.join("tools").join("echo-json.sh");
|
||||
write_file(
|
||||
&script_path,
|
||||
"#!/bin/sh\nINPUT=$(cat)\nprintf '{\"plugin\":\"%s\",\"tool\":\"%s\",\"input\":%s}\\n' \"$CLAW_PLUGIN_ID\" \"$CLAW_TOOL_NAME\" \"$INPUT\"\n",
|
||||
"#!/bin/sh\nINPUT=$(cat)\nprintf '{\"plugin\":\"%s\",\"tool\":\"%s\",\"input\":%s}\\n' \"$CLAWD_PLUGIN_ID\" \"$CLAWD_TOOL_NAME\" \"$INPUT\"\n",
|
||||
);
|
||||
#[cfg(unix)]
|
||||
{
|
||||
@ -2289,6 +2584,37 @@ mod tests {
|
||||
let _ = fs::remove_dir_all(root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_plugin_from_directory_rejects_claude_code_manifest_contracts_with_guidance() {
|
||||
let root = temp_dir("manifest-claude-code-contract");
|
||||
write_file(
|
||||
root.join(MANIFEST_FILE_NAME).as_path(),
|
||||
r#"{
|
||||
"name": "oh-my-claudecode",
|
||||
"version": "4.10.2",
|
||||
"description": "Claude Code plugin manifest",
|
||||
"hooks": {
|
||||
"SessionStart": ["scripts/session-start.mjs"]
|
||||
},
|
||||
"agents": ["agents/*.md"],
|
||||
"commands": ["commands/**/*.md"],
|
||||
"skills": "./skills/",
|
||||
"mcpServers": "./.mcp.json"
|
||||
}"#,
|
||||
);
|
||||
|
||||
let error = load_plugin_from_directory(&root)
|
||||
.expect_err("Claude Code plugin manifest should fail with guidance");
|
||||
let rendered = error.to_string();
|
||||
assert!(rendered.contains("field `skills` uses the Claude Code plugin contract"));
|
||||
assert!(rendered.contains("field `mcpServers` uses the Claude Code plugin contract"));
|
||||
assert!(rendered.contains("field `agents` uses the Claude Code plugin contract"));
|
||||
assert!(rendered.contains("field `commands` uses Claude Code-style directory globs"));
|
||||
assert!(rendered.contains("hook `SessionStart` uses the Claude Code lifecycle contract"));
|
||||
|
||||
let _ = fs::remove_dir_all(root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_plugin_from_directory_rejects_missing_tool_or_command_paths() {
|
||||
let root = temp_dir("manifest-paths");
|
||||
@ -2315,6 +2641,90 @@ mod tests {
|
||||
let _ = fs::remove_dir_all(root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_plugin_from_directory_rejects_missing_lifecycle_paths() {
|
||||
// given
|
||||
let root = temp_dir("manifest-lifecycle-paths");
|
||||
write_file(
|
||||
root.join(MANIFEST_FILE_NAME).as_path(),
|
||||
r#"{
|
||||
"name": "missing-lifecycle-paths",
|
||||
"version": "1.0.0",
|
||||
"description": "Missing lifecycle path validation",
|
||||
"lifecycle": {
|
||||
"Init": ["./lifecycle/init.sh"],
|
||||
"Shutdown": ["./lifecycle/shutdown.sh"]
|
||||
}
|
||||
}"#,
|
||||
);
|
||||
|
||||
// when
|
||||
let error =
|
||||
load_plugin_from_directory(&root).expect_err("missing lifecycle paths should fail");
|
||||
|
||||
// then
|
||||
match error {
|
||||
PluginError::ManifestValidation(errors) => {
|
||||
assert!(errors.iter().any(|error| matches!(
|
||||
error,
|
||||
PluginManifestValidationError::MissingPath { kind, path }
|
||||
if *kind == "lifecycle command"
|
||||
&& path.ends_with(Path::new("lifecycle/init.sh"))
|
||||
)));
|
||||
assert!(errors.iter().any(|error| matches!(
|
||||
error,
|
||||
PluginManifestValidationError::MissingPath { kind, path }
|
||||
if *kind == "lifecycle command"
|
||||
&& path.ends_with(Path::new("lifecycle/shutdown.sh"))
|
||||
)));
|
||||
}
|
||||
other => panic!("expected manifest validation errors, got {other}"),
|
||||
}
|
||||
|
||||
let _ = fs::remove_dir_all(root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_plugin_from_directory_rejects_directory_command_paths() {
|
||||
// given
|
||||
let root = temp_dir("manifest-directory-paths");
|
||||
write_directory_path_plugin(&root, "directory-paths");
|
||||
|
||||
// when
|
||||
let error =
|
||||
load_plugin_from_directory(&root).expect_err("directory command paths should fail");
|
||||
|
||||
// then
|
||||
match error {
|
||||
PluginError::ManifestValidation(errors) => {
|
||||
assert!(errors.iter().any(|error| matches!(
|
||||
error,
|
||||
PluginManifestValidationError::PathIsDirectory { kind, path }
|
||||
if *kind == "hook" && path.ends_with(Path::new("hooks/pre-dir"))
|
||||
)));
|
||||
assert!(errors.iter().any(|error| matches!(
|
||||
error,
|
||||
PluginManifestValidationError::PathIsDirectory { kind, path }
|
||||
if *kind == "lifecycle command"
|
||||
&& path.ends_with(Path::new("lifecycle/init-dir"))
|
||||
)));
|
||||
assert!(errors.iter().any(|error| matches!(
|
||||
error,
|
||||
PluginManifestValidationError::PathIsDirectory { kind, path }
|
||||
if *kind == "tool" && path.ends_with(Path::new("tools/tool-dir"))
|
||||
)));
|
||||
assert!(errors.iter().any(|error| matches!(
|
||||
error,
|
||||
PluginManifestValidationError::PathIsDirectory { kind, path }
|
||||
if *kind == "command" && path.ends_with(Path::new("commands/sync-dir"))
|
||||
)));
|
||||
}
|
||||
other => panic!("expected manifest validation errors, got {other}"),
|
||||
}
|
||||
|
||||
let _ = fs::remove_dir_all(root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_plugin_from_directory_rejects_invalid_permissions() {
|
||||
let root = temp_dir("manifest-invalid-permissions");
|
||||
@ -2806,16 +3216,95 @@ mod tests {
|
||||
let _ = fs::remove_dir_all(source_root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn plugin_registry_report_collects_load_failures_without_dropping_valid_plugins() {
|
||||
// given
|
||||
let config_home = temp_dir("report-home");
|
||||
let external_root = temp_dir("report-external");
|
||||
write_external_plugin(&external_root.join("valid"), "valid-report", "1.0.0");
|
||||
write_broken_plugin(&external_root.join("broken"), "broken-report");
|
||||
|
||||
let mut config = PluginManagerConfig::new(&config_home);
|
||||
config.external_dirs = vec![external_root.clone()];
|
||||
let manager = PluginManager::new(config);
|
||||
|
||||
// when
|
||||
let report = manager
|
||||
.plugin_registry_report()
|
||||
.expect("report should tolerate invalid external plugins");
|
||||
|
||||
// then
|
||||
assert!(report.registry().contains("valid-report@external"));
|
||||
assert_eq!(report.failures().len(), 1);
|
||||
assert_eq!(report.failures()[0].kind, PluginKind::External);
|
||||
assert!(report.failures()[0]
|
||||
.plugin_root
|
||||
.ends_with(Path::new("broken")));
|
||||
assert!(report.failures()[0]
|
||||
.error()
|
||||
.to_string()
|
||||
.contains("does not exist"));
|
||||
|
||||
let error = manager
|
||||
.plugin_registry()
|
||||
.expect_err("strict registry should surface load failures");
|
||||
match error {
|
||||
PluginError::LoadFailures(failures) => {
|
||||
assert_eq!(failures.len(), 1);
|
||||
assert!(failures[0].plugin_root.ends_with(Path::new("broken")));
|
||||
}
|
||||
other => panic!("expected load failures, got {other}"),
|
||||
}
|
||||
|
||||
let _ = fs::remove_dir_all(config_home);
|
||||
let _ = fs::remove_dir_all(external_root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn installed_plugin_registry_report_collects_load_failures_from_install_root() {
|
||||
// given
|
||||
let config_home = temp_dir("installed-report-home");
|
||||
let bundled_root = temp_dir("installed-report-bundled");
|
||||
let install_root = config_home.join("plugins").join("installed");
|
||||
write_external_plugin(&install_root.join("valid"), "installed-valid", "1.0.0");
|
||||
write_broken_plugin(&install_root.join("broken"), "installed-broken");
|
||||
|
||||
let mut config = PluginManagerConfig::new(&config_home);
|
||||
config.bundled_root = Some(bundled_root.clone());
|
||||
config.install_root = Some(install_root);
|
||||
let manager = PluginManager::new(config);
|
||||
|
||||
// when
|
||||
let report = manager
|
||||
.installed_plugin_registry_report()
|
||||
.expect("installed report should tolerate invalid installed plugins");
|
||||
|
||||
// then
|
||||
assert!(report.registry().contains("installed-valid@external"));
|
||||
assert_eq!(report.failures().len(), 1);
|
||||
assert!(report.failures()[0]
|
||||
.plugin_root
|
||||
.ends_with(Path::new("broken")));
|
||||
|
||||
let _ = fs::remove_dir_all(config_home);
|
||||
let _ = fs::remove_dir_all(bundled_root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_plugin_sources_with_missing_hook_paths() {
|
||||
// given
|
||||
let config_home = temp_dir("broken-home");
|
||||
let source_root = temp_dir("broken-source");
|
||||
write_broken_plugin(&source_root, "broken");
|
||||
|
||||
let manager = PluginManager::new(PluginManagerConfig::new(&config_home));
|
||||
|
||||
// when
|
||||
let error = manager
|
||||
.validate_plugin_source(source_root.to_str().expect("utf8 path"))
|
||||
.expect_err("missing hook file should fail validation");
|
||||
|
||||
// then
|
||||
assert!(error.to_string().contains("does not exist"));
|
||||
|
||||
let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home));
|
||||
@ -2828,6 +3317,33 @@ mod tests {
|
||||
let _ = fs::remove_dir_all(source_root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_plugin_sources_with_missing_failure_hook_paths() {
|
||||
// given
|
||||
let config_home = temp_dir("broken-failure-home");
|
||||
let source_root = temp_dir("broken-failure-source");
|
||||
write_broken_failure_hook_plugin(&source_root, "broken-failure");
|
||||
|
||||
let manager = PluginManager::new(PluginManagerConfig::new(&config_home));
|
||||
|
||||
// when
|
||||
let error = manager
|
||||
.validate_plugin_source(source_root.to_str().expect("utf8 path"))
|
||||
.expect_err("missing failure hook file should fail validation");
|
||||
|
||||
// then
|
||||
assert!(error.to_string().contains("does not exist"));
|
||||
|
||||
let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home));
|
||||
let install_error = manager
|
||||
.install(source_root.to_str().expect("utf8 path"))
|
||||
.expect_err("install should reject invalid failure hook paths");
|
||||
assert!(install_error.to_string().contains("does not exist"));
|
||||
|
||||
let _ = fs::remove_dir_all(config_home);
|
||||
let _ = fs::remove_dir_all(source_root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn plugin_registry_runs_initialize_and_shutdown_for_enabled_plugins() {
|
||||
let config_home = temp_dir("lifecycle-home");
|
||||
|
||||
@ -7,13 +7,14 @@ publish.workspace = true
|
||||
|
||||
[dependencies]
|
||||
sha2 = "0.10"
|
||||
which = "7"
|
||||
glob = "0.3"
|
||||
lsp = { path = "../lsp" }
|
||||
plugins = { path = "../plugins" }
|
||||
regex = "1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json.workspace = true
|
||||
tokio = { version = "1", features = ["io-util", "macros", "process", "rt", "rt-multi-thread", "time"] }
|
||||
telemetry = { path = "../telemetry" }
|
||||
tokio = { version = "1", features = ["io-std", "io-util", "macros", "process", "rt", "rt-multi-thread", "time"] }
|
||||
walkdir = "2"
|
||||
|
||||
[lints]
|
||||
|
||||
@ -14,6 +14,7 @@ use crate::sandbox::{
|
||||
};
|
||||
use crate::ConfigLoader;
|
||||
|
||||
/// Input schema for the built-in bash execution tool.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct BashCommandInput {
|
||||
pub command: String,
|
||||
@ -33,6 +34,7 @@ pub struct BashCommandInput {
|
||||
pub allowed_mounts: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Output returned from a bash tool invocation.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct BashCommandOutput {
|
||||
pub stdout: String,
|
||||
@ -64,6 +66,7 @@ pub struct BashCommandOutput {
|
||||
pub sandbox_status: Option<SandboxStatus>,
|
||||
}
|
||||
|
||||
/// Executes a shell command with the requested sandbox settings.
|
||||
pub fn execute_bash(input: BashCommandInput) -> io::Result<BashCommandOutput> {
|
||||
let cwd = env::current_dir()?;
|
||||
let sandbox_status = sandbox_status_for_input(&input, &cwd);
|
||||
@ -134,8 +137,8 @@ async fn execute_bash_async(
|
||||
};
|
||||
|
||||
let (output, interrupted) = output_result;
|
||||
let stdout = String::from_utf8_lossy(&output.stdout).into_owned();
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
|
||||
let stdout = truncate_output(&String::from_utf8_lossy(&output.stdout));
|
||||
let stderr = truncate_output(&String::from_utf8_lossy(&output.stderr));
|
||||
let no_output_expected = Some(stdout.trim().is_empty() && stderr.trim().is_empty());
|
||||
let return_code_interpretation = output.status.code().and_then(|code| {
|
||||
if code == 0 {
|
||||
@ -197,36 +200,31 @@ fn prepare_command(
|
||||
return prepared;
|
||||
}
|
||||
|
||||
let mut prepared = if cfg!(target_os = "windows") && !sh_exists() {
|
||||
let prepared = if cfg!(target_os = "windows") && !sh_exists() {
|
||||
let mut p = Command::new("cmd");
|
||||
p.arg("/C").arg(command);
|
||||
p.arg("/C").arg(command).current_dir(cwd);
|
||||
p
|
||||
} else {
|
||||
let mut p = Command::new("sh");
|
||||
p.arg("-lc").arg(command);
|
||||
p.arg("-lc").arg(command).current_dir(cwd);
|
||||
if sandbox_status.filesystem_active {
|
||||
p.env("HOME", cwd.join(".sandbox-home"));
|
||||
p.env("TMPDIR", cwd.join(".sandbox-tmp"));
|
||||
}
|
||||
p
|
||||
};
|
||||
prepared.current_dir(cwd);
|
||||
if sandbox_status.filesystem_active {
|
||||
prepared.env("HOME", cwd.join(".sandbox-home"));
|
||||
prepared.env("TMPDIR", cwd.join(".sandbox-tmp"));
|
||||
}
|
||||
prepared
|
||||
}
|
||||
|
||||
fn sh_exists() -> bool {
|
||||
env::var_os("PATH").is_some_and(|paths| {
|
||||
env::split_paths(&paths).any(|path| {
|
||||
#[cfg(windows)]
|
||||
{
|
||||
path.join("sh.exe").exists() || path.join("sh.bat").exists() || path.join("sh").exists()
|
||||
which::which("sh").is_ok()
|
||||
}
|
||||
#[cfg(not(windows))]
|
||||
{
|
||||
path.join("sh").exists()
|
||||
true
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_tokio_command(
|
||||
@ -247,20 +245,19 @@ fn prepare_tokio_command(
|
||||
return prepared;
|
||||
}
|
||||
|
||||
let mut prepared = if cfg!(target_os = "windows") && !sh_exists() {
|
||||
let prepared = if cfg!(target_os = "windows") && !sh_exists() {
|
||||
let mut p = TokioCommand::new("cmd");
|
||||
p.arg("/C").arg(command);
|
||||
p.arg("/C").arg(command).current_dir(cwd);
|
||||
p
|
||||
} else {
|
||||
let mut p = TokioCommand::new("sh");
|
||||
p.arg("-lc").arg(command);
|
||||
p.arg("-lc").arg(command).current_dir(cwd);
|
||||
if sandbox_status.filesystem_active {
|
||||
p.env("HOME", cwd.join(".sandbox-home"));
|
||||
p.env("TMPDIR", cwd.join(".sandbox-tmp"));
|
||||
}
|
||||
p
|
||||
};
|
||||
prepared.current_dir(cwd);
|
||||
if sandbox_status.filesystem_active {
|
||||
prepared.env("HOME", cwd.join(".sandbox-home"));
|
||||
prepared.env("TMPDIR", cwd.join(".sandbox-tmp"));
|
||||
}
|
||||
prepared
|
||||
}
|
||||
|
||||
@ -312,3 +309,53 @@ mod tests {
|
||||
assert!(!output.sandbox_status.expect("sandbox status").enabled);
|
||||
}
|
||||
}
|
||||
|
||||
/// Maximum output bytes before truncation (16 KiB, matching upstream).
|
||||
const MAX_OUTPUT_BYTES: usize = 16_384;
|
||||
|
||||
/// Truncate output to `MAX_OUTPUT_BYTES`, appending a marker when trimmed.
|
||||
fn truncate_output(s: &str) -> String {
|
||||
if s.len() <= MAX_OUTPUT_BYTES {
|
||||
return s.to_string();
|
||||
}
|
||||
// Find the last valid UTF-8 boundary at or before MAX_OUTPUT_BYTES
|
||||
let mut end = MAX_OUTPUT_BYTES;
|
||||
while end > 0 && !s.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
let mut truncated = s[..end].to_string();
|
||||
truncated.push_str("\n\n[output truncated — exceeded 16384 bytes]");
|
||||
truncated
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod truncation_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn short_output_unchanged() {
|
||||
let s = "hello world";
|
||||
assert_eq!(truncate_output(s), s);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn long_output_truncated() {
|
||||
let s = "x".repeat(20_000);
|
||||
let result = truncate_output(&s);
|
||||
assert!(result.len() < 20_000);
|
||||
assert!(result.ends_with("[output truncated — exceeded 16384 bytes]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exact_boundary_unchanged() {
|
||||
let s = "a".repeat(MAX_OUTPUT_BYTES);
|
||||
assert_eq!(truncate_output(&s), s);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_over_boundary_truncated() {
|
||||
let s = "a".repeat(MAX_OUTPUT_BYTES + 1);
|
||||
let result = truncate_output(&s);
|
||||
assert!(result.contains("[output truncated"));
|
||||
}
|
||||
}
|
||||
|
||||
1004
crates/runtime/src/bash_validation.rs
Normal file
1004
crates/runtime/src/bash_validation.rs
Normal file
File diff suppressed because it is too large
Load Diff
@ -21,7 +21,7 @@ pub struct BootstrapPlan {
|
||||
|
||||
impl BootstrapPlan {
|
||||
#[must_use]
|
||||
pub fn claw_default() -> Self {
|
||||
pub fn claude_code_default() -> Self {
|
||||
Self::from_phases(vec![
|
||||
BootstrapPhase::CliEntry,
|
||||
BootstrapPhase::FastPathVersion,
|
||||
@ -54,3 +54,58 @@ impl BootstrapPlan {
|
||||
&self.phases
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{BootstrapPhase, BootstrapPlan};
|
||||
|
||||
#[test]
|
||||
fn from_phases_deduplicates_while_preserving_order() {
|
||||
// given
|
||||
let phases = vec![
|
||||
BootstrapPhase::CliEntry,
|
||||
BootstrapPhase::FastPathVersion,
|
||||
BootstrapPhase::CliEntry,
|
||||
BootstrapPhase::MainRuntime,
|
||||
BootstrapPhase::FastPathVersion,
|
||||
];
|
||||
|
||||
// when
|
||||
let plan = BootstrapPlan::from_phases(phases);
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
plan.phases(),
|
||||
&[
|
||||
BootstrapPhase::CliEntry,
|
||||
BootstrapPhase::FastPathVersion,
|
||||
BootstrapPhase::MainRuntime,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn claude_code_default_covers_each_phase_once() {
|
||||
// given
|
||||
let expected = [
|
||||
BootstrapPhase::CliEntry,
|
||||
BootstrapPhase::FastPathVersion,
|
||||
BootstrapPhase::StartupProfiler,
|
||||
BootstrapPhase::SystemPromptFastPath,
|
||||
BootstrapPhase::ChromeMcpFastPath,
|
||||
BootstrapPhase::DaemonWorkerFastPath,
|
||||
BootstrapPhase::BridgeFastPath,
|
||||
BootstrapPhase::DaemonFastPath,
|
||||
BootstrapPhase::BackgroundSessionFastPath,
|
||||
BootstrapPhase::TemplateFastPath,
|
||||
BootstrapPhase::EnvironmentRunnerFastPath,
|
||||
BootstrapPhase::MainRuntime,
|
||||
];
|
||||
|
||||
// when
|
||||
let plan = BootstrapPlan::claude_code_default();
|
||||
|
||||
// then
|
||||
assert_eq!(plan.phases(), &expected);
|
||||
}
|
||||
}
|
||||
|
||||
144
crates/runtime/src/branch_lock.rs
Normal file
144
crates/runtime/src/branch_lock.rs
Normal file
@ -0,0 +1,144 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct BranchLockIntent {
|
||||
#[serde(rename = "laneId")]
|
||||
pub lane_id: String,
|
||||
pub branch: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub worktree: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub modules: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct BranchLockCollision {
|
||||
pub branch: String,
|
||||
pub module: String,
|
||||
#[serde(rename = "laneIds")]
|
||||
pub lane_ids: Vec<String>,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn detect_branch_lock_collisions(intents: &[BranchLockIntent]) -> Vec<BranchLockCollision> {
|
||||
let mut collisions = Vec::new();
|
||||
|
||||
for (index, left) in intents.iter().enumerate() {
|
||||
for right in &intents[index + 1..] {
|
||||
if left.branch != right.branch {
|
||||
continue;
|
||||
}
|
||||
for module in overlapping_modules(&left.modules, &right.modules) {
|
||||
collisions.push(BranchLockCollision {
|
||||
branch: left.branch.clone(),
|
||||
module,
|
||||
lane_ids: vec![left.lane_id.clone(), right.lane_id.clone()],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
collisions.sort_by(|a, b| {
|
||||
a.branch
|
||||
.cmp(&b.branch)
|
||||
.then(a.module.cmp(&b.module))
|
||||
.then(a.lane_ids.cmp(&b.lane_ids))
|
||||
});
|
||||
collisions.dedup();
|
||||
collisions
|
||||
}
|
||||
|
||||
fn overlapping_modules(left: &[String], right: &[String]) -> Vec<String> {
|
||||
let mut overlaps = Vec::new();
|
||||
for left_module in left {
|
||||
for right_module in right {
|
||||
if modules_overlap(left_module, right_module) {
|
||||
overlaps.push(shared_scope(left_module, right_module));
|
||||
}
|
||||
}
|
||||
}
|
||||
overlaps.sort();
|
||||
overlaps.dedup();
|
||||
overlaps
|
||||
}
|
||||
|
||||
fn modules_overlap(left: &str, right: &str) -> bool {
|
||||
left == right
|
||||
|| left.starts_with(&format!("{right}/"))
|
||||
|| right.starts_with(&format!("{left}/"))
|
||||
}
|
||||
|
||||
fn shared_scope(left: &str, right: &str) -> String {
|
||||
if left.starts_with(&format!("{right}/")) || left == right {
|
||||
right.to_string()
|
||||
} else {
|
||||
left.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{detect_branch_lock_collisions, BranchLockIntent};
|
||||
|
||||
#[test]
|
||||
fn detects_same_branch_same_module_collisions() {
|
||||
let collisions = detect_branch_lock_collisions(&[
|
||||
BranchLockIntent {
|
||||
lane_id: "lane-a".to_string(),
|
||||
branch: "feature/lock".to_string(),
|
||||
worktree: Some("wt-a".to_string()),
|
||||
modules: vec!["runtime/mcp".to_string()],
|
||||
},
|
||||
BranchLockIntent {
|
||||
lane_id: "lane-b".to_string(),
|
||||
branch: "feature/lock".to_string(),
|
||||
worktree: Some("wt-b".to_string()),
|
||||
modules: vec!["runtime/mcp".to_string()],
|
||||
},
|
||||
]);
|
||||
|
||||
assert_eq!(collisions.len(), 1);
|
||||
assert_eq!(collisions[0].branch, "feature/lock");
|
||||
assert_eq!(collisions[0].module, "runtime/mcp");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_nested_module_scope_collisions() {
|
||||
let collisions = detect_branch_lock_collisions(&[
|
||||
BranchLockIntent {
|
||||
lane_id: "lane-a".to_string(),
|
||||
branch: "feature/lock".to_string(),
|
||||
worktree: None,
|
||||
modules: vec!["runtime".to_string()],
|
||||
},
|
||||
BranchLockIntent {
|
||||
lane_id: "lane-b".to_string(),
|
||||
branch: "feature/lock".to_string(),
|
||||
worktree: None,
|
||||
modules: vec!["runtime/mcp".to_string()],
|
||||
},
|
||||
]);
|
||||
|
||||
assert_eq!(collisions[0].module, "runtime");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ignores_different_branches() {
|
||||
let collisions = detect_branch_lock_collisions(&[
|
||||
BranchLockIntent {
|
||||
lane_id: "lane-a".to_string(),
|
||||
branch: "feature/a".to_string(),
|
||||
worktree: None,
|
||||
modules: vec!["runtime/mcp".to_string()],
|
||||
},
|
||||
BranchLockIntent {
|
||||
lane_id: "lane-b".to_string(),
|
||||
branch: "feature/b".to_string(),
|
||||
worktree: None,
|
||||
modules: vec!["runtime/mcp".to_string()],
|
||||
},
|
||||
]);
|
||||
|
||||
assert!(collisions.is_empty());
|
||||
}
|
||||
}
|
||||
@ -5,6 +5,7 @@ const COMPACT_CONTINUATION_PREAMBLE: &str =
|
||||
const COMPACT_RECENT_MESSAGES_NOTE: &str = "Recent messages are preserved verbatim.";
|
||||
const COMPACT_DIRECT_RESUME_INSTRUCTION: &str = "Continue the conversation from where it left off without asking the user any further questions. Resume directly — do not acknowledge the summary, do not recap what was happening, and do not preface with continuation text.";
|
||||
|
||||
/// Thresholds controlling when and how a session is compacted.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct CompactionConfig {
|
||||
pub preserve_recent_messages: usize,
|
||||
@ -20,6 +21,7 @@ impl Default for CompactionConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of compacting a session into a summary plus preserved tail messages.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CompactionResult {
|
||||
pub summary: String,
|
||||
@ -28,11 +30,13 @@ pub struct CompactionResult {
|
||||
pub removed_message_count: usize,
|
||||
}
|
||||
|
||||
/// Roughly estimates the token footprint of the current session transcript.
|
||||
#[must_use]
|
||||
pub fn estimate_session_tokens(session: &Session) -> usize {
|
||||
session.messages.iter().map(estimate_message_tokens).sum()
|
||||
}
|
||||
|
||||
/// Returns `true` when the session exceeds the configured compaction budget.
|
||||
#[must_use]
|
||||
pub fn should_compact(session: &Session, config: CompactionConfig) -> bool {
|
||||
let start = compacted_summary_prefix_len(session);
|
||||
@ -46,6 +50,7 @@ pub fn should_compact(session: &Session, config: CompactionConfig) -> bool {
|
||||
>= config.max_estimated_tokens
|
||||
}
|
||||
|
||||
/// Normalizes a compaction summary into user-facing continuation text.
|
||||
#[must_use]
|
||||
pub fn format_compact_summary(summary: &str) -> String {
|
||||
let without_analysis = strip_tag_block(summary, "analysis");
|
||||
@ -61,6 +66,7 @@ pub fn format_compact_summary(summary: &str) -> String {
|
||||
collapse_blank_lines(&formatted).trim().to_string()
|
||||
}
|
||||
|
||||
/// Builds the synthetic system message used after session compaction.
|
||||
#[must_use]
|
||||
pub fn get_compact_continuation_message(
|
||||
summary: &str,
|
||||
@ -85,6 +91,7 @@ pub fn get_compact_continuation_message(
|
||||
base
|
||||
}
|
||||
|
||||
/// Compacts a session by summarizing older messages and preserving the recent tail.
|
||||
#[must_use]
|
||||
pub fn compact_session(session: &Session, config: CompactionConfig) -> CompactionResult {
|
||||
if !should_compact(session, config) {
|
||||
@ -101,10 +108,55 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio
|
||||
.first()
|
||||
.and_then(extract_existing_compacted_summary);
|
||||
let compacted_prefix_len = usize::from(existing_summary.is_some());
|
||||
let keep_from = session
|
||||
let raw_keep_from = session
|
||||
.messages
|
||||
.len()
|
||||
.saturating_sub(config.preserve_recent_messages);
|
||||
// Ensure we do not split a tool-use / tool-result pair at the compaction
|
||||
// boundary. If the first preserved message is a user message whose first
|
||||
// block is a ToolResult, the assistant message with the matching ToolUse
|
||||
// was slated for removal — that produces an orphaned tool role message on
|
||||
// the OpenAI-compat path (400: tool message must follow assistant with
|
||||
// tool_calls). Walk the boundary back until we start at a safe point.
|
||||
let keep_from = {
|
||||
let mut k = raw_keep_from;
|
||||
// If the first preserved message is a tool-result turn, ensure its
|
||||
// paired assistant tool-use turn is preserved too. Without this fix,
|
||||
// the OpenAI-compat adapter sends an orphaned 'tool' role message
|
||||
// with no preceding assistant 'tool_calls', which providers reject
|
||||
// with a 400. We walk back only if the immediately preceding message
|
||||
// is NOT an assistant message that contains a ToolUse block (i.e. the
|
||||
// pair is actually broken at the boundary).
|
||||
loop {
|
||||
if k == 0 || k <= compacted_prefix_len {
|
||||
break;
|
||||
}
|
||||
let first_preserved = &session.messages[k];
|
||||
let starts_with_tool_result = first_preserved
|
||||
.blocks
|
||||
.first()
|
||||
.map(|b| matches!(b, ContentBlock::ToolResult { .. }))
|
||||
.unwrap_or(false);
|
||||
if !starts_with_tool_result {
|
||||
break;
|
||||
}
|
||||
// Check the message just before the current boundary.
|
||||
let preceding = &session.messages[k - 1];
|
||||
let preceding_has_tool_use = preceding
|
||||
.blocks
|
||||
.iter()
|
||||
.any(|b| matches!(b, ContentBlock::ToolUse { .. }));
|
||||
if preceding_has_tool_use {
|
||||
// Pair is intact — walk back one more to include the assistant turn.
|
||||
k = k.saturating_sub(1);
|
||||
break;
|
||||
}
|
||||
// Preceding message has no ToolUse but we have a ToolResult —
|
||||
// this is already an orphaned pair; walk back to try to fix it.
|
||||
k = k.saturating_sub(1);
|
||||
}
|
||||
k
|
||||
};
|
||||
let removed = &session.messages[compacted_prefix_len..keep_from];
|
||||
let preserved = session.messages[keep_from..].to_vec();
|
||||
let summary =
|
||||
@ -119,13 +171,14 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio
|
||||
}];
|
||||
compacted_messages.extend(preserved);
|
||||
|
||||
let mut compacted_session = session.clone();
|
||||
compacted_session.messages = compacted_messages;
|
||||
compacted_session.record_compaction(summary.clone(), removed.len());
|
||||
|
||||
CompactionResult {
|
||||
summary,
|
||||
formatted_summary,
|
||||
compacted_session: Session {
|
||||
version: session.version,
|
||||
messages: compacted_messages,
|
||||
},
|
||||
compacted_session,
|
||||
removed_message_count: removed.len(),
|
||||
}
|
||||
}
|
||||
@ -160,9 +213,7 @@ fn summarize_messages(messages: &[ConversationMessage]) -> String {
|
||||
.filter_map(|block| match block {
|
||||
ContentBlock::ToolUse { name, .. } => Some(name.as_str()),
|
||||
ContentBlock::ToolResult { tool_name, .. } => Some(tool_name.as_str()),
|
||||
ContentBlock::Text { .. }
|
||||
| ContentBlock::Thinking { .. }
|
||||
| ContentBlock::RedactedThinking { .. } => None,
|
||||
ContentBlock::Text { .. } => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
tool_names.sort_unstable();
|
||||
@ -277,8 +328,6 @@ fn summarize_block(block: &ContentBlock) -> String {
|
||||
"tool_result {tool_name}: {}{output}",
|
||||
if *is_error { "error " } else { "" }
|
||||
),
|
||||
ContentBlock::Thinking { thinking, .. } => format!("thinking: {thinking}"),
|
||||
ContentBlock::RedactedThinking { .. } => "thinking: <redacted>".to_string(),
|
||||
};
|
||||
truncate_summary(&raw, 160)
|
||||
}
|
||||
@ -328,8 +377,6 @@ fn collect_key_files(messages: &[ConversationMessage]) -> Vec<String> {
|
||||
.flat_map(|message| message.blocks.iter())
|
||||
.map(|block| match block {
|
||||
ContentBlock::Text { text } => text.as_str(),
|
||||
ContentBlock::Thinking { thinking, .. } => thinking.as_str(),
|
||||
ContentBlock::RedactedThinking { .. } => "",
|
||||
ContentBlock::ToolUse { input, .. } => input.as_str(),
|
||||
ContentBlock::ToolResult { output, .. } => output.as_str(),
|
||||
})
|
||||
@ -354,9 +401,7 @@ fn first_text_block(message: &ConversationMessage) -> Option<&str> {
|
||||
ContentBlock::Text { text } if !text.trim().is_empty() => Some(text.as_str()),
|
||||
ContentBlock::ToolUse { .. }
|
||||
| ContentBlock::ToolResult { .. }
|
||||
| ContentBlock::Text { .. }
|
||||
| ContentBlock::Thinking { .. }
|
||||
| ContentBlock::RedactedThinking { .. } => None,
|
||||
| ContentBlock::Text { .. } => None,
|
||||
})
|
||||
}
|
||||
|
||||
@ -402,8 +447,6 @@ fn estimate_message_tokens(message: &ConversationMessage) -> usize {
|
||||
.iter()
|
||||
.map(|block| match block {
|
||||
ContentBlock::Text { text } => text.len() / 4 + 1,
|
||||
ContentBlock::Thinking { thinking, .. } => thinking.len() / 4 + 1,
|
||||
ContentBlock::RedactedThinking { .. } => 1,
|
||||
ContentBlock::ToolUse { name, input, .. } => (name.len() + input.len()) / 4 + 1,
|
||||
ContentBlock::ToolResult {
|
||||
tool_name, output, ..
|
||||
@ -512,7 +555,7 @@ fn extract_summary_timeline(summary: &str) -> Vec<String> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
collect_key_files, compact_session, estimate_session_tokens, format_compact_summary,
|
||||
collect_key_files, compact_session, format_compact_summary,
|
||||
get_compact_continuation_message, infer_pending_work, should_compact, CompactionConfig,
|
||||
};
|
||||
use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session};
|
||||
@ -525,10 +568,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn leaves_small_sessions_unchanged() {
|
||||
let session = Session {
|
||||
version: 1,
|
||||
messages: vec![ConversationMessage::user_text("hello")],
|
||||
};
|
||||
let mut session = Session::new();
|
||||
session.messages = vec![ConversationMessage::user_text("hello")];
|
||||
|
||||
let result = compact_session(&session, CompactionConfig::default());
|
||||
assert_eq!(result.removed_message_count, 0);
|
||||
@ -539,9 +580,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn compacts_older_messages_into_a_system_summary() {
|
||||
let session = Session {
|
||||
version: 1,
|
||||
messages: vec![
|
||||
let mut session = Session::new();
|
||||
session.messages = vec![
|
||||
ConversationMessage::user_text("one ".repeat(200)),
|
||||
ConversationMessage::assistant(vec![ContentBlock::Text {
|
||||
text: "two ".repeat(200),
|
||||
@ -554,8 +594,7 @@ mod tests {
|
||||
}],
|
||||
usage: None,
|
||||
},
|
||||
],
|
||||
};
|
||||
];
|
||||
|
||||
let result = compact_session(
|
||||
&session,
|
||||
@ -565,7 +604,14 @@ mod tests {
|
||||
},
|
||||
);
|
||||
|
||||
assert_eq!(result.removed_message_count, 2);
|
||||
// With the tool-use/tool-result boundary fix, the compaction preserves
|
||||
// one extra message to avoid an orphaned tool result at the boundary.
|
||||
// messages[1] (assistant) must be kept along with messages[2] (tool result).
|
||||
assert!(
|
||||
result.removed_message_count <= 2,
|
||||
"expected at most 2 removed, got {}",
|
||||
result.removed_message_count
|
||||
);
|
||||
assert_eq!(
|
||||
result.compacted_session.messages[0].role,
|
||||
MessageRole::System
|
||||
@ -583,28 +629,29 @@ mod tests {
|
||||
max_estimated_tokens: 1,
|
||||
}
|
||||
));
|
||||
// Note: with the tool-use/tool-result boundary guard the compacted session
|
||||
// may preserve one extra message at the boundary, so token reduction is
|
||||
// not guaranteed for small sessions. The invariant that matters is that
|
||||
// the removed_message_count is non-zero (something was compacted).
|
||||
assert!(
|
||||
estimate_session_tokens(&result.compacted_session) < estimate_session_tokens(&session)
|
||||
result.removed_message_count > 0,
|
||||
"compaction must remove at least one message"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keeps_previous_compacted_context_when_compacting_again() {
|
||||
let initial_session = Session {
|
||||
version: 1,
|
||||
messages: vec![
|
||||
let mut initial_session = Session::new();
|
||||
initial_session.messages = vec![
|
||||
ConversationMessage::user_text("Investigate rust/crates/runtime/src/compact.rs"),
|
||||
ConversationMessage::assistant(vec![ContentBlock::Text {
|
||||
text: "I will inspect the compact flow.".to_string(),
|
||||
}]),
|
||||
ConversationMessage::user_text(
|
||||
"Also update rust/crates/runtime/src/conversation.rs",
|
||||
),
|
||||
ConversationMessage::user_text("Also update rust/crates/runtime/src/conversation.rs"),
|
||||
ConversationMessage::assistant(vec![ContentBlock::Text {
|
||||
text: "Next: preserve prior summary context during auto compact.".to_string(),
|
||||
}]),
|
||||
],
|
||||
};
|
||||
];
|
||||
let config = CompactionConfig {
|
||||
preserve_recent_messages: 2,
|
||||
max_estimated_tokens: 1,
|
||||
@ -619,13 +666,9 @@ mod tests {
|
||||
}]),
|
||||
]);
|
||||
|
||||
let second = compact_session(
|
||||
&Session {
|
||||
version: 1,
|
||||
messages: follow_up_messages,
|
||||
},
|
||||
config,
|
||||
);
|
||||
let mut second_session = Session::new();
|
||||
second_session.messages = follow_up_messages;
|
||||
let second = compact_session(&second_session, config);
|
||||
|
||||
assert!(second
|
||||
.formatted_summary
|
||||
@ -654,9 +697,8 @@ mod tests {
|
||||
#[test]
|
||||
fn ignores_existing_compacted_summary_when_deciding_to_recompact() {
|
||||
let summary = "<summary>Conversation summary:\n- Scope: earlier work preserved.\n- Key timeline:\n - user: large preserved context\n</summary>";
|
||||
let session = Session {
|
||||
version: 1,
|
||||
messages: vec![
|
||||
let mut session = Session::new();
|
||||
session.messages = vec![
|
||||
ConversationMessage {
|
||||
role: MessageRole::System,
|
||||
blocks: vec![ContentBlock::Text {
|
||||
@ -668,8 +710,7 @@ mod tests {
|
||||
ConversationMessage::assistant(vec![ContentBlock::Text {
|
||||
text: "recent".to_string(),
|
||||
}]),
|
||||
],
|
||||
};
|
||||
];
|
||||
|
||||
assert!(!should_compact(
|
||||
&session,
|
||||
@ -692,10 +733,84 @@ mod tests {
|
||||
#[test]
|
||||
fn extracts_key_files_from_message_content() {
|
||||
let files = collect_key_files(&[ConversationMessage::user_text(
|
||||
"Update rust/crates/runtime/src/compact.rs and rust/crates/tools/src/lib.rs next.",
|
||||
"Update rust/crates/runtime/src/compact.rs and rust/crates/rusty-claude-cli/src/main.rs next.",
|
||||
)]);
|
||||
assert!(files.contains(&"rust/crates/runtime/src/compact.rs".to_string()));
|
||||
assert!(files.contains(&"rust/crates/tools/src/lib.rs".to_string()));
|
||||
assert!(files.contains(&"rust/crates/rusty-claude-cli/src/main.rs".to_string()));
|
||||
}
|
||||
|
||||
/// Regression: compaction must not split an assistant(ToolUse) /
|
||||
/// user(ToolResult) pair at the boundary. An orphaned tool-result message
|
||||
/// without the preceding assistant tool_calls causes a 400 on the
|
||||
/// OpenAI-compat path (gaebal-gajae repro 2026-04-09).
|
||||
#[test]
|
||||
fn compaction_does_not_split_tool_use_tool_result_pair() {
|
||||
use crate::session::{ContentBlock, Session};
|
||||
|
||||
let tool_id = "call_abc";
|
||||
let mut session = Session::default();
|
||||
// Turn 1: user prompt
|
||||
session
|
||||
.push_message(ConversationMessage::user_text("Search for files"))
|
||||
.unwrap();
|
||||
// Turn 2: assistant calls a tool
|
||||
session
|
||||
.push_message(ConversationMessage::assistant(vec![
|
||||
ContentBlock::ToolUse {
|
||||
id: tool_id.to_string(),
|
||||
name: "search".to_string(),
|
||||
input: "{\"q\":\"*.rs\"}".to_string(),
|
||||
},
|
||||
]))
|
||||
.unwrap();
|
||||
// Turn 3: tool result
|
||||
session
|
||||
.push_message(ConversationMessage::tool_result(
|
||||
tool_id,
|
||||
"search",
|
||||
"found 5 files",
|
||||
false,
|
||||
))
|
||||
.unwrap();
|
||||
// Turn 4: assistant final response
|
||||
session
|
||||
.push_message(ConversationMessage::assistant(vec![ContentBlock::Text {
|
||||
text: "Done.".to_string(),
|
||||
}]))
|
||||
.unwrap();
|
||||
|
||||
// Compact preserving only 1 recent message — without the fix this
|
||||
// would cut the boundary so that the tool result (turn 3) is first,
|
||||
// without its preceding assistant tool_calls (turn 2).
|
||||
let config = CompactionConfig {
|
||||
preserve_recent_messages: 1,
|
||||
..CompactionConfig::default()
|
||||
};
|
||||
let result = compact_session(&session, config);
|
||||
// After compaction, no two consecutive messages should have the pattern
|
||||
// tool_result immediately following a non-assistant message (i.e. an
|
||||
// orphaned tool result without a preceding assistant ToolUse).
|
||||
let messages = &result.compacted_session.messages;
|
||||
for i in 1..messages.len() {
|
||||
let curr_is_tool_result = messages[i]
|
||||
.blocks
|
||||
.first()
|
||||
.map(|b| matches!(b, ContentBlock::ToolResult { .. }))
|
||||
.unwrap_or(false);
|
||||
if curr_is_tool_result {
|
||||
let prev_has_tool_use = messages[i - 1]
|
||||
.blocks
|
||||
.iter()
|
||||
.any(|b| matches!(b, ContentBlock::ToolUse { .. }));
|
||||
assert!(
|
||||
prev_has_tool_use,
|
||||
"message[{}] is a ToolResult but message[{}] has no ToolUse: {:?}",
|
||||
i,
|
||||
i - 1,
|
||||
&messages[i - 1].blocks
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
901
crates/runtime/src/config_validate.rs
Normal file
901
crates/runtime/src/config_validate.rs
Normal file
@ -0,0 +1,901 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::path::Path;
|
||||
|
||||
use crate::config::ConfigError;
|
||||
use crate::json::JsonValue;
|
||||
|
||||
/// Diagnostic emitted when a config file contains a suspect field.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ConfigDiagnostic {
|
||||
pub path: String,
|
||||
pub field: String,
|
||||
pub line: Option<usize>,
|
||||
pub kind: DiagnosticKind,
|
||||
}
|
||||
|
||||
/// Classification of the diagnostic.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum DiagnosticKind {
|
||||
UnknownKey {
|
||||
suggestion: Option<String>,
|
||||
},
|
||||
WrongType {
|
||||
expected: &'static str,
|
||||
got: &'static str,
|
||||
},
|
||||
Deprecated {
|
||||
replacement: &'static str,
|
||||
},
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ConfigDiagnostic {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let location = self
|
||||
.line
|
||||
.map_or_else(String::new, |line| format!(" (line {line})"));
|
||||
match &self.kind {
|
||||
DiagnosticKind::UnknownKey { suggestion: None } => {
|
||||
write!(f, "{}: unknown key \"{}\"{location}", self.path, self.field)
|
||||
}
|
||||
DiagnosticKind::UnknownKey {
|
||||
suggestion: Some(hint),
|
||||
} => {
|
||||
write!(
|
||||
f,
|
||||
"{}: unknown key \"{}\"{location}. Did you mean \"{}\"?",
|
||||
self.path, self.field, hint
|
||||
)
|
||||
}
|
||||
DiagnosticKind::WrongType { expected, got } => {
|
||||
write!(
|
||||
f,
|
||||
"{}: field \"{}\" must be {expected}, got {got}{location}",
|
||||
self.path, self.field
|
||||
)
|
||||
}
|
||||
DiagnosticKind::Deprecated { replacement } => {
|
||||
write!(
|
||||
f,
|
||||
"{}: field \"{}\" is deprecated{location}. Use \"{replacement}\" instead",
|
||||
self.path, self.field
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of validating a single config file.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ValidationResult {
|
||||
pub errors: Vec<ConfigDiagnostic>,
|
||||
pub warnings: Vec<ConfigDiagnostic>,
|
||||
}
|
||||
|
||||
impl ValidationResult {
|
||||
#[must_use]
|
||||
pub fn is_ok(&self) -> bool {
|
||||
self.errors.is_empty()
|
||||
}
|
||||
|
||||
fn merge(&mut self, other: Self) {
|
||||
self.errors.extend(other.errors);
|
||||
self.warnings.extend(other.warnings);
|
||||
}
|
||||
}
|
||||
|
||||
// ---- known-key schema ----
|
||||
|
||||
/// Expected type for a config field.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum FieldType {
|
||||
String,
|
||||
Bool,
|
||||
Object,
|
||||
StringArray,
|
||||
Number,
|
||||
}
|
||||
|
||||
impl FieldType {
|
||||
fn label(self) -> &'static str {
|
||||
match self {
|
||||
Self::String => "a string",
|
||||
Self::Bool => "a boolean",
|
||||
Self::Object => "an object",
|
||||
Self::StringArray => "an array of strings",
|
||||
Self::Number => "a number",
|
||||
}
|
||||
}
|
||||
|
||||
fn matches(self, value: &JsonValue) -> bool {
|
||||
match self {
|
||||
Self::String => value.as_str().is_some(),
|
||||
Self::Bool => value.as_bool().is_some(),
|
||||
Self::Object => value.as_object().is_some(),
|
||||
Self::StringArray => value
|
||||
.as_array()
|
||||
.is_some_and(|arr| arr.iter().all(|v| v.as_str().is_some())),
|
||||
Self::Number => value.as_i64().is_some(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn json_type_label(value: &JsonValue) -> &'static str {
|
||||
match value {
|
||||
JsonValue::Null => "null",
|
||||
JsonValue::Bool(_) => "a boolean",
|
||||
JsonValue::Number(_) => "a number",
|
||||
JsonValue::String(_) => "a string",
|
||||
JsonValue::Array(_) => "an array",
|
||||
JsonValue::Object(_) => "an object",
|
||||
}
|
||||
}
|
||||
|
||||
struct FieldSpec {
|
||||
name: &'static str,
|
||||
expected: FieldType,
|
||||
}
|
||||
|
||||
struct DeprecatedField {
|
||||
name: &'static str,
|
||||
replacement: &'static str,
|
||||
}
|
||||
|
||||
const TOP_LEVEL_FIELDS: &[FieldSpec] = &[
|
||||
FieldSpec {
|
||||
name: "$schema",
|
||||
expected: FieldType::String,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "model",
|
||||
expected: FieldType::String,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "hooks",
|
||||
expected: FieldType::Object,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "permissions",
|
||||
expected: FieldType::Object,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "permissionMode",
|
||||
expected: FieldType::String,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "mcpServers",
|
||||
expected: FieldType::Object,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "oauth",
|
||||
expected: FieldType::Object,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "enabledPlugins",
|
||||
expected: FieldType::Object,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "plugins",
|
||||
expected: FieldType::Object,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "sandbox",
|
||||
expected: FieldType::Object,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "env",
|
||||
expected: FieldType::Object,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "aliases",
|
||||
expected: FieldType::Object,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "providerFallbacks",
|
||||
expected: FieldType::Object,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "trustedRoots",
|
||||
expected: FieldType::StringArray,
|
||||
},
|
||||
];
|
||||
|
||||
const HOOKS_FIELDS: &[FieldSpec] = &[
|
||||
FieldSpec {
|
||||
name: "PreToolUse",
|
||||
expected: FieldType::StringArray,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "PostToolUse",
|
||||
expected: FieldType::StringArray,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "PostToolUseFailure",
|
||||
expected: FieldType::StringArray,
|
||||
},
|
||||
];
|
||||
|
||||
const PERMISSIONS_FIELDS: &[FieldSpec] = &[
|
||||
FieldSpec {
|
||||
name: "defaultMode",
|
||||
expected: FieldType::String,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "allow",
|
||||
expected: FieldType::StringArray,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "deny",
|
||||
expected: FieldType::StringArray,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "ask",
|
||||
expected: FieldType::StringArray,
|
||||
},
|
||||
];
|
||||
|
||||
const PLUGINS_FIELDS: &[FieldSpec] = &[
|
||||
FieldSpec {
|
||||
name: "enabled",
|
||||
expected: FieldType::Object,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "externalDirectories",
|
||||
expected: FieldType::StringArray,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "installRoot",
|
||||
expected: FieldType::String,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "registryPath",
|
||||
expected: FieldType::String,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "bundledRoot",
|
||||
expected: FieldType::String,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "maxOutputTokens",
|
||||
expected: FieldType::Number,
|
||||
},
|
||||
];
|
||||
|
||||
const SANDBOX_FIELDS: &[FieldSpec] = &[
|
||||
FieldSpec {
|
||||
name: "enabled",
|
||||
expected: FieldType::Bool,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "namespaceRestrictions",
|
||||
expected: FieldType::Bool,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "networkIsolation",
|
||||
expected: FieldType::Bool,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "filesystemMode",
|
||||
expected: FieldType::String,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "allowedMounts",
|
||||
expected: FieldType::StringArray,
|
||||
},
|
||||
];
|
||||
|
||||
const OAUTH_FIELDS: &[FieldSpec] = &[
|
||||
FieldSpec {
|
||||
name: "clientId",
|
||||
expected: FieldType::String,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "authorizeUrl",
|
||||
expected: FieldType::String,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "tokenUrl",
|
||||
expected: FieldType::String,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "callbackPort",
|
||||
expected: FieldType::Number,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "manualRedirectUrl",
|
||||
expected: FieldType::String,
|
||||
},
|
||||
FieldSpec {
|
||||
name: "scopes",
|
||||
expected: FieldType::StringArray,
|
||||
},
|
||||
];
|
||||
|
||||
const DEPRECATED_FIELDS: &[DeprecatedField] = &[
|
||||
DeprecatedField {
|
||||
name: "permissionMode",
|
||||
replacement: "permissions.defaultMode",
|
||||
},
|
||||
DeprecatedField {
|
||||
name: "enabledPlugins",
|
||||
replacement: "plugins.enabled",
|
||||
},
|
||||
];
|
||||
|
||||
// ---- line-number resolution ----
|
||||
|
||||
/// Find the 1-based line number where a JSON key first appears in the raw source.
|
||||
fn find_key_line(source: &str, key: &str) -> Option<usize> {
|
||||
// Search for `"key"` followed by optional whitespace and a colon.
|
||||
let needle = format!("\"{key}\"");
|
||||
let mut search_start = 0;
|
||||
while let Some(offset) = source[search_start..].find(&needle) {
|
||||
let absolute = search_start + offset;
|
||||
let after = absolute + needle.len();
|
||||
// Verify the next non-whitespace char is `:` to confirm this is a key, not a value.
|
||||
if source[after..].chars().find(|ch| !ch.is_ascii_whitespace()) == Some(':') {
|
||||
return Some(source[..absolute].chars().filter(|&ch| ch == '\n').count() + 1);
|
||||
}
|
||||
search_start = after;
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// ---- core validation ----
|
||||
|
||||
fn validate_object_keys(
|
||||
object: &BTreeMap<String, JsonValue>,
|
||||
known_fields: &[FieldSpec],
|
||||
prefix: &str,
|
||||
source: &str,
|
||||
path_display: &str,
|
||||
) -> ValidationResult {
|
||||
let mut result = ValidationResult {
|
||||
errors: Vec::new(),
|
||||
warnings: Vec::new(),
|
||||
};
|
||||
|
||||
let known_names: Vec<&str> = known_fields.iter().map(|f| f.name).collect();
|
||||
|
||||
for (key, value) in object {
|
||||
let field_path = if prefix.is_empty() {
|
||||
key.clone()
|
||||
} else {
|
||||
format!("{prefix}.{key}")
|
||||
};
|
||||
|
||||
if let Some(spec) = known_fields.iter().find(|f| f.name == key) {
|
||||
// Type check.
|
||||
if !spec.expected.matches(value) {
|
||||
result.errors.push(ConfigDiagnostic {
|
||||
path: path_display.to_string(),
|
||||
field: field_path,
|
||||
line: find_key_line(source, key),
|
||||
kind: DiagnosticKind::WrongType {
|
||||
expected: spec.expected.label(),
|
||||
got: json_type_label(value),
|
||||
},
|
||||
});
|
||||
}
|
||||
} else if DEPRECATED_FIELDS.iter().any(|d| d.name == key) {
|
||||
// Deprecated key — handled separately, not an unknown-key error.
|
||||
} else {
|
||||
// Unknown key.
|
||||
let suggestion = suggest_field(key, &known_names);
|
||||
result.errors.push(ConfigDiagnostic {
|
||||
path: path_display.to_string(),
|
||||
field: field_path,
|
||||
line: find_key_line(source, key),
|
||||
kind: DiagnosticKind::UnknownKey { suggestion },
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn suggest_field(input: &str, candidates: &[&str]) -> Option<String> {
|
||||
let input_lower = input.to_ascii_lowercase();
|
||||
candidates
|
||||
.iter()
|
||||
.filter_map(|candidate| {
|
||||
let distance = simple_edit_distance(&input_lower, &candidate.to_ascii_lowercase());
|
||||
(distance <= 3).then_some((distance, *candidate))
|
||||
})
|
||||
.min_by_key(|(distance, _)| *distance)
|
||||
.map(|(_, name)| name.to_string())
|
||||
}
|
||||
|
||||
fn simple_edit_distance(left: &str, right: &str) -> usize {
|
||||
if left.is_empty() {
|
||||
return right.len();
|
||||
}
|
||||
if right.is_empty() {
|
||||
return left.len();
|
||||
}
|
||||
let right_chars: Vec<char> = right.chars().collect();
|
||||
let mut previous: Vec<usize> = (0..=right_chars.len()).collect();
|
||||
let mut current = vec![0; right_chars.len() + 1];
|
||||
|
||||
for (left_index, left_char) in left.chars().enumerate() {
|
||||
current[0] = left_index + 1;
|
||||
for (right_index, right_char) in right_chars.iter().enumerate() {
|
||||
let cost = usize::from(left_char != *right_char);
|
||||
current[right_index + 1] = (previous[right_index + 1] + 1)
|
||||
.min(current[right_index] + 1)
|
||||
.min(previous[right_index] + cost);
|
||||
}
|
||||
previous.clone_from(¤t);
|
||||
}
|
||||
|
||||
previous[right_chars.len()]
|
||||
}
|
||||
|
||||
/// Validate a parsed config file's keys and types against the known schema.
|
||||
///
|
||||
/// Returns diagnostics (errors and deprecation warnings) without blocking the load.
|
||||
pub fn validate_config_file(
|
||||
object: &BTreeMap<String, JsonValue>,
|
||||
source: &str,
|
||||
file_path: &Path,
|
||||
) -> ValidationResult {
|
||||
let path_display = file_path.display().to_string();
|
||||
let mut result = validate_object_keys(object, TOP_LEVEL_FIELDS, "", source, &path_display);
|
||||
|
||||
// Check deprecated fields.
|
||||
for deprecated in DEPRECATED_FIELDS {
|
||||
if object.contains_key(deprecated.name) {
|
||||
result.warnings.push(ConfigDiagnostic {
|
||||
path: path_display.clone(),
|
||||
field: deprecated.name.to_string(),
|
||||
line: find_key_line(source, deprecated.name),
|
||||
kind: DiagnosticKind::Deprecated {
|
||||
replacement: deprecated.replacement,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Validate known nested objects.
|
||||
if let Some(hooks) = object.get("hooks").and_then(JsonValue::as_object) {
|
||||
result.merge(validate_object_keys(
|
||||
hooks,
|
||||
HOOKS_FIELDS,
|
||||
"hooks",
|
||||
source,
|
||||
&path_display,
|
||||
));
|
||||
}
|
||||
if let Some(permissions) = object.get("permissions").and_then(JsonValue::as_object) {
|
||||
result.merge(validate_object_keys(
|
||||
permissions,
|
||||
PERMISSIONS_FIELDS,
|
||||
"permissions",
|
||||
source,
|
||||
&path_display,
|
||||
));
|
||||
}
|
||||
if let Some(plugins) = object.get("plugins").and_then(JsonValue::as_object) {
|
||||
result.merge(validate_object_keys(
|
||||
plugins,
|
||||
PLUGINS_FIELDS,
|
||||
"plugins",
|
||||
source,
|
||||
&path_display,
|
||||
));
|
||||
}
|
||||
if let Some(sandbox) = object.get("sandbox").and_then(JsonValue::as_object) {
|
||||
result.merge(validate_object_keys(
|
||||
sandbox,
|
||||
SANDBOX_FIELDS,
|
||||
"sandbox",
|
||||
source,
|
||||
&path_display,
|
||||
));
|
||||
}
|
||||
if let Some(oauth) = object.get("oauth").and_then(JsonValue::as_object) {
|
||||
result.merge(validate_object_keys(
|
||||
oauth,
|
||||
OAUTH_FIELDS,
|
||||
"oauth",
|
||||
source,
|
||||
&path_display,
|
||||
));
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Check whether a file path uses an unsupported config format (e.g. TOML).
|
||||
pub fn check_unsupported_format(file_path: &Path) -> Result<(), ConfigError> {
|
||||
if let Some(ext) = file_path.extension().and_then(|e| e.to_str()) {
|
||||
if ext.eq_ignore_ascii_case("toml") {
|
||||
return Err(ConfigError::Parse(format!(
|
||||
"{}: TOML config files are not supported. Use JSON (settings.json) instead",
|
||||
file_path.display()
|
||||
)));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Format all diagnostics into a human-readable report.
|
||||
#[must_use]
|
||||
pub fn format_diagnostics(result: &ValidationResult) -> String {
|
||||
let mut lines = Vec::new();
|
||||
for warning in &result.warnings {
|
||||
lines.push(format!("warning: {warning}"));
|
||||
}
|
||||
for error in &result.errors {
|
||||
lines.push(format!("error: {error}"));
|
||||
}
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::path::PathBuf;
|
||||
|
||||
fn test_path() -> PathBuf {
|
||||
PathBuf::from("/test/settings.json")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_unknown_top_level_key() {
|
||||
// given
|
||||
let source = r#"{"model": "opus", "unknownField": true}"#;
|
||||
let parsed = JsonValue::parse(source).expect("valid json");
|
||||
let object = parsed.as_object().expect("object");
|
||||
|
||||
// when
|
||||
let result = validate_config_file(object, source, &test_path());
|
||||
|
||||
// then
|
||||
assert_eq!(result.errors.len(), 1);
|
||||
assert_eq!(result.errors[0].field, "unknownField");
|
||||
assert!(matches!(
|
||||
result.errors[0].kind,
|
||||
DiagnosticKind::UnknownKey { .. }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_wrong_type_for_model() {
|
||||
// given
|
||||
let source = r#"{"model": 123}"#;
|
||||
let parsed = JsonValue::parse(source).expect("valid json");
|
||||
let object = parsed.as_object().expect("object");
|
||||
|
||||
// when
|
||||
let result = validate_config_file(object, source, &test_path());
|
||||
|
||||
// then
|
||||
assert_eq!(result.errors.len(), 1);
|
||||
assert_eq!(result.errors[0].field, "model");
|
||||
assert!(matches!(
|
||||
result.errors[0].kind,
|
||||
DiagnosticKind::WrongType {
|
||||
expected: "a string",
|
||||
got: "a number"
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_deprecated_permission_mode() {
|
||||
// given
|
||||
let source = r#"{"permissionMode": "plan"}"#;
|
||||
let parsed = JsonValue::parse(source).expect("valid json");
|
||||
let object = parsed.as_object().expect("object");
|
||||
|
||||
// when
|
||||
let result = validate_config_file(object, source, &test_path());
|
||||
|
||||
// then
|
||||
assert_eq!(result.warnings.len(), 1);
|
||||
assert_eq!(result.warnings[0].field, "permissionMode");
|
||||
assert!(matches!(
|
||||
result.warnings[0].kind,
|
||||
DiagnosticKind::Deprecated {
|
||||
replacement: "permissions.defaultMode"
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_deprecated_enabled_plugins() {
|
||||
// given
|
||||
let source = r#"{"enabledPlugins": {"tool-guard@builtin": true}}"#;
|
||||
let parsed = JsonValue::parse(source).expect("valid json");
|
||||
let object = parsed.as_object().expect("object");
|
||||
|
||||
// when
|
||||
let result = validate_config_file(object, source, &test_path());
|
||||
|
||||
// then
|
||||
assert_eq!(result.warnings.len(), 1);
|
||||
assert_eq!(result.warnings[0].field, "enabledPlugins");
|
||||
assert!(matches!(
|
||||
result.warnings[0].kind,
|
||||
DiagnosticKind::Deprecated {
|
||||
replacement: "plugins.enabled"
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reports_line_number_for_unknown_key() {
|
||||
// given
|
||||
let source = "{\n \"model\": \"opus\",\n \"badKey\": true\n}";
|
||||
let parsed = JsonValue::parse(source).expect("valid json");
|
||||
let object = parsed.as_object().expect("object");
|
||||
|
||||
// when
|
||||
let result = validate_config_file(object, source, &test_path());
|
||||
|
||||
// then
|
||||
assert_eq!(result.errors.len(), 1);
|
||||
assert_eq!(result.errors[0].line, Some(3));
|
||||
assert_eq!(result.errors[0].field, "badKey");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reports_line_number_for_wrong_type() {
|
||||
// given
|
||||
let source = "{\n \"model\": 42\n}";
|
||||
let parsed = JsonValue::parse(source).expect("valid json");
|
||||
let object = parsed.as_object().expect("object");
|
||||
|
||||
// when
|
||||
let result = validate_config_file(object, source, &test_path());
|
||||
|
||||
// then
|
||||
assert_eq!(result.errors.len(), 1);
|
||||
assert_eq!(result.errors[0].line, Some(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validates_nested_hooks_keys() {
|
||||
// given
|
||||
let source = r#"{"hooks": {"PreToolUse": ["cmd"], "BadHook": ["x"]}}"#;
|
||||
let parsed = JsonValue::parse(source).expect("valid json");
|
||||
let object = parsed.as_object().expect("object");
|
||||
|
||||
// when
|
||||
let result = validate_config_file(object, source, &test_path());
|
||||
|
||||
// then
|
||||
assert_eq!(result.errors.len(), 1);
|
||||
assert_eq!(result.errors[0].field, "hooks.BadHook");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validates_nested_permissions_keys() {
|
||||
// given
|
||||
let source = r#"{"permissions": {"allow": ["Read"], "denyAll": true}}"#;
|
||||
let parsed = JsonValue::parse(source).expect("valid json");
|
||||
let object = parsed.as_object().expect("object");
|
||||
|
||||
// when
|
||||
let result = validate_config_file(object, source, &test_path());
|
||||
|
||||
// then
|
||||
assert_eq!(result.errors.len(), 1);
|
||||
assert_eq!(result.errors[0].field, "permissions.denyAll");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validates_nested_sandbox_keys() {
|
||||
// given
|
||||
let source = r#"{"sandbox": {"enabled": true, "containerMode": "strict"}}"#;
|
||||
let parsed = JsonValue::parse(source).expect("valid json");
|
||||
let object = parsed.as_object().expect("object");
|
||||
|
||||
// when
|
||||
let result = validate_config_file(object, source, &test_path());
|
||||
|
||||
// then
|
||||
assert_eq!(result.errors.len(), 1);
|
||||
assert_eq!(result.errors[0].field, "sandbox.containerMode");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validates_nested_plugins_keys() {
|
||||
// given
|
||||
let source = r#"{"plugins": {"installRoot": "/tmp", "autoUpdate": true}}"#;
|
||||
let parsed = JsonValue::parse(source).expect("valid json");
|
||||
let object = parsed.as_object().expect("object");
|
||||
|
||||
// when
|
||||
let result = validate_config_file(object, source, &test_path());
|
||||
|
||||
// then
|
||||
assert_eq!(result.errors.len(), 1);
|
||||
assert_eq!(result.errors[0].field, "plugins.autoUpdate");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validates_nested_oauth_keys() {
|
||||
// given
|
||||
let source = r#"{"oauth": {"clientId": "abc", "secret": "hidden"}}"#;
|
||||
let parsed = JsonValue::parse(source).expect("valid json");
|
||||
let object = parsed.as_object().expect("object");
|
||||
|
||||
// when
|
||||
let result = validate_config_file(object, source, &test_path());
|
||||
|
||||
// then
|
||||
assert_eq!(result.errors.len(), 1);
|
||||
assert_eq!(result.errors[0].field, "oauth.secret");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_config_produces_no_diagnostics() {
|
||||
// given
|
||||
let source = r#"{
|
||||
"model": "opus",
|
||||
"hooks": {"PreToolUse": ["guard"]},
|
||||
"permissions": {"defaultMode": "plan", "allow": ["Read"]},
|
||||
"mcpServers": {},
|
||||
"sandbox": {"enabled": false}
|
||||
}"#;
|
||||
let parsed = JsonValue::parse(source).expect("valid json");
|
||||
let object = parsed.as_object().expect("object");
|
||||
|
||||
// when
|
||||
let result = validate_config_file(object, source, &test_path());
|
||||
|
||||
// then
|
||||
assert!(result.is_ok());
|
||||
assert!(result.warnings.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn suggests_close_field_name() {
|
||||
// given
|
||||
let source = r#"{"modle": "opus"}"#;
|
||||
let parsed = JsonValue::parse(source).expect("valid json");
|
||||
let object = parsed.as_object().expect("object");
|
||||
|
||||
// when
|
||||
let result = validate_config_file(object, source, &test_path());
|
||||
|
||||
// then
|
||||
assert_eq!(result.errors.len(), 1);
|
||||
match &result.errors[0].kind {
|
||||
DiagnosticKind::UnknownKey {
|
||||
suggestion: Some(s),
|
||||
} => assert_eq!(s, "model"),
|
||||
other => panic!("expected suggestion, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_diagnostics_includes_all_entries() {
|
||||
// given
|
||||
let source = r#"{"permissionMode": "plan", "badKey": 1}"#;
|
||||
let parsed = JsonValue::parse(source).expect("valid json");
|
||||
let object = parsed.as_object().expect("object");
|
||||
let result = validate_config_file(object, source, &test_path());
|
||||
|
||||
// when
|
||||
let output = format_diagnostics(&result);
|
||||
|
||||
// then
|
||||
assert!(output.contains("warning:"));
|
||||
assert!(output.contains("error:"));
|
||||
assert!(output.contains("badKey"));
|
||||
assert!(output.contains("permissionMode"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_unsupported_format_rejects_toml() {
|
||||
// given
|
||||
let path = PathBuf::from("/home/.claw/settings.toml");
|
||||
|
||||
// when
|
||||
let result = check_unsupported_format(&path);
|
||||
|
||||
// then
|
||||
assert!(result.is_err());
|
||||
let message = result.unwrap_err().to_string();
|
||||
assert!(message.contains("TOML"));
|
||||
assert!(message.contains("settings.toml"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_unsupported_format_allows_json() {
|
||||
// given
|
||||
let path = PathBuf::from("/home/.claw/settings.json");
|
||||
|
||||
// when / then
|
||||
assert!(check_unsupported_format(&path).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_type_in_nested_sandbox_field() {
|
||||
// given
|
||||
let source = r#"{"sandbox": {"enabled": "yes"}}"#;
|
||||
let parsed = JsonValue::parse(source).expect("valid json");
|
||||
let object = parsed.as_object().expect("object");
|
||||
|
||||
// when
|
||||
let result = validate_config_file(object, source, &test_path());
|
||||
|
||||
// then
|
||||
assert_eq!(result.errors.len(), 1);
|
||||
assert_eq!(result.errors[0].field, "sandbox.enabled");
|
||||
assert!(matches!(
|
||||
result.errors[0].kind,
|
||||
DiagnosticKind::WrongType {
|
||||
expected: "a boolean",
|
||||
got: "a string"
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_format_unknown_key_with_line() {
|
||||
// given
|
||||
let diag = ConfigDiagnostic {
|
||||
path: "/test/settings.json".to_string(),
|
||||
field: "badKey".to_string(),
|
||||
line: Some(5),
|
||||
kind: DiagnosticKind::UnknownKey { suggestion: None },
|
||||
};
|
||||
|
||||
// when
|
||||
let output = diag.to_string();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
output,
|
||||
r#"/test/settings.json: unknown key "badKey" (line 5)"#
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_format_wrong_type_with_line() {
|
||||
// given
|
||||
let diag = ConfigDiagnostic {
|
||||
path: "/test/settings.json".to_string(),
|
||||
field: "model".to_string(),
|
||||
line: Some(2),
|
||||
kind: DiagnosticKind::WrongType {
|
||||
expected: "a string",
|
||||
got: "a number",
|
||||
},
|
||||
};
|
||||
|
||||
// when
|
||||
let output = diag.to_string();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
output,
|
||||
r#"/test/settings.json: field "model" must be a string, got a number (line 2)"#
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_format_deprecated_with_line() {
|
||||
// given
|
||||
let diag = ConfigDiagnostic {
|
||||
path: "/test/settings.json".to_string(),
|
||||
field: "permissionMode".to_string(),
|
||||
line: Some(3),
|
||||
kind: DiagnosticKind::Deprecated {
|
||||
replacement: "permissions.defaultMode",
|
||||
},
|
||||
};
|
||||
|
||||
// when
|
||||
let output = diag.to_string();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
output,
|
||||
r#"/test/settings.json: field "permissionMode" is deprecated (line 3). Use "permissions.defaultMode" instead"#
|
||||
);
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -9,6 +9,41 @@ use regex::RegexBuilder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use walkdir::WalkDir;
|
||||
|
||||
/// Maximum file size that can be read (10 MB).
|
||||
const MAX_READ_SIZE: u64 = 10 * 1024 * 1024;
|
||||
|
||||
/// Maximum file size that can be written (10 MB).
|
||||
const MAX_WRITE_SIZE: usize = 10 * 1024 * 1024;
|
||||
|
||||
/// Check whether a file appears to contain binary content by examining
|
||||
/// the first chunk for NUL bytes.
|
||||
fn is_binary_file(path: &Path) -> io::Result<bool> {
|
||||
use std::io::Read;
|
||||
let mut file = fs::File::open(path)?;
|
||||
let mut buffer = [0u8; 8192];
|
||||
let bytes_read = file.read(&mut buffer)?;
|
||||
Ok(buffer[..bytes_read].contains(&0))
|
||||
}
|
||||
|
||||
/// Validate that a resolved path stays within the given workspace root.
|
||||
/// Returns the canonical path on success, or an error if the path escapes
|
||||
/// the workspace boundary (e.g. via `../` traversal or symlink).
|
||||
#[allow(dead_code)]
|
||||
fn validate_workspace_boundary(resolved: &Path, workspace_root: &Path) -> io::Result<()> {
|
||||
if !resolved.starts_with(workspace_root) {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::PermissionDenied,
|
||||
format!(
|
||||
"path {} escapes workspace boundary {}",
|
||||
resolved.display(),
|
||||
workspace_root.display()
|
||||
),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Text payload returned by file-reading operations.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct TextFilePayload {
|
||||
#[serde(rename = "filePath")]
|
||||
@ -22,6 +57,7 @@ pub struct TextFilePayload {
|
||||
pub total_lines: usize,
|
||||
}
|
||||
|
||||
/// Output envelope for the `read_file` tool.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct ReadFileOutput {
|
||||
#[serde(rename = "type")]
|
||||
@ -29,6 +65,7 @@ pub struct ReadFileOutput {
|
||||
pub file: TextFilePayload,
|
||||
}
|
||||
|
||||
/// Structured patch hunk emitted by write and edit operations.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct StructuredPatchHunk {
|
||||
#[serde(rename = "oldStart")]
|
||||
@ -42,6 +79,7 @@ pub struct StructuredPatchHunk {
|
||||
pub lines: Vec<String>,
|
||||
}
|
||||
|
||||
/// Output envelope for full-file write operations.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct WriteFileOutput {
|
||||
#[serde(rename = "type")]
|
||||
@ -57,6 +95,7 @@ pub struct WriteFileOutput {
|
||||
pub git_diff: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Output envelope for targeted string-replacement edits.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct EditFileOutput {
|
||||
#[serde(rename = "filePath")]
|
||||
@ -77,6 +116,7 @@ pub struct EditFileOutput {
|
||||
pub git_diff: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Result of a glob-based filename search.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct GlobSearchOutput {
|
||||
#[serde(rename = "durationMs")]
|
||||
@ -87,6 +127,7 @@ pub struct GlobSearchOutput {
|
||||
pub truncated: bool,
|
||||
}
|
||||
|
||||
/// Parameters accepted by the grep-style search tool.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct GrepSearchInput {
|
||||
pub pattern: String,
|
||||
@ -112,6 +153,7 @@ pub struct GrepSearchInput {
|
||||
pub multiline: Option<bool>,
|
||||
}
|
||||
|
||||
/// Result payload returned by the grep-style search tool.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct GrepSearchOutput {
|
||||
pub mode: Option<String>,
|
||||
@ -129,12 +171,35 @@ pub struct GrepSearchOutput {
|
||||
pub applied_offset: Option<usize>,
|
||||
}
|
||||
|
||||
/// Reads a text file and returns a line-windowed payload.
|
||||
pub fn read_file(
|
||||
path: &str,
|
||||
offset: Option<usize>,
|
||||
limit: Option<usize>,
|
||||
) -> io::Result<ReadFileOutput> {
|
||||
let absolute_path = normalize_path(path)?;
|
||||
|
||||
// Check file size before reading
|
||||
let metadata = fs::metadata(&absolute_path)?;
|
||||
if metadata.len() > MAX_READ_SIZE {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!(
|
||||
"file is too large ({} bytes, max {} bytes)",
|
||||
metadata.len(),
|
||||
MAX_READ_SIZE
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Detect binary files
|
||||
if is_binary_file(&absolute_path)? {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"file appears to be binary",
|
||||
));
|
||||
}
|
||||
|
||||
let content = fs::read_to_string(&absolute_path)?;
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
let start_index = offset.unwrap_or(0).min(lines.len());
|
||||
@ -155,7 +220,19 @@ pub fn read_file(
|
||||
})
|
||||
}
|
||||
|
||||
/// Replaces a file's contents and returns patch metadata.
|
||||
pub fn write_file(path: &str, content: &str) -> io::Result<WriteFileOutput> {
|
||||
if content.len() > MAX_WRITE_SIZE {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!(
|
||||
"content is too large ({} bytes, max {} bytes)",
|
||||
content.len(),
|
||||
MAX_WRITE_SIZE
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let absolute_path = normalize_path_allow_missing(path)?;
|
||||
let original_file = fs::read_to_string(&absolute_path).ok();
|
||||
if let Some(parent) = absolute_path.parent() {
|
||||
@ -177,6 +254,7 @@ pub fn write_file(path: &str, content: &str) -> io::Result<WriteFileOutput> {
|
||||
})
|
||||
}
|
||||
|
||||
/// Performs an in-file string replacement and returns patch metadata.
|
||||
pub fn edit_file(
|
||||
path: &str,
|
||||
old_string: &str,
|
||||
@ -217,6 +295,7 @@ pub fn edit_file(
|
||||
})
|
||||
}
|
||||
|
||||
/// Expands a glob pattern and returns matching filenames.
|
||||
pub fn glob_search(pattern: &str, path: Option<&str>) -> io::Result<GlobSearchOutput> {
|
||||
let started = Instant::now();
|
||||
let base_dir = path
|
||||
@ -229,14 +308,22 @@ pub fn glob_search(pattern: &str, path: Option<&str>) -> io::Result<GlobSearchOu
|
||||
base_dir.join(pattern).to_string_lossy().into_owned()
|
||||
};
|
||||
|
||||
// The `glob` crate does not support brace expansion ({a,b,c}).
|
||||
// Expand braces into multiple patterns so patterns like
|
||||
// `Assets/**/*.{cs,uxml,uss}` work correctly.
|
||||
let expanded = expand_braces(&search_pattern);
|
||||
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
let mut matches = Vec::new();
|
||||
let entries = glob::glob(&search_pattern)
|
||||
for pat in &expanded {
|
||||
let entries = glob::glob(pat)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidInput, error.to_string()))?;
|
||||
for entry in entries.flatten() {
|
||||
if entry.is_file() {
|
||||
if entry.is_file() && seen.insert(entry.clone()) {
|
||||
matches.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
matches.sort_by_key(|path| {
|
||||
fs::metadata(path)
|
||||
@ -260,6 +347,7 @@ pub fn glob_search(pattern: &str, path: Option<&str>) -> io::Result<GlobSearchOu
|
||||
})
|
||||
}
|
||||
|
||||
/// Runs a regex search over workspace files with optional context lines.
|
||||
pub fn grep_search(input: &GrepSearchInput) -> io::Result<GrepSearchOutput> {
|
||||
let base_path = input
|
||||
.path
|
||||
@ -477,18 +565,105 @@ fn normalize_path_allow_missing(path: &str) -> io::Result<PathBuf> {
|
||||
Ok(candidate)
|
||||
}
|
||||
|
||||
/// Read a file with workspace boundary enforcement.
|
||||
#[allow(dead_code)]
|
||||
pub fn read_file_in_workspace(
|
||||
path: &str,
|
||||
offset: Option<usize>,
|
||||
limit: Option<usize>,
|
||||
workspace_root: &Path,
|
||||
) -> io::Result<ReadFileOutput> {
|
||||
let absolute_path = normalize_path(path)?;
|
||||
let canonical_root = workspace_root
|
||||
.canonicalize()
|
||||
.unwrap_or_else(|_| workspace_root.to_path_buf());
|
||||
validate_workspace_boundary(&absolute_path, &canonical_root)?;
|
||||
read_file(path, offset, limit)
|
||||
}
|
||||
|
||||
/// Write a file with workspace boundary enforcement.
|
||||
#[allow(dead_code)]
|
||||
pub fn write_file_in_workspace(
|
||||
path: &str,
|
||||
content: &str,
|
||||
workspace_root: &Path,
|
||||
) -> io::Result<WriteFileOutput> {
|
||||
let absolute_path = normalize_path_allow_missing(path)?;
|
||||
let canonical_root = workspace_root
|
||||
.canonicalize()
|
||||
.unwrap_or_else(|_| workspace_root.to_path_buf());
|
||||
validate_workspace_boundary(&absolute_path, &canonical_root)?;
|
||||
write_file(path, content)
|
||||
}
|
||||
|
||||
/// Edit a file with workspace boundary enforcement.
|
||||
#[allow(dead_code)]
|
||||
pub fn edit_file_in_workspace(
|
||||
path: &str,
|
||||
old_string: &str,
|
||||
new_string: &str,
|
||||
replace_all: bool,
|
||||
workspace_root: &Path,
|
||||
) -> io::Result<EditFileOutput> {
|
||||
let absolute_path = normalize_path(path)?;
|
||||
let canonical_root = workspace_root
|
||||
.canonicalize()
|
||||
.unwrap_or_else(|_| workspace_root.to_path_buf());
|
||||
validate_workspace_boundary(&absolute_path, &canonical_root)?;
|
||||
edit_file(path, old_string, new_string, replace_all)
|
||||
}
|
||||
|
||||
/// Check whether a path is a symlink that resolves outside the workspace.
|
||||
#[allow(dead_code)]
|
||||
pub fn is_symlink_escape(path: &Path, workspace_root: &Path) -> io::Result<bool> {
|
||||
let metadata = fs::symlink_metadata(path)?;
|
||||
if !metadata.is_symlink() {
|
||||
return Ok(false);
|
||||
}
|
||||
let resolved = path.canonicalize()?;
|
||||
let canonical_root = workspace_root
|
||||
.canonicalize()
|
||||
.unwrap_or_else(|_| workspace_root.to_path_buf());
|
||||
Ok(!resolved.starts_with(&canonical_root))
|
||||
}
|
||||
|
||||
/// Expand shell-style brace groups in a glob pattern.
|
||||
///
|
||||
/// Handles one level of braces: `foo.{a,b,c}` → `["foo.a", "foo.b", "foo.c"]`.
|
||||
/// Nested braces are not expanded (uncommon in practice).
|
||||
/// Patterns without braces pass through unchanged.
|
||||
fn expand_braces(pattern: &str) -> Vec<String> {
|
||||
let Some(open) = pattern.find('{') else {
|
||||
return vec![pattern.to_owned()];
|
||||
};
|
||||
let Some(close) = pattern[open..].find('}').map(|i| open + i) else {
|
||||
// Unmatched brace — treat as literal.
|
||||
return vec![pattern.to_owned()];
|
||||
};
|
||||
let prefix = &pattern[..open];
|
||||
let suffix = &pattern[close + 1..];
|
||||
let alternatives = &pattern[open + 1..close];
|
||||
alternatives
|
||||
.split(',')
|
||||
.flat_map(|alt| expand_braces(&format!("{prefix}{alt}{suffix}")))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use super::{edit_file, glob_search, grep_search, read_file, write_file, GrepSearchInput};
|
||||
use super::{
|
||||
edit_file, expand_braces, glob_search, grep_search, is_symlink_escape, read_file,
|
||||
read_file_in_workspace, write_file, GrepSearchInput, MAX_WRITE_SIZE,
|
||||
};
|
||||
|
||||
fn temp_path(name: &str) -> std::path::PathBuf {
|
||||
let unique = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time should move forward")
|
||||
.as_nanos();
|
||||
std::env::temp_dir().join(format!("claw-native-{name}-{unique}"))
|
||||
std::env::temp_dir().join(format!("clawd-native-{name}-{unique}"))
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -513,6 +688,73 @@ mod tests {
|
||||
assert!(output.replace_all);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_binary_files() {
|
||||
let path = temp_path("binary-test.bin");
|
||||
std::fs::write(&path, b"\x00\x01\x02\x03binary content").expect("write should succeed");
|
||||
let result = read_file(path.to_string_lossy().as_ref(), None, None);
|
||||
assert!(result.is_err());
|
||||
let error = result.unwrap_err();
|
||||
assert_eq!(error.kind(), std::io::ErrorKind::InvalidData);
|
||||
assert!(error.to_string().contains("binary"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_oversized_writes() {
|
||||
let path = temp_path("oversize-write.txt");
|
||||
let huge = "x".repeat(MAX_WRITE_SIZE + 1);
|
||||
let result = write_file(path.to_string_lossy().as_ref(), &huge);
|
||||
assert!(result.is_err());
|
||||
let error = result.unwrap_err();
|
||||
assert_eq!(error.kind(), std::io::ErrorKind::InvalidData);
|
||||
assert!(error.to_string().contains("too large"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enforces_workspace_boundary() {
|
||||
let workspace = temp_path("workspace-boundary");
|
||||
std::fs::create_dir_all(&workspace).expect("workspace dir should be created");
|
||||
let inside = workspace.join("inside.txt");
|
||||
write_file(inside.to_string_lossy().as_ref(), "safe content")
|
||||
.expect("write inside workspace should succeed");
|
||||
|
||||
// Reading inside workspace should succeed
|
||||
let result =
|
||||
read_file_in_workspace(inside.to_string_lossy().as_ref(), None, None, &workspace);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Reading outside workspace should fail
|
||||
let outside = temp_path("outside-boundary.txt");
|
||||
write_file(outside.to_string_lossy().as_ref(), "unsafe content")
|
||||
.expect("write outside should succeed");
|
||||
let result =
|
||||
read_file_in_workspace(outside.to_string_lossy().as_ref(), None, None, &workspace);
|
||||
assert!(result.is_err());
|
||||
let error = result.unwrap_err();
|
||||
assert_eq!(error.kind(), std::io::ErrorKind::PermissionDenied);
|
||||
assert!(error.to_string().contains("escapes workspace"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_symlink_escape() {
|
||||
let workspace = temp_path("symlink-workspace");
|
||||
std::fs::create_dir_all(&workspace).expect("workspace dir should be created");
|
||||
let outside = temp_path("symlink-target.txt");
|
||||
std::fs::write(&outside, "target content").expect("target should write");
|
||||
|
||||
let _link_path = workspace.join("escape-link.txt");
|
||||
#[cfg(unix)]
|
||||
{
|
||||
std::os::unix::fs::symlink(&outside, &link_path).expect("symlink should create");
|
||||
assert!(is_symlink_escape(&link_path, &workspace).expect("check should succeed"));
|
||||
}
|
||||
|
||||
// Non-symlink file should not be an escape
|
||||
let normal = workspace.join("normal.txt");
|
||||
std::fs::write(&normal, "normal content").expect("normal file should write");
|
||||
assert!(!is_symlink_escape(&normal, &workspace).expect("check should succeed"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn globs_and_greps_directory() {
|
||||
let dir = temp_path("search-dir");
|
||||
@ -547,4 +789,51 @@ mod tests {
|
||||
.expect("grep should succeed");
|
||||
assert!(grep_output.content.unwrap_or_default().contains("hello"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_braces_no_braces() {
|
||||
assert_eq!(expand_braces("*.rs"), vec!["*.rs"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_braces_single_group() {
|
||||
let mut result = expand_braces("Assets/**/*.{cs,uxml,uss}");
|
||||
result.sort();
|
||||
assert_eq!(
|
||||
result,
|
||||
vec!["Assets/**/*.cs", "Assets/**/*.uss", "Assets/**/*.uxml",]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_braces_nested() {
|
||||
let mut result = expand_braces("src/{a,b}.{rs,toml}");
|
||||
result.sort();
|
||||
assert_eq!(
|
||||
result,
|
||||
vec!["src/a.rs", "src/a.toml", "src/b.rs", "src/b.toml"]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expand_braces_unmatched() {
|
||||
assert_eq!(expand_braces("foo.{bar"), vec!["foo.{bar"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn glob_search_with_braces_finds_files() {
|
||||
let dir = temp_path("glob-braces");
|
||||
std::fs::create_dir_all(&dir).unwrap();
|
||||
std::fs::write(dir.join("a.rs"), "fn main() {}").unwrap();
|
||||
std::fs::write(dir.join("b.toml"), "[package]").unwrap();
|
||||
std::fs::write(dir.join("c.txt"), "hello").unwrap();
|
||||
|
||||
let result =
|
||||
glob_search("*.{rs,toml}", Some(dir.to_str().unwrap())).expect("glob should succeed");
|
||||
assert_eq!(
|
||||
result.num_files, 2,
|
||||
"should match .rs and .toml but not .txt"
|
||||
);
|
||||
let _ = std::fs::remove_dir_all(&dir);
|
||||
}
|
||||
}
|
||||
|
||||
324
crates/runtime/src/git_context.rs
Normal file
324
crates/runtime/src/git_context.rs
Normal file
@ -0,0 +1,324 @@
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
/// A single git commit entry from the log.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct GitCommitEntry {
|
||||
pub hash: String,
|
||||
pub subject: String,
|
||||
}
|
||||
|
||||
/// Git-aware context gathered at startup for injection into the system prompt.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct GitContext {
|
||||
pub branch: Option<String>,
|
||||
pub recent_commits: Vec<GitCommitEntry>,
|
||||
pub staged_files: Vec<String>,
|
||||
}
|
||||
|
||||
const MAX_RECENT_COMMITS: usize = 5;
|
||||
|
||||
impl GitContext {
|
||||
/// Detect the git context from the given working directory.
|
||||
///
|
||||
/// Returns `None` when the directory is not inside a git repository.
|
||||
#[must_use]
|
||||
pub fn detect(cwd: &Path) -> Option<Self> {
|
||||
// Quick gate: is this a git repo at all?
|
||||
let rev_parse = Command::new("git")
|
||||
.args(["rev-parse", "--is-inside-work-tree"])
|
||||
.current_dir(cwd)
|
||||
.output()
|
||||
.ok()?;
|
||||
if !rev_parse.status.success() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(Self {
|
||||
branch: read_branch(cwd),
|
||||
recent_commits: read_recent_commits(cwd),
|
||||
staged_files: read_staged_files(cwd),
|
||||
})
|
||||
}
|
||||
|
||||
/// Render a human-readable summary suitable for system-prompt injection.
|
||||
#[must_use]
|
||||
pub fn render(&self) -> String {
|
||||
let mut lines = Vec::new();
|
||||
|
||||
if let Some(branch) = &self.branch {
|
||||
lines.push(format!("Git branch: {branch}"));
|
||||
}
|
||||
|
||||
if !self.recent_commits.is_empty() {
|
||||
lines.push(String::new());
|
||||
lines.push("Recent commits:".to_string());
|
||||
for entry in &self.recent_commits {
|
||||
lines.push(format!(" {} {}", entry.hash, entry.subject));
|
||||
}
|
||||
}
|
||||
|
||||
if !self.staged_files.is_empty() {
|
||||
lines.push(String::new());
|
||||
lines.push("Staged files:".to_string());
|
||||
for file in &self.staged_files {
|
||||
lines.push(format!(" {file}"));
|
||||
}
|
||||
}
|
||||
|
||||
lines.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
fn read_branch(cwd: &Path) -> Option<String> {
|
||||
let output = Command::new("git")
|
||||
.args(["rev-parse", "--abbrev-ref", "HEAD"])
|
||||
.current_dir(cwd)
|
||||
.output()
|
||||
.ok()?;
|
||||
if !output.status.success() {
|
||||
return None;
|
||||
}
|
||||
let branch = String::from_utf8(output.stdout).ok()?;
|
||||
let trimmed = branch.trim();
|
||||
if trimmed.is_empty() || trimmed == "HEAD" {
|
||||
None
|
||||
} else {
|
||||
Some(trimmed.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn read_recent_commits(cwd: &Path) -> Vec<GitCommitEntry> {
|
||||
let output = Command::new("git")
|
||||
.args([
|
||||
"--no-optional-locks",
|
||||
"log",
|
||||
"--oneline",
|
||||
"-n",
|
||||
&MAX_RECENT_COMMITS.to_string(),
|
||||
"--no-decorate",
|
||||
])
|
||||
.current_dir(cwd)
|
||||
.output()
|
||||
.ok();
|
||||
let Some(output) = output else {
|
||||
return Vec::new();
|
||||
};
|
||||
if !output.status.success() {
|
||||
return Vec::new();
|
||||
}
|
||||
let stdout = String::from_utf8(output.stdout).unwrap_or_default();
|
||||
stdout
|
||||
.lines()
|
||||
.filter_map(|line| {
|
||||
let line = line.trim();
|
||||
if line.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let (hash, subject) = line.split_once(' ')?;
|
||||
Some(GitCommitEntry {
|
||||
hash: hash.to_string(),
|
||||
subject: subject.to_string(),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn read_staged_files(cwd: &Path) -> Vec<String> {
|
||||
let output = Command::new("git")
|
||||
.args(["--no-optional-locks", "diff", "--cached", "--name-only"])
|
||||
.current_dir(cwd)
|
||||
.output()
|
||||
.ok();
|
||||
let Some(output) = output else {
|
||||
return Vec::new();
|
||||
};
|
||||
if !output.status.success() {
|
||||
return Vec::new();
|
||||
}
|
||||
let stdout = String::from_utf8(output.stdout).unwrap_or_default();
|
||||
stdout
|
||||
.lines()
|
||||
.filter(|line| !line.trim().is_empty())
|
||||
.map(|line| line.trim().to_string())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{GitCommitEntry, GitContext};
|
||||
use std::fs;
|
||||
use std::process::Command;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
fn temp_dir(label: &str) -> std::path::PathBuf {
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time should be after epoch")
|
||||
.as_nanos();
|
||||
std::env::temp_dir().join(format!("runtime-git-context-{label}-{nanos}"))
|
||||
}
|
||||
|
||||
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
crate::test_env_lock()
|
||||
}
|
||||
|
||||
fn ensure_valid_cwd() {
|
||||
if std::env::current_dir().is_err() {
|
||||
std::env::set_current_dir(env!("CARGO_MANIFEST_DIR"))
|
||||
.expect("test cwd should be recoverable");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_none_for_non_git_directory() {
|
||||
// given
|
||||
let _guard = env_lock();
|
||||
ensure_valid_cwd();
|
||||
let root = temp_dir("non-git");
|
||||
fs::create_dir_all(&root).expect("create dir");
|
||||
|
||||
// when
|
||||
let context = GitContext::detect(&root);
|
||||
|
||||
// then
|
||||
assert!(context.is_none());
|
||||
fs::remove_dir_all(root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_branch_name_and_commits() {
|
||||
// given
|
||||
let _guard = env_lock();
|
||||
ensure_valid_cwd();
|
||||
let root = temp_dir("branch-commits");
|
||||
fs::create_dir_all(&root).expect("create dir");
|
||||
git(&root, &["init", "--quiet", "--initial-branch=main"]);
|
||||
git(&root, &["config", "user.email", "tests@example.com"]);
|
||||
git(&root, &["config", "user.name", "Git Context Tests"]);
|
||||
fs::write(root.join("a.txt"), "a\n").expect("write a");
|
||||
git(&root, &["add", "a.txt"]);
|
||||
git(&root, &["commit", "-m", "first commit", "--quiet"]);
|
||||
fs::write(root.join("b.txt"), "b\n").expect("write b");
|
||||
git(&root, &["add", "b.txt"]);
|
||||
git(&root, &["commit", "-m", "second commit", "--quiet"]);
|
||||
|
||||
// when
|
||||
let context = GitContext::detect(&root).expect("should detect git repo");
|
||||
|
||||
// then
|
||||
assert_eq!(context.branch.as_deref(), Some("main"));
|
||||
assert_eq!(context.recent_commits.len(), 2);
|
||||
assert_eq!(context.recent_commits[0].subject, "second commit");
|
||||
assert_eq!(context.recent_commits[1].subject, "first commit");
|
||||
assert!(context.staged_files.is_empty());
|
||||
fs::remove_dir_all(root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_staged_files() {
|
||||
// given
|
||||
let _guard = env_lock();
|
||||
ensure_valid_cwd();
|
||||
let root = temp_dir("staged");
|
||||
fs::create_dir_all(&root).expect("create dir");
|
||||
git(&root, &["init", "--quiet", "--initial-branch=main"]);
|
||||
git(&root, &["config", "user.email", "tests@example.com"]);
|
||||
git(&root, &["config", "user.name", "Git Context Tests"]);
|
||||
fs::write(root.join("init.txt"), "init\n").expect("write init");
|
||||
git(&root, &["add", "init.txt"]);
|
||||
git(&root, &["commit", "-m", "initial", "--quiet"]);
|
||||
fs::write(root.join("staged.txt"), "staged\n").expect("write staged");
|
||||
git(&root, &["add", "staged.txt"]);
|
||||
|
||||
// when
|
||||
let context = GitContext::detect(&root).expect("should detect git repo");
|
||||
|
||||
// then
|
||||
assert_eq!(context.staged_files, vec!["staged.txt"]);
|
||||
fs::remove_dir_all(root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn render_formats_all_sections() {
|
||||
// given
|
||||
let context = GitContext {
|
||||
branch: Some("feat/test".to_string()),
|
||||
recent_commits: vec![
|
||||
GitCommitEntry {
|
||||
hash: "abc1234".to_string(),
|
||||
subject: "add feature".to_string(),
|
||||
},
|
||||
GitCommitEntry {
|
||||
hash: "def5678".to_string(),
|
||||
subject: "fix bug".to_string(),
|
||||
},
|
||||
],
|
||||
staged_files: vec!["src/main.rs".to_string()],
|
||||
};
|
||||
|
||||
// when
|
||||
let rendered = context.render();
|
||||
|
||||
// then
|
||||
assert!(rendered.contains("Git branch: feat/test"));
|
||||
assert!(rendered.contains("abc1234 add feature"));
|
||||
assert!(rendered.contains("def5678 fix bug"));
|
||||
assert!(rendered.contains("src/main.rs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn render_omits_empty_sections() {
|
||||
// given
|
||||
let context = GitContext {
|
||||
branch: Some("main".to_string()),
|
||||
recent_commits: Vec::new(),
|
||||
staged_files: Vec::new(),
|
||||
};
|
||||
|
||||
// when
|
||||
let rendered = context.render();
|
||||
|
||||
// then
|
||||
assert!(rendered.contains("Git branch: main"));
|
||||
assert!(!rendered.contains("Recent commits:"));
|
||||
assert!(!rendered.contains("Staged files:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn limits_to_five_recent_commits() {
|
||||
// given
|
||||
let _guard = env_lock();
|
||||
ensure_valid_cwd();
|
||||
let root = temp_dir("five-commits");
|
||||
fs::create_dir_all(&root).expect("create dir");
|
||||
git(&root, &["init", "--quiet", "--initial-branch=main"]);
|
||||
git(&root, &["config", "user.email", "tests@example.com"]);
|
||||
git(&root, &["config", "user.name", "Git Context Tests"]);
|
||||
for i in 1..=8 {
|
||||
let name = format!("file{i}.txt");
|
||||
fs::write(root.join(&name), format!("{i}\n")).expect("write file");
|
||||
git(&root, &["add", &name]);
|
||||
git(&root, &["commit", "-m", &format!("commit {i}"), "--quiet"]);
|
||||
}
|
||||
|
||||
// when
|
||||
let context = GitContext::detect(&root).expect("should detect git repo");
|
||||
|
||||
// then
|
||||
assert_eq!(context.recent_commits.len(), 5);
|
||||
assert_eq!(context.recent_commits[0].subject, "commit 8");
|
||||
assert_eq!(context.recent_commits[4].subject, "commit 4");
|
||||
fs::remove_dir_all(root).expect("cleanup");
|
||||
}
|
||||
|
||||
fn git(cwd: &std::path::Path, args: &[&str]) {
|
||||
let status = Command::new("git")
|
||||
.args(args)
|
||||
.current_dir(cwd)
|
||||
.output()
|
||||
.unwrap_or_else(|_| panic!("git {args:?} should run"))
|
||||
.status;
|
||||
assert!(status.success(), "git {args:?} failed");
|
||||
}
|
||||
}
|
||||
152
crates/runtime/src/green_contract.rs
Normal file
152
crates/runtime/src/green_contract.rs
Normal file
@ -0,0 +1,152 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum GreenLevel {
|
||||
TargetedTests,
|
||||
Package,
|
||||
Workspace,
|
||||
MergeReady,
|
||||
}
|
||||
|
||||
impl GreenLevel {
|
||||
#[must_use]
|
||||
pub fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::TargetedTests => "targeted_tests",
|
||||
Self::Package => "package",
|
||||
Self::Workspace => "workspace",
|
||||
Self::MergeReady => "merge_ready",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for GreenLevel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct GreenContract {
|
||||
pub required_level: GreenLevel,
|
||||
}
|
||||
|
||||
impl GreenContract {
|
||||
#[must_use]
|
||||
pub fn new(required_level: GreenLevel) -> Self {
|
||||
Self { required_level }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn evaluate(self, observed_level: Option<GreenLevel>) -> GreenContractOutcome {
|
||||
match observed_level {
|
||||
Some(level) if level >= self.required_level => GreenContractOutcome::Satisfied {
|
||||
required_level: self.required_level,
|
||||
observed_level: level,
|
||||
},
|
||||
_ => GreenContractOutcome::Unsatisfied {
|
||||
required_level: self.required_level,
|
||||
observed_level,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_satisfied_by(self, observed_level: GreenLevel) -> bool {
|
||||
observed_level >= self.required_level
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "outcome", rename_all = "snake_case")]
|
||||
pub enum GreenContractOutcome {
|
||||
Satisfied {
|
||||
required_level: GreenLevel,
|
||||
observed_level: GreenLevel,
|
||||
},
|
||||
Unsatisfied {
|
||||
required_level: GreenLevel,
|
||||
observed_level: Option<GreenLevel>,
|
||||
},
|
||||
}
|
||||
|
||||
impl GreenContractOutcome {
|
||||
#[must_use]
|
||||
pub fn is_satisfied(&self) -> bool {
|
||||
matches!(self, Self::Satisfied { .. })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn given_matching_level_when_evaluating_contract_then_it_is_satisfied() {
|
||||
// given
|
||||
let contract = GreenContract::new(GreenLevel::Package);
|
||||
|
||||
// when
|
||||
let outcome = contract.evaluate(Some(GreenLevel::Package));
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
outcome,
|
||||
GreenContractOutcome::Satisfied {
|
||||
required_level: GreenLevel::Package,
|
||||
observed_level: GreenLevel::Package,
|
||||
}
|
||||
);
|
||||
assert!(outcome.is_satisfied());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_higher_level_when_checking_requirement_then_it_still_satisfies_contract() {
|
||||
// given
|
||||
let contract = GreenContract::new(GreenLevel::TargetedTests);
|
||||
|
||||
// when
|
||||
let is_satisfied = contract.is_satisfied_by(GreenLevel::Workspace);
|
||||
|
||||
// then
|
||||
assert!(is_satisfied);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_lower_level_when_evaluating_contract_then_it_is_unsatisfied() {
|
||||
// given
|
||||
let contract = GreenContract::new(GreenLevel::Workspace);
|
||||
|
||||
// when
|
||||
let outcome = contract.evaluate(Some(GreenLevel::Package));
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
outcome,
|
||||
GreenContractOutcome::Unsatisfied {
|
||||
required_level: GreenLevel::Workspace,
|
||||
observed_level: Some(GreenLevel::Package),
|
||||
}
|
||||
);
|
||||
assert!(!outcome.is_satisfied());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_no_green_level_when_evaluating_contract_then_contract_is_unsatisfied() {
|
||||
// given
|
||||
let contract = GreenContract::new(GreenLevel::MergeReady);
|
||||
|
||||
// when
|
||||
let outcome = contract.evaluate(None);
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
outcome,
|
||||
GreenContractOutcome::Unsatisfied {
|
||||
required_level: GreenLevel::MergeReady,
|
||||
observed_level: None,
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,8 +1,7 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum JsonValue {
|
||||
Null,
|
||||
Bool(bool),
|
||||
|
||||
383
crates/runtime/src/lane_events.rs
Normal file
383
crates/runtime/src/lane_events.rs
Normal file
@ -0,0 +1,383 @@
|
||||
#![allow(clippy::similar_names)]
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum LaneEventName {
|
||||
#[serde(rename = "lane.started")]
|
||||
Started,
|
||||
#[serde(rename = "lane.ready")]
|
||||
Ready,
|
||||
#[serde(rename = "lane.prompt_misdelivery")]
|
||||
PromptMisdelivery,
|
||||
#[serde(rename = "lane.blocked")]
|
||||
Blocked,
|
||||
#[serde(rename = "lane.red")]
|
||||
Red,
|
||||
#[serde(rename = "lane.green")]
|
||||
Green,
|
||||
#[serde(rename = "lane.commit.created")]
|
||||
CommitCreated,
|
||||
#[serde(rename = "lane.pr.opened")]
|
||||
PrOpened,
|
||||
#[serde(rename = "lane.merge.ready")]
|
||||
MergeReady,
|
||||
#[serde(rename = "lane.finished")]
|
||||
Finished,
|
||||
#[serde(rename = "lane.failed")]
|
||||
Failed,
|
||||
#[serde(rename = "lane.reconciled")]
|
||||
Reconciled,
|
||||
#[serde(rename = "lane.merged")]
|
||||
Merged,
|
||||
#[serde(rename = "lane.superseded")]
|
||||
Superseded,
|
||||
#[serde(rename = "lane.closed")]
|
||||
Closed,
|
||||
#[serde(rename = "branch.stale_against_main")]
|
||||
BranchStaleAgainstMain,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum LaneEventStatus {
|
||||
Running,
|
||||
Ready,
|
||||
Blocked,
|
||||
Red,
|
||||
Green,
|
||||
Completed,
|
||||
Failed,
|
||||
Reconciled,
|
||||
Merged,
|
||||
Superseded,
|
||||
Closed,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum LaneFailureClass {
|
||||
PromptDelivery,
|
||||
TrustGate,
|
||||
BranchDivergence,
|
||||
Compile,
|
||||
Test,
|
||||
PluginStartup,
|
||||
McpStartup,
|
||||
McpHandshake,
|
||||
GatewayRouting,
|
||||
ToolRuntime,
|
||||
Infra,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct LaneEventBlocker {
|
||||
#[serde(rename = "failureClass")]
|
||||
pub failure_class: LaneFailureClass,
|
||||
pub detail: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct LaneCommitProvenance {
|
||||
pub commit: String,
|
||||
pub branch: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub worktree: Option<String>,
|
||||
#[serde(rename = "canonicalCommit", skip_serializing_if = "Option::is_none")]
|
||||
pub canonical_commit: Option<String>,
|
||||
#[serde(rename = "supersededBy", skip_serializing_if = "Option::is_none")]
|
||||
pub superseded_by: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub lineage: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct LaneEvent {
|
||||
pub event: LaneEventName,
|
||||
pub status: LaneEventStatus,
|
||||
#[serde(rename = "emittedAt")]
|
||||
pub emitted_at: String,
|
||||
#[serde(rename = "failureClass", skip_serializing_if = "Option::is_none")]
|
||||
pub failure_class: Option<LaneFailureClass>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub detail: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<Value>,
|
||||
}
|
||||
|
||||
impl LaneEvent {
|
||||
#[must_use]
|
||||
pub fn new(
|
||||
event: LaneEventName,
|
||||
status: LaneEventStatus,
|
||||
emitted_at: impl Into<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
event,
|
||||
status,
|
||||
emitted_at: emitted_at.into(),
|
||||
failure_class: None,
|
||||
detail: None,
|
||||
data: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn started(emitted_at: impl Into<String>) -> Self {
|
||||
Self::new(LaneEventName::Started, LaneEventStatus::Running, emitted_at)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn finished(emitted_at: impl Into<String>, detail: Option<String>) -> Self {
|
||||
Self::new(
|
||||
LaneEventName::Finished,
|
||||
LaneEventStatus::Completed,
|
||||
emitted_at,
|
||||
)
|
||||
.with_optional_detail(detail)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn commit_created(
|
||||
emitted_at: impl Into<String>,
|
||||
detail: Option<String>,
|
||||
provenance: LaneCommitProvenance,
|
||||
) -> Self {
|
||||
Self::new(
|
||||
LaneEventName::CommitCreated,
|
||||
LaneEventStatus::Completed,
|
||||
emitted_at,
|
||||
)
|
||||
.with_optional_detail(detail)
|
||||
.with_data(serde_json::to_value(provenance).expect("commit provenance should serialize"))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn superseded(
|
||||
emitted_at: impl Into<String>,
|
||||
detail: Option<String>,
|
||||
provenance: LaneCommitProvenance,
|
||||
) -> Self {
|
||||
Self::new(
|
||||
LaneEventName::Superseded,
|
||||
LaneEventStatus::Superseded,
|
||||
emitted_at,
|
||||
)
|
||||
.with_optional_detail(detail)
|
||||
.with_data(serde_json::to_value(provenance).expect("commit provenance should serialize"))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn blocked(emitted_at: impl Into<String>, blocker: &LaneEventBlocker) -> Self {
|
||||
Self::new(LaneEventName::Blocked, LaneEventStatus::Blocked, emitted_at)
|
||||
.with_failure_class(blocker.failure_class)
|
||||
.with_detail(blocker.detail.clone())
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn failed(emitted_at: impl Into<String>, blocker: &LaneEventBlocker) -> Self {
|
||||
Self::new(LaneEventName::Failed, LaneEventStatus::Failed, emitted_at)
|
||||
.with_failure_class(blocker.failure_class)
|
||||
.with_detail(blocker.detail.clone())
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_failure_class(mut self, failure_class: LaneFailureClass) -> Self {
|
||||
self.failure_class = Some(failure_class);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
|
||||
self.detail = Some(detail.into());
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_optional_detail(mut self, detail: Option<String>) -> Self {
|
||||
self.detail = detail;
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_data(mut self, data: Value) -> Self {
|
||||
self.data = Some(data);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn dedupe_superseded_commit_events(events: &[LaneEvent]) -> Vec<LaneEvent> {
|
||||
let mut keep = vec![true; events.len()];
|
||||
let mut latest_by_key = std::collections::BTreeMap::<String, usize>::new();
|
||||
|
||||
for (index, event) in events.iter().enumerate() {
|
||||
if event.event != LaneEventName::CommitCreated {
|
||||
continue;
|
||||
}
|
||||
let Some(data) = event.data.as_ref() else {
|
||||
continue;
|
||||
};
|
||||
let key = data
|
||||
.get("canonicalCommit")
|
||||
.or_else(|| data.get("commit"))
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(str::to_string);
|
||||
let superseded = data
|
||||
.get("supersededBy")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.is_some();
|
||||
if superseded {
|
||||
keep[index] = false;
|
||||
continue;
|
||||
}
|
||||
if let Some(key) = key {
|
||||
if let Some(previous) = latest_by_key.insert(key, index) {
|
||||
keep[previous] = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
events
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(keep)
|
||||
.filter_map(|(event, retain)| retain.then_some(event))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde_json::json;
|
||||
|
||||
use super::{
|
||||
dedupe_superseded_commit_events, LaneCommitProvenance, LaneEvent, LaneEventBlocker,
|
||||
LaneEventName, LaneEventStatus, LaneFailureClass,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn canonical_lane_event_names_serialize_to_expected_wire_values() {
|
||||
let cases = [
|
||||
(LaneEventName::Started, "lane.started"),
|
||||
(LaneEventName::Ready, "lane.ready"),
|
||||
(LaneEventName::PromptMisdelivery, "lane.prompt_misdelivery"),
|
||||
(LaneEventName::Blocked, "lane.blocked"),
|
||||
(LaneEventName::Red, "lane.red"),
|
||||
(LaneEventName::Green, "lane.green"),
|
||||
(LaneEventName::CommitCreated, "lane.commit.created"),
|
||||
(LaneEventName::PrOpened, "lane.pr.opened"),
|
||||
(LaneEventName::MergeReady, "lane.merge.ready"),
|
||||
(LaneEventName::Finished, "lane.finished"),
|
||||
(LaneEventName::Failed, "lane.failed"),
|
||||
(LaneEventName::Reconciled, "lane.reconciled"),
|
||||
(LaneEventName::Merged, "lane.merged"),
|
||||
(LaneEventName::Superseded, "lane.superseded"),
|
||||
(LaneEventName::Closed, "lane.closed"),
|
||||
(
|
||||
LaneEventName::BranchStaleAgainstMain,
|
||||
"branch.stale_against_main",
|
||||
),
|
||||
];
|
||||
|
||||
for (event, expected) in cases {
|
||||
assert_eq!(
|
||||
serde_json::to_value(event).expect("serialize event"),
|
||||
json!(expected)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn failure_classes_cover_canonical_taxonomy_wire_values() {
|
||||
let cases = [
|
||||
(LaneFailureClass::PromptDelivery, "prompt_delivery"),
|
||||
(LaneFailureClass::TrustGate, "trust_gate"),
|
||||
(LaneFailureClass::BranchDivergence, "branch_divergence"),
|
||||
(LaneFailureClass::Compile, "compile"),
|
||||
(LaneFailureClass::Test, "test"),
|
||||
(LaneFailureClass::PluginStartup, "plugin_startup"),
|
||||
(LaneFailureClass::McpStartup, "mcp_startup"),
|
||||
(LaneFailureClass::McpHandshake, "mcp_handshake"),
|
||||
(LaneFailureClass::GatewayRouting, "gateway_routing"),
|
||||
(LaneFailureClass::ToolRuntime, "tool_runtime"),
|
||||
(LaneFailureClass::Infra, "infra"),
|
||||
];
|
||||
|
||||
for (failure_class, expected) in cases {
|
||||
assert_eq!(
|
||||
serde_json::to_value(failure_class).expect("serialize failure class"),
|
||||
json!(expected)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blocked_and_failed_events_reuse_blocker_details() {
|
||||
let blocker = LaneEventBlocker {
|
||||
failure_class: LaneFailureClass::McpStartup,
|
||||
detail: "broken server".to_string(),
|
||||
};
|
||||
|
||||
let blocked = LaneEvent::blocked("2026-04-04T00:00:00Z", &blocker);
|
||||
let failed = LaneEvent::failed("2026-04-04T00:00:01Z", &blocker);
|
||||
|
||||
assert_eq!(blocked.event, LaneEventName::Blocked);
|
||||
assert_eq!(blocked.status, LaneEventStatus::Blocked);
|
||||
assert_eq!(blocked.failure_class, Some(LaneFailureClass::McpStartup));
|
||||
assert_eq!(failed.event, LaneEventName::Failed);
|
||||
assert_eq!(failed.status, LaneEventStatus::Failed);
|
||||
assert_eq!(failed.detail.as_deref(), Some("broken server"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn commit_events_can_carry_worktree_and_supersession_metadata() {
|
||||
let event = LaneEvent::commit_created(
|
||||
"2026-04-04T00:00:00Z",
|
||||
Some("commit created".to_string()),
|
||||
LaneCommitProvenance {
|
||||
commit: "abc123".to_string(),
|
||||
branch: "feature/provenance".to_string(),
|
||||
worktree: Some("wt-a".to_string()),
|
||||
canonical_commit: Some("abc123".to_string()),
|
||||
superseded_by: None,
|
||||
lineage: vec!["abc123".to_string()],
|
||||
},
|
||||
);
|
||||
let event_json = serde_json::to_value(&event).expect("lane event should serialize");
|
||||
assert_eq!(event_json["event"], "lane.commit.created");
|
||||
assert_eq!(event_json["data"]["branch"], "feature/provenance");
|
||||
assert_eq!(event_json["data"]["worktree"], "wt-a");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dedupes_superseded_commit_events_by_canonical_commit() {
|
||||
let retained = dedupe_superseded_commit_events(&[
|
||||
LaneEvent::commit_created(
|
||||
"2026-04-04T00:00:00Z",
|
||||
Some("old".to_string()),
|
||||
LaneCommitProvenance {
|
||||
commit: "old123".to_string(),
|
||||
branch: "feature/provenance".to_string(),
|
||||
worktree: Some("wt-a".to_string()),
|
||||
canonical_commit: Some("canon123".to_string()),
|
||||
superseded_by: Some("new123".to_string()),
|
||||
lineage: vec!["old123".to_string(), "new123".to_string()],
|
||||
},
|
||||
),
|
||||
LaneEvent::commit_created(
|
||||
"2026-04-04T00:00:01Z",
|
||||
Some("new".to_string()),
|
||||
LaneCommitProvenance {
|
||||
commit: "new123".to_string(),
|
||||
branch: "feature/provenance".to_string(),
|
||||
worktree: Some("wt-b".to_string()),
|
||||
canonical_commit: Some("canon123".to_string()),
|
||||
superseded_by: None,
|
||||
lineage: vec!["old123".to_string(), "new123".to_string()],
|
||||
},
|
||||
),
|
||||
]);
|
||||
assert_eq!(retained.len(), 1);
|
||||
assert_eq!(retained[0].detail.as_deref(), Some("new"));
|
||||
}
|
||||
}
|
||||
@ -1,64 +1,112 @@
|
||||
//! Core runtime primitives for the `claw` CLI and supporting crates.
|
||||
//!
|
||||
//! This crate owns session persistence, permission evaluation, prompt assembly,
|
||||
//! MCP plumbing, tool-facing file operations, and the core conversation loop
|
||||
//! that drives interactive and one-shot turns.
|
||||
|
||||
mod bash;
|
||||
pub mod bash_validation;
|
||||
mod bootstrap;
|
||||
pub mod branch_lock;
|
||||
mod compact;
|
||||
mod config;
|
||||
pub mod config_validate;
|
||||
mod conversation;
|
||||
mod file_ops;
|
||||
mod git_context;
|
||||
pub mod green_contract;
|
||||
mod hooks;
|
||||
mod json;
|
||||
mod lane_events;
|
||||
pub mod lsp_client;
|
||||
mod mcp;
|
||||
mod mcp_client;
|
||||
pub mod mcp_lifecycle_hardened;
|
||||
pub mod mcp_server;
|
||||
mod mcp_stdio;
|
||||
pub mod mcp_tool_bridge;
|
||||
mod oauth;
|
||||
pub mod permission_enforcer;
|
||||
mod permissions;
|
||||
pub mod plugin_lifecycle;
|
||||
mod policy_engine;
|
||||
mod prompt;
|
||||
pub mod recovery_recipes;
|
||||
mod remote;
|
||||
pub mod sandbox;
|
||||
mod session;
|
||||
pub mod session_control;
|
||||
pub use session_control::SessionStore;
|
||||
mod sse;
|
||||
pub mod stale_base;
|
||||
pub mod stale_branch;
|
||||
pub mod summary_compression;
|
||||
pub mod task_packet;
|
||||
pub mod task_registry;
|
||||
pub mod team_cron_registry;
|
||||
#[cfg(test)]
|
||||
mod trust_resolver;
|
||||
mod usage;
|
||||
pub mod worker_boot;
|
||||
|
||||
pub use lsp::{
|
||||
FileDiagnostics, LspContextEnrichment, LspError, LspManager, LspServerConfig,
|
||||
SymbolLocation, WorkspaceDiagnostics,
|
||||
};
|
||||
pub use bash::{execute_bash, BashCommandInput, BashCommandOutput};
|
||||
pub use bootstrap::{BootstrapPhase, BootstrapPlan};
|
||||
pub use branch_lock::{detect_branch_lock_collisions, BranchLockCollision, BranchLockIntent};
|
||||
pub use compact::{
|
||||
compact_session, estimate_session_tokens, format_compact_summary,
|
||||
get_compact_continuation_message, should_compact, CompactionConfig, CompactionResult,
|
||||
};
|
||||
pub use config::{
|
||||
ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpManagedProxyServerConfig,
|
||||
McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig,
|
||||
ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpConfigCollection,
|
||||
McpManagedProxyServerConfig, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig,
|
||||
McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig,
|
||||
ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, RuntimeHookConfig,
|
||||
RuntimePluginConfig, ScopedMcpServerConfig, CLAW_SETTINGS_SCHEMA_NAME,
|
||||
ProviderFallbackConfig, ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig,
|
||||
RuntimeHookConfig, RuntimePermissionRuleConfig, RuntimePluginConfig, ScopedMcpServerConfig,
|
||||
CLAW_SETTINGS_SCHEMA_NAME,
|
||||
};
|
||||
pub use config_validate::{
|
||||
check_unsupported_format, format_diagnostics, validate_config_file, ConfigDiagnostic,
|
||||
DiagnosticKind, ValidationResult,
|
||||
};
|
||||
pub use conversation::{
|
||||
ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor,
|
||||
ToolError, ToolExecutor, TurnSummary,
|
||||
auto_compaction_threshold_from_env, ApiClient, ApiRequest, AssistantEvent, AutoCompactionEvent,
|
||||
ConversationRuntime, PromptCacheEvent, RuntimeError, StaticToolExecutor, ToolError,
|
||||
ToolExecutor, TurnSummary,
|
||||
};
|
||||
pub use file_ops::{
|
||||
edit_file, glob_search, grep_search, read_file, write_file, EditFileOutput, GlobSearchOutput,
|
||||
GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload,
|
||||
WriteFileOutput,
|
||||
};
|
||||
pub use hooks::{HookEvent, HookRunResult, HookRunner};
|
||||
pub use git_context::{GitCommitEntry, GitContext};
|
||||
pub use hooks::{
|
||||
HookAbortSignal, HookEvent, HookProgressEvent, HookProgressReporter, HookRunResult, HookRunner,
|
||||
};
|
||||
pub use lane_events::{
|
||||
dedupe_superseded_commit_events, LaneCommitProvenance, LaneEvent, LaneEventBlocker,
|
||||
LaneEventName, LaneEventStatus, LaneFailureClass,
|
||||
};
|
||||
pub use mcp::{
|
||||
mcp_server_signature, mcp_tool_name, mcp_tool_prefix, normalize_name_for_mcp,
|
||||
scoped_mcp_config_hash, unwrap_ccr_proxy_url,
|
||||
};
|
||||
pub use mcp_client::{
|
||||
McpManagedProxyTransport, McpClientAuth, McpClientBootstrap, McpClientTransport,
|
||||
McpClientAuth, McpClientBootstrap, McpClientTransport, McpManagedProxyTransport,
|
||||
McpRemoteTransport, McpSdkTransport, McpStdioTransport,
|
||||
};
|
||||
pub use mcp_lifecycle_hardened::{
|
||||
McpDegradedReport, McpErrorSurface, McpFailedServer, McpLifecyclePhase, McpLifecycleState,
|
||||
McpLifecycleValidator, McpPhaseResult,
|
||||
};
|
||||
pub use mcp_server::{McpServer, McpServerSpec, ToolCallHandler, MCP_SERVER_PROTOCOL_VERSION};
|
||||
pub use mcp_stdio::{
|
||||
spawn_mcp_stdio_process, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse,
|
||||
ManagedMcpTool, McpInitializeClientInfo, McpInitializeParams, McpInitializeResult,
|
||||
McpInitializeServerInfo, McpListResourcesParams, McpListResourcesResult, McpListToolsParams,
|
||||
McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpResource,
|
||||
McpResourceContents, McpServerManager, McpServerManagerError, McpStdioProcess, McpTool,
|
||||
McpToolCallContent, McpToolCallParams, McpToolCallResult, UnsupportedMcpServer,
|
||||
ManagedMcpTool, McpDiscoveryFailure, McpInitializeClientInfo, McpInitializeParams,
|
||||
McpInitializeResult, McpInitializeServerInfo, McpListResourcesParams, McpListResourcesResult,
|
||||
McpListToolsParams, McpListToolsResult, McpReadResourceParams, McpReadResourceResult,
|
||||
McpResource, McpResourceContents, McpServerManager, McpServerManagerError, McpStdioProcess,
|
||||
McpTool, McpToolCallContent, McpToolCallParams, McpToolCallResult, McpToolDiscoveryReport,
|
||||
UnsupportedMcpServer,
|
||||
};
|
||||
pub use oauth::{
|
||||
clear_oauth_credentials, code_challenge_s256, credentials_path, generate_pkce_pair,
|
||||
@ -68,22 +116,59 @@ pub use oauth::{
|
||||
PkceChallengeMethod, PkceCodePair,
|
||||
};
|
||||
pub use permissions::{
|
||||
PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision,
|
||||
PermissionPrompter, PermissionRequest,
|
||||
PermissionContext, PermissionMode, PermissionOutcome, PermissionOverride, PermissionPolicy,
|
||||
PermissionPromptDecision, PermissionPrompter, PermissionRequest,
|
||||
};
|
||||
pub use plugin_lifecycle::{
|
||||
DegradedMode, DiscoveryResult, PluginHealthcheck, PluginLifecycle, PluginLifecycleEvent,
|
||||
PluginState, ResourceInfo, ServerHealth, ServerStatus, ToolInfo,
|
||||
};
|
||||
pub use policy_engine::{
|
||||
evaluate, DiffScope, GreenLevel, LaneBlocker, LaneContext, PolicyAction, PolicyCondition,
|
||||
PolicyEngine, PolicyRule, ReconcileReason, ReviewStatus,
|
||||
};
|
||||
pub use prompt::{
|
||||
load_system_prompt, prepend_bullets, ContextFile, ProjectContext, PromptBuildError,
|
||||
SystemPromptBuilder, FRONTIER_MODEL_NAME, SYSTEM_PROMPT_DYNAMIC_BOUNDARY,
|
||||
};
|
||||
pub use recovery_recipes::{
|
||||
attempt_recovery, recipe_for, EscalationPolicy, FailureScenario, RecoveryContext,
|
||||
RecoveryEvent, RecoveryRecipe, RecoveryResult, RecoveryStep,
|
||||
};
|
||||
pub use remote::{
|
||||
inherited_upstream_proxy_env, no_proxy_list, read_token, upstream_proxy_ws_url,
|
||||
RemoteSessionContext, UpstreamProxyBootstrap, UpstreamProxyState, DEFAULT_REMOTE_BASE_URL,
|
||||
DEFAULT_SESSION_TOKEN_PATH, DEFAULT_SYSTEM_CA_BUNDLE, NO_PROXY_HOSTS, UPSTREAM_PROXY_ENV_KEYS,
|
||||
};
|
||||
pub use session::{ContentBlock, ConversationMessage, MessageRole, Session, SessionError};
|
||||
pub use sandbox::{
|
||||
build_linux_sandbox_command, detect_container_environment, detect_container_environment_from,
|
||||
resolve_sandbox_status, resolve_sandbox_status_for_request, ContainerEnvironment,
|
||||
FilesystemIsolationMode, LinuxSandboxCommand, SandboxConfig, SandboxDetectionInputs,
|
||||
SandboxRequest, SandboxStatus,
|
||||
};
|
||||
pub use session::{
|
||||
ContentBlock, ConversationMessage, MessageRole, Session, SessionCompaction, SessionError,
|
||||
SessionFork, SessionPromptEntry,
|
||||
};
|
||||
pub use sse::{IncrementalSseParser, SseEvent};
|
||||
pub use stale_base::{
|
||||
check_base_commit, format_stale_base_warning, read_claw_base_file, resolve_expected_base,
|
||||
BaseCommitSource, BaseCommitState,
|
||||
};
|
||||
pub use stale_branch::{
|
||||
apply_policy, check_freshness, BranchFreshness, StaleBranchAction, StaleBranchEvent,
|
||||
StaleBranchPolicy,
|
||||
};
|
||||
pub use task_packet::{validate_packet, TaskPacket, TaskPacketValidationError, ValidatedPacket};
|
||||
#[cfg(test)]
|
||||
pub use trust_resolver::{TrustConfig, TrustDecision, TrustEvent, TrustPolicy, TrustResolver};
|
||||
pub use usage::{
|
||||
format_usd, pricing_for_model, ModelPricing, TokenUsage, UsageCostEstimate, UsageTracker,
|
||||
};
|
||||
pub use worker_boot::{
|
||||
Worker, WorkerEvent, WorkerEventKind, WorkerEventPayload, WorkerFailure, WorkerFailureKind,
|
||||
WorkerPromptTarget, WorkerReadySnapshot, WorkerRegistry, WorkerStatus, WorkerTrustResolution,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn test_env_lock() -> std::sync::MutexGuard<'static, ()> {
|
||||
|
||||
747
crates/runtime/src/lsp_client.rs
Normal file
747
crates/runtime/src/lsp_client.rs
Normal file
@ -0,0 +1,747 @@
|
||||
#![allow(clippy::should_implement_trait, clippy::must_use_candidate)]
|
||||
//! LSP (Language Server Protocol) client registry for tool dispatch.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Supported LSP actions.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum LspAction {
|
||||
Diagnostics,
|
||||
Hover,
|
||||
Definition,
|
||||
References,
|
||||
Completion,
|
||||
Symbols,
|
||||
Format,
|
||||
}
|
||||
|
||||
impl LspAction {
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
match s {
|
||||
"diagnostics" => Some(Self::Diagnostics),
|
||||
"hover" => Some(Self::Hover),
|
||||
"definition" | "goto_definition" => Some(Self::Definition),
|
||||
"references" | "find_references" => Some(Self::References),
|
||||
"completion" | "completions" => Some(Self::Completion),
|
||||
"symbols" | "document_symbols" => Some(Self::Symbols),
|
||||
"format" | "formatting" => Some(Self::Format),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspDiagnostic {
|
||||
pub path: String,
|
||||
pub line: u32,
|
||||
pub character: u32,
|
||||
pub severity: String,
|
||||
pub message: String,
|
||||
pub source: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspLocation {
|
||||
pub path: String,
|
||||
pub line: u32,
|
||||
pub character: u32,
|
||||
pub end_line: Option<u32>,
|
||||
pub end_character: Option<u32>,
|
||||
pub preview: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspHoverResult {
|
||||
pub content: String,
|
||||
pub language: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspCompletionItem {
|
||||
pub label: String,
|
||||
pub kind: Option<String>,
|
||||
pub detail: Option<String>,
|
||||
pub insert_text: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspSymbol {
|
||||
pub name: String,
|
||||
pub kind: String,
|
||||
pub path: String,
|
||||
pub line: u32,
|
||||
pub character: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum LspServerStatus {
|
||||
Connected,
|
||||
Disconnected,
|
||||
Starting,
|
||||
Error,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for LspServerStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Connected => write!(f, "connected"),
|
||||
Self::Disconnected => write!(f, "disconnected"),
|
||||
Self::Starting => write!(f, "starting"),
|
||||
Self::Error => write!(f, "error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LspServerState {
|
||||
pub language: String,
|
||||
pub status: LspServerStatus,
|
||||
pub root_path: Option<String>,
|
||||
pub capabilities: Vec<String>,
|
||||
pub diagnostics: Vec<LspDiagnostic>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct LspRegistry {
|
||||
inner: Arc<Mutex<RegistryInner>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct RegistryInner {
|
||||
servers: HashMap<String, LspServerState>,
|
||||
}
|
||||
|
||||
impl LspRegistry {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn register(
|
||||
&self,
|
||||
language: &str,
|
||||
status: LspServerStatus,
|
||||
root_path: Option<&str>,
|
||||
capabilities: Vec<String>,
|
||||
) {
|
||||
let mut inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
inner.servers.insert(
|
||||
language.to_owned(),
|
||||
LspServerState {
|
||||
language: language.to_owned(),
|
||||
status,
|
||||
root_path: root_path.map(str::to_owned),
|
||||
capabilities,
|
||||
diagnostics: Vec::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
pub fn get(&self, language: &str) -> Option<LspServerState> {
|
||||
let inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
inner.servers.get(language).cloned()
|
||||
}
|
||||
|
||||
/// Find the appropriate server for a file path based on extension.
|
||||
pub fn find_server_for_path(&self, path: &str) -> Option<LspServerState> {
|
||||
let ext = std::path::Path::new(path)
|
||||
.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let language = match ext {
|
||||
"rs" => "rust",
|
||||
"ts" | "tsx" => "typescript",
|
||||
"js" | "jsx" => "javascript",
|
||||
"py" => "python",
|
||||
"go" => "go",
|
||||
"java" => "java",
|
||||
"c" | "h" => "c",
|
||||
"cpp" | "hpp" | "cc" => "cpp",
|
||||
"rb" => "ruby",
|
||||
"lua" => "lua",
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
self.get(language)
|
||||
}
|
||||
|
||||
/// List all registered servers.
|
||||
pub fn list_servers(&self) -> Vec<LspServerState> {
|
||||
let inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
inner.servers.values().cloned().collect()
|
||||
}
|
||||
|
||||
/// Add diagnostics to a server.
|
||||
pub fn add_diagnostics(
|
||||
&self,
|
||||
language: &str,
|
||||
diagnostics: Vec<LspDiagnostic>,
|
||||
) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
let server = inner
|
||||
.servers
|
||||
.get_mut(language)
|
||||
.ok_or_else(|| format!("LSP server not found for language: {language}"))?;
|
||||
server.diagnostics.extend(diagnostics);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get diagnostics for a specific file path.
|
||||
pub fn get_diagnostics(&self, path: &str) -> Vec<LspDiagnostic> {
|
||||
let inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
inner
|
||||
.servers
|
||||
.values()
|
||||
.flat_map(|s| &s.diagnostics)
|
||||
.filter(|d| d.path == path)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Clear diagnostics for a language server.
|
||||
pub fn clear_diagnostics(&self, language: &str) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
let server = inner
|
||||
.servers
|
||||
.get_mut(language)
|
||||
.ok_or_else(|| format!("LSP server not found for language: {language}"))?;
|
||||
server.diagnostics.clear();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Disconnect a server.
|
||||
pub fn disconnect(&self, language: &str) -> Option<LspServerState> {
|
||||
let mut inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
inner.servers.remove(language)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn len(&self) -> usize {
|
||||
let inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
inner.servers.len()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
/// Dispatch an LSP action and return a structured result.
|
||||
pub fn dispatch(
|
||||
&self,
|
||||
action: &str,
|
||||
path: Option<&str>,
|
||||
line: Option<u32>,
|
||||
character: Option<u32>,
|
||||
_query: Option<&str>,
|
||||
) -> Result<serde_json::Value, String> {
|
||||
let lsp_action =
|
||||
LspAction::from_str(action).ok_or_else(|| format!("unknown LSP action: {action}"))?;
|
||||
|
||||
// For diagnostics, we can check existing cached diagnostics
|
||||
if lsp_action == LspAction::Diagnostics {
|
||||
if let Some(path) = path {
|
||||
let diags = self.get_diagnostics(path);
|
||||
return Ok(serde_json::json!({
|
||||
"action": "diagnostics",
|
||||
"path": path,
|
||||
"diagnostics": diags,
|
||||
"count": diags.len()
|
||||
}));
|
||||
}
|
||||
// All diagnostics across all servers
|
||||
let inner = self.inner.lock().expect("lsp registry lock poisoned");
|
||||
let all_diags: Vec<_> = inner
|
||||
.servers
|
||||
.values()
|
||||
.flat_map(|s| &s.diagnostics)
|
||||
.collect();
|
||||
return Ok(serde_json::json!({
|
||||
"action": "diagnostics",
|
||||
"diagnostics": all_diags,
|
||||
"count": all_diags.len()
|
||||
}));
|
||||
}
|
||||
|
||||
// For other actions, we need a connected server for the given file
|
||||
let path = path.ok_or("path is required for this LSP action")?;
|
||||
let server = self
|
||||
.find_server_for_path(path)
|
||||
.ok_or_else(|| format!("no LSP server available for path: {path}"))?;
|
||||
|
||||
if server.status != LspServerStatus::Connected {
|
||||
return Err(format!(
|
||||
"LSP server for '{}' is not connected (status: {})",
|
||||
server.language, server.status
|
||||
));
|
||||
}
|
||||
|
||||
// Return structured placeholder — actual LSP JSON-RPC calls would
|
||||
// go through the real LSP process here.
|
||||
Ok(serde_json::json!({
|
||||
"action": action,
|
||||
"path": path,
|
||||
"line": line,
|
||||
"character": character,
|
||||
"language": server.language,
|
||||
"status": "dispatched",
|
||||
"message": format!("LSP {} dispatched to {} server", action, server.language)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn registers_and_retrieves_server() {
|
||||
let registry = LspRegistry::new();
|
||||
registry.register(
|
||||
"rust",
|
||||
LspServerStatus::Connected,
|
||||
Some("/workspace"),
|
||||
vec!["hover".into(), "completion".into()],
|
||||
);
|
||||
|
||||
let server = registry.get("rust").expect("should exist");
|
||||
assert_eq!(server.language, "rust");
|
||||
assert_eq!(server.status, LspServerStatus::Connected);
|
||||
assert_eq!(server.capabilities.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn finds_server_by_file_extension() {
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
registry.register("typescript", LspServerStatus::Connected, None, vec![]);
|
||||
|
||||
let rs_server = registry.find_server_for_path("src/main.rs").unwrap();
|
||||
assert_eq!(rs_server.language, "rust");
|
||||
|
||||
let ts_server = registry.find_server_for_path("src/index.ts").unwrap();
|
||||
assert_eq!(ts_server.language, "typescript");
|
||||
|
||||
assert!(registry.find_server_for_path("data.csv").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manages_diagnostics() {
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"rust",
|
||||
vec![LspDiagnostic {
|
||||
path: "src/main.rs".into(),
|
||||
line: 10,
|
||||
character: 5,
|
||||
severity: "error".into(),
|
||||
message: "mismatched types".into(),
|
||||
source: Some("rust-analyzer".into()),
|
||||
}],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let diags = registry.get_diagnostics("src/main.rs");
|
||||
assert_eq!(diags.len(), 1);
|
||||
assert_eq!(diags[0].message, "mismatched types");
|
||||
|
||||
registry.clear_diagnostics("rust").unwrap();
|
||||
assert!(registry.get_diagnostics("src/main.rs").is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatches_diagnostics_action() {
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"rust",
|
||||
vec![LspDiagnostic {
|
||||
path: "src/lib.rs".into(),
|
||||
line: 1,
|
||||
character: 0,
|
||||
severity: "warning".into(),
|
||||
message: "unused import".into(),
|
||||
source: None,
|
||||
}],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let result = registry
|
||||
.dispatch("diagnostics", Some("src/lib.rs"), None, None, None)
|
||||
.unwrap();
|
||||
assert_eq!(result["count"], 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatches_hover_action() {
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
|
||||
let result = registry
|
||||
.dispatch("hover", Some("src/main.rs"), Some(10), Some(5), None)
|
||||
.unwrap();
|
||||
assert_eq!(result["action"], "hover");
|
||||
assert_eq!(result["language"], "rust");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_action_on_disconnected_server() {
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Disconnected, None, vec![]);
|
||||
|
||||
assert!(registry
|
||||
.dispatch("hover", Some("src/main.rs"), Some(1), Some(0), None)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_unknown_action() {
|
||||
let registry = LspRegistry::new();
|
||||
assert!(registry
|
||||
.dispatch("unknown_action", Some("file.rs"), None, None, None)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disconnects_server() {
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
assert_eq!(registry.len(), 1);
|
||||
|
||||
let removed = registry.disconnect("rust");
|
||||
assert!(removed.is_some());
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lsp_action_from_str_all_aliases() {
|
||||
// given
|
||||
let cases = [
|
||||
("diagnostics", Some(LspAction::Diagnostics)),
|
||||
("hover", Some(LspAction::Hover)),
|
||||
("definition", Some(LspAction::Definition)),
|
||||
("goto_definition", Some(LspAction::Definition)),
|
||||
("references", Some(LspAction::References)),
|
||||
("find_references", Some(LspAction::References)),
|
||||
("completion", Some(LspAction::Completion)),
|
||||
("completions", Some(LspAction::Completion)),
|
||||
("symbols", Some(LspAction::Symbols)),
|
||||
("document_symbols", Some(LspAction::Symbols)),
|
||||
("format", Some(LspAction::Format)),
|
||||
("formatting", Some(LspAction::Format)),
|
||||
("unknown", None),
|
||||
];
|
||||
|
||||
// when
|
||||
let resolved: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(input, expected)| (input, LspAction::from_str(input), expected))
|
||||
.collect();
|
||||
|
||||
// then
|
||||
for (input, actual, expected) in resolved {
|
||||
assert_eq!(actual, expected, "unexpected action resolution for {input}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lsp_server_status_display_all_variants() {
|
||||
// given
|
||||
let cases = [
|
||||
(LspServerStatus::Connected, "connected"),
|
||||
(LspServerStatus::Disconnected, "disconnected"),
|
||||
(LspServerStatus::Starting, "starting"),
|
||||
(LspServerStatus::Error, "error"),
|
||||
];
|
||||
|
||||
// when
|
||||
let rendered: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(status, expected)| (status.to_string(), expected))
|
||||
.collect();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
rendered,
|
||||
vec![
|
||||
("connected".to_string(), "connected"),
|
||||
("disconnected".to_string(), "disconnected"),
|
||||
("starting".to_string(), "starting"),
|
||||
("error".to_string(), "error"),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_diagnostics_without_path_aggregates() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
registry.register("python", LspServerStatus::Connected, None, vec![]);
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"rust",
|
||||
vec![LspDiagnostic {
|
||||
path: "src/lib.rs".into(),
|
||||
line: 1,
|
||||
character: 0,
|
||||
severity: "warning".into(),
|
||||
message: "unused import".into(),
|
||||
source: Some("rust-analyzer".into()),
|
||||
}],
|
||||
)
|
||||
.expect("rust diagnostics should add");
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"python",
|
||||
vec![LspDiagnostic {
|
||||
path: "script.py".into(),
|
||||
line: 2,
|
||||
character: 4,
|
||||
severity: "error".into(),
|
||||
message: "undefined name".into(),
|
||||
source: Some("pyright".into()),
|
||||
}],
|
||||
)
|
||||
.expect("python diagnostics should add");
|
||||
|
||||
// when
|
||||
let result = registry
|
||||
.dispatch("diagnostics", None, None, None, None)
|
||||
.expect("aggregate diagnostics should work");
|
||||
|
||||
// then
|
||||
assert_eq!(result["action"], "diagnostics");
|
||||
assert_eq!(result["count"], 2);
|
||||
assert_eq!(result["diagnostics"].as_array().map(Vec::len), Some(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_non_diagnostics_requires_path() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.dispatch("hover", None, Some(1), Some(0), None);
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
result.expect_err("path should be required"),
|
||||
"path is required for this LSP action"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_no_server_for_path_errors() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.dispatch("hover", Some("notes.md"), Some(1), Some(0), None);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("missing server should fail");
|
||||
assert!(error.contains("no LSP server available for path: notes.md"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_disconnected_server_error_payload() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("typescript", LspServerStatus::Disconnected, None, vec![]);
|
||||
|
||||
// when
|
||||
let result = registry.dispatch("hover", Some("src/index.ts"), Some(3), Some(2), None);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("disconnected server should fail");
|
||||
assert!(error.contains("typescript"));
|
||||
assert!(error.contains("disconnected"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_server_for_all_extensions() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
for language in [
|
||||
"rust",
|
||||
"typescript",
|
||||
"javascript",
|
||||
"python",
|
||||
"go",
|
||||
"java",
|
||||
"c",
|
||||
"cpp",
|
||||
"ruby",
|
||||
"lua",
|
||||
] {
|
||||
registry.register(language, LspServerStatus::Connected, None, vec![]);
|
||||
}
|
||||
let cases = [
|
||||
("src/main.rs", "rust"),
|
||||
("src/index.ts", "typescript"),
|
||||
("src/view.tsx", "typescript"),
|
||||
("src/app.js", "javascript"),
|
||||
("src/app.jsx", "javascript"),
|
||||
("script.py", "python"),
|
||||
("main.go", "go"),
|
||||
("Main.java", "java"),
|
||||
("native.c", "c"),
|
||||
("native.h", "c"),
|
||||
("native.cpp", "cpp"),
|
||||
("native.hpp", "cpp"),
|
||||
("native.cc", "cpp"),
|
||||
("script.rb", "ruby"),
|
||||
("script.lua", "lua"),
|
||||
];
|
||||
|
||||
// when
|
||||
let resolved: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(path, expected)| {
|
||||
(
|
||||
path,
|
||||
registry
|
||||
.find_server_for_path(path)
|
||||
.map(|server| server.language),
|
||||
expected,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// then
|
||||
for (path, actual, expected) in resolved {
|
||||
assert_eq!(
|
||||
actual.as_deref(),
|
||||
Some(expected),
|
||||
"unexpected mapping for {path}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_server_for_path_no_extension() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
|
||||
// when
|
||||
let result = registry.find_server_for_path("Makefile");
|
||||
|
||||
// then
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_servers_with_multiple() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
registry.register("typescript", LspServerStatus::Starting, None, vec![]);
|
||||
registry.register("python", LspServerStatus::Error, None, vec![]);
|
||||
|
||||
// when
|
||||
let servers = registry.list_servers();
|
||||
|
||||
// then
|
||||
assert_eq!(servers.len(), 3);
|
||||
assert!(servers.iter().any(|server| server.language == "rust"));
|
||||
assert!(servers.iter().any(|server| server.language == "typescript"));
|
||||
assert!(servers.iter().any(|server| server.language == "python"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_missing_server_returns_none() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
|
||||
// when
|
||||
let server = registry.get("missing");
|
||||
|
||||
// then
|
||||
assert!(server.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_diagnostics_missing_language_errors() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.add_diagnostics("missing", vec![]);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("missing language should fail");
|
||||
assert!(error.contains("LSP server not found for language: missing"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_diagnostics_across_servers() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
let shared_path = "shared/file.txt";
|
||||
registry.register("rust", LspServerStatus::Connected, None, vec![]);
|
||||
registry.register("python", LspServerStatus::Connected, None, vec![]);
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"rust",
|
||||
vec![LspDiagnostic {
|
||||
path: shared_path.into(),
|
||||
line: 4,
|
||||
character: 1,
|
||||
severity: "warning".into(),
|
||||
message: "warn".into(),
|
||||
source: None,
|
||||
}],
|
||||
)
|
||||
.expect("rust diagnostics should add");
|
||||
registry
|
||||
.add_diagnostics(
|
||||
"python",
|
||||
vec![LspDiagnostic {
|
||||
path: shared_path.into(),
|
||||
line: 8,
|
||||
character: 3,
|
||||
severity: "error".into(),
|
||||
message: "err".into(),
|
||||
source: None,
|
||||
}],
|
||||
)
|
||||
.expect("python diagnostics should add");
|
||||
|
||||
// when
|
||||
let diagnostics = registry.get_diagnostics(shared_path);
|
||||
|
||||
// then
|
||||
assert_eq!(diagnostics.len(), 2);
|
||||
assert!(diagnostics
|
||||
.iter()
|
||||
.any(|diagnostic| diagnostic.message == "warn"));
|
||||
assert!(diagnostics
|
||||
.iter()
|
||||
.any(|diagnostic| diagnostic.message == "err"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clear_diagnostics_missing_language_errors() {
|
||||
// given
|
||||
let registry = LspRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.clear_diagnostics("missing");
|
||||
|
||||
// then
|
||||
let error = result.expect_err("missing language should fail");
|
||||
assert!(error.contains("LSP server not found for language: missing"));
|
||||
}
|
||||
}
|
||||
@ -84,10 +84,13 @@ pub fn mcp_server_signature(config: &McpServerConfig) -> Option<String> {
|
||||
pub fn scoped_mcp_config_hash(config: &ScopedMcpServerConfig) -> String {
|
||||
let rendered = match &config.config {
|
||||
McpServerConfig::Stdio(stdio) => format!(
|
||||
"stdio|{}|{}|{}",
|
||||
"stdio|{}|{}|{}|{}",
|
||||
stdio.command,
|
||||
render_command_signature(&stdio.args),
|
||||
render_env_signature(&stdio.env)
|
||||
render_env_signature(&stdio.env),
|
||||
stdio
|
||||
.tool_call_timeout_ms
|
||||
.map_or_else(String::new, |timeout_ms| timeout_ms.to_string())
|
||||
),
|
||||
McpServerConfig::Sse(remote) => format!(
|
||||
"sse|{}|{}|{}|{}",
|
||||
@ -245,6 +248,7 @@ mod tests {
|
||||
command: "uvx".to_string(),
|
||||
args: vec!["mcp-server".to_string()],
|
||||
env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]),
|
||||
tool_call_timeout_ms: None,
|
||||
});
|
||||
assert_eq!(
|
||||
mcp_server_signature(&stdio),
|
||||
|
||||
@ -3,6 +3,8 @@ use std::collections::BTreeMap;
|
||||
use crate::config::{McpOAuthConfig, McpServerConfig, ScopedMcpServerConfig};
|
||||
use crate::mcp::{mcp_server_signature, mcp_tool_prefix, normalize_name_for_mcp};
|
||||
|
||||
pub const DEFAULT_MCP_TOOL_CALL_TIMEOUT_MS: u64 = 60_000;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum McpClientTransport {
|
||||
Stdio(McpStdioTransport),
|
||||
@ -18,6 +20,7 @@ pub struct McpStdioTransport {
|
||||
pub command: String,
|
||||
pub args: Vec<String>,
|
||||
pub env: BTreeMap<String, String>,
|
||||
pub tool_call_timeout_ms: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@ -75,6 +78,7 @@ impl McpClientTransport {
|
||||
command: config.command.clone(),
|
||||
args: config.args.clone(),
|
||||
env: config.env.clone(),
|
||||
tool_call_timeout_ms: config.tool_call_timeout_ms,
|
||||
}),
|
||||
McpServerConfig::Sse(config) => Self::Sse(McpRemoteTransport {
|
||||
url: config.url.clone(),
|
||||
@ -105,6 +109,14 @@ impl McpClientTransport {
|
||||
}
|
||||
}
|
||||
|
||||
impl McpStdioTransport {
|
||||
#[must_use]
|
||||
pub fn resolved_tool_call_timeout_ms(&self) -> u64 {
|
||||
self.tool_call_timeout_ms
|
||||
.unwrap_or(DEFAULT_MCP_TOOL_CALL_TIMEOUT_MS)
|
||||
}
|
||||
}
|
||||
|
||||
impl McpClientAuth {
|
||||
#[must_use]
|
||||
pub fn from_oauth(oauth: Option<McpOAuthConfig>) -> Self {
|
||||
@ -136,6 +148,7 @@ mod tests {
|
||||
command: "uvx".to_string(),
|
||||
args: vec!["mcp-server".to_string()],
|
||||
env: BTreeMap::from([("TOKEN".to_string(), "secret".to_string())]),
|
||||
tool_call_timeout_ms: Some(15_000),
|
||||
}),
|
||||
};
|
||||
|
||||
@ -154,6 +167,7 @@ mod tests {
|
||||
transport.env.get("TOKEN").map(String::as_str),
|
||||
Some("secret")
|
||||
);
|
||||
assert_eq!(transport.tool_call_timeout_ms, Some(15_000));
|
||||
}
|
||||
other => panic!("expected stdio transport, got {other:?}"),
|
||||
}
|
||||
|
||||
843
crates/runtime/src/mcp_lifecycle_hardened.rs
Normal file
843
crates/runtime/src/mcp_lifecycle_hardened.rs
Normal file
@ -0,0 +1,843 @@
|
||||
#![allow(clippy::unnested_or_patterns, clippy::map_unwrap_or)]
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
fn now_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum McpLifecyclePhase {
|
||||
ConfigLoad,
|
||||
ServerRegistration,
|
||||
SpawnConnect,
|
||||
InitializeHandshake,
|
||||
ToolDiscovery,
|
||||
ResourceDiscovery,
|
||||
Ready,
|
||||
Invocation,
|
||||
ErrorSurfacing,
|
||||
Shutdown,
|
||||
Cleanup,
|
||||
}
|
||||
|
||||
impl McpLifecyclePhase {
|
||||
#[must_use]
|
||||
pub fn all() -> [Self; 11] {
|
||||
[
|
||||
Self::ConfigLoad,
|
||||
Self::ServerRegistration,
|
||||
Self::SpawnConnect,
|
||||
Self::InitializeHandshake,
|
||||
Self::ToolDiscovery,
|
||||
Self::ResourceDiscovery,
|
||||
Self::Ready,
|
||||
Self::Invocation,
|
||||
Self::ErrorSurfacing,
|
||||
Self::Shutdown,
|
||||
Self::Cleanup,
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for McpLifecyclePhase {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::ConfigLoad => write!(f, "config_load"),
|
||||
Self::ServerRegistration => write!(f, "server_registration"),
|
||||
Self::SpawnConnect => write!(f, "spawn_connect"),
|
||||
Self::InitializeHandshake => write!(f, "initialize_handshake"),
|
||||
Self::ToolDiscovery => write!(f, "tool_discovery"),
|
||||
Self::ResourceDiscovery => write!(f, "resource_discovery"),
|
||||
Self::Ready => write!(f, "ready"),
|
||||
Self::Invocation => write!(f, "invocation"),
|
||||
Self::ErrorSurfacing => write!(f, "error_surfacing"),
|
||||
Self::Shutdown => write!(f, "shutdown"),
|
||||
Self::Cleanup => write!(f, "cleanup"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct McpErrorSurface {
|
||||
pub phase: McpLifecyclePhase,
|
||||
pub server_name: Option<String>,
|
||||
pub message: String,
|
||||
pub context: BTreeMap<String, String>,
|
||||
pub recoverable: bool,
|
||||
pub timestamp: u64,
|
||||
}
|
||||
|
||||
impl McpErrorSurface {
|
||||
#[must_use]
|
||||
pub fn new(
|
||||
phase: McpLifecyclePhase,
|
||||
server_name: Option<String>,
|
||||
message: impl Into<String>,
|
||||
context: BTreeMap<String, String>,
|
||||
recoverable: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
phase,
|
||||
server_name,
|
||||
message: message.into(),
|
||||
context,
|
||||
recoverable,
|
||||
timestamp: now_secs(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for McpErrorSurface {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"MCP lifecycle error during {}: {}",
|
||||
self.phase, self.message
|
||||
)?;
|
||||
if let Some(server_name) = &self.server_name {
|
||||
write!(f, " (server: {server_name})")?;
|
||||
}
|
||||
if !self.context.is_empty() {
|
||||
write!(f, " with context {:?}", self.context)?;
|
||||
}
|
||||
if self.recoverable {
|
||||
write!(f, " [recoverable]")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for McpErrorSurface {}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum McpPhaseResult {
|
||||
Success {
|
||||
phase: McpLifecyclePhase,
|
||||
duration: Duration,
|
||||
},
|
||||
Failure {
|
||||
phase: McpLifecyclePhase,
|
||||
error: McpErrorSurface,
|
||||
},
|
||||
Timeout {
|
||||
phase: McpLifecyclePhase,
|
||||
waited: Duration,
|
||||
error: McpErrorSurface,
|
||||
},
|
||||
}
|
||||
|
||||
impl McpPhaseResult {
|
||||
#[must_use]
|
||||
pub fn phase(&self) -> McpLifecyclePhase {
|
||||
match self {
|
||||
Self::Success { phase, .. }
|
||||
| Self::Failure { phase, .. }
|
||||
| Self::Timeout { phase, .. } => *phase,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct McpLifecycleState {
|
||||
current_phase: Option<McpLifecyclePhase>,
|
||||
phase_errors: BTreeMap<McpLifecyclePhase, Vec<McpErrorSurface>>,
|
||||
phase_timestamps: BTreeMap<McpLifecyclePhase, u64>,
|
||||
phase_results: Vec<McpPhaseResult>,
|
||||
}
|
||||
|
||||
impl McpLifecycleState {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn current_phase(&self) -> Option<McpLifecyclePhase> {
|
||||
self.current_phase
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn errors_for_phase(&self, phase: McpLifecyclePhase) -> &[McpErrorSurface] {
|
||||
self.phase_errors
|
||||
.get(&phase)
|
||||
.map(Vec::as_slice)
|
||||
.unwrap_or(&[])
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn results(&self) -> &[McpPhaseResult] {
|
||||
&self.phase_results
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn phase_timestamps(&self) -> &BTreeMap<McpLifecyclePhase, u64> {
|
||||
&self.phase_timestamps
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn phase_timestamp(&self, phase: McpLifecyclePhase) -> Option<u64> {
|
||||
self.phase_timestamps.get(&phase).copied()
|
||||
}
|
||||
|
||||
fn record_phase(&mut self, phase: McpLifecyclePhase) {
|
||||
self.current_phase = Some(phase);
|
||||
self.phase_timestamps.insert(phase, now_secs());
|
||||
}
|
||||
|
||||
fn record_error(&mut self, error: McpErrorSurface) {
|
||||
self.phase_errors
|
||||
.entry(error.phase)
|
||||
.or_default()
|
||||
.push(error);
|
||||
}
|
||||
|
||||
fn record_result(&mut self, result: McpPhaseResult) {
|
||||
self.phase_results.push(result);
|
||||
}
|
||||
|
||||
fn can_resume_after_error(&self) -> bool {
|
||||
match self.phase_results.last() {
|
||||
Some(McpPhaseResult::Failure { error, .. } | McpPhaseResult::Timeout { error, .. }) => {
|
||||
error.recoverable
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct McpFailedServer {
|
||||
pub server_name: String,
|
||||
pub phase: McpLifecyclePhase,
|
||||
pub error: McpErrorSurface,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct McpDegradedReport {
|
||||
pub working_servers: Vec<String>,
|
||||
pub failed_servers: Vec<McpFailedServer>,
|
||||
pub available_tools: Vec<String>,
|
||||
pub missing_tools: Vec<String>,
|
||||
}
|
||||
|
||||
impl McpDegradedReport {
|
||||
#[must_use]
|
||||
pub fn new(
|
||||
working_servers: Vec<String>,
|
||||
failed_servers: Vec<McpFailedServer>,
|
||||
available_tools: Vec<String>,
|
||||
expected_tools: Vec<String>,
|
||||
) -> Self {
|
||||
let working_servers = dedupe_sorted(working_servers);
|
||||
let available_tools = dedupe_sorted(available_tools);
|
||||
let available_tool_set: BTreeSet<_> = available_tools.iter().cloned().collect();
|
||||
let expected_tools = dedupe_sorted(expected_tools);
|
||||
let missing_tools = expected_tools
|
||||
.into_iter()
|
||||
.filter(|tool| !available_tool_set.contains(tool))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
working_servers,
|
||||
failed_servers,
|
||||
available_tools,
|
||||
missing_tools,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct McpLifecycleValidator {
|
||||
state: McpLifecycleState,
|
||||
}
|
||||
|
||||
impl McpLifecycleValidator {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn state(&self) -> &McpLifecycleState {
|
||||
&self.state
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn validate_phase_transition(from: McpLifecyclePhase, to: McpLifecyclePhase) -> bool {
|
||||
match (from, to) {
|
||||
(McpLifecyclePhase::ConfigLoad, McpLifecyclePhase::ServerRegistration)
|
||||
| (McpLifecyclePhase::ServerRegistration, McpLifecyclePhase::SpawnConnect)
|
||||
| (McpLifecyclePhase::SpawnConnect, McpLifecyclePhase::InitializeHandshake)
|
||||
| (McpLifecyclePhase::InitializeHandshake, McpLifecyclePhase::ToolDiscovery)
|
||||
| (McpLifecyclePhase::ToolDiscovery, McpLifecyclePhase::ResourceDiscovery)
|
||||
| (McpLifecyclePhase::ToolDiscovery, McpLifecyclePhase::Ready)
|
||||
| (McpLifecyclePhase::ResourceDiscovery, McpLifecyclePhase::Ready)
|
||||
| (McpLifecyclePhase::Ready, McpLifecyclePhase::Invocation)
|
||||
| (McpLifecyclePhase::Invocation, McpLifecyclePhase::Ready)
|
||||
| (McpLifecyclePhase::ErrorSurfacing, McpLifecyclePhase::Ready)
|
||||
| (McpLifecyclePhase::ErrorSurfacing, McpLifecyclePhase::Shutdown)
|
||||
| (McpLifecyclePhase::Shutdown, McpLifecyclePhase::Cleanup) => true,
|
||||
(_, McpLifecyclePhase::Shutdown) => from != McpLifecyclePhase::Cleanup,
|
||||
(_, McpLifecyclePhase::ErrorSurfacing) => {
|
||||
from != McpLifecyclePhase::Cleanup && from != McpLifecyclePhase::Shutdown
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_phase(&mut self, phase: McpLifecyclePhase) -> McpPhaseResult {
|
||||
let started = Instant::now();
|
||||
|
||||
if let Some(current_phase) = self.state.current_phase() {
|
||||
if current_phase == McpLifecyclePhase::ErrorSurfacing
|
||||
&& phase == McpLifecyclePhase::Ready
|
||||
&& !self.state.can_resume_after_error()
|
||||
{
|
||||
return self.record_failure(McpErrorSurface::new(
|
||||
phase,
|
||||
None,
|
||||
"cannot return to ready after a non-recoverable MCP lifecycle failure",
|
||||
BTreeMap::from([
|
||||
("from".to_string(), current_phase.to_string()),
|
||||
("to".to_string(), phase.to_string()),
|
||||
]),
|
||||
false,
|
||||
));
|
||||
}
|
||||
|
||||
if !Self::validate_phase_transition(current_phase, phase) {
|
||||
return self.record_failure(McpErrorSurface::new(
|
||||
phase,
|
||||
None,
|
||||
format!("invalid MCP lifecycle transition from {current_phase} to {phase}"),
|
||||
BTreeMap::from([
|
||||
("from".to_string(), current_phase.to_string()),
|
||||
("to".to_string(), phase.to_string()),
|
||||
]),
|
||||
false,
|
||||
));
|
||||
}
|
||||
} else if phase != McpLifecyclePhase::ConfigLoad {
|
||||
return self.record_failure(McpErrorSurface::new(
|
||||
phase,
|
||||
None,
|
||||
format!("invalid initial MCP lifecycle phase {phase}"),
|
||||
BTreeMap::from([("phase".to_string(), phase.to_string())]),
|
||||
false,
|
||||
));
|
||||
}
|
||||
|
||||
self.state.record_phase(phase);
|
||||
let result = McpPhaseResult::Success {
|
||||
phase,
|
||||
duration: started.elapsed(),
|
||||
};
|
||||
self.state.record_result(result.clone());
|
||||
result
|
||||
}
|
||||
|
||||
pub fn record_failure(&mut self, error: McpErrorSurface) -> McpPhaseResult {
|
||||
let phase = error.phase;
|
||||
self.state.record_error(error.clone());
|
||||
self.state.record_phase(McpLifecyclePhase::ErrorSurfacing);
|
||||
let result = McpPhaseResult::Failure { phase, error };
|
||||
self.state.record_result(result.clone());
|
||||
result
|
||||
}
|
||||
|
||||
pub fn record_timeout(
|
||||
&mut self,
|
||||
phase: McpLifecyclePhase,
|
||||
waited: Duration,
|
||||
server_name: Option<String>,
|
||||
mut context: BTreeMap<String, String>,
|
||||
) -> McpPhaseResult {
|
||||
context.insert("waited_ms".to_string(), waited.as_millis().to_string());
|
||||
let error = McpErrorSurface::new(
|
||||
phase,
|
||||
server_name,
|
||||
format!(
|
||||
"MCP lifecycle phase {phase} timed out after {} ms",
|
||||
waited.as_millis()
|
||||
),
|
||||
context,
|
||||
true,
|
||||
);
|
||||
self.state.record_error(error.clone());
|
||||
self.state.record_phase(McpLifecyclePhase::ErrorSurfacing);
|
||||
let result = McpPhaseResult::Timeout {
|
||||
phase,
|
||||
waited,
|
||||
error,
|
||||
};
|
||||
self.state.record_result(result.clone());
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
fn dedupe_sorted(mut values: Vec<String>) -> Vec<String> {
|
||||
values.sort();
|
||||
values.dedup();
|
||||
values
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn phase_display_matches_serde_name() {
|
||||
// given
|
||||
let phases = McpLifecyclePhase::all();
|
||||
|
||||
// when
|
||||
let serialized = phases
|
||||
.into_iter()
|
||||
.map(|phase| {
|
||||
(
|
||||
phase.to_string(),
|
||||
serde_json::to_value(phase).expect("serialize phase"),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// then
|
||||
for (display, json_value) in serialized {
|
||||
assert_eq!(json_value, json!(display));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_startup_path_when_running_to_cleanup_then_each_control_transition_succeeds() {
|
||||
// given
|
||||
let mut validator = McpLifecycleValidator::new();
|
||||
let phases = [
|
||||
McpLifecyclePhase::ConfigLoad,
|
||||
McpLifecyclePhase::ServerRegistration,
|
||||
McpLifecyclePhase::SpawnConnect,
|
||||
McpLifecyclePhase::InitializeHandshake,
|
||||
McpLifecyclePhase::ToolDiscovery,
|
||||
McpLifecyclePhase::ResourceDiscovery,
|
||||
McpLifecyclePhase::Ready,
|
||||
McpLifecyclePhase::Invocation,
|
||||
McpLifecyclePhase::Ready,
|
||||
McpLifecyclePhase::Shutdown,
|
||||
McpLifecyclePhase::Cleanup,
|
||||
];
|
||||
|
||||
// when
|
||||
let results = phases
|
||||
.into_iter()
|
||||
.map(|phase| validator.run_phase(phase))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// then
|
||||
assert!(results
|
||||
.iter()
|
||||
.all(|result| matches!(result, McpPhaseResult::Success { .. })));
|
||||
assert_eq!(
|
||||
validator.state().current_phase(),
|
||||
Some(McpLifecyclePhase::Cleanup)
|
||||
);
|
||||
for phase in [
|
||||
McpLifecyclePhase::ConfigLoad,
|
||||
McpLifecyclePhase::ServerRegistration,
|
||||
McpLifecyclePhase::SpawnConnect,
|
||||
McpLifecyclePhase::InitializeHandshake,
|
||||
McpLifecyclePhase::ToolDiscovery,
|
||||
McpLifecyclePhase::ResourceDiscovery,
|
||||
McpLifecyclePhase::Ready,
|
||||
McpLifecyclePhase::Invocation,
|
||||
McpLifecyclePhase::Shutdown,
|
||||
McpLifecyclePhase::Cleanup,
|
||||
] {
|
||||
assert!(validator.state().phase_timestamp(phase).is_some());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_tool_discovery_when_resource_discovery_is_skipped_then_ready_is_still_allowed() {
|
||||
// given
|
||||
let mut validator = McpLifecycleValidator::new();
|
||||
for phase in [
|
||||
McpLifecyclePhase::ConfigLoad,
|
||||
McpLifecyclePhase::ServerRegistration,
|
||||
McpLifecyclePhase::SpawnConnect,
|
||||
McpLifecyclePhase::InitializeHandshake,
|
||||
McpLifecyclePhase::ToolDiscovery,
|
||||
] {
|
||||
let result = validator.run_phase(phase);
|
||||
assert!(matches!(result, McpPhaseResult::Success { .. }));
|
||||
}
|
||||
|
||||
// when
|
||||
let result = validator.run_phase(McpLifecyclePhase::Ready);
|
||||
|
||||
// then
|
||||
assert!(matches!(result, McpPhaseResult::Success { .. }));
|
||||
assert_eq!(
|
||||
validator.state().current_phase(),
|
||||
Some(McpLifecyclePhase::Ready)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validates_expected_phase_transitions() {
|
||||
// given
|
||||
let valid_transitions = [
|
||||
(
|
||||
McpLifecyclePhase::ConfigLoad,
|
||||
McpLifecyclePhase::ServerRegistration,
|
||||
),
|
||||
(
|
||||
McpLifecyclePhase::ServerRegistration,
|
||||
McpLifecyclePhase::SpawnConnect,
|
||||
),
|
||||
(
|
||||
McpLifecyclePhase::SpawnConnect,
|
||||
McpLifecyclePhase::InitializeHandshake,
|
||||
),
|
||||
(
|
||||
McpLifecyclePhase::InitializeHandshake,
|
||||
McpLifecyclePhase::ToolDiscovery,
|
||||
),
|
||||
(
|
||||
McpLifecyclePhase::ToolDiscovery,
|
||||
McpLifecyclePhase::ResourceDiscovery,
|
||||
),
|
||||
(McpLifecyclePhase::ToolDiscovery, McpLifecyclePhase::Ready),
|
||||
(
|
||||
McpLifecyclePhase::ResourceDiscovery,
|
||||
McpLifecyclePhase::Ready,
|
||||
),
|
||||
(McpLifecyclePhase::Ready, McpLifecyclePhase::Invocation),
|
||||
(McpLifecyclePhase::Invocation, McpLifecyclePhase::Ready),
|
||||
(McpLifecyclePhase::Ready, McpLifecyclePhase::Shutdown),
|
||||
(
|
||||
McpLifecyclePhase::Invocation,
|
||||
McpLifecyclePhase::ErrorSurfacing,
|
||||
),
|
||||
(
|
||||
McpLifecyclePhase::ErrorSurfacing,
|
||||
McpLifecyclePhase::Shutdown,
|
||||
),
|
||||
(McpLifecyclePhase::Shutdown, McpLifecyclePhase::Cleanup),
|
||||
];
|
||||
|
||||
// when / then
|
||||
for (from, to) in valid_transitions {
|
||||
assert!(McpLifecycleValidator::validate_phase_transition(from, to));
|
||||
}
|
||||
assert!(!McpLifecycleValidator::validate_phase_transition(
|
||||
McpLifecyclePhase::Ready,
|
||||
McpLifecyclePhase::ConfigLoad,
|
||||
));
|
||||
assert!(!McpLifecycleValidator::validate_phase_transition(
|
||||
McpLifecyclePhase::Cleanup,
|
||||
McpLifecyclePhase::Ready,
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_invalid_transition_when_running_phase_then_structured_failure_is_recorded() {
|
||||
// given
|
||||
let mut validator = McpLifecycleValidator::new();
|
||||
let _ = validator.run_phase(McpLifecyclePhase::ConfigLoad);
|
||||
let _ = validator.run_phase(McpLifecyclePhase::ServerRegistration);
|
||||
|
||||
// when
|
||||
let result = validator.run_phase(McpLifecyclePhase::Ready);
|
||||
|
||||
// then
|
||||
match result {
|
||||
McpPhaseResult::Failure { phase, error } => {
|
||||
assert_eq!(phase, McpLifecyclePhase::Ready);
|
||||
assert!(!error.recoverable);
|
||||
assert_eq!(error.phase, McpLifecyclePhase::Ready);
|
||||
assert_eq!(
|
||||
error.context.get("from").map(String::as_str),
|
||||
Some("server_registration")
|
||||
);
|
||||
assert_eq!(error.context.get("to").map(String::as_str), Some("ready"));
|
||||
}
|
||||
other => panic!("expected failure result, got {other:?}"),
|
||||
}
|
||||
assert_eq!(
|
||||
validator.state().current_phase(),
|
||||
Some(McpLifecyclePhase::ErrorSurfacing)
|
||||
);
|
||||
assert_eq!(
|
||||
validator
|
||||
.state()
|
||||
.errors_for_phase(McpLifecyclePhase::Ready)
|
||||
.len(),
|
||||
1
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_each_phase_when_failure_is_recorded_then_error_is_tracked_per_phase() {
|
||||
// given
|
||||
let mut validator = McpLifecycleValidator::new();
|
||||
|
||||
// when / then
|
||||
for phase in McpLifecyclePhase::all() {
|
||||
let result = validator.record_failure(McpErrorSurface::new(
|
||||
phase,
|
||||
Some("alpha".to_string()),
|
||||
format!("failure at {phase}"),
|
||||
BTreeMap::from([("server".to_string(), "alpha".to_string())]),
|
||||
phase == McpLifecyclePhase::ResourceDiscovery,
|
||||
));
|
||||
|
||||
match result {
|
||||
McpPhaseResult::Failure {
|
||||
phase: failed_phase,
|
||||
error,
|
||||
} => {
|
||||
assert_eq!(failed_phase, phase);
|
||||
assert_eq!(error.phase, phase);
|
||||
assert_eq!(
|
||||
error.recoverable,
|
||||
phase == McpLifecyclePhase::ResourceDiscovery
|
||||
);
|
||||
}
|
||||
other => panic!("expected failure result, got {other:?}"),
|
||||
}
|
||||
assert_eq!(validator.state().errors_for_phase(phase).len(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_spawn_connect_timeout_when_recorded_then_waited_duration_is_preserved() {
|
||||
// given
|
||||
let mut validator = McpLifecycleValidator::new();
|
||||
let waited = Duration::from_millis(250);
|
||||
|
||||
// when
|
||||
let result = validator.record_timeout(
|
||||
McpLifecyclePhase::SpawnConnect,
|
||||
waited,
|
||||
Some("alpha".to_string()),
|
||||
BTreeMap::from([("attempt".to_string(), "1".to_string())]),
|
||||
);
|
||||
|
||||
// then
|
||||
match result {
|
||||
McpPhaseResult::Timeout {
|
||||
phase,
|
||||
waited: actual,
|
||||
error,
|
||||
} => {
|
||||
assert_eq!(phase, McpLifecyclePhase::SpawnConnect);
|
||||
assert_eq!(actual, waited);
|
||||
assert!(error.recoverable);
|
||||
assert_eq!(error.server_name.as_deref(), Some("alpha"));
|
||||
}
|
||||
other => panic!("expected timeout result, got {other:?}"),
|
||||
}
|
||||
let errors = validator
|
||||
.state()
|
||||
.errors_for_phase(McpLifecyclePhase::SpawnConnect);
|
||||
assert_eq!(errors.len(), 1);
|
||||
assert_eq!(
|
||||
errors[0].context.get("waited_ms").map(String::as_str),
|
||||
Some("250")
|
||||
);
|
||||
assert_eq!(
|
||||
validator.state().current_phase(),
|
||||
Some(McpLifecyclePhase::ErrorSurfacing)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_partial_server_health_when_building_degraded_report_then_missing_tools_are_reported() {
|
||||
// given
|
||||
let failed = vec![McpFailedServer {
|
||||
server_name: "broken".to_string(),
|
||||
phase: McpLifecyclePhase::InitializeHandshake,
|
||||
error: McpErrorSurface::new(
|
||||
McpLifecyclePhase::InitializeHandshake,
|
||||
Some("broken".to_string()),
|
||||
"initialize failed",
|
||||
BTreeMap::from([("reason".to_string(), "broken pipe".to_string())]),
|
||||
false,
|
||||
),
|
||||
}];
|
||||
|
||||
// when
|
||||
let report = McpDegradedReport::new(
|
||||
vec!["alpha".to_string(), "beta".to_string(), "alpha".to_string()],
|
||||
failed,
|
||||
vec![
|
||||
"alpha.echo".to_string(),
|
||||
"beta.search".to_string(),
|
||||
"alpha.echo".to_string(),
|
||||
],
|
||||
vec![
|
||||
"alpha.echo".to_string(),
|
||||
"beta.search".to_string(),
|
||||
"broken.fetch".to_string(),
|
||||
],
|
||||
);
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
report.working_servers,
|
||||
vec!["alpha".to_string(), "beta".to_string()]
|
||||
);
|
||||
assert_eq!(report.failed_servers.len(), 1);
|
||||
assert_eq!(report.failed_servers[0].server_name, "broken");
|
||||
assert_eq!(
|
||||
report.available_tools,
|
||||
vec!["alpha.echo".to_string(), "beta.search".to_string()]
|
||||
);
|
||||
assert_eq!(report.missing_tools, vec!["broken.fetch".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_failure_during_resource_discovery_when_shutting_down_then_cleanup_still_succeeds() {
|
||||
// given
|
||||
let mut validator = McpLifecycleValidator::new();
|
||||
for phase in [
|
||||
McpLifecyclePhase::ConfigLoad,
|
||||
McpLifecyclePhase::ServerRegistration,
|
||||
McpLifecyclePhase::SpawnConnect,
|
||||
McpLifecyclePhase::InitializeHandshake,
|
||||
McpLifecyclePhase::ToolDiscovery,
|
||||
] {
|
||||
let result = validator.run_phase(phase);
|
||||
assert!(matches!(result, McpPhaseResult::Success { .. }));
|
||||
}
|
||||
let _ = validator.record_failure(McpErrorSurface::new(
|
||||
McpLifecyclePhase::ResourceDiscovery,
|
||||
Some("alpha".to_string()),
|
||||
"resource listing failed",
|
||||
BTreeMap::from([("reason".to_string(), "timeout".to_string())]),
|
||||
true,
|
||||
));
|
||||
|
||||
// when
|
||||
let shutdown = validator.run_phase(McpLifecyclePhase::Shutdown);
|
||||
let cleanup = validator.run_phase(McpLifecyclePhase::Cleanup);
|
||||
|
||||
// then
|
||||
assert!(matches!(shutdown, McpPhaseResult::Success { .. }));
|
||||
assert!(matches!(cleanup, McpPhaseResult::Success { .. }));
|
||||
assert_eq!(
|
||||
validator.state().current_phase(),
|
||||
Some(McpLifecyclePhase::Cleanup)
|
||||
);
|
||||
assert!(validator
|
||||
.state()
|
||||
.phase_timestamp(McpLifecyclePhase::ErrorSurfacing)
|
||||
.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_surface_display_includes_phase_server_and_recoverable_flag() {
|
||||
// given
|
||||
let error = McpErrorSurface::new(
|
||||
McpLifecyclePhase::SpawnConnect,
|
||||
Some("alpha".to_string()),
|
||||
"process exited early",
|
||||
BTreeMap::from([("exit_code".to_string(), "1".to_string())]),
|
||||
true,
|
||||
);
|
||||
|
||||
// when
|
||||
let rendered = error.to_string();
|
||||
|
||||
// then
|
||||
assert!(rendered.contains("spawn_connect"));
|
||||
assert!(rendered.contains("process exited early"));
|
||||
assert!(rendered.contains("server: alpha"));
|
||||
assert!(rendered.contains("recoverable"));
|
||||
let trait_object: &dyn std::error::Error = &error;
|
||||
assert_eq!(trait_object.to_string(), rendered);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_nonrecoverable_failure_when_returning_to_ready_then_validator_rejects_resume() {
|
||||
// given
|
||||
let mut validator = McpLifecycleValidator::new();
|
||||
for phase in [
|
||||
McpLifecyclePhase::ConfigLoad,
|
||||
McpLifecyclePhase::ServerRegistration,
|
||||
McpLifecyclePhase::SpawnConnect,
|
||||
McpLifecyclePhase::InitializeHandshake,
|
||||
McpLifecyclePhase::ToolDiscovery,
|
||||
McpLifecyclePhase::Ready,
|
||||
] {
|
||||
let result = validator.run_phase(phase);
|
||||
assert!(matches!(result, McpPhaseResult::Success { .. }));
|
||||
}
|
||||
let _ = validator.record_failure(McpErrorSurface::new(
|
||||
McpLifecyclePhase::Invocation,
|
||||
Some("alpha".to_string()),
|
||||
"tool call corrupted the session",
|
||||
BTreeMap::from([("reason".to_string(), "invalid frame".to_string())]),
|
||||
false,
|
||||
));
|
||||
|
||||
// when
|
||||
let result = validator.run_phase(McpLifecyclePhase::Ready);
|
||||
|
||||
// then
|
||||
match result {
|
||||
McpPhaseResult::Failure { phase, error } => {
|
||||
assert_eq!(phase, McpLifecyclePhase::Ready);
|
||||
assert!(!error.recoverable);
|
||||
assert!(error.message.contains("non-recoverable"));
|
||||
}
|
||||
other => panic!("expected failure result, got {other:?}"),
|
||||
}
|
||||
assert_eq!(
|
||||
validator.state().current_phase(),
|
||||
Some(McpLifecyclePhase::ErrorSurfacing)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_recoverable_failure_when_returning_to_ready_then_validator_allows_resume() {
|
||||
// given
|
||||
let mut validator = McpLifecycleValidator::new();
|
||||
for phase in [
|
||||
McpLifecyclePhase::ConfigLoad,
|
||||
McpLifecyclePhase::ServerRegistration,
|
||||
McpLifecyclePhase::SpawnConnect,
|
||||
McpLifecyclePhase::InitializeHandshake,
|
||||
McpLifecyclePhase::ToolDiscovery,
|
||||
McpLifecyclePhase::Ready,
|
||||
] {
|
||||
let result = validator.run_phase(phase);
|
||||
assert!(matches!(result, McpPhaseResult::Success { .. }));
|
||||
}
|
||||
let _ = validator.record_failure(McpErrorSurface::new(
|
||||
McpLifecyclePhase::Invocation,
|
||||
Some("alpha".to_string()),
|
||||
"tool call failed but can be retried",
|
||||
BTreeMap::from([("reason".to_string(), "upstream timeout".to_string())]),
|
||||
true,
|
||||
));
|
||||
|
||||
// when
|
||||
let result = validator.run_phase(McpLifecyclePhase::Ready);
|
||||
|
||||
// then
|
||||
assert!(matches!(result, McpPhaseResult::Success { .. }));
|
||||
assert_eq!(
|
||||
validator.state().current_phase(),
|
||||
Some(McpLifecyclePhase::Ready)
|
||||
);
|
||||
}
|
||||
}
|
||||
440
crates/runtime/src/mcp_server.rs
Normal file
440
crates/runtime/src/mcp_server.rs
Normal file
@ -0,0 +1,440 @@
|
||||
//! Minimal Model Context Protocol (MCP) server.
|
||||
//!
|
||||
//! Implements a newline-safe, LSP-framed JSON-RPC server over stdio that
|
||||
//! answers `initialize`, `tools/list`, and `tools/call` requests. The framing
|
||||
//! matches the client transport implemented in [`crate::mcp_stdio`] so this
|
||||
//! server can be driven by either an external MCP client (e.g. Claude
|
||||
//! Desktop) or `claw`'s own [`McpServerManager`](crate::McpServerManager).
|
||||
//!
|
||||
//! The server is intentionally small: it exposes a list of pre-built
|
||||
//! [`McpTool`] descriptors and delegates `tools/call` to a caller-supplied
|
||||
//! handler. Tool execution itself lives in the `tools` crate; this module is
|
||||
//! purely the transport + dispatch loop.
|
||||
//!
|
||||
//! [`McpTool`]: crate::mcp_stdio::McpTool
|
||||
|
||||
use std::io;
|
||||
|
||||
use serde_json::{json, Value as JsonValue};
|
||||
use tokio::io::{
|
||||
stdin, stdout, AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, Stdin, Stdout,
|
||||
};
|
||||
|
||||
use crate::mcp_stdio::{
|
||||
JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse, McpInitializeResult,
|
||||
McpInitializeServerInfo, McpListToolsResult, McpTool, McpToolCallContent, McpToolCallParams,
|
||||
McpToolCallResult,
|
||||
};
|
||||
|
||||
/// Protocol version the server advertises during `initialize`.
|
||||
///
|
||||
/// Matches the version used by the built-in client in
|
||||
/// [`crate::mcp_stdio`], so the two stay in lockstep.
|
||||
pub const MCP_SERVER_PROTOCOL_VERSION: &str = "2025-03-26";
|
||||
|
||||
/// Synchronous handler invoked for every `tools/call` request.
|
||||
///
|
||||
/// Returning `Ok(text)` yields a single `text` content block and
|
||||
/// `isError: false`. Returning `Err(message)` yields a `text` block with the
|
||||
/// error and `isError: true`, mirroring the error-surfacing convention used
|
||||
/// elsewhere in claw.
|
||||
pub type ToolCallHandler =
|
||||
Box<dyn Fn(&str, &JsonValue) -> Result<String, String> + Send + Sync + 'static>;
|
||||
|
||||
/// Configuration for an [`McpServer`] instance.
|
||||
///
|
||||
/// Named `McpServerSpec` rather than `McpServerConfig` to avoid colliding
|
||||
/// with the existing client-side [`crate::config::McpServerConfig`] that
|
||||
/// describes *remote* MCP servers the runtime connects to.
|
||||
pub struct McpServerSpec {
|
||||
/// Name advertised in the `serverInfo` field of the `initialize` response.
|
||||
pub server_name: String,
|
||||
/// Version advertised in the `serverInfo` field of the `initialize`
|
||||
/// response.
|
||||
pub server_version: String,
|
||||
/// Tool descriptors returned for `tools/list`.
|
||||
pub tools: Vec<McpTool>,
|
||||
/// Handler invoked for `tools/call`.
|
||||
pub tool_handler: ToolCallHandler,
|
||||
}
|
||||
|
||||
/// Minimal MCP stdio server.
|
||||
///
|
||||
/// The server runs a blocking read/dispatch/write loop over the current
|
||||
/// process's stdin/stdout, terminating cleanly when the peer closes the
|
||||
/// stream.
|
||||
pub struct McpServer {
|
||||
spec: McpServerSpec,
|
||||
stdin: BufReader<Stdin>,
|
||||
stdout: Stdout,
|
||||
}
|
||||
|
||||
impl McpServer {
|
||||
#[must_use]
|
||||
pub fn new(spec: McpServerSpec) -> Self {
|
||||
Self {
|
||||
spec,
|
||||
stdin: BufReader::new(stdin()),
|
||||
stdout: stdout(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs the server until the client closes stdin.
|
||||
///
|
||||
/// Returns `Ok(())` on clean EOF; any other I/O error is propagated so
|
||||
/// callers can log and exit non-zero.
|
||||
pub async fn run(&mut self) -> io::Result<()> {
|
||||
loop {
|
||||
let Some(payload) = read_frame(&mut self.stdin).await? else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
// Requests and notifications share a wire format; the absence of
|
||||
// `id` distinguishes notifications, which must never receive a
|
||||
// response.
|
||||
let message: JsonValue = match serde_json::from_slice(&payload) {
|
||||
Ok(value) => value,
|
||||
Err(error) => {
|
||||
// Parse error with null id per JSON-RPC 2.0 §4.2.
|
||||
let response = JsonRpcResponse::<JsonValue> {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: JsonRpcId::Null,
|
||||
result: None,
|
||||
error: Some(JsonRpcError {
|
||||
code: -32700,
|
||||
message: format!("parse error: {error}"),
|
||||
data: None,
|
||||
}),
|
||||
};
|
||||
write_response(&mut self.stdout, &response).await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if message.get("id").is_none() {
|
||||
// Notification: dispatch for side effects only (e.g. log),
|
||||
// but send no reply.
|
||||
continue;
|
||||
}
|
||||
|
||||
let request: JsonRpcRequest<JsonValue> = match serde_json::from_value(message) {
|
||||
Ok(request) => request,
|
||||
Err(error) => {
|
||||
let response = JsonRpcResponse::<JsonValue> {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: JsonRpcId::Null,
|
||||
result: None,
|
||||
error: Some(JsonRpcError {
|
||||
code: -32600,
|
||||
message: format!("invalid request: {error}"),
|
||||
data: None,
|
||||
}),
|
||||
};
|
||||
write_response(&mut self.stdout, &response).await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let response = self.dispatch(request);
|
||||
write_response(&mut self.stdout, &response).await?;
|
||||
}
|
||||
}
|
||||
|
||||
fn dispatch(&self, request: JsonRpcRequest<JsonValue>) -> JsonRpcResponse<JsonValue> {
|
||||
let id = request.id.clone();
|
||||
match request.method.as_str() {
|
||||
"initialize" => self.handle_initialize(id),
|
||||
"tools/list" => self.handle_tools_list(id),
|
||||
"tools/call" => self.handle_tools_call(id, request.params),
|
||||
other => JsonRpcResponse {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id,
|
||||
result: None,
|
||||
error: Some(JsonRpcError {
|
||||
code: -32601,
|
||||
message: format!("method not found: {other}"),
|
||||
data: None,
|
||||
}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_initialize(&self, id: JsonRpcId) -> JsonRpcResponse<JsonValue> {
|
||||
let result = McpInitializeResult {
|
||||
protocol_version: MCP_SERVER_PROTOCOL_VERSION.to_string(),
|
||||
capabilities: json!({ "tools": {} }),
|
||||
server_info: McpInitializeServerInfo {
|
||||
name: self.spec.server_name.clone(),
|
||||
version: self.spec.server_version.clone(),
|
||||
},
|
||||
};
|
||||
JsonRpcResponse {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id,
|
||||
result: serde_json::to_value(result).ok(),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_tools_list(&self, id: JsonRpcId) -> JsonRpcResponse<JsonValue> {
|
||||
let result = McpListToolsResult {
|
||||
tools: self.spec.tools.clone(),
|
||||
next_cursor: None,
|
||||
};
|
||||
JsonRpcResponse {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id,
|
||||
result: serde_json::to_value(result).ok(),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_tools_call(
|
||||
&self,
|
||||
id: JsonRpcId,
|
||||
params: Option<JsonValue>,
|
||||
) -> JsonRpcResponse<JsonValue> {
|
||||
let Some(params) = params else {
|
||||
return invalid_params_response(id, "missing params for tools/call");
|
||||
};
|
||||
let call: McpToolCallParams = match serde_json::from_value(params) {
|
||||
Ok(value) => value,
|
||||
Err(error) => {
|
||||
return invalid_params_response(id, &format!("invalid tools/call params: {error}"));
|
||||
}
|
||||
};
|
||||
let arguments = call.arguments.unwrap_or_else(|| json!({}));
|
||||
let tool_result = (self.spec.tool_handler)(&call.name, &arguments);
|
||||
let (text, is_error) = match tool_result {
|
||||
Ok(text) => (text, false),
|
||||
Err(message) => (message, true),
|
||||
};
|
||||
let mut data = std::collections::BTreeMap::new();
|
||||
data.insert("text".to_string(), JsonValue::String(text));
|
||||
let call_result = McpToolCallResult {
|
||||
content: vec![McpToolCallContent {
|
||||
kind: "text".to_string(),
|
||||
data,
|
||||
}],
|
||||
structured_content: None,
|
||||
is_error: Some(is_error),
|
||||
meta: None,
|
||||
};
|
||||
JsonRpcResponse {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id,
|
||||
result: serde_json::to_value(call_result).ok(),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn invalid_params_response(id: JsonRpcId, message: &str) -> JsonRpcResponse<JsonValue> {
|
||||
JsonRpcResponse {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id,
|
||||
result: None,
|
||||
error: Some(JsonRpcError {
|
||||
code: -32602,
|
||||
message: message.to_string(),
|
||||
data: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads a single LSP-framed JSON-RPC payload from `reader`.
|
||||
///
|
||||
/// Returns `Ok(None)` on clean EOF before any header bytes have been read,
|
||||
/// matching how [`crate::mcp_stdio::McpStdioProcess`] treats stream closure.
|
||||
async fn read_frame(reader: &mut BufReader<Stdin>) -> io::Result<Option<Vec<u8>>> {
|
||||
let mut content_length: Option<usize> = None;
|
||||
let mut first_header = true;
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
let bytes_read = reader.read_line(&mut line).await?;
|
||||
if bytes_read == 0 {
|
||||
if first_header {
|
||||
return Ok(None);
|
||||
}
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
"MCP stdio stream closed while reading headers",
|
||||
));
|
||||
}
|
||||
first_header = false;
|
||||
if line == "\r\n" || line == "\n" {
|
||||
break;
|
||||
}
|
||||
let header = line.trim_end_matches(['\r', '\n']);
|
||||
if let Some((name, value)) = header.split_once(':') {
|
||||
if name.trim().eq_ignore_ascii_case("Content-Length") {
|
||||
let parsed = value
|
||||
.trim()
|
||||
.parse::<usize>()
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
content_length = Some(parsed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let content_length = content_length.ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::InvalidData, "missing Content-Length header")
|
||||
})?;
|
||||
let mut payload = vec![0_u8; content_length];
|
||||
reader.read_exact(&mut payload).await?;
|
||||
Ok(Some(payload))
|
||||
}
|
||||
|
||||
async fn write_response(
|
||||
stdout: &mut Stdout,
|
||||
response: &JsonRpcResponse<JsonValue>,
|
||||
) -> io::Result<()> {
|
||||
let body = serde_json::to_vec(response)
|
||||
.map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
|
||||
let header = format!("Content-Length: {}\r\n\r\n", body.len());
|
||||
stdout.write_all(header.as_bytes()).await?;
|
||||
stdout.write_all(&body).await?;
|
||||
stdout.flush().await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn dispatch_initialize_returns_server_info() {
|
||||
let server = McpServer {
|
||||
spec: McpServerSpec {
|
||||
server_name: "test".to_string(),
|
||||
server_version: "9.9.9".to_string(),
|
||||
tools: Vec::new(),
|
||||
tool_handler: Box::new(|_, _| Ok(String::new())),
|
||||
},
|
||||
stdin: BufReader::new(stdin()),
|
||||
stdout: stdout(),
|
||||
};
|
||||
let request = JsonRpcRequest::<JsonValue> {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: JsonRpcId::Number(1),
|
||||
method: "initialize".to_string(),
|
||||
params: None,
|
||||
};
|
||||
let response = server.dispatch(request);
|
||||
assert_eq!(response.id, JsonRpcId::Number(1));
|
||||
assert!(response.error.is_none());
|
||||
let result = response.result.expect("initialize result");
|
||||
assert_eq!(result["protocolVersion"], MCP_SERVER_PROTOCOL_VERSION);
|
||||
assert_eq!(result["serverInfo"]["name"], "test");
|
||||
assert_eq!(result["serverInfo"]["version"], "9.9.9");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_tools_list_returns_registered_tools() {
|
||||
let tool = McpTool {
|
||||
name: "echo".to_string(),
|
||||
description: Some("Echo".to_string()),
|
||||
input_schema: Some(json!({"type": "object"})),
|
||||
annotations: None,
|
||||
meta: None,
|
||||
};
|
||||
let server = McpServer {
|
||||
spec: McpServerSpec {
|
||||
server_name: "test".to_string(),
|
||||
server_version: "0.0.0".to_string(),
|
||||
tools: vec![tool.clone()],
|
||||
tool_handler: Box::new(|_, _| Ok(String::new())),
|
||||
},
|
||||
stdin: BufReader::new(stdin()),
|
||||
stdout: stdout(),
|
||||
};
|
||||
let request = JsonRpcRequest::<JsonValue> {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: JsonRpcId::Number(2),
|
||||
method: "tools/list".to_string(),
|
||||
params: None,
|
||||
};
|
||||
let response = server.dispatch(request);
|
||||
assert!(response.error.is_none());
|
||||
let result = response.result.expect("tools/list result");
|
||||
assert_eq!(result["tools"][0]["name"], "echo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_tools_call_wraps_handler_output() {
|
||||
let server = McpServer {
|
||||
spec: McpServerSpec {
|
||||
server_name: "test".to_string(),
|
||||
server_version: "0.0.0".to_string(),
|
||||
tools: Vec::new(),
|
||||
tool_handler: Box::new(|name, args| Ok(format!("called {name} with {args}"))),
|
||||
},
|
||||
stdin: BufReader::new(stdin()),
|
||||
stdout: stdout(),
|
||||
};
|
||||
let request = JsonRpcRequest::<JsonValue> {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: JsonRpcId::Number(3),
|
||||
method: "tools/call".to_string(),
|
||||
params: Some(json!({
|
||||
"name": "echo",
|
||||
"arguments": {"text": "hi"}
|
||||
})),
|
||||
};
|
||||
let response = server.dispatch(request);
|
||||
assert!(response.error.is_none());
|
||||
let result = response.result.expect("tools/call result");
|
||||
assert_eq!(result["isError"], false);
|
||||
assert_eq!(result["content"][0]["type"], "text");
|
||||
assert!(result["content"][0]["text"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.starts_with("called echo"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_tools_call_surfaces_handler_error() {
|
||||
let server = McpServer {
|
||||
spec: McpServerSpec {
|
||||
server_name: "test".to_string(),
|
||||
server_version: "0.0.0".to_string(),
|
||||
tools: Vec::new(),
|
||||
tool_handler: Box::new(|_, _| Err("boom".to_string())),
|
||||
},
|
||||
stdin: BufReader::new(stdin()),
|
||||
stdout: stdout(),
|
||||
};
|
||||
let request = JsonRpcRequest::<JsonValue> {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: JsonRpcId::Number(4),
|
||||
method: "tools/call".to_string(),
|
||||
params: Some(json!({"name": "broken"})),
|
||||
};
|
||||
let response = server.dispatch(request);
|
||||
let result = response.result.expect("tools/call result");
|
||||
assert_eq!(result["isError"], true);
|
||||
assert_eq!(result["content"][0]["text"], "boom");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dispatch_unknown_method_returns_method_not_found() {
|
||||
let server = McpServer {
|
||||
spec: McpServerSpec {
|
||||
server_name: "test".to_string(),
|
||||
server_version: "0.0.0".to_string(),
|
||||
tools: Vec::new(),
|
||||
tool_handler: Box::new(|_, _| Ok(String::new())),
|
||||
},
|
||||
stdin: BufReader::new(stdin()),
|
||||
stdout: stdout(),
|
||||
};
|
||||
let request = JsonRpcRequest::<JsonValue> {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: JsonRpcId::Number(5),
|
||||
method: "nonsense".to_string(),
|
||||
params: None,
|
||||
};
|
||||
let response = server.dispatch(request);
|
||||
let error = response.error.expect("error payload");
|
||||
assert_eq!(error.code, -32601);
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
920
crates/runtime/src/mcp_tool_bridge.rs
Normal file
920
crates/runtime/src/mcp_tool_bridge.rs
Normal file
@ -0,0 +1,920 @@
|
||||
#![allow(
|
||||
clippy::await_holding_lock,
|
||||
clippy::doc_markdown,
|
||||
clippy::match_same_arms,
|
||||
clippy::must_use_candidate,
|
||||
clippy::uninlined_format_args,
|
||||
clippy::unnested_or_patterns
|
||||
)]
|
||||
//! Bridge between MCP tool surface (ListMcpResources, ReadMcpResource, McpAuth, MCP)
|
||||
//! and the existing McpServerManager runtime.
|
||||
//!
|
||||
//! Provides a stateful client registry that tool handlers can use to
|
||||
//! connect to MCP servers and invoke their capabilities.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
use crate::mcp::mcp_tool_name;
|
||||
use crate::mcp_stdio::McpServerManager;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Status of a managed MCP server connection.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum McpConnectionStatus {
|
||||
Disconnected,
|
||||
Connecting,
|
||||
Connected,
|
||||
AuthRequired,
|
||||
Error,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for McpConnectionStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Disconnected => write!(f, "disconnected"),
|
||||
Self::Connecting => write!(f, "connecting"),
|
||||
Self::Connected => write!(f, "connected"),
|
||||
Self::AuthRequired => write!(f, "auth_required"),
|
||||
Self::Error => write!(f, "error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Metadata about an MCP resource.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpResourceInfo {
|
||||
pub uri: String,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub mime_type: Option<String>,
|
||||
}
|
||||
|
||||
/// Metadata about an MCP tool exposed by a server.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpToolInfo {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub input_schema: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Tracked state of an MCP server connection.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpServerState {
|
||||
pub server_name: String,
|
||||
pub status: McpConnectionStatus,
|
||||
pub tools: Vec<McpToolInfo>,
|
||||
pub resources: Vec<McpResourceInfo>,
|
||||
pub server_info: Option<String>,
|
||||
pub error_message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct McpToolRegistry {
|
||||
inner: Arc<Mutex<HashMap<String, McpServerState>>>,
|
||||
manager: Arc<OnceLock<Arc<Mutex<McpServerManager>>>>,
|
||||
}
|
||||
|
||||
impl McpToolRegistry {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn set_manager(
|
||||
&self,
|
||||
manager: Arc<Mutex<McpServerManager>>,
|
||||
) -> Result<(), Arc<Mutex<McpServerManager>>> {
|
||||
self.manager.set(manager)
|
||||
}
|
||||
|
||||
pub fn register_server(
|
||||
&self,
|
||||
server_name: &str,
|
||||
status: McpConnectionStatus,
|
||||
tools: Vec<McpToolInfo>,
|
||||
resources: Vec<McpResourceInfo>,
|
||||
server_info: Option<String>,
|
||||
) {
|
||||
let mut inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
inner.insert(
|
||||
server_name.to_owned(),
|
||||
McpServerState {
|
||||
server_name: server_name.to_owned(),
|
||||
status,
|
||||
tools,
|
||||
resources,
|
||||
server_info,
|
||||
error_message: None,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
pub fn get_server(&self, server_name: &str) -> Option<McpServerState> {
|
||||
let inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
inner.get(server_name).cloned()
|
||||
}
|
||||
|
||||
pub fn list_servers(&self) -> Vec<McpServerState> {
|
||||
let inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
inner.values().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn list_resources(&self, server_name: &str) -> Result<Vec<McpResourceInfo>, String> {
|
||||
let inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
match inner.get(server_name) {
|
||||
Some(state) => {
|
||||
if state.status != McpConnectionStatus::Connected {
|
||||
return Err(format!(
|
||||
"server '{}' is not connected (status: {})",
|
||||
server_name, state.status
|
||||
));
|
||||
}
|
||||
Ok(state.resources.clone())
|
||||
}
|
||||
None => Err(format!("server '{}' not found", server_name)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_resource(&self, server_name: &str, uri: &str) -> Result<McpResourceInfo, String> {
|
||||
let inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
let state = inner
|
||||
.get(server_name)
|
||||
.ok_or_else(|| format!("server '{}' not found", server_name))?;
|
||||
|
||||
if state.status != McpConnectionStatus::Connected {
|
||||
return Err(format!(
|
||||
"server '{}' is not connected (status: {})",
|
||||
server_name, state.status
|
||||
));
|
||||
}
|
||||
|
||||
state
|
||||
.resources
|
||||
.iter()
|
||||
.find(|r| r.uri == uri)
|
||||
.cloned()
|
||||
.ok_or_else(|| format!("resource '{}' not found on server '{}'", uri, server_name))
|
||||
}
|
||||
|
||||
pub fn list_tools(&self, server_name: &str) -> Result<Vec<McpToolInfo>, String> {
|
||||
let inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
match inner.get(server_name) {
|
||||
Some(state) => {
|
||||
if state.status != McpConnectionStatus::Connected {
|
||||
return Err(format!(
|
||||
"server '{}' is not connected (status: {})",
|
||||
server_name, state.status
|
||||
));
|
||||
}
|
||||
Ok(state.tools.clone())
|
||||
}
|
||||
None => Err(format!("server '{}' not found", server_name)),
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_tool_call(
|
||||
manager: Arc<Mutex<McpServerManager>>,
|
||||
qualified_tool_name: String,
|
||||
arguments: Option<serde_json::Value>,
|
||||
) -> Result<serde_json::Value, String> {
|
||||
let join_handle = std::thread::Builder::new()
|
||||
.name(format!("mcp-tool-call-{qualified_tool_name}"))
|
||||
.spawn(move || {
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.map_err(|error| format!("failed to create MCP tool runtime: {error}"))?;
|
||||
|
||||
runtime.block_on(async move {
|
||||
let response = {
|
||||
let mut manager = manager
|
||||
.lock()
|
||||
.map_err(|_| "mcp server manager lock poisoned".to_string())?;
|
||||
manager
|
||||
.discover_tools()
|
||||
.await
|
||||
.map_err(|error| error.to_string())?;
|
||||
let response = manager
|
||||
.call_tool(&qualified_tool_name, arguments)
|
||||
.await
|
||||
.map_err(|error| error.to_string());
|
||||
let shutdown = manager.shutdown().await.map_err(|error| error.to_string());
|
||||
|
||||
match (response, shutdown) {
|
||||
(Ok(response), Ok(())) => Ok(response),
|
||||
(Err(error), Ok(())) | (Err(error), Err(_)) => Err(error),
|
||||
(Ok(_), Err(error)) => Err(error),
|
||||
}
|
||||
}?;
|
||||
|
||||
if let Some(error) = response.error {
|
||||
return Err(format!(
|
||||
"MCP server returned JSON-RPC error for tools/call: {} ({})",
|
||||
error.message, error.code
|
||||
));
|
||||
}
|
||||
|
||||
let result = response.result.ok_or_else(|| {
|
||||
"MCP server returned no result for tools/call".to_string()
|
||||
})?;
|
||||
|
||||
serde_json::to_value(result)
|
||||
.map_err(|error| format!("failed to serialize MCP tool result: {error}"))
|
||||
})
|
||||
})
|
||||
.map_err(|error| format!("failed to spawn MCP tool call thread: {error}"))?;
|
||||
|
||||
join_handle.join().map_err(|panic_payload| {
|
||||
if let Some(message) = panic_payload.downcast_ref::<&str>() {
|
||||
format!("MCP tool call thread panicked: {message}")
|
||||
} else if let Some(message) = panic_payload.downcast_ref::<String>() {
|
||||
format!("MCP tool call thread panicked: {message}")
|
||||
} else {
|
||||
"MCP tool call thread panicked".to_string()
|
||||
}
|
||||
})?
|
||||
}
|
||||
|
||||
pub fn call_tool(
|
||||
&self,
|
||||
server_name: &str,
|
||||
tool_name: &str,
|
||||
arguments: &serde_json::Value,
|
||||
) -> Result<serde_json::Value, String> {
|
||||
let inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
let state = inner
|
||||
.get(server_name)
|
||||
.ok_or_else(|| format!("server '{}' not found", server_name))?;
|
||||
|
||||
if state.status != McpConnectionStatus::Connected {
|
||||
return Err(format!(
|
||||
"server '{}' is not connected (status: {})",
|
||||
server_name, state.status
|
||||
));
|
||||
}
|
||||
|
||||
if !state.tools.iter().any(|t| t.name == tool_name) {
|
||||
return Err(format!(
|
||||
"tool '{}' not found on server '{}'",
|
||||
tool_name, server_name
|
||||
));
|
||||
}
|
||||
|
||||
drop(inner);
|
||||
|
||||
let manager = self
|
||||
.manager
|
||||
.get()
|
||||
.cloned()
|
||||
.ok_or_else(|| "MCP server manager is not configured".to_string())?;
|
||||
|
||||
Self::spawn_tool_call(
|
||||
manager,
|
||||
mcp_tool_name(server_name, tool_name),
|
||||
(!arguments.is_null()).then(|| arguments.clone()),
|
||||
)
|
||||
}
|
||||
|
||||
/// Set auth status for a server.
|
||||
pub fn set_auth_status(
|
||||
&self,
|
||||
server_name: &str,
|
||||
status: McpConnectionStatus,
|
||||
) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
let state = inner
|
||||
.get_mut(server_name)
|
||||
.ok_or_else(|| format!("server '{}' not found", server_name))?;
|
||||
state.status = status;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Disconnect / remove a server.
|
||||
pub fn disconnect(&self, server_name: &str) -> Option<McpServerState> {
|
||||
let mut inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
inner.remove(server_name)
|
||||
}
|
||||
|
||||
/// Number of registered servers.
|
||||
#[must_use]
|
||||
pub fn len(&self) -> usize {
|
||||
let inner = self.inner.lock().expect("mcp registry lock poisoned");
|
||||
inner.len()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, unix))]
|
||||
mod tests {
|
||||
use std::collections::BTreeMap;
|
||||
use std::fs;
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use super::*;
|
||||
use crate::config::{
|
||||
ConfigSource, McpServerConfig, McpStdioServerConfig, ScopedMcpServerConfig,
|
||||
};
|
||||
|
||||
fn temp_dir() -> PathBuf {
|
||||
static NEXT_TEMP_DIR_ID: AtomicU64 = AtomicU64::new(0);
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time should be after epoch")
|
||||
.as_nanos();
|
||||
let unique_id = NEXT_TEMP_DIR_ID.fetch_add(1, Ordering::Relaxed);
|
||||
std::env::temp_dir().join(format!("runtime-mcp-tool-bridge-{nanos}-{unique_id}"))
|
||||
}
|
||||
|
||||
fn cleanup_script(script_path: &Path) {
|
||||
if let Some(root) = script_path.parent() {
|
||||
let _ = fs::remove_dir_all(root);
|
||||
}
|
||||
}
|
||||
|
||||
fn write_bridge_mcp_server_script() -> PathBuf {
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("temp dir");
|
||||
let script_path = root.join("bridge-mcp-server.py");
|
||||
let script = [
|
||||
"#!/usr/bin/env python3",
|
||||
"import json, os, sys",
|
||||
"LABEL = os.environ.get('MCP_SERVER_LABEL', 'server')",
|
||||
"LOG_PATH = os.environ.get('MCP_LOG_PATH')",
|
||||
"",
|
||||
"def log(method):",
|
||||
" if LOG_PATH:",
|
||||
" with open(LOG_PATH, 'a', encoding='utf-8') as handle:",
|
||||
" handle.write(f'{method}\\n')",
|
||||
"",
|
||||
"def read_message():",
|
||||
" header = b''",
|
||||
r" while not header.endswith(b'\r\n\r\n'):",
|
||||
" chunk = sys.stdin.buffer.read(1)",
|
||||
" if not chunk:",
|
||||
" return None",
|
||||
" header += chunk",
|
||||
" length = 0",
|
||||
r" for line in header.decode().split('\r\n'):",
|
||||
r" if line.lower().startswith('content-length:'):",
|
||||
r" length = int(line.split(':', 1)[1].strip())",
|
||||
" payload = sys.stdin.buffer.read(length)",
|
||||
" return json.loads(payload.decode())",
|
||||
"",
|
||||
"def send_message(message):",
|
||||
" payload = json.dumps(message).encode()",
|
||||
r" sys.stdout.buffer.write(f'Content-Length: {len(payload)}\r\n\r\n'.encode() + payload)",
|
||||
" sys.stdout.buffer.flush()",
|
||||
"",
|
||||
"while True:",
|
||||
" request = read_message()",
|
||||
" if request is None:",
|
||||
" break",
|
||||
" method = request['method']",
|
||||
" log(method)",
|
||||
" if method == 'initialize':",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'result': {",
|
||||
" 'protocolVersion': request['params']['protocolVersion'],",
|
||||
" 'capabilities': {'tools': {}},",
|
||||
" 'serverInfo': {'name': LABEL, 'version': '1.0.0'}",
|
||||
" }",
|
||||
" })",
|
||||
" elif method == 'tools/list':",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'result': {",
|
||||
" 'tools': [",
|
||||
" {",
|
||||
" 'name': 'echo',",
|
||||
" 'description': f'Echo tool for {LABEL}',",
|
||||
" 'inputSchema': {",
|
||||
" 'type': 'object',",
|
||||
" 'properties': {'text': {'type': 'string'}},",
|
||||
" 'required': ['text']",
|
||||
" }",
|
||||
" }",
|
||||
" ]",
|
||||
" }",
|
||||
" })",
|
||||
" elif method == 'tools/call':",
|
||||
" args = request['params'].get('arguments') or {}",
|
||||
" text = args.get('text', '')",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'result': {",
|
||||
" 'content': [{'type': 'text', 'text': f'{LABEL}:{text}'}],",
|
||||
" 'structuredContent': {'server': LABEL, 'echoed': text},",
|
||||
" 'isError': False",
|
||||
" }",
|
||||
" })",
|
||||
" else:",
|
||||
" send_message({",
|
||||
" 'jsonrpc': '2.0',",
|
||||
" 'id': request['id'],",
|
||||
" 'error': {'code': -32601, 'message': f'unknown method: {method}'},",
|
||||
" })",
|
||||
"",
|
||||
]
|
||||
.join("\n");
|
||||
fs::write(&script_path, script).expect("write script");
|
||||
let mut permissions = fs::metadata(&script_path).expect("metadata").permissions();
|
||||
permissions.set_mode(0o755);
|
||||
fs::set_permissions(&script_path, permissions).expect("chmod");
|
||||
script_path
|
||||
}
|
||||
|
||||
fn manager_server_config(
|
||||
script_path: &Path,
|
||||
server_name: &str,
|
||||
log_path: &Path,
|
||||
) -> ScopedMcpServerConfig {
|
||||
ScopedMcpServerConfig {
|
||||
scope: ConfigSource::Local,
|
||||
config: McpServerConfig::Stdio(McpStdioServerConfig {
|
||||
command: "python3".to_string(),
|
||||
args: vec![script_path.to_string_lossy().into_owned()],
|
||||
env: BTreeMap::from([
|
||||
("MCP_SERVER_LABEL".to_string(), server_name.to_string()),
|
||||
(
|
||||
"MCP_LOG_PATH".to_string(),
|
||||
log_path.to_string_lossy().into_owned(),
|
||||
),
|
||||
]),
|
||||
tool_call_timeout_ms: Some(1_000),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registers_and_retrieves_server() {
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"test-server",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "greet".into(),
|
||||
description: Some("Greet someone".into()),
|
||||
input_schema: None,
|
||||
}],
|
||||
vec![McpResourceInfo {
|
||||
uri: "res://data".into(),
|
||||
name: "Data".into(),
|
||||
description: None,
|
||||
mime_type: Some("application/json".into()),
|
||||
}],
|
||||
Some("TestServer v1.0".into()),
|
||||
);
|
||||
|
||||
let server = registry.get_server("test-server").expect("should exist");
|
||||
assert_eq!(server.status, McpConnectionStatus::Connected);
|
||||
assert_eq!(server.tools.len(), 1);
|
||||
assert_eq!(server.resources.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lists_resources_from_connected_server() {
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![],
|
||||
vec![McpResourceInfo {
|
||||
uri: "res://alpha".into(),
|
||||
name: "Alpha".into(),
|
||||
description: None,
|
||||
mime_type: None,
|
||||
}],
|
||||
None,
|
||||
);
|
||||
|
||||
let resources = registry.list_resources("srv").expect("should succeed");
|
||||
assert_eq!(resources.len(), 1);
|
||||
assert_eq!(resources[0].uri, "res://alpha");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_resource_listing_for_disconnected_server() {
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Disconnected,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
assert!(registry.list_resources("srv").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reads_specific_resource() {
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![],
|
||||
vec![McpResourceInfo {
|
||||
uri: "res://data".into(),
|
||||
name: "Data".into(),
|
||||
description: Some("Test data".into()),
|
||||
mime_type: Some("text/plain".into()),
|
||||
}],
|
||||
None,
|
||||
);
|
||||
|
||||
let resource = registry
|
||||
.read_resource("srv", "res://data")
|
||||
.expect("should find");
|
||||
assert_eq!(resource.name, "Data");
|
||||
|
||||
assert!(registry.read_resource("srv", "res://missing").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_connected_server_without_manager_when_calling_tool_then_it_errors() {
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "greet".into(),
|
||||
description: None,
|
||||
input_schema: None,
|
||||
}],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
|
||||
let error = registry
|
||||
.call_tool("srv", "greet", &serde_json::json!({"name": "world"}))
|
||||
.expect_err("should require a configured manager");
|
||||
assert!(error.contains("MCP server manager is not configured"));
|
||||
|
||||
// Unknown tool should fail
|
||||
assert!(registry
|
||||
.call_tool("srv", "missing", &serde_json::json!({}))
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn given_connected_server_with_manager_when_calling_tool_then_it_returns_live_result() {
|
||||
let script_path = write_bridge_mcp_server_script();
|
||||
let root = script_path.parent().expect("script parent");
|
||||
let log_path = root.join("bridge.log");
|
||||
let servers = BTreeMap::from([(
|
||||
"alpha".to_string(),
|
||||
manager_server_config(&script_path, "alpha", &log_path),
|
||||
)]);
|
||||
let manager = Arc::new(Mutex::new(McpServerManager::from_servers(&servers)));
|
||||
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"alpha",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "echo".into(),
|
||||
description: Some("Echo tool for alpha".into()),
|
||||
input_schema: Some(serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"]
|
||||
})),
|
||||
}],
|
||||
vec![],
|
||||
Some("bridge test server".into()),
|
||||
);
|
||||
registry
|
||||
.set_manager(Arc::clone(&manager))
|
||||
.expect("manager should only be set once");
|
||||
|
||||
let result = registry
|
||||
.call_tool("alpha", "echo", &serde_json::json!({"text": "hello"}))
|
||||
.expect("should return live MCP result");
|
||||
|
||||
assert_eq!(
|
||||
result["structuredContent"]["server"],
|
||||
serde_json::json!("alpha")
|
||||
);
|
||||
assert_eq!(
|
||||
result["structuredContent"]["echoed"],
|
||||
serde_json::json!("hello")
|
||||
);
|
||||
assert_eq!(
|
||||
result["content"][0]["text"],
|
||||
serde_json::json!("alpha:hello")
|
||||
);
|
||||
|
||||
let log = fs::read_to_string(&log_path).expect("read log");
|
||||
assert_eq!(
|
||||
log.lines().collect::<Vec<_>>(),
|
||||
vec!["initialize", "tools/list", "tools/call"]
|
||||
);
|
||||
|
||||
cleanup_script(&script_path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_tool_call_on_disconnected_server() {
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::AuthRequired,
|
||||
vec![McpToolInfo {
|
||||
name: "greet".into(),
|
||||
description: None,
|
||||
input_schema: None,
|
||||
}],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(registry
|
||||
.call_tool("srv", "greet", &serde_json::json!({}))
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sets_auth_and_disconnects() {
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::AuthRequired,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
|
||||
registry
|
||||
.set_auth_status("srv", McpConnectionStatus::Connected)
|
||||
.expect("should succeed");
|
||||
let state = registry.get_server("srv").unwrap();
|
||||
assert_eq!(state.status, McpConnectionStatus::Connected);
|
||||
|
||||
let removed = registry.disconnect("srv");
|
||||
assert!(removed.is_some());
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_operations_on_missing_server() {
|
||||
let registry = McpToolRegistry::new();
|
||||
assert!(registry.list_resources("missing").is_err());
|
||||
assert!(registry.read_resource("missing", "uri").is_err());
|
||||
assert!(registry.list_tools("missing").is_err());
|
||||
assert!(registry
|
||||
.call_tool("missing", "tool", &serde_json::json!({}))
|
||||
.is_err());
|
||||
assert!(registry
|
||||
.set_auth_status("missing", McpConnectionStatus::Connected)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mcp_connection_status_display_all_variants() {
|
||||
// given
|
||||
let cases = [
|
||||
(McpConnectionStatus::Disconnected, "disconnected"),
|
||||
(McpConnectionStatus::Connecting, "connecting"),
|
||||
(McpConnectionStatus::Connected, "connected"),
|
||||
(McpConnectionStatus::AuthRequired, "auth_required"),
|
||||
(McpConnectionStatus::Error, "error"),
|
||||
];
|
||||
|
||||
// when
|
||||
let rendered: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(status, expected)| (status.to_string(), expected))
|
||||
.collect();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
rendered,
|
||||
vec![
|
||||
("disconnected".to_string(), "disconnected"),
|
||||
("connecting".to_string(), "connecting"),
|
||||
("connected".to_string(), "connected"),
|
||||
("auth_required".to_string(), "auth_required"),
|
||||
("error".to_string(), "error"),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_servers_returns_all_registered() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"alpha",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
registry.register_server(
|
||||
"beta",
|
||||
McpConnectionStatus::Connecting,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
|
||||
// when
|
||||
let servers = registry.list_servers();
|
||||
|
||||
// then
|
||||
assert_eq!(servers.len(), 2);
|
||||
assert!(servers.iter().any(|server| server.server_name == "alpha"));
|
||||
assert!(servers.iter().any(|server| server.server_name == "beta"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_tools_from_connected_server() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "inspect".into(),
|
||||
description: Some("Inspect data".into()),
|
||||
input_schema: Some(serde_json::json!({"type": "object"})),
|
||||
}],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
|
||||
// when
|
||||
let tools = registry.list_tools("srv").expect("tools should list");
|
||||
|
||||
// then
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].name, "inspect");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_tools_rejects_disconnected_server() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::AuthRequired,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
|
||||
// when
|
||||
let result = registry.list_tools("srv");
|
||||
|
||||
// then
|
||||
let error = result.expect_err("non-connected server should fail");
|
||||
assert!(error.contains("not connected"));
|
||||
assert!(error.contains("auth_required"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_tools_rejects_missing_server() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.list_tools("missing");
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
result.expect_err("missing server should fail"),
|
||||
"server 'missing' not found"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_server_returns_none_for_missing() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
|
||||
// when
|
||||
let server = registry.get_server("missing");
|
||||
|
||||
// then
|
||||
assert!(server.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_tool_payload_structure() {
|
||||
let script_path = write_bridge_mcp_server_script();
|
||||
let root = script_path.parent().expect("script parent");
|
||||
let log_path = root.join("payload.log");
|
||||
let servers = BTreeMap::from([(
|
||||
"srv".to_string(),
|
||||
manager_server_config(&script_path, "srv", &log_path),
|
||||
)]);
|
||||
let registry = McpToolRegistry::new();
|
||||
let arguments = serde_json::json!({"text": "world"});
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "echo".into(),
|
||||
description: Some("Echo tool for srv".into()),
|
||||
input_schema: Some(serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {"text": {"type": "string"}},
|
||||
"required": ["text"]
|
||||
})),
|
||||
}],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
registry
|
||||
.set_manager(Arc::new(Mutex::new(McpServerManager::from_servers(
|
||||
&servers,
|
||||
))))
|
||||
.expect("manager should only be set once");
|
||||
|
||||
let result = registry
|
||||
.call_tool("srv", "echo", &arguments)
|
||||
.expect("tool should return live payload");
|
||||
|
||||
assert_eq!(result["structuredContent"]["server"], "srv");
|
||||
assert_eq!(result["structuredContent"]["echoed"], "world");
|
||||
assert_eq!(result["content"][0]["text"], "srv:world");
|
||||
|
||||
cleanup_script(&script_path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn upsert_overwrites_existing_server() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
registry.register_server("srv", McpConnectionStatus::Connecting, vec![], vec![], None);
|
||||
|
||||
// when
|
||||
registry.register_server(
|
||||
"srv",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![McpToolInfo {
|
||||
name: "inspect".into(),
|
||||
description: None,
|
||||
input_schema: None,
|
||||
}],
|
||||
vec![],
|
||||
Some("Inspector".into()),
|
||||
);
|
||||
let state = registry.get_server("srv").expect("server should exist");
|
||||
|
||||
// then
|
||||
assert_eq!(state.status, McpConnectionStatus::Connected);
|
||||
assert_eq!(state.tools.len(), 1);
|
||||
assert_eq!(state.server_info.as_deref(), Some("Inspector"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disconnect_missing_returns_none() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
|
||||
// when
|
||||
let removed = registry.disconnect("missing");
|
||||
|
||||
// then
|
||||
assert!(removed.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn len_and_is_empty_transitions() {
|
||||
// given
|
||||
let registry = McpToolRegistry::new();
|
||||
|
||||
// when
|
||||
registry.register_server(
|
||||
"alpha",
|
||||
McpConnectionStatus::Connected,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
registry.register_server("beta", McpConnectionStatus::Connected, vec![], vec![], None);
|
||||
let after_create = registry.len();
|
||||
registry.disconnect("alpha");
|
||||
let after_first_remove = registry.len();
|
||||
registry.disconnect("beta");
|
||||
|
||||
// then
|
||||
assert_eq!(after_create, 2);
|
||||
assert_eq!(after_first_remove, 1);
|
||||
assert_eq!(registry.len(), 0);
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
}
|
||||
@ -9,6 +9,7 @@ use sha2::{Digest, Sha256};
|
||||
|
||||
use crate::config::OAuthConfig;
|
||||
|
||||
/// Persisted OAuth access token bundle used by the CLI.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct OAuthTokenSet {
|
||||
pub access_token: String,
|
||||
@ -17,6 +18,7 @@ pub struct OAuthTokenSet {
|
||||
pub scopes: Vec<String>,
|
||||
}
|
||||
|
||||
/// PKCE verifier/challenge pair generated for an OAuth authorization flow.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PkceCodePair {
|
||||
pub verifier: String,
|
||||
@ -24,6 +26,7 @@ pub struct PkceCodePair {
|
||||
pub challenge_method: PkceChallengeMethod,
|
||||
}
|
||||
|
||||
/// Challenge algorithms supported by the local PKCE helpers.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum PkceChallengeMethod {
|
||||
S256,
|
||||
@ -38,6 +41,7 @@ impl PkceChallengeMethod {
|
||||
}
|
||||
}
|
||||
|
||||
/// Parameters needed to build an authorization URL for browser-based login.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct OAuthAuthorizationRequest {
|
||||
pub authorize_url: String,
|
||||
@ -50,6 +54,7 @@ pub struct OAuthAuthorizationRequest {
|
||||
pub extra_params: BTreeMap<String, String>,
|
||||
}
|
||||
|
||||
/// Request body for exchanging an OAuth authorization code for tokens.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct OAuthTokenExchangeRequest {
|
||||
pub grant_type: &'static str,
|
||||
@ -60,6 +65,7 @@ pub struct OAuthTokenExchangeRequest {
|
||||
pub state: String,
|
||||
}
|
||||
|
||||
/// Request body for refreshing an existing OAuth token set.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct OAuthRefreshRequest {
|
||||
pub grant_type: &'static str,
|
||||
@ -68,6 +74,7 @@ pub struct OAuthRefreshRequest {
|
||||
pub scopes: Vec<String>,
|
||||
}
|
||||
|
||||
/// Parsed query parameters returned to the local OAuth callback endpoint.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct OAuthCallbackParams {
|
||||
pub code: Option<String>,
|
||||
@ -327,15 +334,16 @@ fn credentials_home_dir() -> io::Result<PathBuf> {
|
||||
if let Some(path) = std::env::var_os("CLAW_CONFIG_HOME") {
|
||||
return Ok(PathBuf::from(path));
|
||||
}
|
||||
if let Some(path) = std::env::var_os("HOME") {
|
||||
return Ok(PathBuf::from(path).join(".claw"));
|
||||
}
|
||||
if cfg!(target_os = "windows") {
|
||||
if let Some(path) = std::env::var_os("USERPROFILE") {
|
||||
return Ok(PathBuf::from(path).join(".claw"));
|
||||
}
|
||||
}
|
||||
Err(io::Error::new(io::ErrorKind::NotFound, "HOME or USERPROFILE is not set"))
|
||||
let home = std::env::var_os("HOME")
|
||||
.or_else(|| std::env::var_os("USERPROFILE"))
|
||||
.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::NotFound,
|
||||
"HOME is not set (on Windows, set USERPROFILE or HOME, \
|
||||
or use CLAW_CONFIG_HOME to point directly at the config directory)",
|
||||
)
|
||||
})?;
|
||||
Ok(PathBuf::from(home).join(".claw"))
|
||||
}
|
||||
|
||||
fn read_credentials_root(path: &PathBuf) -> io::Result<Map<String, Value>> {
|
||||
@ -448,7 +456,7 @@ fn decode_hex(byte: u8) -> Result<u8, String> {
|
||||
b'0'..=b'9' => Ok(byte - b'0'),
|
||||
b'a'..=b'f' => Ok(byte - b'a' + 10),
|
||||
b'A'..=b'F' => Ok(byte - b'A' + 10),
|
||||
_ => Err(format!("invalid percent-encoding byte: {byte}")),
|
||||
_ => Err(format!("invalid percent byte: {byte}")),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
551
crates/runtime/src/permission_enforcer.rs
Normal file
551
crates/runtime/src/permission_enforcer.rs
Normal file
@ -0,0 +1,551 @@
|
||||
#![allow(
|
||||
clippy::match_wildcard_for_single_variants,
|
||||
clippy::must_use_candidate,
|
||||
clippy::uninlined_format_args
|
||||
)]
|
||||
//! Permission enforcement layer that gates tool execution based on the
|
||||
//! active `PermissionPolicy`.
|
||||
|
||||
use crate::permissions::{PermissionMode, PermissionOutcome, PermissionPolicy};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "outcome")]
|
||||
pub enum EnforcementResult {
|
||||
/// Tool execution is allowed.
|
||||
Allowed,
|
||||
/// Tool execution was denied due to insufficient permissions.
|
||||
Denied {
|
||||
tool: String,
|
||||
active_mode: String,
|
||||
required_mode: String,
|
||||
reason: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct PermissionEnforcer {
|
||||
policy: PermissionPolicy,
|
||||
}
|
||||
|
||||
impl PermissionEnforcer {
|
||||
#[must_use]
|
||||
pub fn new(policy: PermissionPolicy) -> Self {
|
||||
Self { policy }
|
||||
}
|
||||
|
||||
/// Check whether a tool can be executed under the current permission policy.
|
||||
/// Auto-denies when prompting is required but no prompter is provided.
|
||||
pub fn check(&self, tool_name: &str, input: &str) -> EnforcementResult {
|
||||
// When the active mode is Prompt, defer to the caller's interactive
|
||||
// prompt flow rather than hard-denying (the enforcer has no prompter).
|
||||
if self.policy.active_mode() == PermissionMode::Prompt {
|
||||
return EnforcementResult::Allowed;
|
||||
}
|
||||
|
||||
let outcome = self.policy.authorize(tool_name, input, None);
|
||||
|
||||
match outcome {
|
||||
PermissionOutcome::Allow => EnforcementResult::Allowed,
|
||||
PermissionOutcome::Deny { reason } => {
|
||||
let active_mode = self.policy.active_mode();
|
||||
let required_mode = self.policy.required_mode_for(tool_name);
|
||||
EnforcementResult::Denied {
|
||||
tool: tool_name.to_owned(),
|
||||
active_mode: active_mode.as_str().to_owned(),
|
||||
required_mode: required_mode.as_str().to_owned(),
|
||||
reason,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_allowed(&self, tool_name: &str, input: &str) -> bool {
|
||||
matches!(self.check(tool_name, input), EnforcementResult::Allowed)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn active_mode(&self) -> PermissionMode {
|
||||
self.policy.active_mode()
|
||||
}
|
||||
|
||||
/// Classify a file operation against workspace boundaries.
|
||||
pub fn check_file_write(&self, path: &str, workspace_root: &str) -> EnforcementResult {
|
||||
let mode = self.policy.active_mode();
|
||||
|
||||
match mode {
|
||||
PermissionMode::ReadOnly => EnforcementResult::Denied {
|
||||
tool: "write_file".to_owned(),
|
||||
active_mode: mode.as_str().to_owned(),
|
||||
required_mode: PermissionMode::WorkspaceWrite.as_str().to_owned(),
|
||||
reason: format!("file writes are not allowed in '{}' mode", mode.as_str()),
|
||||
},
|
||||
PermissionMode::WorkspaceWrite => {
|
||||
if is_within_workspace(path, workspace_root) {
|
||||
EnforcementResult::Allowed
|
||||
} else {
|
||||
EnforcementResult::Denied {
|
||||
tool: "write_file".to_owned(),
|
||||
active_mode: mode.as_str().to_owned(),
|
||||
required_mode: PermissionMode::DangerFullAccess.as_str().to_owned(),
|
||||
reason: format!(
|
||||
"path '{}' is outside workspace root '{}'",
|
||||
path, workspace_root
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
// Allow and DangerFullAccess permit all writes
|
||||
PermissionMode::Allow | PermissionMode::DangerFullAccess => EnforcementResult::Allowed,
|
||||
PermissionMode::Prompt => EnforcementResult::Denied {
|
||||
tool: "write_file".to_owned(),
|
||||
active_mode: mode.as_str().to_owned(),
|
||||
required_mode: PermissionMode::WorkspaceWrite.as_str().to_owned(),
|
||||
reason: "file write requires confirmation in prompt mode".to_owned(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a bash command should be allowed based on current mode.
|
||||
pub fn check_bash(&self, command: &str) -> EnforcementResult {
|
||||
let mode = self.policy.active_mode();
|
||||
|
||||
match mode {
|
||||
PermissionMode::ReadOnly => {
|
||||
if is_read_only_command(command) {
|
||||
EnforcementResult::Allowed
|
||||
} else {
|
||||
EnforcementResult::Denied {
|
||||
tool: "bash".to_owned(),
|
||||
active_mode: mode.as_str().to_owned(),
|
||||
required_mode: PermissionMode::WorkspaceWrite.as_str().to_owned(),
|
||||
reason: format!(
|
||||
"command may modify state; not allowed in '{}' mode",
|
||||
mode.as_str()
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
PermissionMode::Prompt => EnforcementResult::Denied {
|
||||
tool: "bash".to_owned(),
|
||||
active_mode: mode.as_str().to_owned(),
|
||||
required_mode: PermissionMode::DangerFullAccess.as_str().to_owned(),
|
||||
reason: "bash requires confirmation in prompt mode".to_owned(),
|
||||
},
|
||||
// WorkspaceWrite, Allow, DangerFullAccess: permit bash
|
||||
_ => EnforcementResult::Allowed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple workspace boundary check via string prefix.
|
||||
fn is_within_workspace(path: &str, workspace_root: &str) -> bool {
|
||||
let normalized = if path.starts_with('/') {
|
||||
path.to_owned()
|
||||
} else {
|
||||
format!("{workspace_root}/{path}")
|
||||
};
|
||||
|
||||
let root = if workspace_root.ends_with('/') {
|
||||
workspace_root.to_owned()
|
||||
} else {
|
||||
format!("{workspace_root}/")
|
||||
};
|
||||
|
||||
normalized.starts_with(&root) || normalized == workspace_root.trim_end_matches('/')
|
||||
}
|
||||
|
||||
/// Conservative heuristic: is this bash command read-only?
|
||||
fn is_read_only_command(command: &str) -> bool {
|
||||
let first_token = command
|
||||
.split_whitespace()
|
||||
.next()
|
||||
.unwrap_or("")
|
||||
.rsplit('/')
|
||||
.next()
|
||||
.unwrap_or("");
|
||||
|
||||
matches!(
|
||||
first_token,
|
||||
"cat"
|
||||
| "head"
|
||||
| "tail"
|
||||
| "less"
|
||||
| "more"
|
||||
| "wc"
|
||||
| "ls"
|
||||
| "find"
|
||||
| "grep"
|
||||
| "rg"
|
||||
| "awk"
|
||||
| "sed"
|
||||
| "echo"
|
||||
| "printf"
|
||||
| "which"
|
||||
| "where"
|
||||
| "whoami"
|
||||
| "pwd"
|
||||
| "env"
|
||||
| "printenv"
|
||||
| "date"
|
||||
| "cal"
|
||||
| "df"
|
||||
| "du"
|
||||
| "free"
|
||||
| "uptime"
|
||||
| "uname"
|
||||
| "file"
|
||||
| "stat"
|
||||
| "diff"
|
||||
| "sort"
|
||||
| "uniq"
|
||||
| "tr"
|
||||
| "cut"
|
||||
| "paste"
|
||||
| "tee"
|
||||
| "xargs"
|
||||
| "test"
|
||||
| "true"
|
||||
| "false"
|
||||
| "type"
|
||||
| "readlink"
|
||||
| "realpath"
|
||||
| "basename"
|
||||
| "dirname"
|
||||
| "sha256sum"
|
||||
| "md5sum"
|
||||
| "b3sum"
|
||||
| "xxd"
|
||||
| "hexdump"
|
||||
| "od"
|
||||
| "strings"
|
||||
| "tree"
|
||||
| "jq"
|
||||
| "yq"
|
||||
| "python3"
|
||||
| "python"
|
||||
| "node"
|
||||
| "ruby"
|
||||
| "cargo"
|
||||
| "rustc"
|
||||
| "git"
|
||||
| "gh"
|
||||
) && !command.contains("-i ")
|
||||
&& !command.contains("--in-place")
|
||||
&& !command.contains(" > ")
|
||||
&& !command.contains(" >> ")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_enforcer(mode: PermissionMode) -> PermissionEnforcer {
|
||||
let policy = PermissionPolicy::new(mode);
|
||||
PermissionEnforcer::new(policy)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allow_mode_permits_everything() {
|
||||
let enforcer = make_enforcer(PermissionMode::Allow);
|
||||
assert!(enforcer.is_allowed("bash", ""));
|
||||
assert!(enforcer.is_allowed("write_file", ""));
|
||||
assert!(enforcer.is_allowed("edit_file", ""));
|
||||
assert_eq!(
|
||||
enforcer.check_file_write("/outside/path", "/workspace"),
|
||||
EnforcementResult::Allowed
|
||||
);
|
||||
assert_eq!(enforcer.check_bash("rm -rf /"), EnforcementResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_only_denies_writes() {
|
||||
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
|
||||
.with_tool_requirement("read_file", PermissionMode::ReadOnly)
|
||||
.with_tool_requirement("grep_search", PermissionMode::ReadOnly)
|
||||
.with_tool_requirement("write_file", PermissionMode::WorkspaceWrite);
|
||||
|
||||
let enforcer = PermissionEnforcer::new(policy);
|
||||
assert!(enforcer.is_allowed("read_file", ""));
|
||||
assert!(enforcer.is_allowed("grep_search", ""));
|
||||
|
||||
// write_file requires WorkspaceWrite but we're in ReadOnly
|
||||
let result = enforcer.check("write_file", "");
|
||||
assert!(matches!(result, EnforcementResult::Denied { .. }));
|
||||
|
||||
let result = enforcer.check_file_write("/workspace/file.rs", "/workspace");
|
||||
assert!(matches!(result, EnforcementResult::Denied { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_only_allows_read_commands() {
|
||||
let enforcer = make_enforcer(PermissionMode::ReadOnly);
|
||||
assert_eq!(
|
||||
enforcer.check_bash("cat src/main.rs"),
|
||||
EnforcementResult::Allowed
|
||||
);
|
||||
assert_eq!(
|
||||
enforcer.check_bash("grep -r 'pattern' ."),
|
||||
EnforcementResult::Allowed
|
||||
);
|
||||
assert_eq!(enforcer.check_bash("ls -la"), EnforcementResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_only_denies_write_commands() {
|
||||
let enforcer = make_enforcer(PermissionMode::ReadOnly);
|
||||
let result = enforcer.check_bash("rm file.txt");
|
||||
assert!(matches!(result, EnforcementResult::Denied { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_write_allows_within_workspace() {
|
||||
let enforcer = make_enforcer(PermissionMode::WorkspaceWrite);
|
||||
let result = enforcer.check_file_write("/workspace/src/main.rs", "/workspace");
|
||||
assert_eq!(result, EnforcementResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_write_denies_outside_workspace() {
|
||||
let enforcer = make_enforcer(PermissionMode::WorkspaceWrite);
|
||||
let result = enforcer.check_file_write("/etc/passwd", "/workspace");
|
||||
assert!(matches!(result, EnforcementResult::Denied { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_mode_denies_without_prompter() {
|
||||
let enforcer = make_enforcer(PermissionMode::Prompt);
|
||||
let result = enforcer.check_bash("echo test");
|
||||
assert!(matches!(result, EnforcementResult::Denied { .. }));
|
||||
|
||||
let result = enforcer.check_file_write("/workspace/file.rs", "/workspace");
|
||||
assert!(matches!(result, EnforcementResult::Denied { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_boundary_check() {
|
||||
assert!(is_within_workspace("/workspace/src/main.rs", "/workspace"));
|
||||
assert!(is_within_workspace("/workspace", "/workspace"));
|
||||
assert!(!is_within_workspace("/etc/passwd", "/workspace"));
|
||||
assert!(!is_within_workspace("/workspacex/hack", "/workspace"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_only_command_heuristic() {
|
||||
assert!(is_read_only_command("cat file.txt"));
|
||||
assert!(is_read_only_command("grep pattern file"));
|
||||
assert!(is_read_only_command("git log --oneline"));
|
||||
assert!(!is_read_only_command("rm file.txt"));
|
||||
assert!(!is_read_only_command("echo test > file.txt"));
|
||||
assert!(!is_read_only_command("sed -i 's/a/b/' file"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn active_mode_returns_policy_mode() {
|
||||
// given
|
||||
let modes = [
|
||||
PermissionMode::ReadOnly,
|
||||
PermissionMode::WorkspaceWrite,
|
||||
PermissionMode::DangerFullAccess,
|
||||
PermissionMode::Prompt,
|
||||
PermissionMode::Allow,
|
||||
];
|
||||
|
||||
// when
|
||||
let active_modes: Vec<_> = modes
|
||||
.into_iter()
|
||||
.map(|mode| make_enforcer(mode).active_mode())
|
||||
.collect();
|
||||
|
||||
// then
|
||||
assert_eq!(active_modes, modes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn danger_full_access_permits_file_writes_and_bash() {
|
||||
// given
|
||||
let enforcer = make_enforcer(PermissionMode::DangerFullAccess);
|
||||
|
||||
// when
|
||||
let file_result = enforcer.check_file_write("/outside/workspace/file.txt", "/workspace");
|
||||
let bash_result = enforcer.check_bash("rm -rf /tmp/scratch");
|
||||
|
||||
// then
|
||||
assert_eq!(file_result, EnforcementResult::Allowed);
|
||||
assert_eq!(bash_result, EnforcementResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn check_denied_payload_contains_tool_and_modes() {
|
||||
// given
|
||||
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
|
||||
.with_tool_requirement("write_file", PermissionMode::WorkspaceWrite);
|
||||
let enforcer = PermissionEnforcer::new(policy);
|
||||
|
||||
// when
|
||||
let result = enforcer.check("write_file", "{}");
|
||||
|
||||
// then
|
||||
match result {
|
||||
EnforcementResult::Denied {
|
||||
tool,
|
||||
active_mode,
|
||||
required_mode,
|
||||
reason,
|
||||
} => {
|
||||
assert_eq!(tool, "write_file");
|
||||
assert_eq!(active_mode, "read-only");
|
||||
assert_eq!(required_mode, "workspace-write");
|
||||
assert!(reason.contains("requires workspace-write permission"));
|
||||
}
|
||||
other => panic!("expected denied result, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_write_relative_path_resolved() {
|
||||
// given
|
||||
let enforcer = make_enforcer(PermissionMode::WorkspaceWrite);
|
||||
|
||||
// when
|
||||
let result = enforcer.check_file_write("src/main.rs", "/workspace");
|
||||
|
||||
// then
|
||||
assert_eq!(result, EnforcementResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_root_with_trailing_slash() {
|
||||
// given
|
||||
let enforcer = make_enforcer(PermissionMode::WorkspaceWrite);
|
||||
|
||||
// when
|
||||
let result = enforcer.check_file_write("/workspace/src/main.rs", "/workspace/");
|
||||
|
||||
// then
|
||||
assert_eq!(result, EnforcementResult::Allowed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_root_equality() {
|
||||
// given
|
||||
let root = "/workspace/";
|
||||
|
||||
// when
|
||||
let equal_to_root = is_within_workspace("/workspace", root);
|
||||
|
||||
// then
|
||||
assert!(equal_to_root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_heuristic_full_path_prefix() {
|
||||
// given
|
||||
let full_path_command = "/usr/bin/cat Cargo.toml";
|
||||
let git_path_command = "/usr/local/bin/git status";
|
||||
|
||||
// when
|
||||
let cat_result = is_read_only_command(full_path_command);
|
||||
let git_result = is_read_only_command(git_path_command);
|
||||
|
||||
// then
|
||||
assert!(cat_result);
|
||||
assert!(git_result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_heuristic_redirects_block_read_only_commands() {
|
||||
// given
|
||||
let overwrite = "cat Cargo.toml > out.txt";
|
||||
let append = "echo test >> out.txt";
|
||||
|
||||
// when
|
||||
let overwrite_result = is_read_only_command(overwrite);
|
||||
let append_result = is_read_only_command(append);
|
||||
|
||||
// then
|
||||
assert!(!overwrite_result);
|
||||
assert!(!append_result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_heuristic_in_place_flag_blocks() {
|
||||
// given
|
||||
let interactive_python = "python -i script.py";
|
||||
let in_place_sed = "sed --in-place 's/a/b/' file.txt";
|
||||
|
||||
// when
|
||||
let interactive_result = is_read_only_command(interactive_python);
|
||||
let in_place_result = is_read_only_command(in_place_sed);
|
||||
|
||||
// then
|
||||
assert!(!interactive_result);
|
||||
assert!(!in_place_result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bash_heuristic_empty_command() {
|
||||
// given
|
||||
let empty = "";
|
||||
let whitespace = " ";
|
||||
|
||||
// when
|
||||
let empty_result = is_read_only_command(empty);
|
||||
let whitespace_result = is_read_only_command(whitespace);
|
||||
|
||||
// then
|
||||
assert!(!empty_result);
|
||||
assert!(!whitespace_result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_mode_check_bash_denied_payload_fields() {
|
||||
// given
|
||||
let enforcer = make_enforcer(PermissionMode::Prompt);
|
||||
|
||||
// when
|
||||
let result = enforcer.check_bash("git status");
|
||||
|
||||
// then
|
||||
match result {
|
||||
EnforcementResult::Denied {
|
||||
tool,
|
||||
active_mode,
|
||||
required_mode,
|
||||
reason,
|
||||
} => {
|
||||
assert_eq!(tool, "bash");
|
||||
assert_eq!(active_mode, "prompt");
|
||||
assert_eq!(required_mode, "danger-full-access");
|
||||
assert_eq!(reason, "bash requires confirmation in prompt mode");
|
||||
}
|
||||
other => panic!("expected denied result, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_only_check_file_write_denied_payload() {
|
||||
// given
|
||||
let enforcer = make_enforcer(PermissionMode::ReadOnly);
|
||||
|
||||
// when
|
||||
let result = enforcer.check_file_write("/workspace/file.txt", "/workspace");
|
||||
|
||||
// then
|
||||
match result {
|
||||
EnforcementResult::Denied {
|
||||
tool,
|
||||
active_mode,
|
||||
required_mode,
|
||||
reason,
|
||||
} => {
|
||||
assert_eq!(tool, "write_file");
|
||||
assert_eq!(active_mode, "read-only");
|
||||
assert_eq!(required_mode, "workspace-write");
|
||||
assert!(reason.contains("file writes are not allowed"));
|
||||
}
|
||||
other => panic!("expected denied result, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,5 +1,10 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::config::RuntimePermissionRuleConfig;
|
||||
|
||||
/// Permission level assigned to a tool invocation or runtime session.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub enum PermissionMode {
|
||||
ReadOnly,
|
||||
@ -22,34 +27,81 @@ impl PermissionMode {
|
||||
}
|
||||
}
|
||||
|
||||
/// Hook-provided override applied before standard permission evaluation.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum PermissionOverride {
|
||||
Allow,
|
||||
Deny,
|
||||
Ask,
|
||||
}
|
||||
|
||||
/// Additional permission context supplied by hooks or higher-level orchestration.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub struct PermissionContext {
|
||||
override_decision: Option<PermissionOverride>,
|
||||
override_reason: Option<String>,
|
||||
}
|
||||
|
||||
impl PermissionContext {
|
||||
#[must_use]
|
||||
pub fn new(
|
||||
override_decision: Option<PermissionOverride>,
|
||||
override_reason: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
override_decision,
|
||||
override_reason,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn override_decision(&self) -> Option<PermissionOverride> {
|
||||
self.override_decision
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn override_reason(&self) -> Option<&str> {
|
||||
self.override_reason.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
/// Full authorization request presented to a permission prompt.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PermissionRequest {
|
||||
pub tool_name: String,
|
||||
pub input: String,
|
||||
pub current_mode: PermissionMode,
|
||||
pub required_mode: PermissionMode,
|
||||
pub reason: Option<String>,
|
||||
}
|
||||
|
||||
/// User-facing decision returned by a [`PermissionPrompter`].
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PermissionPromptDecision {
|
||||
Allow,
|
||||
Deny { reason: String },
|
||||
}
|
||||
|
||||
/// Prompting interface used when policy requires interactive approval.
|
||||
pub trait PermissionPrompter {
|
||||
fn decide(&mut self, request: &PermissionRequest) -> PermissionPromptDecision;
|
||||
}
|
||||
|
||||
/// Final authorization result after evaluating static rules and prompts.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PermissionOutcome {
|
||||
Allow,
|
||||
Deny { reason: String },
|
||||
}
|
||||
|
||||
/// Evaluates permission mode requirements plus allow/deny/ask rules.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PermissionPolicy {
|
||||
active_mode: PermissionMode,
|
||||
tool_requirements: BTreeMap<String, PermissionMode>,
|
||||
allow_rules: Vec<PermissionRule>,
|
||||
deny_rules: Vec<PermissionRule>,
|
||||
ask_rules: Vec<PermissionRule>,
|
||||
}
|
||||
|
||||
impl PermissionPolicy {
|
||||
@ -58,6 +110,9 @@ impl PermissionPolicy {
|
||||
Self {
|
||||
active_mode,
|
||||
tool_requirements: BTreeMap::new(),
|
||||
allow_rules: Vec::new(),
|
||||
deny_rules: Vec::new(),
|
||||
ask_rules: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -72,6 +127,26 @@ impl PermissionPolicy {
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_permission_rules(mut self, config: &RuntimePermissionRuleConfig) -> Self {
|
||||
self.allow_rules = config
|
||||
.allow()
|
||||
.iter()
|
||||
.map(|rule| PermissionRule::parse(rule))
|
||||
.collect();
|
||||
self.deny_rules = config
|
||||
.deny()
|
||||
.iter()
|
||||
.map(|rule| PermissionRule::parse(rule))
|
||||
.collect();
|
||||
self.ask_rules = config
|
||||
.ask()
|
||||
.iter()
|
||||
.map(|rule| PermissionRule::parse(rule))
|
||||
.collect();
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn active_mode(&self) -> PermissionMode {
|
||||
self.active_mode
|
||||
@ -90,38 +165,121 @@ impl PermissionPolicy {
|
||||
&self,
|
||||
tool_name: &str,
|
||||
input: &str,
|
||||
mut prompter: Option<&mut dyn PermissionPrompter>,
|
||||
prompter: Option<&mut dyn PermissionPrompter>,
|
||||
) -> PermissionOutcome {
|
||||
let current_mode = self.active_mode();
|
||||
let required_mode = self.required_mode_for(tool_name);
|
||||
if current_mode == PermissionMode::Allow || current_mode >= required_mode {
|
||||
return PermissionOutcome::Allow;
|
||||
self.authorize_with_context(tool_name, input, &PermissionContext::default(), prompter)
|
||||
}
|
||||
|
||||
let request = PermissionRequest {
|
||||
tool_name: tool_name.to_string(),
|
||||
input: input.to_string(),
|
||||
#[must_use]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn authorize_with_context(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
input: &str,
|
||||
context: &PermissionContext,
|
||||
prompter: Option<&mut dyn PermissionPrompter>,
|
||||
) -> PermissionOutcome {
|
||||
if let Some(rule) = Self::find_matching_rule(&self.deny_rules, tool_name, input) {
|
||||
return PermissionOutcome::Deny {
|
||||
reason: format!(
|
||||
"Permission to use {tool_name} has been denied by rule '{}'",
|
||||
rule.raw
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
let current_mode = self.active_mode();
|
||||
let required_mode = self.required_mode_for(tool_name);
|
||||
let ask_rule = Self::find_matching_rule(&self.ask_rules, tool_name, input);
|
||||
let allow_rule = Self::find_matching_rule(&self.allow_rules, tool_name, input);
|
||||
|
||||
match context.override_decision() {
|
||||
Some(PermissionOverride::Deny) => {
|
||||
return PermissionOutcome::Deny {
|
||||
reason: context.override_reason().map_or_else(
|
||||
|| format!("tool '{tool_name}' denied by hook"),
|
||||
ToOwned::to_owned,
|
||||
),
|
||||
};
|
||||
}
|
||||
Some(PermissionOverride::Ask) => {
|
||||
let reason = context.override_reason().map_or_else(
|
||||
|| format!("tool '{tool_name}' requires approval due to hook guidance"),
|
||||
ToOwned::to_owned,
|
||||
);
|
||||
return Self::prompt_or_deny(
|
||||
tool_name,
|
||||
input,
|
||||
current_mode,
|
||||
required_mode,
|
||||
};
|
||||
Some(reason),
|
||||
prompter,
|
||||
);
|
||||
}
|
||||
Some(PermissionOverride::Allow) => {
|
||||
if let Some(rule) = ask_rule {
|
||||
let reason = format!(
|
||||
"tool '{tool_name}' requires approval due to ask rule '{}'",
|
||||
rule.raw
|
||||
);
|
||||
return Self::prompt_or_deny(
|
||||
tool_name,
|
||||
input,
|
||||
current_mode,
|
||||
required_mode,
|
||||
Some(reason),
|
||||
prompter,
|
||||
);
|
||||
}
|
||||
if allow_rule.is_some()
|
||||
|| current_mode == PermissionMode::Allow
|
||||
|| current_mode >= required_mode
|
||||
{
|
||||
return PermissionOutcome::Allow;
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
if let Some(rule) = ask_rule {
|
||||
let reason = format!(
|
||||
"tool '{tool_name}' requires approval due to ask rule '{}'",
|
||||
rule.raw
|
||||
);
|
||||
return Self::prompt_or_deny(
|
||||
tool_name,
|
||||
input,
|
||||
current_mode,
|
||||
required_mode,
|
||||
Some(reason),
|
||||
prompter,
|
||||
);
|
||||
}
|
||||
|
||||
if allow_rule.is_some()
|
||||
|| current_mode == PermissionMode::Allow
|
||||
|| current_mode >= required_mode
|
||||
{
|
||||
return PermissionOutcome::Allow;
|
||||
}
|
||||
|
||||
if current_mode == PermissionMode::Prompt
|
||||
|| (current_mode == PermissionMode::WorkspaceWrite
|
||||
&& required_mode == PermissionMode::DangerFullAccess)
|
||||
{
|
||||
return match prompter.as_mut() {
|
||||
Some(prompter) => match prompter.decide(&request) {
|
||||
PermissionPromptDecision::Allow => PermissionOutcome::Allow,
|
||||
PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason },
|
||||
},
|
||||
None => PermissionOutcome::Deny {
|
||||
reason: format!(
|
||||
let reason = Some(format!(
|
||||
"tool '{tool_name}' requires approval to escalate from {} to {}",
|
||||
current_mode.as_str(),
|
||||
required_mode.as_str()
|
||||
),
|
||||
},
|
||||
};
|
||||
));
|
||||
return Self::prompt_or_deny(
|
||||
tool_name,
|
||||
input,
|
||||
current_mode,
|
||||
required_mode,
|
||||
reason,
|
||||
prompter,
|
||||
);
|
||||
}
|
||||
|
||||
PermissionOutcome::Deny {
|
||||
@ -132,14 +290,191 @@ impl PermissionPolicy {
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn prompt_or_deny(
|
||||
tool_name: &str,
|
||||
input: &str,
|
||||
current_mode: PermissionMode,
|
||||
required_mode: PermissionMode,
|
||||
reason: Option<String>,
|
||||
mut prompter: Option<&mut dyn PermissionPrompter>,
|
||||
) -> PermissionOutcome {
|
||||
let request = PermissionRequest {
|
||||
tool_name: tool_name.to_string(),
|
||||
input: input.to_string(),
|
||||
current_mode,
|
||||
required_mode,
|
||||
reason: reason.clone(),
|
||||
};
|
||||
|
||||
match prompter.as_mut() {
|
||||
Some(prompter) => match prompter.decide(&request) {
|
||||
PermissionPromptDecision::Allow => PermissionOutcome::Allow,
|
||||
PermissionPromptDecision::Deny { reason } => PermissionOutcome::Deny { reason },
|
||||
},
|
||||
None => PermissionOutcome::Deny {
|
||||
reason: reason.unwrap_or_else(|| {
|
||||
format!(
|
||||
"tool '{tool_name}' requires approval to run while mode is {}",
|
||||
current_mode.as_str()
|
||||
)
|
||||
}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn find_matching_rule<'a>(
|
||||
rules: &'a [PermissionRule],
|
||||
tool_name: &str,
|
||||
input: &str,
|
||||
) -> Option<&'a PermissionRule> {
|
||||
rules.iter().find(|rule| rule.matches(tool_name, input))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct PermissionRule {
|
||||
raw: String,
|
||||
tool_name: String,
|
||||
matcher: PermissionRuleMatcher,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
enum PermissionRuleMatcher {
|
||||
Any,
|
||||
Exact(String),
|
||||
Prefix(String),
|
||||
}
|
||||
|
||||
impl PermissionRule {
|
||||
fn parse(raw: &str) -> Self {
|
||||
let trimmed = raw.trim();
|
||||
let open = find_first_unescaped(trimmed, '(');
|
||||
let close = find_last_unescaped(trimmed, ')');
|
||||
|
||||
if let (Some(open), Some(close)) = (open, close) {
|
||||
if close == trimmed.len() - 1 && open < close {
|
||||
let tool_name = trimmed[..open].trim();
|
||||
let content = &trimmed[open + 1..close];
|
||||
if !tool_name.is_empty() {
|
||||
let matcher = parse_rule_matcher(content);
|
||||
return Self {
|
||||
raw: trimmed.to_string(),
|
||||
tool_name: tool_name.to_string(),
|
||||
matcher,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
raw: trimmed.to_string(),
|
||||
tool_name: trimmed.to_string(),
|
||||
matcher: PermissionRuleMatcher::Any,
|
||||
}
|
||||
}
|
||||
|
||||
fn matches(&self, tool_name: &str, input: &str) -> bool {
|
||||
if self.tool_name != tool_name {
|
||||
return false;
|
||||
}
|
||||
|
||||
match &self.matcher {
|
||||
PermissionRuleMatcher::Any => true,
|
||||
PermissionRuleMatcher::Exact(expected) => {
|
||||
extract_permission_subject(input).is_some_and(|candidate| candidate == *expected)
|
||||
}
|
||||
PermissionRuleMatcher::Prefix(prefix) => extract_permission_subject(input)
|
||||
.is_some_and(|candidate| candidate.starts_with(prefix)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_rule_matcher(content: &str) -> PermissionRuleMatcher {
|
||||
let unescaped = unescape_rule_content(content.trim());
|
||||
if unescaped.is_empty() || unescaped == "*" {
|
||||
PermissionRuleMatcher::Any
|
||||
} else if let Some(prefix) = unescaped.strip_suffix(":*") {
|
||||
PermissionRuleMatcher::Prefix(prefix.to_string())
|
||||
} else {
|
||||
PermissionRuleMatcher::Exact(unescaped)
|
||||
}
|
||||
}
|
||||
|
||||
fn unescape_rule_content(content: &str) -> String {
|
||||
content
|
||||
.replace(r"\(", "(")
|
||||
.replace(r"\)", ")")
|
||||
.replace(r"\\", r"\")
|
||||
}
|
||||
|
||||
fn find_first_unescaped(value: &str, needle: char) -> Option<usize> {
|
||||
let mut escaped = false;
|
||||
for (idx, ch) in value.char_indices() {
|
||||
if ch == '\\' {
|
||||
escaped = !escaped;
|
||||
continue;
|
||||
}
|
||||
if ch == needle && !escaped {
|
||||
return Some(idx);
|
||||
}
|
||||
escaped = false;
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn find_last_unescaped(value: &str, needle: char) -> Option<usize> {
|
||||
let chars = value.char_indices().collect::<Vec<_>>();
|
||||
for (pos, (idx, ch)) in chars.iter().enumerate().rev() {
|
||||
if *ch != needle {
|
||||
continue;
|
||||
}
|
||||
let mut backslashes = 0;
|
||||
for (_, prev) in chars[..pos].iter().rev() {
|
||||
if *prev == '\\' {
|
||||
backslashes += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if backslashes % 2 == 0 {
|
||||
return Some(*idx);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn extract_permission_subject(input: &str) -> Option<String> {
|
||||
let parsed = serde_json::from_str::<Value>(input).ok();
|
||||
if let Some(Value::Object(object)) = parsed {
|
||||
for key in [
|
||||
"command",
|
||||
"path",
|
||||
"file_path",
|
||||
"filePath",
|
||||
"notebook_path",
|
||||
"notebookPath",
|
||||
"url",
|
||||
"pattern",
|
||||
"code",
|
||||
"message",
|
||||
] {
|
||||
if let Some(value) = object.get(key).and_then(Value::as_str) {
|
||||
return Some(value.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(!input.trim().is_empty()).then(|| input.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
PermissionMode, PermissionOutcome, PermissionPolicy, PermissionPromptDecision,
|
||||
PermissionPrompter, PermissionRequest,
|
||||
PermissionContext, PermissionMode, PermissionOutcome, PermissionOverride, PermissionPolicy,
|
||||
PermissionPromptDecision, PermissionPrompter, PermissionRequest,
|
||||
};
|
||||
use crate::config::RuntimePermissionRuleConfig;
|
||||
|
||||
struct RecordingPrompter {
|
||||
seen: Vec<PermissionRequest>,
|
||||
@ -229,4 +564,120 @@ mod tests {
|
||||
PermissionOutcome::Deny { reason } if reason == "not now"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn applies_rule_based_denials_and_allows() {
|
||||
let rules = RuntimePermissionRuleConfig::new(
|
||||
vec!["bash(git:*)".to_string()],
|
||||
vec!["bash(rm -rf:*)".to_string()],
|
||||
Vec::new(),
|
||||
);
|
||||
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
|
||||
.with_tool_requirement("bash", PermissionMode::DangerFullAccess)
|
||||
.with_permission_rules(&rules);
|
||||
|
||||
assert_eq!(
|
||||
policy.authorize("bash", r#"{"command":"git status"}"#, None),
|
||||
PermissionOutcome::Allow
|
||||
);
|
||||
assert!(matches!(
|
||||
policy.authorize("bash", r#"{"command":"rm -rf /tmp/x"}"#, None),
|
||||
PermissionOutcome::Deny { reason } if reason.contains("denied by rule")
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ask_rules_force_prompt_even_when_mode_allows() {
|
||||
let rules = RuntimePermissionRuleConfig::new(
|
||||
Vec::new(),
|
||||
Vec::new(),
|
||||
vec!["bash(git:*)".to_string()],
|
||||
);
|
||||
let policy = PermissionPolicy::new(PermissionMode::DangerFullAccess)
|
||||
.with_tool_requirement("bash", PermissionMode::DangerFullAccess)
|
||||
.with_permission_rules(&rules);
|
||||
let mut prompter = RecordingPrompter {
|
||||
seen: Vec::new(),
|
||||
allow: true,
|
||||
};
|
||||
|
||||
let outcome = policy.authorize("bash", r#"{"command":"git status"}"#, Some(&mut prompter));
|
||||
|
||||
assert_eq!(outcome, PermissionOutcome::Allow);
|
||||
assert_eq!(prompter.seen.len(), 1);
|
||||
assert!(prompter.seen[0]
|
||||
.reason
|
||||
.as_deref()
|
||||
.is_some_and(|reason| reason.contains("ask rule")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hook_allow_still_respects_ask_rules() {
|
||||
let rules = RuntimePermissionRuleConfig::new(
|
||||
Vec::new(),
|
||||
Vec::new(),
|
||||
vec!["bash(git:*)".to_string()],
|
||||
);
|
||||
let policy = PermissionPolicy::new(PermissionMode::ReadOnly)
|
||||
.with_tool_requirement("bash", PermissionMode::DangerFullAccess)
|
||||
.with_permission_rules(&rules);
|
||||
let context = PermissionContext::new(
|
||||
Some(PermissionOverride::Allow),
|
||||
Some("hook approved".to_string()),
|
||||
);
|
||||
let mut prompter = RecordingPrompter {
|
||||
seen: Vec::new(),
|
||||
allow: true,
|
||||
};
|
||||
|
||||
let outcome = policy.authorize_with_context(
|
||||
"bash",
|
||||
r#"{"command":"git status"}"#,
|
||||
&context,
|
||||
Some(&mut prompter),
|
||||
);
|
||||
|
||||
assert_eq!(outcome, PermissionOutcome::Allow);
|
||||
assert_eq!(prompter.seen.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hook_deny_short_circuits_permission_flow() {
|
||||
let policy = PermissionPolicy::new(PermissionMode::DangerFullAccess)
|
||||
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
|
||||
let context = PermissionContext::new(
|
||||
Some(PermissionOverride::Deny),
|
||||
Some("blocked by hook".to_string()),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
policy.authorize_with_context("bash", "{}", &context, None),
|
||||
PermissionOutcome::Deny {
|
||||
reason: "blocked by hook".to_string(),
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hook_ask_forces_prompt() {
|
||||
let policy = PermissionPolicy::new(PermissionMode::DangerFullAccess)
|
||||
.with_tool_requirement("bash", PermissionMode::DangerFullAccess);
|
||||
let context = PermissionContext::new(
|
||||
Some(PermissionOverride::Ask),
|
||||
Some("hook requested confirmation".to_string()),
|
||||
);
|
||||
let mut prompter = RecordingPrompter {
|
||||
seen: Vec::new(),
|
||||
allow: true,
|
||||
};
|
||||
|
||||
let outcome = policy.authorize_with_context("bash", "{}", &context, Some(&mut prompter));
|
||||
|
||||
assert_eq!(outcome, PermissionOutcome::Allow);
|
||||
assert_eq!(prompter.seen.len(), 1);
|
||||
assert_eq!(
|
||||
prompter.seen[0].reason.as_deref(),
|
||||
Some("hook requested confirmation")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
533
crates/runtime/src/plugin_lifecycle.rs
Normal file
533
crates/runtime/src/plugin_lifecycle.rs
Normal file
@ -0,0 +1,533 @@
|
||||
#![allow(clippy::redundant_closure_for_method_calls)]
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::config::RuntimePluginConfig;
|
||||
use crate::mcp_tool_bridge::{McpResourceInfo, McpToolInfo};
|
||||
|
||||
fn now_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
pub type ToolInfo = McpToolInfo;
|
||||
pub type ResourceInfo = McpResourceInfo;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ServerStatus {
|
||||
Healthy,
|
||||
Degraded,
|
||||
Failed,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ServerStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Healthy => write!(f, "healthy"),
|
||||
Self::Degraded => write!(f, "degraded"),
|
||||
Self::Failed => write!(f, "failed"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct ServerHealth {
|
||||
pub server_name: String,
|
||||
pub status: ServerStatus,
|
||||
pub capabilities: Vec<String>,
|
||||
pub last_error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "state")]
|
||||
pub enum PluginState {
|
||||
Unconfigured,
|
||||
Validated,
|
||||
Starting,
|
||||
Healthy,
|
||||
Degraded {
|
||||
healthy_servers: Vec<String>,
|
||||
failed_servers: Vec<ServerHealth>,
|
||||
},
|
||||
Failed {
|
||||
reason: String,
|
||||
},
|
||||
ShuttingDown,
|
||||
Stopped,
|
||||
}
|
||||
|
||||
impl PluginState {
|
||||
#[must_use]
|
||||
pub fn from_servers(servers: &[ServerHealth]) -> Self {
|
||||
if servers.is_empty() {
|
||||
return Self::Failed {
|
||||
reason: "no servers available".to_string(),
|
||||
};
|
||||
}
|
||||
|
||||
let healthy_servers = servers
|
||||
.iter()
|
||||
.filter(|server| server.status != ServerStatus::Failed)
|
||||
.map(|server| server.server_name.clone())
|
||||
.collect::<Vec<_>>();
|
||||
let failed_servers = servers
|
||||
.iter()
|
||||
.filter(|server| server.status == ServerStatus::Failed)
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
let has_degraded_server = servers
|
||||
.iter()
|
||||
.any(|server| server.status == ServerStatus::Degraded);
|
||||
|
||||
if failed_servers.is_empty() && !has_degraded_server {
|
||||
Self::Healthy
|
||||
} else if healthy_servers.is_empty() {
|
||||
Self::Failed {
|
||||
reason: format!("all {} servers failed", failed_servers.len()),
|
||||
}
|
||||
} else {
|
||||
Self::Degraded {
|
||||
healthy_servers,
|
||||
failed_servers,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PluginState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Unconfigured => write!(f, "unconfigured"),
|
||||
Self::Validated => write!(f, "validated"),
|
||||
Self::Starting => write!(f, "starting"),
|
||||
Self::Healthy => write!(f, "healthy"),
|
||||
Self::Degraded { .. } => write!(f, "degraded"),
|
||||
Self::Failed { .. } => write!(f, "failed"),
|
||||
Self::ShuttingDown => write!(f, "shutting_down"),
|
||||
Self::Stopped => write!(f, "stopped"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct PluginHealthcheck {
|
||||
pub plugin_name: String,
|
||||
pub state: PluginState,
|
||||
pub servers: Vec<ServerHealth>,
|
||||
pub last_check: u64,
|
||||
}
|
||||
|
||||
impl PluginHealthcheck {
|
||||
#[must_use]
|
||||
pub fn new(plugin_name: impl Into<String>, servers: Vec<ServerHealth>) -> Self {
|
||||
let state = PluginState::from_servers(&servers);
|
||||
Self {
|
||||
plugin_name: plugin_name.into(),
|
||||
state,
|
||||
servers,
|
||||
last_check: now_secs(),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn degraded_mode(&self, discovery: &DiscoveryResult) -> Option<DegradedMode> {
|
||||
match &self.state {
|
||||
PluginState::Degraded {
|
||||
healthy_servers,
|
||||
failed_servers,
|
||||
} => Some(DegradedMode {
|
||||
available_tools: discovery
|
||||
.tools
|
||||
.iter()
|
||||
.map(|tool| tool.name.clone())
|
||||
.collect(),
|
||||
unavailable_tools: failed_servers
|
||||
.iter()
|
||||
.flat_map(|server| server.capabilities.iter().cloned())
|
||||
.collect(),
|
||||
reason: format!(
|
||||
"{} servers healthy, {} servers failed",
|
||||
healthy_servers.len(),
|
||||
failed_servers.len()
|
||||
),
|
||||
}),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DiscoveryResult {
|
||||
pub tools: Vec<ToolInfo>,
|
||||
pub resources: Vec<ResourceInfo>,
|
||||
pub partial: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct DegradedMode {
|
||||
pub available_tools: Vec<String>,
|
||||
pub unavailable_tools: Vec<String>,
|
||||
pub reason: String,
|
||||
}
|
||||
|
||||
impl DegradedMode {
|
||||
#[must_use]
|
||||
pub fn new(
|
||||
available_tools: Vec<String>,
|
||||
unavailable_tools: Vec<String>,
|
||||
reason: impl Into<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
available_tools,
|
||||
unavailable_tools,
|
||||
reason: reason.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PluginLifecycleEvent {
|
||||
ConfigValidated,
|
||||
StartupHealthy,
|
||||
StartupDegraded,
|
||||
StartupFailed,
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PluginLifecycleEvent {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::ConfigValidated => write!(f, "config_validated"),
|
||||
Self::StartupHealthy => write!(f, "startup_healthy"),
|
||||
Self::StartupDegraded => write!(f, "startup_degraded"),
|
||||
Self::StartupFailed => write!(f, "startup_failed"),
|
||||
Self::Shutdown => write!(f, "shutdown"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait PluginLifecycle {
|
||||
fn validate_config(&self, config: &RuntimePluginConfig) -> Result<(), String>;
|
||||
fn healthcheck(&self) -> PluginHealthcheck;
|
||||
fn discover(&self) -> DiscoveryResult;
|
||||
fn shutdown(&mut self) -> Result<(), String>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct MockPluginLifecycle {
|
||||
plugin_name: String,
|
||||
valid_config: bool,
|
||||
healthcheck: PluginHealthcheck,
|
||||
discovery: DiscoveryResult,
|
||||
shutdown_error: Option<String>,
|
||||
shutdown_called: bool,
|
||||
}
|
||||
|
||||
impl MockPluginLifecycle {
|
||||
fn new(
|
||||
plugin_name: &str,
|
||||
valid_config: bool,
|
||||
servers: Vec<ServerHealth>,
|
||||
discovery: DiscoveryResult,
|
||||
shutdown_error: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
plugin_name: plugin_name.to_string(),
|
||||
valid_config,
|
||||
healthcheck: PluginHealthcheck::new(plugin_name, servers),
|
||||
discovery,
|
||||
shutdown_error,
|
||||
shutdown_called: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PluginLifecycle for MockPluginLifecycle {
|
||||
fn validate_config(&self, _config: &RuntimePluginConfig) -> Result<(), String> {
|
||||
if self.valid_config {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!(
|
||||
"plugin `{}` failed configuration validation",
|
||||
self.plugin_name
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn healthcheck(&self) -> PluginHealthcheck {
|
||||
if self.shutdown_called {
|
||||
PluginHealthcheck {
|
||||
plugin_name: self.plugin_name.clone(),
|
||||
state: PluginState::Stopped,
|
||||
servers: self.healthcheck.servers.clone(),
|
||||
last_check: now_secs(),
|
||||
}
|
||||
} else {
|
||||
self.healthcheck.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn discover(&self) -> DiscoveryResult {
|
||||
self.discovery.clone()
|
||||
}
|
||||
|
||||
fn shutdown(&mut self) -> Result<(), String> {
|
||||
if let Some(error) = &self.shutdown_error {
|
||||
return Err(error.clone());
|
||||
}
|
||||
|
||||
self.shutdown_called = true;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn healthy_server(name: &str, capabilities: &[&str]) -> ServerHealth {
|
||||
ServerHealth {
|
||||
server_name: name.to_string(),
|
||||
status: ServerStatus::Healthy,
|
||||
capabilities: capabilities
|
||||
.iter()
|
||||
.map(|capability| capability.to_string())
|
||||
.collect(),
|
||||
last_error: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn failed_server(name: &str, capabilities: &[&str], error: &str) -> ServerHealth {
|
||||
ServerHealth {
|
||||
server_name: name.to_string(),
|
||||
status: ServerStatus::Failed,
|
||||
capabilities: capabilities
|
||||
.iter()
|
||||
.map(|capability| capability.to_string())
|
||||
.collect(),
|
||||
last_error: Some(error.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn degraded_server(name: &str, capabilities: &[&str], error: &str) -> ServerHealth {
|
||||
ServerHealth {
|
||||
server_name: name.to_string(),
|
||||
status: ServerStatus::Degraded,
|
||||
capabilities: capabilities
|
||||
.iter()
|
||||
.map(|capability| capability.to_string())
|
||||
.collect(),
|
||||
last_error: Some(error.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn tool(name: &str) -> ToolInfo {
|
||||
ToolInfo {
|
||||
name: name.to_string(),
|
||||
description: Some(format!("{name} tool")),
|
||||
input_schema: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn resource(name: &str, uri: &str) -> ResourceInfo {
|
||||
ResourceInfo {
|
||||
uri: uri.to_string(),
|
||||
name: name.to_string(),
|
||||
description: Some(format!("{name} resource")),
|
||||
mime_type: Some("application/json".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn full_lifecycle_happy_path() {
|
||||
// given
|
||||
let mut lifecycle = MockPluginLifecycle::new(
|
||||
"healthy-plugin",
|
||||
true,
|
||||
vec![
|
||||
healthy_server("alpha", &["search", "read"]),
|
||||
healthy_server("beta", &["write"]),
|
||||
],
|
||||
DiscoveryResult {
|
||||
tools: vec![tool("search"), tool("read"), tool("write")],
|
||||
resources: vec![resource("docs", "file:///docs")],
|
||||
partial: false,
|
||||
},
|
||||
None,
|
||||
);
|
||||
let config = RuntimePluginConfig::default();
|
||||
|
||||
// when
|
||||
let validation = lifecycle.validate_config(&config);
|
||||
let healthcheck = lifecycle.healthcheck();
|
||||
let discovery = lifecycle.discover();
|
||||
let shutdown = lifecycle.shutdown();
|
||||
let post_shutdown = lifecycle.healthcheck();
|
||||
|
||||
// then
|
||||
assert_eq!(validation, Ok(()));
|
||||
assert_eq!(healthcheck.state, PluginState::Healthy);
|
||||
assert_eq!(healthcheck.plugin_name, "healthy-plugin");
|
||||
assert_eq!(discovery.tools.len(), 3);
|
||||
assert_eq!(discovery.resources.len(), 1);
|
||||
assert!(!discovery.partial);
|
||||
assert_eq!(shutdown, Ok(()));
|
||||
assert_eq!(post_shutdown.state, PluginState::Stopped);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn degraded_startup_when_one_of_three_servers_fails() {
|
||||
// given
|
||||
let lifecycle = MockPluginLifecycle::new(
|
||||
"degraded-plugin",
|
||||
true,
|
||||
vec![
|
||||
healthy_server("alpha", &["search"]),
|
||||
failed_server("beta", &["write"], "connection refused"),
|
||||
healthy_server("gamma", &["read"]),
|
||||
],
|
||||
DiscoveryResult {
|
||||
tools: vec![tool("search"), tool("read")],
|
||||
resources: vec![resource("alpha-docs", "file:///alpha")],
|
||||
partial: true,
|
||||
},
|
||||
None,
|
||||
);
|
||||
|
||||
// when
|
||||
let healthcheck = lifecycle.healthcheck();
|
||||
let discovery = lifecycle.discover();
|
||||
let degraded_mode = healthcheck
|
||||
.degraded_mode(&discovery)
|
||||
.expect("degraded startup should expose degraded mode");
|
||||
|
||||
// then
|
||||
match healthcheck.state {
|
||||
PluginState::Degraded {
|
||||
healthy_servers,
|
||||
failed_servers,
|
||||
} => {
|
||||
assert_eq!(
|
||||
healthy_servers,
|
||||
vec!["alpha".to_string(), "gamma".to_string()]
|
||||
);
|
||||
assert_eq!(failed_servers.len(), 1);
|
||||
assert_eq!(failed_servers[0].server_name, "beta");
|
||||
assert_eq!(
|
||||
failed_servers[0].last_error.as_deref(),
|
||||
Some("connection refused")
|
||||
);
|
||||
}
|
||||
other => panic!("expected degraded state, got {other:?}"),
|
||||
}
|
||||
assert!(discovery.partial);
|
||||
assert_eq!(
|
||||
degraded_mode.available_tools,
|
||||
vec!["search".to_string(), "read".to_string()]
|
||||
);
|
||||
assert_eq!(degraded_mode.unavailable_tools, vec!["write".to_string()]);
|
||||
assert_eq!(degraded_mode.reason, "2 servers healthy, 1 servers failed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn degraded_server_status_keeps_server_usable() {
|
||||
// given
|
||||
let lifecycle = MockPluginLifecycle::new(
|
||||
"soft-degraded-plugin",
|
||||
true,
|
||||
vec![
|
||||
healthy_server("alpha", &["search"]),
|
||||
degraded_server("beta", &["write"], "high latency"),
|
||||
],
|
||||
DiscoveryResult {
|
||||
tools: vec![tool("search"), tool("write")],
|
||||
resources: Vec::new(),
|
||||
partial: true,
|
||||
},
|
||||
None,
|
||||
);
|
||||
|
||||
// when
|
||||
let healthcheck = lifecycle.healthcheck();
|
||||
|
||||
// then
|
||||
match healthcheck.state {
|
||||
PluginState::Degraded {
|
||||
healthy_servers,
|
||||
failed_servers,
|
||||
} => {
|
||||
assert_eq!(
|
||||
healthy_servers,
|
||||
vec!["alpha".to_string(), "beta".to_string()]
|
||||
);
|
||||
assert!(failed_servers.is_empty());
|
||||
}
|
||||
other => panic!("expected degraded state, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn complete_failure_when_all_servers_fail() {
|
||||
// given
|
||||
let lifecycle = MockPluginLifecycle::new(
|
||||
"failed-plugin",
|
||||
true,
|
||||
vec![
|
||||
failed_server("alpha", &["search"], "timeout"),
|
||||
failed_server("beta", &["read"], "handshake failed"),
|
||||
],
|
||||
DiscoveryResult {
|
||||
tools: Vec::new(),
|
||||
resources: Vec::new(),
|
||||
partial: false,
|
||||
},
|
||||
None,
|
||||
);
|
||||
|
||||
// when
|
||||
let healthcheck = lifecycle.healthcheck();
|
||||
let discovery = lifecycle.discover();
|
||||
|
||||
// then
|
||||
match &healthcheck.state {
|
||||
PluginState::Failed { reason } => {
|
||||
assert_eq!(reason, "all 2 servers failed");
|
||||
}
|
||||
other => panic!("expected failed state, got {other:?}"),
|
||||
}
|
||||
assert!(!discovery.partial);
|
||||
assert!(discovery.tools.is_empty());
|
||||
assert!(discovery.resources.is_empty());
|
||||
assert!(healthcheck.degraded_mode(&discovery).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn graceful_shutdown() {
|
||||
// given
|
||||
let mut lifecycle = MockPluginLifecycle::new(
|
||||
"shutdown-plugin",
|
||||
true,
|
||||
vec![healthy_server("alpha", &["search"])],
|
||||
DiscoveryResult {
|
||||
tools: vec![tool("search")],
|
||||
resources: Vec::new(),
|
||||
partial: false,
|
||||
},
|
||||
None,
|
||||
);
|
||||
|
||||
// when
|
||||
let shutdown = lifecycle.shutdown();
|
||||
let post_shutdown = lifecycle.healthcheck();
|
||||
|
||||
// then
|
||||
assert_eq!(shutdown, Ok(()));
|
||||
assert_eq!(PluginLifecycleEvent::Shutdown.to_string(), "shutdown");
|
||||
assert_eq!(post_shutdown.state, PluginState::Stopped);
|
||||
}
|
||||
}
|
||||
581
crates/runtime/src/policy_engine.rs
Normal file
581
crates/runtime/src/policy_engine.rs
Normal file
@ -0,0 +1,581 @@
|
||||
use std::time::Duration;
|
||||
|
||||
pub type GreenLevel = u8;
|
||||
|
||||
const STALE_BRANCH_THRESHOLD: Duration = Duration::from_secs(60 * 60);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PolicyRule {
|
||||
pub name: String,
|
||||
pub condition: PolicyCondition,
|
||||
pub action: PolicyAction,
|
||||
pub priority: u32,
|
||||
}
|
||||
|
||||
impl PolicyRule {
|
||||
#[must_use]
|
||||
pub fn new(
|
||||
name: impl Into<String>,
|
||||
condition: PolicyCondition,
|
||||
action: PolicyAction,
|
||||
priority: u32,
|
||||
) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
condition,
|
||||
action,
|
||||
priority,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn matches(&self, context: &LaneContext) -> bool {
|
||||
self.condition.matches(context)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PolicyCondition {
|
||||
And(Vec<PolicyCondition>),
|
||||
Or(Vec<PolicyCondition>),
|
||||
GreenAt { level: GreenLevel },
|
||||
StaleBranch,
|
||||
StartupBlocked,
|
||||
LaneCompleted,
|
||||
LaneReconciled,
|
||||
ReviewPassed,
|
||||
ScopedDiff,
|
||||
TimedOut { duration: Duration },
|
||||
}
|
||||
|
||||
impl PolicyCondition {
|
||||
#[must_use]
|
||||
pub fn matches(&self, context: &LaneContext) -> bool {
|
||||
match self {
|
||||
Self::And(conditions) => conditions
|
||||
.iter()
|
||||
.all(|condition| condition.matches(context)),
|
||||
Self::Or(conditions) => conditions
|
||||
.iter()
|
||||
.any(|condition| condition.matches(context)),
|
||||
Self::GreenAt { level } => context.green_level >= *level,
|
||||
Self::StaleBranch => context.branch_freshness >= STALE_BRANCH_THRESHOLD,
|
||||
Self::StartupBlocked => context.blocker == LaneBlocker::Startup,
|
||||
Self::LaneCompleted => context.completed,
|
||||
Self::LaneReconciled => context.reconciled,
|
||||
Self::ReviewPassed => context.review_status == ReviewStatus::Approved,
|
||||
Self::ScopedDiff => context.diff_scope == DiffScope::Scoped,
|
||||
Self::TimedOut { duration } => context.branch_freshness >= *duration,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum PolicyAction {
|
||||
MergeToDev,
|
||||
MergeForward,
|
||||
RecoverOnce,
|
||||
Escalate { reason: String },
|
||||
CloseoutLane,
|
||||
CleanupSession,
|
||||
Reconcile { reason: ReconcileReason },
|
||||
Notify { channel: String },
|
||||
Block { reason: String },
|
||||
Chain(Vec<PolicyAction>),
|
||||
}
|
||||
|
||||
/// Why a lane was reconciled without further action.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ReconcileReason {
|
||||
/// Branch already merged into main — no PR needed.
|
||||
AlreadyMerged,
|
||||
/// Work superseded by another lane or direct commit.
|
||||
Superseded,
|
||||
/// PR would be empty — all changes already landed.
|
||||
EmptyDiff,
|
||||
/// Lane manually closed by operator.
|
||||
ManualClose,
|
||||
}
|
||||
|
||||
impl PolicyAction {
|
||||
fn flatten_into(&self, actions: &mut Vec<PolicyAction>) {
|
||||
match self {
|
||||
Self::Chain(chained) => {
|
||||
for action in chained {
|
||||
action.flatten_into(actions);
|
||||
}
|
||||
}
|
||||
_ => actions.push(self.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum LaneBlocker {
|
||||
None,
|
||||
Startup,
|
||||
External,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ReviewStatus {
|
||||
Pending,
|
||||
Approved,
|
||||
Rejected,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum DiffScope {
|
||||
Full,
|
||||
Scoped,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct LaneContext {
|
||||
pub lane_id: String,
|
||||
pub green_level: GreenLevel,
|
||||
pub branch_freshness: Duration,
|
||||
pub blocker: LaneBlocker,
|
||||
pub review_status: ReviewStatus,
|
||||
pub diff_scope: DiffScope,
|
||||
pub completed: bool,
|
||||
pub reconciled: bool,
|
||||
}
|
||||
|
||||
impl LaneContext {
|
||||
#[must_use]
|
||||
pub fn new(
|
||||
lane_id: impl Into<String>,
|
||||
green_level: GreenLevel,
|
||||
branch_freshness: Duration,
|
||||
blocker: LaneBlocker,
|
||||
review_status: ReviewStatus,
|
||||
diff_scope: DiffScope,
|
||||
completed: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
lane_id: lane_id.into(),
|
||||
green_level,
|
||||
branch_freshness,
|
||||
blocker,
|
||||
review_status,
|
||||
diff_scope,
|
||||
completed,
|
||||
reconciled: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a lane context that is already reconciled (no further action needed).
|
||||
#[must_use]
|
||||
pub fn reconciled(lane_id: impl Into<String>) -> Self {
|
||||
Self {
|
||||
lane_id: lane_id.into(),
|
||||
green_level: 0,
|
||||
branch_freshness: Duration::from_secs(0),
|
||||
blocker: LaneBlocker::None,
|
||||
review_status: ReviewStatus::Pending,
|
||||
diff_scope: DiffScope::Full,
|
||||
completed: true,
|
||||
reconciled: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PolicyEngine {
|
||||
rules: Vec<PolicyRule>,
|
||||
}
|
||||
|
||||
impl PolicyEngine {
|
||||
#[must_use]
|
||||
pub fn new(mut rules: Vec<PolicyRule>) -> Self {
|
||||
rules.sort_by_key(|rule| rule.priority);
|
||||
Self { rules }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn rules(&self) -> &[PolicyRule] {
|
||||
&self.rules
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn evaluate(&self, context: &LaneContext) -> Vec<PolicyAction> {
|
||||
evaluate(self, context)
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn evaluate(engine: &PolicyEngine, context: &LaneContext) -> Vec<PolicyAction> {
|
||||
let mut actions = Vec::new();
|
||||
for rule in &engine.rules {
|
||||
if rule.matches(context) {
|
||||
rule.action.flatten_into(&mut actions);
|
||||
}
|
||||
}
|
||||
actions
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use super::{
|
||||
evaluate, DiffScope, LaneBlocker, LaneContext, PolicyAction, PolicyCondition, PolicyEngine,
|
||||
PolicyRule, ReconcileReason, ReviewStatus, STALE_BRANCH_THRESHOLD,
|
||||
};
|
||||
|
||||
fn default_context() -> LaneContext {
|
||||
LaneContext::new(
|
||||
"lane-7",
|
||||
0,
|
||||
Duration::from_secs(0),
|
||||
LaneBlocker::None,
|
||||
ReviewStatus::Pending,
|
||||
DiffScope::Full,
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merge_to_dev_rule_fires_for_green_scoped_reviewed_lane() {
|
||||
// given
|
||||
let engine = PolicyEngine::new(vec![PolicyRule::new(
|
||||
"merge-to-dev",
|
||||
PolicyCondition::And(vec![
|
||||
PolicyCondition::GreenAt { level: 2 },
|
||||
PolicyCondition::ScopedDiff,
|
||||
PolicyCondition::ReviewPassed,
|
||||
]),
|
||||
PolicyAction::MergeToDev,
|
||||
20,
|
||||
)]);
|
||||
let context = LaneContext::new(
|
||||
"lane-7",
|
||||
3,
|
||||
Duration::from_secs(5),
|
||||
LaneBlocker::None,
|
||||
ReviewStatus::Approved,
|
||||
DiffScope::Scoped,
|
||||
false,
|
||||
);
|
||||
|
||||
// when
|
||||
let actions = engine.evaluate(&context);
|
||||
|
||||
// then
|
||||
assert_eq!(actions, vec![PolicyAction::MergeToDev]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stale_branch_rule_fires_at_threshold() {
|
||||
// given
|
||||
let engine = PolicyEngine::new(vec![PolicyRule::new(
|
||||
"merge-forward",
|
||||
PolicyCondition::StaleBranch,
|
||||
PolicyAction::MergeForward,
|
||||
10,
|
||||
)]);
|
||||
let context = LaneContext::new(
|
||||
"lane-7",
|
||||
1,
|
||||
STALE_BRANCH_THRESHOLD,
|
||||
LaneBlocker::None,
|
||||
ReviewStatus::Pending,
|
||||
DiffScope::Full,
|
||||
false,
|
||||
);
|
||||
|
||||
// when
|
||||
let actions = engine.evaluate(&context);
|
||||
|
||||
// then
|
||||
assert_eq!(actions, vec![PolicyAction::MergeForward]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn startup_blocked_rule_recovers_then_escalates() {
|
||||
// given
|
||||
let engine = PolicyEngine::new(vec![PolicyRule::new(
|
||||
"startup-recovery",
|
||||
PolicyCondition::StartupBlocked,
|
||||
PolicyAction::Chain(vec![
|
||||
PolicyAction::RecoverOnce,
|
||||
PolicyAction::Escalate {
|
||||
reason: "startup remained blocked".to_string(),
|
||||
},
|
||||
]),
|
||||
15,
|
||||
)]);
|
||||
let context = LaneContext::new(
|
||||
"lane-7",
|
||||
0,
|
||||
Duration::from_secs(0),
|
||||
LaneBlocker::Startup,
|
||||
ReviewStatus::Pending,
|
||||
DiffScope::Full,
|
||||
false,
|
||||
);
|
||||
|
||||
// when
|
||||
let actions = engine.evaluate(&context);
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
actions,
|
||||
vec![
|
||||
PolicyAction::RecoverOnce,
|
||||
PolicyAction::Escalate {
|
||||
reason: "startup remained blocked".to_string(),
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn completed_lane_rule_closes_out_and_cleans_up() {
|
||||
// given
|
||||
let engine = PolicyEngine::new(vec![PolicyRule::new(
|
||||
"lane-closeout",
|
||||
PolicyCondition::LaneCompleted,
|
||||
PolicyAction::Chain(vec![
|
||||
PolicyAction::CloseoutLane,
|
||||
PolicyAction::CleanupSession,
|
||||
]),
|
||||
30,
|
||||
)]);
|
||||
let context = LaneContext::new(
|
||||
"lane-7",
|
||||
0,
|
||||
Duration::from_secs(0),
|
||||
LaneBlocker::None,
|
||||
ReviewStatus::Pending,
|
||||
DiffScope::Full,
|
||||
true,
|
||||
);
|
||||
|
||||
// when
|
||||
let actions = engine.evaluate(&context);
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
actions,
|
||||
vec![PolicyAction::CloseoutLane, PolicyAction::CleanupSession]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matching_rules_are_returned_in_priority_order_with_stable_ties() {
|
||||
// given
|
||||
let engine = PolicyEngine::new(vec![
|
||||
PolicyRule::new(
|
||||
"late-cleanup",
|
||||
PolicyCondition::And(vec![]),
|
||||
PolicyAction::CleanupSession,
|
||||
30,
|
||||
),
|
||||
PolicyRule::new(
|
||||
"first-notify",
|
||||
PolicyCondition::And(vec![]),
|
||||
PolicyAction::Notify {
|
||||
channel: "ops".to_string(),
|
||||
},
|
||||
10,
|
||||
),
|
||||
PolicyRule::new(
|
||||
"second-notify",
|
||||
PolicyCondition::And(vec![]),
|
||||
PolicyAction::Notify {
|
||||
channel: "review".to_string(),
|
||||
},
|
||||
10,
|
||||
),
|
||||
PolicyRule::new(
|
||||
"merge",
|
||||
PolicyCondition::And(vec![]),
|
||||
PolicyAction::MergeToDev,
|
||||
20,
|
||||
),
|
||||
]);
|
||||
let context = default_context();
|
||||
|
||||
// when
|
||||
let actions = evaluate(&engine, &context);
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
actions,
|
||||
vec![
|
||||
PolicyAction::Notify {
|
||||
channel: "ops".to_string(),
|
||||
},
|
||||
PolicyAction::Notify {
|
||||
channel: "review".to_string(),
|
||||
},
|
||||
PolicyAction::MergeToDev,
|
||||
PolicyAction::CleanupSession,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn combinators_handle_empty_cases_and_nested_chains() {
|
||||
// given
|
||||
let engine = PolicyEngine::new(vec![
|
||||
PolicyRule::new(
|
||||
"empty-and",
|
||||
PolicyCondition::And(vec![]),
|
||||
PolicyAction::Notify {
|
||||
channel: "orchestrator".to_string(),
|
||||
},
|
||||
5,
|
||||
),
|
||||
PolicyRule::new(
|
||||
"empty-or",
|
||||
PolicyCondition::Or(vec![]),
|
||||
PolicyAction::Block {
|
||||
reason: "should not fire".to_string(),
|
||||
},
|
||||
10,
|
||||
),
|
||||
PolicyRule::new(
|
||||
"nested",
|
||||
PolicyCondition::Or(vec![
|
||||
PolicyCondition::StartupBlocked,
|
||||
PolicyCondition::And(vec![
|
||||
PolicyCondition::GreenAt { level: 2 },
|
||||
PolicyCondition::TimedOut {
|
||||
duration: Duration::from_secs(5),
|
||||
},
|
||||
]),
|
||||
]),
|
||||
PolicyAction::Chain(vec![
|
||||
PolicyAction::Notify {
|
||||
channel: "alerts".to_string(),
|
||||
},
|
||||
PolicyAction::Chain(vec![
|
||||
PolicyAction::MergeForward,
|
||||
PolicyAction::CleanupSession,
|
||||
]),
|
||||
]),
|
||||
15,
|
||||
),
|
||||
]);
|
||||
let context = LaneContext::new(
|
||||
"lane-7",
|
||||
2,
|
||||
Duration::from_secs(10),
|
||||
LaneBlocker::External,
|
||||
ReviewStatus::Pending,
|
||||
DiffScope::Full,
|
||||
false,
|
||||
);
|
||||
|
||||
// when
|
||||
let actions = engine.evaluate(&context);
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
actions,
|
||||
vec![
|
||||
PolicyAction::Notify {
|
||||
channel: "orchestrator".to_string(),
|
||||
},
|
||||
PolicyAction::Notify {
|
||||
channel: "alerts".to_string(),
|
||||
},
|
||||
PolicyAction::MergeForward,
|
||||
PolicyAction::CleanupSession,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reconciled_lane_emits_reconcile_and_cleanup() {
|
||||
// given — a lane where branch is already merged, no PR needed, session stale
|
||||
let engine = PolicyEngine::new(vec![
|
||||
PolicyRule::new(
|
||||
"reconcile-closeout",
|
||||
PolicyCondition::LaneReconciled,
|
||||
PolicyAction::Chain(vec![
|
||||
PolicyAction::Reconcile {
|
||||
reason: ReconcileReason::AlreadyMerged,
|
||||
},
|
||||
PolicyAction::CloseoutLane,
|
||||
PolicyAction::CleanupSession,
|
||||
]),
|
||||
5,
|
||||
),
|
||||
// This rule should NOT fire — reconciled lanes are completed but we want
|
||||
// the more specific reconcile rule to handle them
|
||||
PolicyRule::new(
|
||||
"generic-closeout",
|
||||
PolicyCondition::And(vec![
|
||||
PolicyCondition::LaneCompleted,
|
||||
// Only fire if NOT reconciled
|
||||
PolicyCondition::And(vec![]),
|
||||
]),
|
||||
PolicyAction::CloseoutLane,
|
||||
30,
|
||||
),
|
||||
]);
|
||||
let context = LaneContext::reconciled("lane-9411");
|
||||
|
||||
// when
|
||||
let actions = engine.evaluate(&context);
|
||||
|
||||
// then — reconcile rule fires first (priority 5), then generic closeout also fires
|
||||
// because reconciled context has completed=true
|
||||
assert_eq!(
|
||||
actions,
|
||||
vec![
|
||||
PolicyAction::Reconcile {
|
||||
reason: ReconcileReason::AlreadyMerged,
|
||||
},
|
||||
PolicyAction::CloseoutLane,
|
||||
PolicyAction::CleanupSession,
|
||||
PolicyAction::CloseoutLane,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reconciled_context_has_correct_defaults() {
|
||||
let ctx = LaneContext::reconciled("test-lane");
|
||||
assert_eq!(ctx.lane_id, "test-lane");
|
||||
assert!(ctx.completed);
|
||||
assert!(ctx.reconciled);
|
||||
assert_eq!(ctx.blocker, LaneBlocker::None);
|
||||
assert_eq!(ctx.green_level, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_reconciled_lane_does_not_trigger_reconcile_rule() {
|
||||
let engine = PolicyEngine::new(vec![PolicyRule::new(
|
||||
"reconcile-closeout",
|
||||
PolicyCondition::LaneReconciled,
|
||||
PolicyAction::Reconcile {
|
||||
reason: ReconcileReason::EmptyDiff,
|
||||
},
|
||||
5,
|
||||
)]);
|
||||
// Normal completed lane — not reconciled
|
||||
let context = LaneContext::new(
|
||||
"lane-7",
|
||||
0,
|
||||
Duration::from_secs(0),
|
||||
LaneBlocker::None,
|
||||
ReviewStatus::Pending,
|
||||
DiffScope::Full,
|
||||
true,
|
||||
);
|
||||
|
||||
let actions = engine.evaluate(&context);
|
||||
assert!(actions.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reconcile_reason_variants_are_distinct() {
|
||||
assert_ne!(ReconcileReason::AlreadyMerged, ReconcileReason::Superseded);
|
||||
assert_ne!(ReconcileReason::EmptyDiff, ReconcileReason::ManualClose);
|
||||
}
|
||||
}
|
||||
@ -4,8 +4,9 @@ use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
|
||||
use crate::config::{ConfigError, ConfigLoader, RuntimeConfig};
|
||||
use lsp::LspContextEnrichment;
|
||||
use crate::git_context::GitContext;
|
||||
|
||||
/// Errors raised while assembling the final system prompt.
|
||||
#[derive(Debug)]
|
||||
pub enum PromptBuildError {
|
||||
Io(std::io::Error),
|
||||
@ -35,23 +36,28 @@ impl From<ConfigError> for PromptBuildError {
|
||||
}
|
||||
}
|
||||
|
||||
/// Marker separating static prompt scaffolding from dynamic runtime context.
|
||||
pub const SYSTEM_PROMPT_DYNAMIC_BOUNDARY: &str = "__SYSTEM_PROMPT_DYNAMIC_BOUNDARY__";
|
||||
pub const FRONTIER_MODEL_NAME: &str = "Opus 4.6";
|
||||
/// Human-readable default frontier model name embedded into generated prompts.
|
||||
pub const FRONTIER_MODEL_NAME: &str = "Claude Opus 4.6";
|
||||
const MAX_INSTRUCTION_FILE_CHARS: usize = 4_000;
|
||||
const MAX_TOTAL_INSTRUCTION_CHARS: usize = 12_000;
|
||||
|
||||
/// Contents of an instruction file included in prompt construction.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ContextFile {
|
||||
pub path: PathBuf,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// Project-local context injected into the rendered system prompt.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct ProjectContext {
|
||||
pub cwd: PathBuf,
|
||||
pub current_date: String,
|
||||
pub git_status: Option<String>,
|
||||
pub git_diff: Option<String>,
|
||||
pub git_context: Option<GitContext>,
|
||||
pub instruction_files: Vec<ContextFile>,
|
||||
}
|
||||
|
||||
@ -67,6 +73,7 @@ impl ProjectContext {
|
||||
current_date: current_date.into(),
|
||||
git_status: None,
|
||||
git_diff: None,
|
||||
git_context: None,
|
||||
instruction_files,
|
||||
})
|
||||
}
|
||||
@ -78,10 +85,12 @@ impl ProjectContext {
|
||||
let mut context = Self::discover(cwd, current_date)?;
|
||||
context.git_status = read_git_status(&context.cwd);
|
||||
context.git_diff = read_git_diff(&context.cwd);
|
||||
context.git_context = GitContext::detect(&context.cwd);
|
||||
Ok(context)
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for the runtime system prompt and dynamic environment sections.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct SystemPromptBuilder {
|
||||
output_style_name: Option<String>,
|
||||
@ -131,15 +140,6 @@ impl SystemPromptBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_lsp_context(mut self, enrichment: &LspContextEnrichment) -> Self {
|
||||
if !enrichment.is_empty() {
|
||||
self.append_sections
|
||||
.push(enrichment.render_prompt_section());
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn build(&self) -> Vec<String> {
|
||||
let mut sections = Vec::new();
|
||||
@ -194,6 +194,7 @@ impl SystemPromptBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
/// Formats each item as an indented bullet for prompt sections.
|
||||
#[must_use]
|
||||
pub fn prepend_bullets(items: Vec<String>) -> Vec<String> {
|
||||
items.into_iter().map(|item| format!(" - {item}")).collect()
|
||||
@ -211,9 +212,9 @@ fn discover_instruction_files(cwd: &Path) -> std::io::Result<Vec<ContextFile>> {
|
||||
let mut files = Vec::new();
|
||||
for dir in directories {
|
||||
for candidate in [
|
||||
dir.join("CLAW.md"),
|
||||
dir.join("CLAW.local.md"),
|
||||
dir.join(".claw").join("CLAW.md"),
|
||||
dir.join("CLAUDE.md"),
|
||||
dir.join("CLAUDE.local.md"),
|
||||
dir.join(".claw").join("CLAUDE.md"),
|
||||
dir.join(".claw").join("instructions.md"),
|
||||
] {
|
||||
push_context_file(&mut files, candidate)?;
|
||||
@ -292,7 +293,7 @@ fn render_project_context(project_context: &ProjectContext) -> String {
|
||||
];
|
||||
if !project_context.instruction_files.is_empty() {
|
||||
bullets.push(format!(
|
||||
"Claw instruction files discovered: {}.",
|
||||
"Claude instruction files discovered: {}.",
|
||||
project_context.instruction_files.len()
|
||||
));
|
||||
}
|
||||
@ -302,16 +303,32 @@ fn render_project_context(project_context: &ProjectContext) -> String {
|
||||
lines.push("Git status snapshot:".to_string());
|
||||
lines.push(status.clone());
|
||||
}
|
||||
if let Some(ref gc) = project_context.git_context {
|
||||
if !gc.recent_commits.is_empty() {
|
||||
lines.push(String::new());
|
||||
lines.push("Recent commits (last 5):".to_string());
|
||||
for c in &gc.recent_commits {
|
||||
lines.push(format!(" {} {}", c.hash, c.subject));
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(diff) = &project_context.git_diff {
|
||||
lines.push(String::new());
|
||||
lines.push("Git diff snapshot:".to_string());
|
||||
lines.push(diff.clone());
|
||||
}
|
||||
if let Some(git_context) = &project_context.git_context {
|
||||
let rendered = git_context.render();
|
||||
if !rendered.is_empty() {
|
||||
lines.push(String::new());
|
||||
lines.push(rendered);
|
||||
}
|
||||
}
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
fn render_instruction_files(files: &[ContextFile]) -> String {
|
||||
let mut sections = vec!["# Claw instructions".to_string()];
|
||||
let mut sections = vec!["# Claude instructions".to_string()];
|
||||
let mut remaining_chars = MAX_TOTAL_INSTRUCTION_CHARS;
|
||||
for file in files {
|
||||
if remaining_chars == 0 {
|
||||
@ -411,6 +428,7 @@ fn collapse_blank_lines(content: &str) -> String {
|
||||
result
|
||||
}
|
||||
|
||||
/// Loads config and project context, then renders the system prompt text.
|
||||
pub fn load_system_prompt(
|
||||
cwd: impl Into<PathBuf>,
|
||||
current_date: impl Into<String>,
|
||||
@ -523,24 +541,31 @@ mod tests {
|
||||
crate::test_env_lock()
|
||||
}
|
||||
|
||||
fn ensure_valid_cwd() {
|
||||
if std::env::current_dir().is_err() {
|
||||
std::env::set_current_dir(env!("CARGO_MANIFEST_DIR"))
|
||||
.expect("test cwd should be recoverable");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discovers_instruction_files_from_ancestor_chain() {
|
||||
let root = temp_dir();
|
||||
let nested = root.join("apps").join("api");
|
||||
fs::create_dir_all(nested.join(".claw")).expect("nested claw dir");
|
||||
fs::write(root.join("CLAW.md"), "root instructions").expect("write root instructions");
|
||||
fs::write(root.join("CLAW.local.md"), "local instructions")
|
||||
fs::write(root.join("CLAUDE.md"), "root instructions").expect("write root instructions");
|
||||
fs::write(root.join("CLAUDE.local.md"), "local instructions")
|
||||
.expect("write local instructions");
|
||||
fs::create_dir_all(root.join("apps")).expect("apps dir");
|
||||
fs::create_dir_all(root.join("apps").join(".claw")).expect("apps claw dir");
|
||||
fs::write(root.join("apps").join("CLAW.md"), "apps instructions")
|
||||
fs::write(root.join("apps").join("CLAUDE.md"), "apps instructions")
|
||||
.expect("write apps instructions");
|
||||
fs::write(
|
||||
root.join("apps").join(".claw").join("instructions.md"),
|
||||
"apps dot claw instructions",
|
||||
"apps dot claude instructions",
|
||||
)
|
||||
.expect("write apps dot claw instructions");
|
||||
fs::write(nested.join(".claw").join("CLAW.md"), "nested rules")
|
||||
.expect("write apps dot claude instructions");
|
||||
fs::write(nested.join(".claw").join("CLAUDE.md"), "nested rules")
|
||||
.expect("write nested rules");
|
||||
fs::write(
|
||||
nested.join(".claw").join("instructions.md"),
|
||||
@ -561,7 +586,7 @@ mod tests {
|
||||
"root instructions",
|
||||
"local instructions",
|
||||
"apps instructions",
|
||||
"apps dot claw instructions",
|
||||
"apps dot claude instructions",
|
||||
"nested rules",
|
||||
"nested instructions"
|
||||
]
|
||||
@ -574,8 +599,8 @@ mod tests {
|
||||
let root = temp_dir();
|
||||
let nested = root.join("apps").join("api");
|
||||
fs::create_dir_all(&nested).expect("nested dir");
|
||||
fs::write(root.join("CLAW.md"), "same rules\n\n").expect("write root");
|
||||
fs::write(nested.join("CLAW.md"), "same rules\n").expect("write nested");
|
||||
fs::write(root.join("CLAUDE.md"), "same rules\n\n").expect("write root");
|
||||
fs::write(nested.join("CLAUDE.md"), "same rules\n").expect("write nested");
|
||||
|
||||
let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load");
|
||||
assert_eq!(context.instruction_files.len(), 1);
|
||||
@ -603,14 +628,15 @@ mod tests {
|
||||
#[test]
|
||||
fn displays_context_paths_compactly() {
|
||||
assert_eq!(
|
||||
display_context_path(Path::new("/tmp/project/.claw/CLAW.md")),
|
||||
"CLAW.md"
|
||||
display_context_path(Path::new("/tmp/project/.claw/CLAUDE.md")),
|
||||
"CLAUDE.md"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discover_with_git_includes_status_snapshot() {
|
||||
let _guard = env_lock();
|
||||
ensure_valid_cwd();
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("root dir");
|
||||
std::process::Command::new("git")
|
||||
@ -618,7 +644,7 @@ mod tests {
|
||||
.current_dir(&root)
|
||||
.status()
|
||||
.expect("git init should run");
|
||||
fs::write(root.join("CLAW.md"), "rules").expect("write instructions");
|
||||
fs::write(root.join("CLAUDE.md"), "rules").expect("write instructions");
|
||||
fs::write(root.join("tracked.txt"), "hello").expect("write tracked file");
|
||||
|
||||
let context =
|
||||
@ -626,16 +652,99 @@ mod tests {
|
||||
|
||||
let status = context.git_status.expect("git status should be present");
|
||||
assert!(status.contains("## No commits yet on") || status.contains("## "));
|
||||
assert!(status.contains("?? CLAW.md"));
|
||||
assert!(status.contains("?? CLAUDE.md"));
|
||||
assert!(status.contains("?? tracked.txt"));
|
||||
assert!(context.git_diff.is_none());
|
||||
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discover_with_git_includes_recent_commits_and_renders_them() {
|
||||
// given: a git repo with three commits and a current branch
|
||||
let _guard = env_lock();
|
||||
ensure_valid_cwd();
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("root dir");
|
||||
std::process::Command::new("git")
|
||||
.args(["init", "--quiet", "-b", "main"])
|
||||
.current_dir(&root)
|
||||
.status()
|
||||
.expect("git init should run");
|
||||
std::process::Command::new("git")
|
||||
.args(["config", "user.email", "tests@example.com"])
|
||||
.current_dir(&root)
|
||||
.status()
|
||||
.expect("git config email should run");
|
||||
std::process::Command::new("git")
|
||||
.args(["config", "user.name", "Runtime Prompt Tests"])
|
||||
.current_dir(&root)
|
||||
.status()
|
||||
.expect("git config name should run");
|
||||
for (file, message) in [
|
||||
("a.txt", "first commit"),
|
||||
("b.txt", "second commit"),
|
||||
("c.txt", "third commit"),
|
||||
] {
|
||||
fs::write(root.join(file), "x\n").expect("write commit file");
|
||||
std::process::Command::new("git")
|
||||
.args(["add", file])
|
||||
.current_dir(&root)
|
||||
.status()
|
||||
.expect("git add should run");
|
||||
std::process::Command::new("git")
|
||||
.args(["commit", "-m", message, "--quiet"])
|
||||
.current_dir(&root)
|
||||
.status()
|
||||
.expect("git commit should run");
|
||||
}
|
||||
fs::write(root.join("d.txt"), "staged\n").expect("write staged file");
|
||||
std::process::Command::new("git")
|
||||
.args(["add", "d.txt"])
|
||||
.current_dir(&root)
|
||||
.status()
|
||||
.expect("git add staged should run");
|
||||
|
||||
// when: discovering project context with git auto-include
|
||||
let context =
|
||||
ProjectContext::discover_with_git(&root, "2026-03-31").expect("context should load");
|
||||
let rendered = SystemPromptBuilder::new()
|
||||
.with_os("linux", "6.8")
|
||||
.with_project_context(context.clone())
|
||||
.render();
|
||||
|
||||
// then: branch, recent commits and staged files are present in context
|
||||
let gc = context
|
||||
.git_context
|
||||
.as_ref()
|
||||
.expect("git context should be present");
|
||||
let commits: String = gc
|
||||
.recent_commits
|
||||
.iter()
|
||||
.map(|c| c.subject.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
assert!(commits.contains("first commit"));
|
||||
assert!(commits.contains("second commit"));
|
||||
assert!(commits.contains("third commit"));
|
||||
assert_eq!(gc.recent_commits.len(), 3);
|
||||
|
||||
let status = context.git_status.as_deref().expect("status snapshot");
|
||||
assert!(status.contains("## main"));
|
||||
assert!(status.contains("A d.txt"));
|
||||
|
||||
assert!(rendered.contains("Recent commits (last 5):"));
|
||||
assert!(rendered.contains("first commit"));
|
||||
assert!(rendered.contains("Git status snapshot:"));
|
||||
assert!(rendered.contains("## main"));
|
||||
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discover_with_git_includes_diff_snapshot_for_tracked_changes() {
|
||||
let _guard = env_lock();
|
||||
ensure_valid_cwd();
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("root dir");
|
||||
std::process::Command::new("git")
|
||||
@ -677,10 +786,10 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_system_prompt_reads_claw_files_and_config() {
|
||||
fn load_system_prompt_reads_claude_files_and_config() {
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(root.join(".claw")).expect("claw dir");
|
||||
fs::write(root.join("CLAW.md"), "Project rules").expect("write instructions");
|
||||
fs::write(root.join("CLAUDE.md"), "Project rules").expect("write instructions");
|
||||
fs::write(
|
||||
root.join(".claw").join("settings.json"),
|
||||
r#"{"permissionMode":"acceptEdits"}"#,
|
||||
@ -688,6 +797,7 @@ mod tests {
|
||||
.expect("write settings");
|
||||
|
||||
let _guard = env_lock();
|
||||
ensure_valid_cwd();
|
||||
let previous = std::env::current_dir().expect("cwd");
|
||||
let original_home = std::env::var("HOME").ok();
|
||||
let original_claw_home = std::env::var("CLAW_CONFIG_HOME").ok();
|
||||
@ -719,10 +829,10 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn renders_claw_code_style_sections_with_project_context() {
|
||||
fn renders_claude_code_style_sections_with_project_context() {
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(root.join(".claw")).expect("claw dir");
|
||||
fs::write(root.join("CLAW.md"), "Project rules").expect("write CLAW.md");
|
||||
fs::write(root.join("CLAUDE.md"), "Project rules").expect("write CLAUDE.md");
|
||||
fs::write(
|
||||
root.join(".claw").join("settings.json"),
|
||||
r#"{"permissionMode":"acceptEdits"}"#,
|
||||
@ -743,7 +853,7 @@ mod tests {
|
||||
|
||||
assert!(prompt.contains("# System"));
|
||||
assert!(prompt.contains("# Project context"));
|
||||
assert!(prompt.contains("# Claw instructions"));
|
||||
assert!(prompt.contains("# Claude instructions"));
|
||||
assert!(prompt.contains("Project rules"));
|
||||
assert!(prompt.contains("permissionMode"));
|
||||
assert!(prompt.contains(SYSTEM_PROMPT_DYNAMIC_BOUNDARY));
|
||||
@ -760,7 +870,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn discovers_dot_claw_instructions_markdown() {
|
||||
fn discovers_dot_claude_instructions_markdown() {
|
||||
let root = temp_dir();
|
||||
let nested = root.join("apps").join("api");
|
||||
fs::create_dir_all(nested.join(".claw")).expect("nested claw dir");
|
||||
@ -785,10 +895,10 @@ mod tests {
|
||||
#[test]
|
||||
fn renders_instruction_file_metadata() {
|
||||
let rendered = render_instruction_files(&[ContextFile {
|
||||
path: PathBuf::from("/tmp/project/CLAW.md"),
|
||||
path: PathBuf::from("/tmp/project/CLAUDE.md"),
|
||||
content: "Project rules".to_string(),
|
||||
}]);
|
||||
assert!(rendered.contains("# Claw instructions"));
|
||||
assert!(rendered.contains("# Claude instructions"));
|
||||
assert!(rendered.contains("scope: /tmp/project"));
|
||||
assert!(rendered.contains("Project rules"));
|
||||
}
|
||||
|
||||
631
crates/runtime/src/recovery_recipes.rs
Normal file
631
crates/runtime/src/recovery_recipes.rs
Normal file
@ -0,0 +1,631 @@
|
||||
#![allow(clippy::cast_possible_truncation, clippy::uninlined_format_args)]
|
||||
//! Recovery recipes for common failure scenarios.
|
||||
//!
|
||||
//! Encodes known automatic recoveries for the six failure scenarios
|
||||
//! listed in ROADMAP item 8, and enforces one automatic recovery
|
||||
//! attempt before escalation. Each attempt is emitted as a structured
|
||||
//! recovery event.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::worker_boot::WorkerFailureKind;
|
||||
|
||||
/// The six failure scenarios that have known recovery recipes.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum FailureScenario {
|
||||
TrustPromptUnresolved,
|
||||
PromptMisdelivery,
|
||||
StaleBranch,
|
||||
CompileRedCrossCrate,
|
||||
McpHandshakeFailure,
|
||||
PartialPluginStartup,
|
||||
ProviderFailure,
|
||||
}
|
||||
|
||||
impl FailureScenario {
|
||||
/// Returns all known failure scenarios.
|
||||
#[must_use]
|
||||
pub fn all() -> &'static [FailureScenario] {
|
||||
&[
|
||||
Self::TrustPromptUnresolved,
|
||||
Self::PromptMisdelivery,
|
||||
Self::StaleBranch,
|
||||
Self::CompileRedCrossCrate,
|
||||
Self::McpHandshakeFailure,
|
||||
Self::PartialPluginStartup,
|
||||
Self::ProviderFailure,
|
||||
]
|
||||
}
|
||||
|
||||
/// Map a `WorkerFailureKind` to the corresponding `FailureScenario`.
|
||||
/// This is the bridge that lets recovery policy consume worker boot events.
|
||||
#[must_use]
|
||||
pub fn from_worker_failure_kind(kind: WorkerFailureKind) -> Self {
|
||||
match kind {
|
||||
WorkerFailureKind::TrustGate => Self::TrustPromptUnresolved,
|
||||
WorkerFailureKind::PromptDelivery => Self::PromptMisdelivery,
|
||||
WorkerFailureKind::Protocol => Self::McpHandshakeFailure,
|
||||
WorkerFailureKind::Provider => Self::ProviderFailure,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for FailureScenario {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::TrustPromptUnresolved => write!(f, "trust_prompt_unresolved"),
|
||||
Self::PromptMisdelivery => write!(f, "prompt_misdelivery"),
|
||||
Self::StaleBranch => write!(f, "stale_branch"),
|
||||
Self::CompileRedCrossCrate => write!(f, "compile_red_cross_crate"),
|
||||
Self::McpHandshakeFailure => write!(f, "mcp_handshake_failure"),
|
||||
Self::PartialPluginStartup => write!(f, "partial_plugin_startup"),
|
||||
Self::ProviderFailure => write!(f, "provider_failure"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Individual step that can be executed as part of a recovery recipe.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RecoveryStep {
|
||||
AcceptTrustPrompt,
|
||||
RedirectPromptToAgent,
|
||||
RebaseBranch,
|
||||
CleanBuild,
|
||||
RetryMcpHandshake { timeout: u64 },
|
||||
RestartPlugin { name: String },
|
||||
RestartWorker,
|
||||
EscalateToHuman { reason: String },
|
||||
}
|
||||
|
||||
/// Policy governing what happens when automatic recovery is exhausted.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum EscalationPolicy {
|
||||
AlertHuman,
|
||||
LogAndContinue,
|
||||
Abort,
|
||||
}
|
||||
|
||||
/// A recovery recipe encodes the sequence of steps to attempt for a
|
||||
/// given failure scenario, along with the maximum number of automatic
|
||||
/// attempts and the escalation policy.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct RecoveryRecipe {
|
||||
pub scenario: FailureScenario,
|
||||
pub steps: Vec<RecoveryStep>,
|
||||
pub max_attempts: u32,
|
||||
pub escalation_policy: EscalationPolicy,
|
||||
}
|
||||
|
||||
/// Outcome of a recovery attempt.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RecoveryResult {
|
||||
Recovered {
|
||||
steps_taken: u32,
|
||||
},
|
||||
PartialRecovery {
|
||||
recovered: Vec<RecoveryStep>,
|
||||
remaining: Vec<RecoveryStep>,
|
||||
},
|
||||
EscalationRequired {
|
||||
reason: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Structured event emitted during recovery.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RecoveryEvent {
|
||||
RecoveryAttempted {
|
||||
scenario: FailureScenario,
|
||||
recipe: RecoveryRecipe,
|
||||
result: RecoveryResult,
|
||||
},
|
||||
RecoverySucceeded,
|
||||
RecoveryFailed,
|
||||
Escalated,
|
||||
}
|
||||
|
||||
/// Minimal context for tracking recovery state and emitting events.
|
||||
///
|
||||
/// Holds per-scenario attempt counts, a structured event log, and an
|
||||
/// optional simulation knob for controlling step outcomes during tests.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct RecoveryContext {
|
||||
attempts: HashMap<FailureScenario, u32>,
|
||||
events: Vec<RecoveryEvent>,
|
||||
/// Optional step index at which simulated execution fails.
|
||||
/// `None` means all steps succeed.
|
||||
fail_at_step: Option<usize>,
|
||||
}
|
||||
|
||||
impl RecoveryContext {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Configure a step index at which simulated execution will fail.
|
||||
#[must_use]
|
||||
pub fn with_fail_at_step(mut self, index: usize) -> Self {
|
||||
self.fail_at_step = Some(index);
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns the structured event log populated during recovery.
|
||||
#[must_use]
|
||||
pub fn events(&self) -> &[RecoveryEvent] {
|
||||
&self.events
|
||||
}
|
||||
|
||||
/// Returns the number of recovery attempts made for a scenario.
|
||||
#[must_use]
|
||||
pub fn attempt_count(&self, scenario: &FailureScenario) -> u32 {
|
||||
self.attempts.get(scenario).copied().unwrap_or(0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the known recovery recipe for the given failure scenario.
|
||||
#[must_use]
|
||||
pub fn recipe_for(scenario: &FailureScenario) -> RecoveryRecipe {
|
||||
match scenario {
|
||||
FailureScenario::TrustPromptUnresolved => RecoveryRecipe {
|
||||
scenario: *scenario,
|
||||
steps: vec![RecoveryStep::AcceptTrustPrompt],
|
||||
max_attempts: 1,
|
||||
escalation_policy: EscalationPolicy::AlertHuman,
|
||||
},
|
||||
FailureScenario::PromptMisdelivery => RecoveryRecipe {
|
||||
scenario: *scenario,
|
||||
steps: vec![RecoveryStep::RedirectPromptToAgent],
|
||||
max_attempts: 1,
|
||||
escalation_policy: EscalationPolicy::AlertHuman,
|
||||
},
|
||||
FailureScenario::StaleBranch => RecoveryRecipe {
|
||||
scenario: *scenario,
|
||||
steps: vec![RecoveryStep::RebaseBranch, RecoveryStep::CleanBuild],
|
||||
max_attempts: 1,
|
||||
escalation_policy: EscalationPolicy::AlertHuman,
|
||||
},
|
||||
FailureScenario::CompileRedCrossCrate => RecoveryRecipe {
|
||||
scenario: *scenario,
|
||||
steps: vec![RecoveryStep::CleanBuild],
|
||||
max_attempts: 1,
|
||||
escalation_policy: EscalationPolicy::AlertHuman,
|
||||
},
|
||||
FailureScenario::McpHandshakeFailure => RecoveryRecipe {
|
||||
scenario: *scenario,
|
||||
steps: vec![RecoveryStep::RetryMcpHandshake { timeout: 5000 }],
|
||||
max_attempts: 1,
|
||||
escalation_policy: EscalationPolicy::Abort,
|
||||
},
|
||||
FailureScenario::PartialPluginStartup => RecoveryRecipe {
|
||||
scenario: *scenario,
|
||||
steps: vec![
|
||||
RecoveryStep::RestartPlugin {
|
||||
name: "stalled".to_string(),
|
||||
},
|
||||
RecoveryStep::RetryMcpHandshake { timeout: 3000 },
|
||||
],
|
||||
max_attempts: 1,
|
||||
escalation_policy: EscalationPolicy::LogAndContinue,
|
||||
},
|
||||
FailureScenario::ProviderFailure => RecoveryRecipe {
|
||||
scenario: *scenario,
|
||||
steps: vec![RecoveryStep::RestartWorker],
|
||||
max_attempts: 1,
|
||||
escalation_policy: EscalationPolicy::AlertHuman,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempts automatic recovery for the given failure scenario.
|
||||
///
|
||||
/// Looks up the recipe, enforces the one-attempt-before-escalation
|
||||
/// policy, simulates step execution (controlled by the context), and
|
||||
/// emits structured [`RecoveryEvent`]s for every attempt.
|
||||
pub fn attempt_recovery(scenario: &FailureScenario, ctx: &mut RecoveryContext) -> RecoveryResult {
|
||||
let recipe = recipe_for(scenario);
|
||||
let attempt_count = ctx.attempts.entry(*scenario).or_insert(0);
|
||||
|
||||
// Enforce one automatic recovery attempt before escalation.
|
||||
if *attempt_count >= recipe.max_attempts {
|
||||
let result = RecoveryResult::EscalationRequired {
|
||||
reason: format!(
|
||||
"max recovery attempts ({}) exceeded for {}",
|
||||
recipe.max_attempts, scenario
|
||||
),
|
||||
};
|
||||
ctx.events.push(RecoveryEvent::RecoveryAttempted {
|
||||
scenario: *scenario,
|
||||
recipe,
|
||||
result: result.clone(),
|
||||
});
|
||||
ctx.events.push(RecoveryEvent::Escalated);
|
||||
return result;
|
||||
}
|
||||
|
||||
*attempt_count += 1;
|
||||
|
||||
// Execute steps, honoring the optional fail_at_step simulation.
|
||||
let fail_index = ctx.fail_at_step;
|
||||
let mut executed = Vec::new();
|
||||
let mut failed = false;
|
||||
|
||||
for (i, step) in recipe.steps.iter().enumerate() {
|
||||
if fail_index == Some(i) {
|
||||
failed = true;
|
||||
break;
|
||||
}
|
||||
executed.push(step.clone());
|
||||
}
|
||||
|
||||
let result = if failed {
|
||||
let remaining: Vec<RecoveryStep> = recipe.steps[executed.len()..].to_vec();
|
||||
if executed.is_empty() {
|
||||
RecoveryResult::EscalationRequired {
|
||||
reason: format!("recovery failed at first step for {}", scenario),
|
||||
}
|
||||
} else {
|
||||
RecoveryResult::PartialRecovery {
|
||||
recovered: executed,
|
||||
remaining,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
RecoveryResult::Recovered {
|
||||
steps_taken: recipe.steps.len() as u32,
|
||||
}
|
||||
};
|
||||
|
||||
// Emit the attempt as structured event data.
|
||||
ctx.events.push(RecoveryEvent::RecoveryAttempted {
|
||||
scenario: *scenario,
|
||||
recipe,
|
||||
result: result.clone(),
|
||||
});
|
||||
|
||||
match &result {
|
||||
RecoveryResult::Recovered { .. } => {
|
||||
ctx.events.push(RecoveryEvent::RecoverySucceeded);
|
||||
}
|
||||
RecoveryResult::PartialRecovery { .. } => {
|
||||
ctx.events.push(RecoveryEvent::RecoveryFailed);
|
||||
}
|
||||
RecoveryResult::EscalationRequired { .. } => {
|
||||
ctx.events.push(RecoveryEvent::Escalated);
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn each_scenario_has_a_matching_recipe() {
|
||||
// given
|
||||
let scenarios = FailureScenario::all();
|
||||
|
||||
// when / then
|
||||
for scenario in scenarios {
|
||||
let recipe = recipe_for(scenario);
|
||||
assert_eq!(
|
||||
recipe.scenario, *scenario,
|
||||
"recipe scenario should match requested scenario"
|
||||
);
|
||||
assert!(
|
||||
!recipe.steps.is_empty(),
|
||||
"recipe for {} should have at least one step",
|
||||
scenario
|
||||
);
|
||||
assert!(
|
||||
recipe.max_attempts >= 1,
|
||||
"recipe for {} should allow at least one attempt",
|
||||
scenario
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn successful_recovery_returns_recovered_and_emits_events() {
|
||||
// given
|
||||
let mut ctx = RecoveryContext::new();
|
||||
let scenario = FailureScenario::TrustPromptUnresolved;
|
||||
|
||||
// when
|
||||
let result = attempt_recovery(&scenario, &mut ctx);
|
||||
|
||||
// then
|
||||
assert_eq!(result, RecoveryResult::Recovered { steps_taken: 1 });
|
||||
assert_eq!(ctx.events().len(), 2);
|
||||
assert!(matches!(
|
||||
&ctx.events()[0],
|
||||
RecoveryEvent::RecoveryAttempted {
|
||||
scenario: s,
|
||||
result: r,
|
||||
..
|
||||
} if *s == FailureScenario::TrustPromptUnresolved
|
||||
&& matches!(r, RecoveryResult::Recovered { steps_taken: 1 })
|
||||
));
|
||||
assert_eq!(ctx.events()[1], RecoveryEvent::RecoverySucceeded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn escalation_after_max_attempts_exceeded() {
|
||||
// given
|
||||
let mut ctx = RecoveryContext::new();
|
||||
let scenario = FailureScenario::PromptMisdelivery;
|
||||
|
||||
// when — first attempt succeeds
|
||||
let first = attempt_recovery(&scenario, &mut ctx);
|
||||
assert!(matches!(first, RecoveryResult::Recovered { .. }));
|
||||
|
||||
// when — second attempt should escalate
|
||||
let second = attempt_recovery(&scenario, &mut ctx);
|
||||
|
||||
// then
|
||||
assert!(
|
||||
matches!(
|
||||
&second,
|
||||
RecoveryResult::EscalationRequired { reason }
|
||||
if reason.contains("max recovery attempts")
|
||||
),
|
||||
"second attempt should require escalation, got: {second:?}"
|
||||
);
|
||||
assert_eq!(ctx.attempt_count(&scenario), 1);
|
||||
assert!(ctx
|
||||
.events()
|
||||
.iter()
|
||||
.any(|e| matches!(e, RecoveryEvent::Escalated)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn partial_recovery_when_step_fails_midway() {
|
||||
// given — PartialPluginStartup has two steps; fail at step index 1
|
||||
let mut ctx = RecoveryContext::new().with_fail_at_step(1);
|
||||
let scenario = FailureScenario::PartialPluginStartup;
|
||||
|
||||
// when
|
||||
let result = attempt_recovery(&scenario, &mut ctx);
|
||||
|
||||
// then
|
||||
match &result {
|
||||
RecoveryResult::PartialRecovery {
|
||||
recovered,
|
||||
remaining,
|
||||
} => {
|
||||
assert_eq!(recovered.len(), 1, "one step should have succeeded");
|
||||
assert_eq!(remaining.len(), 1, "one step should remain");
|
||||
assert!(matches!(recovered[0], RecoveryStep::RestartPlugin { .. }));
|
||||
assert!(matches!(
|
||||
remaining[0],
|
||||
RecoveryStep::RetryMcpHandshake { .. }
|
||||
));
|
||||
}
|
||||
other => panic!("expected PartialRecovery, got {other:?}"),
|
||||
}
|
||||
assert!(ctx
|
||||
.events()
|
||||
.iter()
|
||||
.any(|e| matches!(e, RecoveryEvent::RecoveryFailed)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn first_step_failure_escalates_immediately() {
|
||||
// given — fail at step index 0
|
||||
let mut ctx = RecoveryContext::new().with_fail_at_step(0);
|
||||
let scenario = FailureScenario::CompileRedCrossCrate;
|
||||
|
||||
// when
|
||||
let result = attempt_recovery(&scenario, &mut ctx);
|
||||
|
||||
// then
|
||||
assert!(
|
||||
matches!(
|
||||
&result,
|
||||
RecoveryResult::EscalationRequired { reason }
|
||||
if reason.contains("failed at first step")
|
||||
),
|
||||
"zero-step failure should escalate, got: {result:?}"
|
||||
);
|
||||
assert!(ctx
|
||||
.events()
|
||||
.iter()
|
||||
.any(|e| matches!(e, RecoveryEvent::Escalated)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn emitted_events_include_structured_attempt_data() {
|
||||
// given
|
||||
let mut ctx = RecoveryContext::new();
|
||||
let scenario = FailureScenario::McpHandshakeFailure;
|
||||
|
||||
// when
|
||||
let _ = attempt_recovery(&scenario, &mut ctx);
|
||||
|
||||
// then — verify the RecoveryAttempted event carries full context
|
||||
let attempted = ctx
|
||||
.events()
|
||||
.iter()
|
||||
.find(|e| matches!(e, RecoveryEvent::RecoveryAttempted { .. }))
|
||||
.expect("should have emitted RecoveryAttempted event");
|
||||
|
||||
match attempted {
|
||||
RecoveryEvent::RecoveryAttempted {
|
||||
scenario: s,
|
||||
recipe,
|
||||
result,
|
||||
} => {
|
||||
assert_eq!(*s, scenario);
|
||||
assert_eq!(recipe.scenario, scenario);
|
||||
assert!(!recipe.steps.is_empty());
|
||||
assert!(matches!(result, RecoveryResult::Recovered { .. }));
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
// Verify the event is serializable as structured JSON
|
||||
let json = serde_json::to_string(&ctx.events()[0])
|
||||
.expect("recovery event should be serializable to JSON");
|
||||
assert!(
|
||||
json.contains("mcp_handshake_failure"),
|
||||
"serialized event should contain scenario name"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recovery_context_tracks_attempts_per_scenario() {
|
||||
// given
|
||||
let mut ctx = RecoveryContext::new();
|
||||
|
||||
// when
|
||||
assert_eq!(ctx.attempt_count(&FailureScenario::StaleBranch), 0);
|
||||
attempt_recovery(&FailureScenario::StaleBranch, &mut ctx);
|
||||
|
||||
// then
|
||||
assert_eq!(ctx.attempt_count(&FailureScenario::StaleBranch), 1);
|
||||
assert_eq!(ctx.attempt_count(&FailureScenario::PromptMisdelivery), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stale_branch_recipe_has_rebase_then_clean_build() {
|
||||
// given
|
||||
let recipe = recipe_for(&FailureScenario::StaleBranch);
|
||||
|
||||
// then
|
||||
assert_eq!(recipe.steps.len(), 2);
|
||||
assert_eq!(recipe.steps[0], RecoveryStep::RebaseBranch);
|
||||
assert_eq!(recipe.steps[1], RecoveryStep::CleanBuild);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn partial_plugin_startup_recipe_has_restart_then_handshake() {
|
||||
// given
|
||||
let recipe = recipe_for(&FailureScenario::PartialPluginStartup);
|
||||
|
||||
// then
|
||||
assert_eq!(recipe.steps.len(), 2);
|
||||
assert!(matches!(
|
||||
recipe.steps[0],
|
||||
RecoveryStep::RestartPlugin { .. }
|
||||
));
|
||||
assert!(matches!(
|
||||
recipe.steps[1],
|
||||
RecoveryStep::RetryMcpHandshake { timeout: 3000 }
|
||||
));
|
||||
assert_eq!(recipe.escalation_policy, EscalationPolicy::LogAndContinue);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn failure_scenario_display_all_variants() {
|
||||
// given
|
||||
let cases = [
|
||||
(
|
||||
FailureScenario::TrustPromptUnresolved,
|
||||
"trust_prompt_unresolved",
|
||||
),
|
||||
(FailureScenario::PromptMisdelivery, "prompt_misdelivery"),
|
||||
(FailureScenario::StaleBranch, "stale_branch"),
|
||||
(
|
||||
FailureScenario::CompileRedCrossCrate,
|
||||
"compile_red_cross_crate",
|
||||
),
|
||||
(
|
||||
FailureScenario::McpHandshakeFailure,
|
||||
"mcp_handshake_failure",
|
||||
),
|
||||
(
|
||||
FailureScenario::PartialPluginStartup,
|
||||
"partial_plugin_startup",
|
||||
),
|
||||
];
|
||||
|
||||
// when / then
|
||||
for (scenario, expected) in &cases {
|
||||
assert_eq!(scenario.to_string(), *expected);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multi_step_success_reports_correct_steps_taken() {
|
||||
// given — StaleBranch has 2 steps, no simulated failure
|
||||
let mut ctx = RecoveryContext::new();
|
||||
let scenario = FailureScenario::StaleBranch;
|
||||
|
||||
// when
|
||||
let result = attempt_recovery(&scenario, &mut ctx);
|
||||
|
||||
// then
|
||||
assert_eq!(result, RecoveryResult::Recovered { steps_taken: 2 });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mcp_handshake_recipe_uses_abort_escalation_policy() {
|
||||
// given
|
||||
let recipe = recipe_for(&FailureScenario::McpHandshakeFailure);
|
||||
|
||||
// then
|
||||
assert_eq!(recipe.escalation_policy, EscalationPolicy::Abort);
|
||||
assert_eq!(recipe.max_attempts, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn worker_failure_kind_maps_to_failure_scenario() {
|
||||
// given / when / then — verify the bridge is correct
|
||||
assert_eq!(
|
||||
FailureScenario::from_worker_failure_kind(WorkerFailureKind::TrustGate),
|
||||
FailureScenario::TrustPromptUnresolved,
|
||||
);
|
||||
assert_eq!(
|
||||
FailureScenario::from_worker_failure_kind(WorkerFailureKind::PromptDelivery),
|
||||
FailureScenario::PromptMisdelivery,
|
||||
);
|
||||
assert_eq!(
|
||||
FailureScenario::from_worker_failure_kind(WorkerFailureKind::Protocol),
|
||||
FailureScenario::McpHandshakeFailure,
|
||||
);
|
||||
assert_eq!(
|
||||
FailureScenario::from_worker_failure_kind(WorkerFailureKind::Provider),
|
||||
FailureScenario::ProviderFailure,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_failure_recipe_uses_restart_worker_step() {
|
||||
// given
|
||||
let recipe = recipe_for(&FailureScenario::ProviderFailure);
|
||||
|
||||
// then
|
||||
assert_eq!(recipe.scenario, FailureScenario::ProviderFailure);
|
||||
assert!(recipe.steps.contains(&RecoveryStep::RestartWorker));
|
||||
assert_eq!(recipe.escalation_policy, EscalationPolicy::AlertHuman);
|
||||
assert_eq!(recipe.max_attempts, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_failure_recovery_attempt_succeeds_then_escalates() {
|
||||
// given
|
||||
let mut ctx = RecoveryContext::new();
|
||||
let scenario = FailureScenario::ProviderFailure;
|
||||
|
||||
// when — first attempt
|
||||
let first = attempt_recovery(&scenario, &mut ctx);
|
||||
assert!(matches!(first, RecoveryResult::Recovered { .. }));
|
||||
|
||||
// when — second attempt should escalate (max_attempts=1)
|
||||
let second = attempt_recovery(&scenario, &mut ctx);
|
||||
assert!(matches!(second, RecoveryResult::EscalationRequired { .. }));
|
||||
assert!(ctx
|
||||
.events()
|
||||
.iter()
|
||||
.any(|e| matches!(e, RecoveryEvent::Escalated)));
|
||||
}
|
||||
}
|
||||
@ -72,9 +72,9 @@ impl RemoteSessionContext {
|
||||
#[must_use]
|
||||
pub fn from_env_map(env_map: &BTreeMap<String, String>) -> Self {
|
||||
Self {
|
||||
enabled: env_truthy(env_map.get("CLAW_CODE_REMOTE")),
|
||||
enabled: env_truthy(env_map.get("CLAUDE_CODE_REMOTE")),
|
||||
session_id: env_map
|
||||
.get("CLAW_CODE_REMOTE_SESSION_ID")
|
||||
.get("CLAUDE_CODE_REMOTE_SESSION_ID")
|
||||
.filter(|value| !value.is_empty())
|
||||
.cloned(),
|
||||
base_url: env_map
|
||||
@ -272,9 +272,9 @@ mod tests {
|
||||
#[test]
|
||||
fn remote_context_reads_env_state() {
|
||||
let env = BTreeMap::from([
|
||||
("CLAW_CODE_REMOTE".to_string(), "true".to_string()),
|
||||
("CLAUDE_CODE_REMOTE".to_string(), "true".to_string()),
|
||||
(
|
||||
"CLAW_CODE_REMOTE_SESSION_ID".to_string(),
|
||||
"CLAUDE_CODE_REMOTE_SESSION_ID".to_string(),
|
||||
"session-123".to_string(),
|
||||
),
|
||||
(
|
||||
@ -291,7 +291,7 @@ mod tests {
|
||||
#[test]
|
||||
fn bootstrap_fails_open_when_token_or_session_is_missing() {
|
||||
let env = BTreeMap::from([
|
||||
("CLAW_CODE_REMOTE".to_string(), "1".to_string()),
|
||||
("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()),
|
||||
("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()),
|
||||
]);
|
||||
let bootstrap = UpstreamProxyBootstrap::from_env_map(&env);
|
||||
@ -307,10 +307,10 @@ mod tests {
|
||||
fs::write(&token_path, "secret-token\n").expect("write token");
|
||||
|
||||
let env = BTreeMap::from([
|
||||
("CLAW_CODE_REMOTE".to_string(), "1".to_string()),
|
||||
("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()),
|
||||
("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()),
|
||||
(
|
||||
"CLAW_CODE_REMOTE_SESSION_ID".to_string(),
|
||||
"CLAUDE_CODE_REMOTE_SESSION_ID".to_string(),
|
||||
"session-123".to_string(),
|
||||
),
|
||||
(
|
||||
|
||||
@ -107,23 +107,11 @@ impl SandboxConfig {
|
||||
|
||||
#[must_use]
|
||||
pub fn detect_container_environment() -> ContainerEnvironment {
|
||||
let proc_1_cgroup = if cfg!(target_os = "linux") {
|
||||
fs::read_to_string("/proc/1/cgroup").ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let proc_1_cgroup = fs::read_to_string("/proc/1/cgroup").ok();
|
||||
detect_container_environment_from(SandboxDetectionInputs {
|
||||
env_pairs: env::vars().collect(),
|
||||
dockerenv_exists: if cfg!(target_os = "linux") {
|
||||
Path::new("/.dockerenv").exists()
|
||||
} else {
|
||||
false
|
||||
},
|
||||
containerenv_exists: if cfg!(target_os = "linux") {
|
||||
Path::new("/run/.containerenv").exists()
|
||||
} else {
|
||||
false
|
||||
},
|
||||
dockerenv_exists: Path::new("/.dockerenv").exists(),
|
||||
containerenv_exists: Path::new("/run/.containerenv").exists(),
|
||||
proc_1_cgroup: proc_1_cgroup.as_deref(),
|
||||
})
|
||||
}
|
||||
@ -173,7 +161,7 @@ pub fn resolve_sandbox_status(config: &SandboxConfig, cwd: &Path) -> SandboxStat
|
||||
#[must_use]
|
||||
pub fn resolve_sandbox_status_for_request(request: &SandboxRequest, cwd: &Path) -> SandboxStatus {
|
||||
let container = detect_container_environment();
|
||||
let namespace_supported = cfg!(target_os = "linux") && command_exists("unshare");
|
||||
let namespace_supported = cfg!(target_os = "linux") && unshare_user_namespace_works();
|
||||
let network_supported = namespace_supported;
|
||||
let filesystem_active =
|
||||
request.enabled && request.filesystem_mode != FilesystemIsolationMode::Off;
|
||||
@ -294,6 +282,27 @@ fn command_exists(command: &str) -> bool {
|
||||
.is_some_and(|paths| env::split_paths(&paths).any(|path| path.join(command).exists()))
|
||||
}
|
||||
|
||||
/// Check whether `unshare --user` actually works on this system.
|
||||
/// On some CI environments (e.g. GitHub Actions), the binary exists but
|
||||
/// user namespaces are restricted, causing silent failures.
|
||||
fn unshare_user_namespace_works() -> bool {
|
||||
use std::sync::OnceLock;
|
||||
static RESULT: OnceLock<bool> = OnceLock::new();
|
||||
*RESULT.get_or_init(|| {
|
||||
if !command_exists("unshare") {
|
||||
return false;
|
||||
}
|
||||
std::process::Command::new("unshare")
|
||||
.args(["--user", "--map-root-user", "true"])
|
||||
.stdin(std::process::Stdio::null())
|
||||
.stdout(std::process::Stdio::null())
|
||||
.stderr(std::process::Stdio::null())
|
||||
.status()
|
||||
.map(|s| s.success())
|
||||
.unwrap_or(false)
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
873
crates/runtime/src/session_control.rs
Normal file
873
crates/runtime/src/session_control.rs
Normal file
@ -0,0 +1,873 @@
|
||||
#![allow(dead_code)]
|
||||
use std::env;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::UNIX_EPOCH;
|
||||
|
||||
use crate::session::{Session, SessionError};
|
||||
|
||||
/// Per-worktree session store that namespaces on-disk session files by
|
||||
/// workspace fingerprint so that parallel `opencode serve` instances never
|
||||
/// collide.
|
||||
///
|
||||
/// Create via [`SessionStore::from_cwd`] (derives the store path from the
|
||||
/// server's working directory) or [`SessionStore::from_data_dir`] (honours an
|
||||
/// explicit `--data-dir` flag). Both constructors produce a directory layout
|
||||
/// of `<data_dir>/sessions/<workspace_hash>/` where `<workspace_hash>` is a
|
||||
/// stable hex digest of the canonical workspace root.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SessionStore {
|
||||
/// Resolved root of the session namespace, e.g.
|
||||
/// `/home/user/project/.claw/sessions/a1b2c3d4e5f60718/`.
|
||||
sessions_root: PathBuf,
|
||||
/// The canonical workspace path that was fingerprinted.
|
||||
workspace_root: PathBuf,
|
||||
}
|
||||
|
||||
impl SessionStore {
|
||||
/// Build a store from the server's current working directory.
|
||||
///
|
||||
/// The on-disk layout becomes `<cwd>/.claw/sessions/<workspace_hash>/`.
|
||||
pub fn from_cwd(cwd: impl AsRef<Path>) -> Result<Self, SessionControlError> {
|
||||
let cwd = cwd.as_ref();
|
||||
let sessions_root = cwd
|
||||
.join(".claw")
|
||||
.join("sessions")
|
||||
.join(workspace_fingerprint(cwd));
|
||||
fs::create_dir_all(&sessions_root)?;
|
||||
Ok(Self {
|
||||
sessions_root,
|
||||
workspace_root: cwd.to_path_buf(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Build a store from an explicit `--data-dir` flag.
|
||||
///
|
||||
/// The on-disk layout becomes `<data_dir>/sessions/<workspace_hash>/`
|
||||
/// where `<workspace_hash>` is derived from `workspace_root`.
|
||||
pub fn from_data_dir(
|
||||
data_dir: impl AsRef<Path>,
|
||||
workspace_root: impl AsRef<Path>,
|
||||
) -> Result<Self, SessionControlError> {
|
||||
let workspace_root = workspace_root.as_ref();
|
||||
let sessions_root = data_dir
|
||||
.as_ref()
|
||||
.join("sessions")
|
||||
.join(workspace_fingerprint(workspace_root));
|
||||
fs::create_dir_all(&sessions_root)?;
|
||||
Ok(Self {
|
||||
sessions_root,
|
||||
workspace_root: workspace_root.to_path_buf(),
|
||||
})
|
||||
}
|
||||
|
||||
/// The fully resolved sessions directory for this namespace.
|
||||
#[must_use]
|
||||
pub fn sessions_dir(&self) -> &Path {
|
||||
&self.sessions_root
|
||||
}
|
||||
|
||||
/// The workspace root this store is bound to.
|
||||
#[must_use]
|
||||
pub fn workspace_root(&self) -> &Path {
|
||||
&self.workspace_root
|
||||
}
|
||||
|
||||
pub fn create_handle(&self, session_id: &str) -> SessionHandle {
|
||||
let id = session_id.to_string();
|
||||
let path = self
|
||||
.sessions_root
|
||||
.join(format!("{id}.{PRIMARY_SESSION_EXTENSION}"));
|
||||
SessionHandle { id, path }
|
||||
}
|
||||
|
||||
pub fn resolve_reference(&self, reference: &str) -> Result<SessionHandle, SessionControlError> {
|
||||
if is_session_reference_alias(reference) {
|
||||
let latest = self.latest_session()?;
|
||||
return Ok(SessionHandle {
|
||||
id: latest.id,
|
||||
path: latest.path,
|
||||
});
|
||||
}
|
||||
|
||||
let direct = PathBuf::from(reference);
|
||||
let candidate = if direct.is_absolute() {
|
||||
direct.clone()
|
||||
} else {
|
||||
self.workspace_root.join(&direct)
|
||||
};
|
||||
let looks_like_path = direct.extension().is_some() || direct.components().count() > 1;
|
||||
let path = if candidate.exists() {
|
||||
candidate
|
||||
} else if looks_like_path {
|
||||
return Err(SessionControlError::Format(
|
||||
format_missing_session_reference(reference),
|
||||
));
|
||||
} else {
|
||||
self.resolve_managed_path(reference)?
|
||||
};
|
||||
|
||||
Ok(SessionHandle {
|
||||
id: session_id_from_path(&path).unwrap_or_else(|| reference.to_string()),
|
||||
path,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn resolve_managed_path(&self, session_id: &str) -> Result<PathBuf, SessionControlError> {
|
||||
for extension in [PRIMARY_SESSION_EXTENSION, LEGACY_SESSION_EXTENSION] {
|
||||
let path = self.sessions_root.join(format!("{session_id}.{extension}"));
|
||||
if path.exists() {
|
||||
return Ok(path);
|
||||
}
|
||||
}
|
||||
Err(SessionControlError::Format(
|
||||
format_missing_session_reference(session_id),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn list_sessions(&self) -> Result<Vec<ManagedSessionSummary>, SessionControlError> {
|
||||
let mut sessions = Vec::new();
|
||||
let read_result = fs::read_dir(&self.sessions_root);
|
||||
let entries = match read_result {
|
||||
Ok(entries) => entries,
|
||||
Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(sessions),
|
||||
Err(err) => return Err(err.into()),
|
||||
};
|
||||
for entry in entries {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if !is_managed_session_file(&path) {
|
||||
continue;
|
||||
}
|
||||
let metadata = entry.metadata()?;
|
||||
let modified_epoch_millis = metadata
|
||||
.modified()
|
||||
.ok()
|
||||
.and_then(|time| time.duration_since(UNIX_EPOCH).ok())
|
||||
.map(|duration| duration.as_millis())
|
||||
.unwrap_or_default();
|
||||
let (id, message_count, parent_session_id, branch_name) =
|
||||
match Session::load_from_path(&path) {
|
||||
Ok(session) => {
|
||||
let parent_session_id = session
|
||||
.fork
|
||||
.as_ref()
|
||||
.map(|fork| fork.parent_session_id.clone());
|
||||
let branch_name = session
|
||||
.fork
|
||||
.as_ref()
|
||||
.and_then(|fork| fork.branch_name.clone());
|
||||
(
|
||||
session.session_id,
|
||||
session.messages.len(),
|
||||
parent_session_id,
|
||||
branch_name,
|
||||
)
|
||||
}
|
||||
Err(_) => (
|
||||
path.file_stem()
|
||||
.and_then(|value| value.to_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string(),
|
||||
0,
|
||||
None,
|
||||
None,
|
||||
),
|
||||
};
|
||||
sessions.push(ManagedSessionSummary {
|
||||
id,
|
||||
path,
|
||||
modified_epoch_millis,
|
||||
message_count,
|
||||
parent_session_id,
|
||||
branch_name,
|
||||
});
|
||||
}
|
||||
sessions.sort_by(|left, right| {
|
||||
right
|
||||
.modified_epoch_millis
|
||||
.cmp(&left.modified_epoch_millis)
|
||||
.then_with(|| right.id.cmp(&left.id))
|
||||
});
|
||||
Ok(sessions)
|
||||
}
|
||||
|
||||
pub fn latest_session(&self) -> Result<ManagedSessionSummary, SessionControlError> {
|
||||
self.list_sessions()?
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| SessionControlError::Format(format_no_managed_sessions()))
|
||||
}
|
||||
|
||||
pub fn load_session(
|
||||
&self,
|
||||
reference: &str,
|
||||
) -> Result<LoadedManagedSession, SessionControlError> {
|
||||
let handle = self.resolve_reference(reference)?;
|
||||
let session = Session::load_from_path(&handle.path)?;
|
||||
Ok(LoadedManagedSession {
|
||||
handle: SessionHandle {
|
||||
id: session.session_id.clone(),
|
||||
path: handle.path,
|
||||
},
|
||||
session,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fork_session(
|
||||
&self,
|
||||
session: &Session,
|
||||
branch_name: Option<String>,
|
||||
) -> Result<ForkedManagedSession, SessionControlError> {
|
||||
let parent_session_id = session.session_id.clone();
|
||||
let forked = session.fork(branch_name);
|
||||
let handle = self.create_handle(&forked.session_id);
|
||||
let branch_name = forked
|
||||
.fork
|
||||
.as_ref()
|
||||
.and_then(|fork| fork.branch_name.clone());
|
||||
let forked = forked.with_persistence_path(handle.path.clone());
|
||||
forked.save_to_path(&handle.path)?;
|
||||
Ok(ForkedManagedSession {
|
||||
parent_session_id,
|
||||
handle,
|
||||
session: forked,
|
||||
branch_name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Stable hex fingerprint of a workspace path.
|
||||
///
|
||||
/// Uses FNV-1a (64-bit) to produce a 16-char hex string that partitions the
|
||||
/// on-disk session directory per workspace root.
|
||||
#[must_use]
|
||||
pub fn workspace_fingerprint(workspace_root: &Path) -> String {
|
||||
let input = workspace_root.to_string_lossy();
|
||||
let mut hash = 0xcbf2_9ce4_8422_2325_u64;
|
||||
for byte in input.as_bytes() {
|
||||
hash ^= u64::from(*byte);
|
||||
hash = hash.wrapping_mul(0x0100_0000_01b3);
|
||||
}
|
||||
format!("{hash:016x}")
|
||||
}
|
||||
|
||||
pub const PRIMARY_SESSION_EXTENSION: &str = "jsonl";
|
||||
pub const LEGACY_SESSION_EXTENSION: &str = "json";
|
||||
pub const LATEST_SESSION_REFERENCE: &str = "latest";
|
||||
|
||||
const SESSION_REFERENCE_ALIASES: &[&str] = &[LATEST_SESSION_REFERENCE, "last", "recent"];
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SessionHandle {
|
||||
pub id: String,
|
||||
pub path: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ManagedSessionSummary {
|
||||
pub id: String,
|
||||
pub path: PathBuf,
|
||||
pub modified_epoch_millis: u128,
|
||||
pub message_count: usize,
|
||||
pub parent_session_id: Option<String>,
|
||||
pub branch_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct LoadedManagedSession {
|
||||
pub handle: SessionHandle,
|
||||
pub session: Session,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ForkedManagedSession {
|
||||
pub parent_session_id: String,
|
||||
pub handle: SessionHandle,
|
||||
pub session: Session,
|
||||
pub branch_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum SessionControlError {
|
||||
Io(std::io::Error),
|
||||
Session(SessionError),
|
||||
Format(String),
|
||||
}
|
||||
|
||||
impl Display for SessionControlError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Io(error) => write!(f, "{error}"),
|
||||
Self::Session(error) => write!(f, "{error}"),
|
||||
Self::Format(error) => write!(f, "{error}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for SessionControlError {}
|
||||
|
||||
impl From<std::io::Error> for SessionControlError {
|
||||
fn from(value: std::io::Error) -> Self {
|
||||
Self::Io(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SessionError> for SessionControlError {
|
||||
fn from(value: SessionError) -> Self {
|
||||
Self::Session(value)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sessions_dir() -> Result<PathBuf, SessionControlError> {
|
||||
managed_sessions_dir_for(env::current_dir()?)
|
||||
}
|
||||
|
||||
pub fn managed_sessions_dir_for(
|
||||
base_dir: impl AsRef<Path>,
|
||||
) -> Result<PathBuf, SessionControlError> {
|
||||
let path = base_dir.as_ref().join(".claw").join("sessions");
|
||||
fs::create_dir_all(&path)?;
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
pub fn create_managed_session_handle(
|
||||
session_id: &str,
|
||||
) -> Result<SessionHandle, SessionControlError> {
|
||||
create_managed_session_handle_for(env::current_dir()?, session_id)
|
||||
}
|
||||
|
||||
pub fn create_managed_session_handle_for(
|
||||
base_dir: impl AsRef<Path>,
|
||||
session_id: &str,
|
||||
) -> Result<SessionHandle, SessionControlError> {
|
||||
let id = session_id.to_string();
|
||||
let path =
|
||||
managed_sessions_dir_for(base_dir)?.join(format!("{id}.{PRIMARY_SESSION_EXTENSION}"));
|
||||
Ok(SessionHandle { id, path })
|
||||
}
|
||||
|
||||
pub fn resolve_session_reference(reference: &str) -> Result<SessionHandle, SessionControlError> {
|
||||
resolve_session_reference_for(env::current_dir()?, reference)
|
||||
}
|
||||
|
||||
pub fn resolve_session_reference_for(
|
||||
base_dir: impl AsRef<Path>,
|
||||
reference: &str,
|
||||
) -> Result<SessionHandle, SessionControlError> {
|
||||
let base_dir = base_dir.as_ref();
|
||||
if is_session_reference_alias(reference) {
|
||||
let latest = latest_managed_session_for(base_dir)?;
|
||||
return Ok(SessionHandle {
|
||||
id: latest.id,
|
||||
path: latest.path,
|
||||
});
|
||||
}
|
||||
|
||||
let direct = PathBuf::from(reference);
|
||||
let candidate = if direct.is_absolute() {
|
||||
direct.clone()
|
||||
} else {
|
||||
base_dir.join(&direct)
|
||||
};
|
||||
let looks_like_path = direct.extension().is_some() || direct.components().count() > 1;
|
||||
let path = if candidate.exists() {
|
||||
candidate
|
||||
} else if looks_like_path {
|
||||
return Err(SessionControlError::Format(
|
||||
format_missing_session_reference(reference),
|
||||
));
|
||||
} else {
|
||||
resolve_managed_session_path_for(base_dir, reference)?
|
||||
};
|
||||
|
||||
Ok(SessionHandle {
|
||||
id: session_id_from_path(&path).unwrap_or_else(|| reference.to_string()),
|
||||
path,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn resolve_managed_session_path(session_id: &str) -> Result<PathBuf, SessionControlError> {
|
||||
resolve_managed_session_path_for(env::current_dir()?, session_id)
|
||||
}
|
||||
|
||||
pub fn resolve_managed_session_path_for(
|
||||
base_dir: impl AsRef<Path>,
|
||||
session_id: &str,
|
||||
) -> Result<PathBuf, SessionControlError> {
|
||||
let directory = managed_sessions_dir_for(base_dir)?;
|
||||
for extension in [PRIMARY_SESSION_EXTENSION, LEGACY_SESSION_EXTENSION] {
|
||||
let path = directory.join(format!("{session_id}.{extension}"));
|
||||
if path.exists() {
|
||||
return Ok(path);
|
||||
}
|
||||
}
|
||||
Err(SessionControlError::Format(
|
||||
format_missing_session_reference(session_id),
|
||||
))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_managed_session_file(path: &Path) -> bool {
|
||||
path.extension()
|
||||
.and_then(|ext| ext.to_str())
|
||||
.is_some_and(|extension| {
|
||||
extension == PRIMARY_SESSION_EXTENSION || extension == LEGACY_SESSION_EXTENSION
|
||||
})
|
||||
}
|
||||
|
||||
pub fn list_managed_sessions() -> Result<Vec<ManagedSessionSummary>, SessionControlError> {
|
||||
list_managed_sessions_for(env::current_dir()?)
|
||||
}
|
||||
|
||||
pub fn list_managed_sessions_for(
|
||||
base_dir: impl AsRef<Path>,
|
||||
) -> Result<Vec<ManagedSessionSummary>, SessionControlError> {
|
||||
let mut sessions = Vec::new();
|
||||
for entry in fs::read_dir(managed_sessions_dir_for(base_dir)?)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if !is_managed_session_file(&path) {
|
||||
continue;
|
||||
}
|
||||
let metadata = entry.metadata()?;
|
||||
let modified_epoch_millis = metadata
|
||||
.modified()
|
||||
.ok()
|
||||
.and_then(|time| time.duration_since(UNIX_EPOCH).ok())
|
||||
.map(|duration| duration.as_millis())
|
||||
.unwrap_or_default();
|
||||
let (id, message_count, parent_session_id, branch_name) =
|
||||
match Session::load_from_path(&path) {
|
||||
Ok(session) => {
|
||||
let parent_session_id = session
|
||||
.fork
|
||||
.as_ref()
|
||||
.map(|fork| fork.parent_session_id.clone());
|
||||
let branch_name = session
|
||||
.fork
|
||||
.as_ref()
|
||||
.and_then(|fork| fork.branch_name.clone());
|
||||
(
|
||||
session.session_id,
|
||||
session.messages.len(),
|
||||
parent_session_id,
|
||||
branch_name,
|
||||
)
|
||||
}
|
||||
Err(_) => (
|
||||
path.file_stem()
|
||||
.and_then(|value| value.to_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string(),
|
||||
0,
|
||||
None,
|
||||
None,
|
||||
),
|
||||
};
|
||||
sessions.push(ManagedSessionSummary {
|
||||
id,
|
||||
path,
|
||||
modified_epoch_millis,
|
||||
message_count,
|
||||
parent_session_id,
|
||||
branch_name,
|
||||
});
|
||||
}
|
||||
sessions.sort_by(|left, right| {
|
||||
right
|
||||
.modified_epoch_millis
|
||||
.cmp(&left.modified_epoch_millis)
|
||||
.then_with(|| right.id.cmp(&left.id))
|
||||
});
|
||||
Ok(sessions)
|
||||
}
|
||||
|
||||
pub fn latest_managed_session() -> Result<ManagedSessionSummary, SessionControlError> {
|
||||
latest_managed_session_for(env::current_dir()?)
|
||||
}
|
||||
|
||||
pub fn latest_managed_session_for(
|
||||
base_dir: impl AsRef<Path>,
|
||||
) -> Result<ManagedSessionSummary, SessionControlError> {
|
||||
list_managed_sessions_for(base_dir)?
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| SessionControlError::Format(format_no_managed_sessions()))
|
||||
}
|
||||
|
||||
pub fn load_managed_session(reference: &str) -> Result<LoadedManagedSession, SessionControlError> {
|
||||
load_managed_session_for(env::current_dir()?, reference)
|
||||
}
|
||||
|
||||
pub fn load_managed_session_for(
|
||||
base_dir: impl AsRef<Path>,
|
||||
reference: &str,
|
||||
) -> Result<LoadedManagedSession, SessionControlError> {
|
||||
let handle = resolve_session_reference_for(base_dir, reference)?;
|
||||
let session = Session::load_from_path(&handle.path)?;
|
||||
Ok(LoadedManagedSession {
|
||||
handle: SessionHandle {
|
||||
id: session.session_id.clone(),
|
||||
path: handle.path,
|
||||
},
|
||||
session,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn fork_managed_session(
|
||||
session: &Session,
|
||||
branch_name: Option<String>,
|
||||
) -> Result<ForkedManagedSession, SessionControlError> {
|
||||
fork_managed_session_for(env::current_dir()?, session, branch_name)
|
||||
}
|
||||
|
||||
pub fn fork_managed_session_for(
|
||||
base_dir: impl AsRef<Path>,
|
||||
session: &Session,
|
||||
branch_name: Option<String>,
|
||||
) -> Result<ForkedManagedSession, SessionControlError> {
|
||||
let parent_session_id = session.session_id.clone();
|
||||
let forked = session.fork(branch_name);
|
||||
let handle = create_managed_session_handle_for(base_dir, &forked.session_id)?;
|
||||
let branch_name = forked
|
||||
.fork
|
||||
.as_ref()
|
||||
.and_then(|fork| fork.branch_name.clone());
|
||||
let forked = forked.with_persistence_path(handle.path.clone());
|
||||
forked.save_to_path(&handle.path)?;
|
||||
Ok(ForkedManagedSession {
|
||||
parent_session_id,
|
||||
handle,
|
||||
session: forked,
|
||||
branch_name,
|
||||
})
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_session_reference_alias(reference: &str) -> bool {
|
||||
SESSION_REFERENCE_ALIASES
|
||||
.iter()
|
||||
.any(|alias| reference.eq_ignore_ascii_case(alias))
|
||||
}
|
||||
|
||||
fn session_id_from_path(path: &Path) -> Option<String> {
|
||||
path.file_name()
|
||||
.and_then(|value| value.to_str())
|
||||
.and_then(|name| {
|
||||
name.strip_suffix(&format!(".{PRIMARY_SESSION_EXTENSION}"))
|
||||
.or_else(|| name.strip_suffix(&format!(".{LEGACY_SESSION_EXTENSION}")))
|
||||
})
|
||||
.map(ToOwned::to_owned)
|
||||
}
|
||||
|
||||
fn format_missing_session_reference(reference: &str) -> String {
|
||||
format!(
|
||||
"session not found: {reference}\nHint: managed sessions live in .claw/sessions/. Try `{LATEST_SESSION_REFERENCE}` for the most recent session or `/session list` in the REPL."
|
||||
)
|
||||
}
|
||||
|
||||
fn format_no_managed_sessions() -> String {
|
||||
format!(
|
||||
"no managed sessions found in .claw/sessions/\nStart `claw` to create a session, then rerun with `--resume {LATEST_SESSION_REFERENCE}`."
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
create_managed_session_handle_for, fork_managed_session_for, is_session_reference_alias,
|
||||
list_managed_sessions_for, load_managed_session_for, resolve_session_reference_for,
|
||||
workspace_fingerprint, ManagedSessionSummary, SessionStore, LATEST_SESSION_REFERENCE,
|
||||
};
|
||||
use crate::session::Session;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
fn temp_dir() -> PathBuf {
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time should be after epoch")
|
||||
.as_nanos();
|
||||
std::env::temp_dir().join(format!("runtime-session-control-{nanos}"))
|
||||
}
|
||||
|
||||
fn persist_session(root: &Path, text: &str) -> Session {
|
||||
let mut session = Session::new();
|
||||
session
|
||||
.push_user_text(text)
|
||||
.expect("session message should save");
|
||||
let handle = create_managed_session_handle_for(root, &session.session_id)
|
||||
.expect("managed session handle should build");
|
||||
let session = session.with_persistence_path(handle.path.clone());
|
||||
session
|
||||
.save_to_path(&handle.path)
|
||||
.expect("session should persist");
|
||||
session
|
||||
}
|
||||
|
||||
fn wait_for_next_millisecond() {
|
||||
let start = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time should be after epoch")
|
||||
.as_millis();
|
||||
while SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time should be after epoch")
|
||||
.as_millis()
|
||||
<= start
|
||||
{}
|
||||
}
|
||||
|
||||
fn summary_by_id<'a>(
|
||||
summaries: &'a [ManagedSessionSummary],
|
||||
id: &str,
|
||||
) -> &'a ManagedSessionSummary {
|
||||
summaries
|
||||
.iter()
|
||||
.find(|summary| summary.id == id)
|
||||
.expect("session summary should exist")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_and_lists_managed_sessions() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("root dir should exist");
|
||||
let older = persist_session(&root, "older session");
|
||||
wait_for_next_millisecond();
|
||||
let newer = persist_session(&root, "newer session");
|
||||
|
||||
// when
|
||||
let sessions = list_managed_sessions_for(&root).expect("managed sessions should list");
|
||||
|
||||
// then
|
||||
assert_eq!(sessions.len(), 2);
|
||||
assert_eq!(sessions[0].id, newer.session_id);
|
||||
assert_eq!(summary_by_id(&sessions, &older.session_id).message_count, 1);
|
||||
assert_eq!(summary_by_id(&sessions, &newer.session_id).message_count, 1);
|
||||
fs::remove_dir_all(root).expect("temp dir should clean up");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolves_latest_alias_and_loads_session_from_workspace_root() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("root dir should exist");
|
||||
let older = persist_session(&root, "older session");
|
||||
wait_for_next_millisecond();
|
||||
let newer = persist_session(&root, "newer session");
|
||||
|
||||
// when
|
||||
let handle = resolve_session_reference_for(&root, LATEST_SESSION_REFERENCE)
|
||||
.expect("latest alias should resolve");
|
||||
let loaded = load_managed_session_for(&root, "recent")
|
||||
.expect("recent alias should load the latest session");
|
||||
|
||||
// then
|
||||
assert_eq!(handle.id, newer.session_id);
|
||||
assert_eq!(loaded.handle.id, newer.session_id);
|
||||
assert_eq!(loaded.session.messages.len(), 1);
|
||||
assert_ne!(loaded.handle.id, older.session_id);
|
||||
assert!(is_session_reference_alias("last"));
|
||||
fs::remove_dir_all(root).expect("temp dir should clean up");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forks_session_into_managed_storage_with_lineage() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("root dir should exist");
|
||||
let source = persist_session(&root, "parent session");
|
||||
|
||||
// when
|
||||
let forked = fork_managed_session_for(&root, &source, Some("incident-review".to_string()))
|
||||
.expect("session should fork");
|
||||
let sessions = list_managed_sessions_for(&root).expect("managed sessions should list");
|
||||
let summary = summary_by_id(&sessions, &forked.handle.id);
|
||||
|
||||
// then
|
||||
assert_eq!(forked.parent_session_id, source.session_id);
|
||||
assert_eq!(forked.branch_name.as_deref(), Some("incident-review"));
|
||||
assert_eq!(
|
||||
summary.parent_session_id.as_deref(),
|
||||
Some(source.session_id.as_str())
|
||||
);
|
||||
assert_eq!(summary.branch_name.as_deref(), Some("incident-review"));
|
||||
assert_eq!(
|
||||
forked.session.persistence_path(),
|
||||
Some(forked.handle.path.as_path())
|
||||
);
|
||||
fs::remove_dir_all(root).expect("temp dir should clean up");
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Per-worktree session isolation (SessionStore) tests
|
||||
// ------------------------------------------------------------------
|
||||
|
||||
fn persist_session_via_store(store: &SessionStore, text: &str) -> Session {
|
||||
let mut session = Session::new();
|
||||
session
|
||||
.push_user_text(text)
|
||||
.expect("session message should save");
|
||||
let handle = store.create_handle(&session.session_id);
|
||||
let session = session.with_persistence_path(handle.path.clone());
|
||||
session
|
||||
.save_to_path(&handle.path)
|
||||
.expect("session should persist");
|
||||
session
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn workspace_fingerprint_is_deterministic_and_differs_per_path() {
|
||||
// given
|
||||
let path_a = Path::new("/tmp/worktree-alpha");
|
||||
let path_b = Path::new("/tmp/worktree-beta");
|
||||
|
||||
// when
|
||||
let fp_a1 = workspace_fingerprint(path_a);
|
||||
let fp_a2 = workspace_fingerprint(path_a);
|
||||
let fp_b = workspace_fingerprint(path_b);
|
||||
|
||||
// then
|
||||
assert_eq!(fp_a1, fp_a2, "same path must produce the same fingerprint");
|
||||
assert_ne!(
|
||||
fp_a1, fp_b,
|
||||
"different paths must produce different fingerprints"
|
||||
);
|
||||
assert_eq!(fp_a1.len(), 16, "fingerprint must be a 16-char hex string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_store_from_cwd_isolates_sessions_by_workspace() {
|
||||
// given
|
||||
let base = temp_dir();
|
||||
let workspace_a = base.join("repo-alpha");
|
||||
let workspace_b = base.join("repo-beta");
|
||||
fs::create_dir_all(&workspace_a).expect("workspace a should exist");
|
||||
fs::create_dir_all(&workspace_b).expect("workspace b should exist");
|
||||
|
||||
let store_a = SessionStore::from_cwd(&workspace_a).expect("store a should build");
|
||||
let store_b = SessionStore::from_cwd(&workspace_b).expect("store b should build");
|
||||
|
||||
// when
|
||||
let session_a = persist_session_via_store(&store_a, "alpha work");
|
||||
let _session_b = persist_session_via_store(&store_b, "beta work");
|
||||
|
||||
// then — each store only sees its own sessions
|
||||
let list_a = store_a.list_sessions().expect("list a");
|
||||
let list_b = store_b.list_sessions().expect("list b");
|
||||
assert_eq!(list_a.len(), 1, "store a should see exactly one session");
|
||||
assert_eq!(list_b.len(), 1, "store b should see exactly one session");
|
||||
assert_eq!(list_a[0].id, session_a.session_id);
|
||||
assert_ne!(
|
||||
store_a.sessions_dir(),
|
||||
store_b.sessions_dir(),
|
||||
"session directories must differ across workspaces"
|
||||
);
|
||||
fs::remove_dir_all(base).expect("temp dir should clean up");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_store_from_data_dir_namespaces_by_workspace() {
|
||||
// given
|
||||
let base = temp_dir();
|
||||
let data_dir = base.join("global-data");
|
||||
let workspace_a = PathBuf::from("/tmp/project-one");
|
||||
let workspace_b = PathBuf::from("/tmp/project-two");
|
||||
fs::create_dir_all(&data_dir).expect("data dir should exist");
|
||||
|
||||
let store_a =
|
||||
SessionStore::from_data_dir(&data_dir, &workspace_a).expect("store a should build");
|
||||
let store_b =
|
||||
SessionStore::from_data_dir(&data_dir, &workspace_b).expect("store b should build");
|
||||
|
||||
// when
|
||||
persist_session_via_store(&store_a, "work in project-one");
|
||||
persist_session_via_store(&store_b, "work in project-two");
|
||||
|
||||
// then
|
||||
assert_ne!(
|
||||
store_a.sessions_dir(),
|
||||
store_b.sessions_dir(),
|
||||
"data-dir stores must namespace by workspace"
|
||||
);
|
||||
assert_eq!(store_a.list_sessions().expect("list a").len(), 1);
|
||||
assert_eq!(store_b.list_sessions().expect("list b").len(), 1);
|
||||
assert_eq!(store_a.workspace_root(), workspace_a.as_path());
|
||||
assert_eq!(store_b.workspace_root(), workspace_b.as_path());
|
||||
fs::remove_dir_all(base).expect("temp dir should clean up");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_store_create_and_load_round_trip() {
|
||||
// given
|
||||
let base = temp_dir();
|
||||
fs::create_dir_all(&base).expect("base dir should exist");
|
||||
let store = SessionStore::from_cwd(&base).expect("store should build");
|
||||
let session = persist_session_via_store(&store, "round-trip message");
|
||||
|
||||
// when
|
||||
let loaded = store
|
||||
.load_session(&session.session_id)
|
||||
.expect("session should load via store");
|
||||
|
||||
// then
|
||||
assert_eq!(loaded.handle.id, session.session_id);
|
||||
assert_eq!(loaded.session.messages.len(), 1);
|
||||
fs::remove_dir_all(base).expect("temp dir should clean up");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_store_latest_and_resolve_reference() {
|
||||
// given
|
||||
let base = temp_dir();
|
||||
fs::create_dir_all(&base).expect("base dir should exist");
|
||||
let store = SessionStore::from_cwd(&base).expect("store should build");
|
||||
let _older = persist_session_via_store(&store, "older");
|
||||
wait_for_next_millisecond();
|
||||
let newer = persist_session_via_store(&store, "newer");
|
||||
|
||||
// when
|
||||
let latest = store.latest_session().expect("latest should resolve");
|
||||
let handle = store
|
||||
.resolve_reference("latest")
|
||||
.expect("latest alias should resolve");
|
||||
|
||||
// then
|
||||
assert_eq!(latest.id, newer.session_id);
|
||||
assert_eq!(handle.id, newer.session_id);
|
||||
fs::remove_dir_all(base).expect("temp dir should clean up");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_store_fork_stays_in_same_namespace() {
|
||||
// given
|
||||
let base = temp_dir();
|
||||
fs::create_dir_all(&base).expect("base dir should exist");
|
||||
let store = SessionStore::from_cwd(&base).expect("store should build");
|
||||
let source = persist_session_via_store(&store, "parent work");
|
||||
|
||||
// when
|
||||
let forked = store
|
||||
.fork_session(&source, Some("bugfix".to_string()))
|
||||
.expect("fork should succeed");
|
||||
let sessions = store.list_sessions().expect("list sessions");
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
sessions.len(),
|
||||
2,
|
||||
"forked session must land in the same namespace"
|
||||
);
|
||||
assert_eq!(forked.parent_session_id, source.session_id);
|
||||
assert_eq!(forked.branch_name.as_deref(), Some("bugfix"));
|
||||
assert!(
|
||||
forked.handle.path.starts_with(store.sessions_dir()),
|
||||
"forked session path must be inside the store namespace"
|
||||
);
|
||||
fs::remove_dir_all(base).expect("temp dir should clean up");
|
||||
}
|
||||
}
|
||||
@ -80,7 +80,11 @@ impl IncrementalSseParser {
|
||||
}
|
||||
|
||||
fn take_event(&mut self) -> Option<SseEvent> {
|
||||
if self.data_lines.is_empty() && self.event_name.is_none() && self.id.is_none() && self.retry.is_none() {
|
||||
if self.data_lines.is_empty()
|
||||
&& self.event_name.is_none()
|
||||
&& self.id.is_none()
|
||||
&& self.retry.is_none()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
@ -102,8 +106,13 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parses_streaming_events() {
|
||||
// given
|
||||
let mut parser = IncrementalSseParser::new();
|
||||
|
||||
// when
|
||||
let first = parser.push_chunk("event: message\ndata: hel");
|
||||
|
||||
// then
|
||||
assert!(first.is_empty());
|
||||
|
||||
let second = parser.push_chunk("lo\n\nid: 1\ndata: world\n\n");
|
||||
@ -125,4 +134,25 @@ mod tests {
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn finish_flushes_a_trailing_event_without_separator() {
|
||||
// given
|
||||
let mut parser = IncrementalSseParser::new();
|
||||
parser.push_chunk("event: message\ndata: trailing");
|
||||
|
||||
// when
|
||||
let events = parser.finish();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
events,
|
||||
vec![SseEvent {
|
||||
event: Some("message".to_string()),
|
||||
data: "trailing".to_string(),
|
||||
id: None,
|
||||
retry: None,
|
||||
}]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
429
crates/runtime/src/stale_base.rs
Normal file
429
crates/runtime/src/stale_base.rs
Normal file
@ -0,0 +1,429 @@
|
||||
#![allow(clippy::must_use_candidate)]
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
/// Outcome of comparing the worktree HEAD against the expected base commit.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum BaseCommitState {
|
||||
/// HEAD matches the expected base commit.
|
||||
Matches,
|
||||
/// HEAD has diverged from the expected base.
|
||||
Diverged { expected: String, actual: String },
|
||||
/// No expected base was supplied (neither flag nor file).
|
||||
NoExpectedBase,
|
||||
/// The working directory is not inside a git repository.
|
||||
NotAGitRepo,
|
||||
}
|
||||
|
||||
/// Where the expected base commit originated from.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum BaseCommitSource {
|
||||
Flag(String),
|
||||
File(String),
|
||||
}
|
||||
|
||||
/// Read the `.claw-base` file from the given directory and return the trimmed
|
||||
/// commit hash, or `None` when the file is absent or empty.
|
||||
pub fn read_claw_base_file(cwd: &Path) -> Option<String> {
|
||||
let path = cwd.join(".claw-base");
|
||||
let content = std::fs::read_to_string(path).ok()?;
|
||||
let trimmed = content.trim();
|
||||
if trimmed.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(trimmed.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the expected base commit: prefer the `--base-commit` flag value,
|
||||
/// fall back to reading `.claw-base` from `cwd`.
|
||||
pub fn resolve_expected_base(flag_value: Option<&str>, cwd: &Path) -> Option<BaseCommitSource> {
|
||||
if let Some(value) = flag_value {
|
||||
let trimmed = value.trim();
|
||||
if !trimmed.is_empty() {
|
||||
return Some(BaseCommitSource::Flag(trimmed.to_string()));
|
||||
}
|
||||
}
|
||||
read_claw_base_file(cwd).map(BaseCommitSource::File)
|
||||
}
|
||||
|
||||
/// Verify that the worktree HEAD matches `expected_base`.
|
||||
///
|
||||
/// Returns [`BaseCommitState::NoExpectedBase`] when no expected commit is
|
||||
/// provided (the check is effectively a no-op in that case).
|
||||
pub fn check_base_commit(cwd: &Path, expected_base: Option<&BaseCommitSource>) -> BaseCommitState {
|
||||
let Some(source) = expected_base else {
|
||||
return BaseCommitState::NoExpectedBase;
|
||||
};
|
||||
let expected_raw = match source {
|
||||
BaseCommitSource::Flag(value) | BaseCommitSource::File(value) => value.as_str(),
|
||||
};
|
||||
|
||||
let Some(head_sha) = resolve_head_sha(cwd) else {
|
||||
return BaseCommitState::NotAGitRepo;
|
||||
};
|
||||
|
||||
let Some(expected_sha) = resolve_rev(cwd, expected_raw) else {
|
||||
// If the expected ref cannot be resolved, compare raw strings as a
|
||||
// best-effort fallback (e.g. partial SHA provided by the caller).
|
||||
return if head_sha.starts_with(expected_raw) || expected_raw.starts_with(&head_sha) {
|
||||
BaseCommitState::Matches
|
||||
} else {
|
||||
BaseCommitState::Diverged {
|
||||
expected: expected_raw.to_string(),
|
||||
actual: head_sha,
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
if head_sha == expected_sha {
|
||||
BaseCommitState::Matches
|
||||
} else {
|
||||
BaseCommitState::Diverged {
|
||||
expected: expected_sha,
|
||||
actual: head_sha,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a human-readable warning when the base commit has diverged.
|
||||
///
|
||||
/// Returns `None` for non-warning states (`Matches`, `NoExpectedBase`).
|
||||
pub fn format_stale_base_warning(state: &BaseCommitState) -> Option<String> {
|
||||
match state {
|
||||
BaseCommitState::Diverged { expected, actual } => Some(format!(
|
||||
"warning: worktree HEAD ({actual}) does not match expected base commit ({expected}). \
|
||||
Session may run against a stale codebase."
|
||||
)),
|
||||
BaseCommitState::NotAGitRepo => {
|
||||
Some("warning: stale-base check skipped — not inside a git repository.".to_string())
|
||||
}
|
||||
BaseCommitState::Matches | BaseCommitState::NoExpectedBase => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_head_sha(cwd: &Path) -> Option<String> {
|
||||
resolve_rev(cwd, "HEAD")
|
||||
}
|
||||
|
||||
fn resolve_rev(cwd: &Path, rev: &str) -> Option<String> {
|
||||
let output = Command::new("git")
|
||||
.args(["rev-parse", rev])
|
||||
.current_dir(cwd)
|
||||
.output()
|
||||
.ok()?;
|
||||
if !output.status.success() {
|
||||
return None;
|
||||
}
|
||||
let sha = String::from_utf8(output.stdout).ok()?;
|
||||
let trimmed = sha.trim();
|
||||
if trimmed.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(trimmed.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::fs;
|
||||
use std::process::Command;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
fn temp_dir() -> std::path::PathBuf {
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time should be after epoch")
|
||||
.as_nanos();
|
||||
std::env::temp_dir().join(format!("runtime-stale-base-{nanos}"))
|
||||
}
|
||||
|
||||
fn init_repo(path: &std::path::Path) {
|
||||
fs::create_dir_all(path).expect("create repo dir");
|
||||
run(path, &["init", "--quiet", "-b", "main"]);
|
||||
run(path, &["config", "user.email", "tests@example.com"]);
|
||||
run(path, &["config", "user.name", "Stale Base Tests"]);
|
||||
fs::write(path.join("init.txt"), "initial\n").expect("write init file");
|
||||
run(path, &["add", "."]);
|
||||
run(path, &["commit", "-m", "initial commit", "--quiet"]);
|
||||
}
|
||||
|
||||
fn run(cwd: &std::path::Path, args: &[&str]) {
|
||||
let status = Command::new("git")
|
||||
.args(args)
|
||||
.current_dir(cwd)
|
||||
.status()
|
||||
.unwrap_or_else(|e| panic!("git {} failed to execute: {e}", args.join(" ")));
|
||||
assert!(
|
||||
status.success(),
|
||||
"git {} exited with {status}",
|
||||
args.join(" ")
|
||||
);
|
||||
}
|
||||
|
||||
fn commit_file(repo: &std::path::Path, name: &str, msg: &str) {
|
||||
fs::write(repo.join(name), format!("{msg}\n")).expect("write file");
|
||||
run(repo, &["add", name]);
|
||||
run(repo, &["commit", "-m", msg, "--quiet"]);
|
||||
}
|
||||
|
||||
fn head_sha(repo: &std::path::Path) -> String {
|
||||
let output = Command::new("git")
|
||||
.args(["rev-parse", "HEAD"])
|
||||
.current_dir(repo)
|
||||
.output()
|
||||
.expect("git rev-parse HEAD");
|
||||
String::from_utf8(output.stdout)
|
||||
.expect("valid utf8")
|
||||
.trim()
|
||||
.to_string()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matches_when_head_equals_expected_base() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
init_repo(&root);
|
||||
let sha = head_sha(&root);
|
||||
let source = BaseCommitSource::Flag(sha);
|
||||
|
||||
// when
|
||||
let state = check_base_commit(&root, Some(&source));
|
||||
|
||||
// then
|
||||
assert_eq!(state, BaseCommitState::Matches);
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn diverged_when_head_moved_past_expected_base() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
init_repo(&root);
|
||||
let old_sha = head_sha(&root);
|
||||
commit_file(&root, "extra.txt", "move head forward");
|
||||
let new_sha = head_sha(&root);
|
||||
let source = BaseCommitSource::Flag(old_sha.clone());
|
||||
|
||||
// when
|
||||
let state = check_base_commit(&root, Some(&source));
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
state,
|
||||
BaseCommitState::Diverged {
|
||||
expected: old_sha,
|
||||
actual: new_sha,
|
||||
}
|
||||
);
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_expected_base_when_source_is_none() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
init_repo(&root);
|
||||
|
||||
// when
|
||||
let state = check_base_commit(&root, None);
|
||||
|
||||
// then
|
||||
assert_eq!(state, BaseCommitState::NoExpectedBase);
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_a_git_repo_when_outside_repo() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("create dir");
|
||||
let source = BaseCommitSource::Flag("abc1234".to_string());
|
||||
|
||||
// when
|
||||
let state = check_base_commit(&root, Some(&source));
|
||||
|
||||
// then
|
||||
assert_eq!(state, BaseCommitState::NotAGitRepo);
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reads_claw_base_file() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("create dir");
|
||||
fs::write(root.join(".claw-base"), "abc1234def5678\n").expect("write .claw-base");
|
||||
|
||||
// when
|
||||
let value = read_claw_base_file(&root);
|
||||
|
||||
// then
|
||||
assert_eq!(value, Some("abc1234def5678".to_string()));
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_none_for_missing_claw_base_file() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("create dir");
|
||||
|
||||
// when
|
||||
let value = read_claw_base_file(&root);
|
||||
|
||||
// then
|
||||
assert!(value.is_none());
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_none_for_empty_claw_base_file() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("create dir");
|
||||
fs::write(root.join(".claw-base"), " \n").expect("write empty .claw-base");
|
||||
|
||||
// when
|
||||
let value = read_claw_base_file(&root);
|
||||
|
||||
// then
|
||||
assert!(value.is_none());
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_expected_base_prefers_flag_over_file() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("create dir");
|
||||
fs::write(root.join(".claw-base"), "from_file\n").expect("write .claw-base");
|
||||
|
||||
// when
|
||||
let source = resolve_expected_base(Some("from_flag"), &root);
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
source,
|
||||
Some(BaseCommitSource::Flag("from_flag".to_string()))
|
||||
);
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_expected_base_falls_back_to_file() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("create dir");
|
||||
fs::write(root.join(".claw-base"), "from_file\n").expect("write .claw-base");
|
||||
|
||||
// when
|
||||
let source = resolve_expected_base(None, &root);
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
source,
|
||||
Some(BaseCommitSource::File("from_file".to_string()))
|
||||
);
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_expected_base_returns_none_when_nothing_available() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("create dir");
|
||||
|
||||
// when
|
||||
let source = resolve_expected_base(None, &root);
|
||||
|
||||
// then
|
||||
assert!(source.is_none());
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_warning_returns_message_for_diverged() {
|
||||
// given
|
||||
let state = BaseCommitState::Diverged {
|
||||
expected: "abc1234".to_string(),
|
||||
actual: "def5678".to_string(),
|
||||
};
|
||||
|
||||
// when
|
||||
let warning = format_stale_base_warning(&state);
|
||||
|
||||
// then
|
||||
let message = warning.expect("should produce warning");
|
||||
assert!(message.contains("abc1234"));
|
||||
assert!(message.contains("def5678"));
|
||||
assert!(message.contains("stale codebase"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_warning_returns_none_for_matches() {
|
||||
// given
|
||||
let state = BaseCommitState::Matches;
|
||||
|
||||
// when
|
||||
let warning = format_stale_base_warning(&state);
|
||||
|
||||
// then
|
||||
assert!(warning.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_warning_returns_none_for_no_expected_base() {
|
||||
// given
|
||||
let state = BaseCommitState::NoExpectedBase;
|
||||
|
||||
// when
|
||||
let warning = format_stale_base_warning(&state);
|
||||
|
||||
// then
|
||||
assert!(warning.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matches_with_claw_base_file_in_real_repo() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
init_repo(&root);
|
||||
let sha = head_sha(&root);
|
||||
fs::write(root.join(".claw-base"), format!("{sha}\n")).expect("write .claw-base");
|
||||
let source = resolve_expected_base(None, &root);
|
||||
|
||||
// when
|
||||
let state = check_base_commit(&root, source.as_ref());
|
||||
|
||||
// then
|
||||
assert_eq!(state, BaseCommitState::Matches);
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn diverged_with_claw_base_file_after_new_commit() {
|
||||
// given
|
||||
let root = temp_dir();
|
||||
init_repo(&root);
|
||||
let old_sha = head_sha(&root);
|
||||
fs::write(root.join(".claw-base"), format!("{old_sha}\n")).expect("write .claw-base");
|
||||
commit_file(&root, "new.txt", "advance head");
|
||||
let new_sha = head_sha(&root);
|
||||
let source = resolve_expected_base(None, &root);
|
||||
|
||||
// when
|
||||
let state = check_base_commit(&root, source.as_ref());
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
state,
|
||||
BaseCommitState::Diverged {
|
||||
expected: old_sha,
|
||||
actual: new_sha,
|
||||
}
|
||||
);
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
}
|
||||
417
crates/runtime/src/stale_branch.rs
Normal file
417
crates/runtime/src/stale_branch.rs
Normal file
@ -0,0 +1,417 @@
|
||||
#![allow(clippy::must_use_candidate)]
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum BranchFreshness {
|
||||
Fresh,
|
||||
Stale {
|
||||
commits_behind: usize,
|
||||
missing_fixes: Vec<String>,
|
||||
},
|
||||
Diverged {
|
||||
ahead: usize,
|
||||
behind: usize,
|
||||
missing_fixes: Vec<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum StaleBranchPolicy {
|
||||
AutoRebase,
|
||||
AutoMergeForward,
|
||||
WarnOnly,
|
||||
Block,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum StaleBranchEvent {
|
||||
BranchStaleAgainstMain {
|
||||
branch: String,
|
||||
commits_behind: usize,
|
||||
missing_fixes: Vec<String>,
|
||||
},
|
||||
RebaseAttempted {
|
||||
branch: String,
|
||||
result: String,
|
||||
},
|
||||
MergeForwardAttempted {
|
||||
branch: String,
|
||||
result: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum StaleBranchAction {
|
||||
Noop,
|
||||
Warn { message: String },
|
||||
Block { message: String },
|
||||
Rebase,
|
||||
MergeForward,
|
||||
}
|
||||
|
||||
pub fn check_freshness(branch: &str, main_ref: &str) -> BranchFreshness {
|
||||
check_freshness_in(branch, main_ref, Path::new("."))
|
||||
}
|
||||
|
||||
pub fn apply_policy(freshness: &BranchFreshness, policy: StaleBranchPolicy) -> StaleBranchAction {
|
||||
match freshness {
|
||||
BranchFreshness::Fresh => StaleBranchAction::Noop,
|
||||
BranchFreshness::Stale {
|
||||
commits_behind,
|
||||
missing_fixes,
|
||||
} => match policy {
|
||||
StaleBranchPolicy::WarnOnly => StaleBranchAction::Warn {
|
||||
message: format!(
|
||||
"Branch is {commits_behind} commit(s) behind main. Missing fixes: {}",
|
||||
if missing_fixes.is_empty() {
|
||||
"(none)".to_string()
|
||||
} else {
|
||||
missing_fixes.join("; ")
|
||||
}
|
||||
),
|
||||
},
|
||||
StaleBranchPolicy::Block => StaleBranchAction::Block {
|
||||
message: format!(
|
||||
"Branch is {commits_behind} commit(s) behind main and must be updated before proceeding."
|
||||
),
|
||||
},
|
||||
StaleBranchPolicy::AutoRebase => StaleBranchAction::Rebase,
|
||||
StaleBranchPolicy::AutoMergeForward => StaleBranchAction::MergeForward,
|
||||
},
|
||||
BranchFreshness::Diverged {
|
||||
ahead,
|
||||
behind,
|
||||
missing_fixes,
|
||||
} => match policy {
|
||||
StaleBranchPolicy::WarnOnly => StaleBranchAction::Warn {
|
||||
message: format!(
|
||||
"Branch has diverged: {ahead} commit(s) ahead, {behind} commit(s) behind main. Missing fixes: {}",
|
||||
format_missing_fixes(missing_fixes)
|
||||
),
|
||||
},
|
||||
StaleBranchPolicy::Block => StaleBranchAction::Block {
|
||||
message: format!(
|
||||
"Branch has diverged ({ahead} ahead, {behind} behind) and must be reconciled before proceeding. Missing fixes: {}",
|
||||
format_missing_fixes(missing_fixes)
|
||||
),
|
||||
},
|
||||
StaleBranchPolicy::AutoRebase => StaleBranchAction::Rebase,
|
||||
StaleBranchPolicy::AutoMergeForward => StaleBranchAction::MergeForward,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn check_freshness_in(
|
||||
branch: &str,
|
||||
main_ref: &str,
|
||||
repo_path: &Path,
|
||||
) -> BranchFreshness {
|
||||
let behind = rev_list_count(main_ref, branch, repo_path);
|
||||
let ahead = rev_list_count(branch, main_ref, repo_path);
|
||||
|
||||
if behind == 0 {
|
||||
return BranchFreshness::Fresh;
|
||||
}
|
||||
|
||||
if ahead > 0 {
|
||||
return BranchFreshness::Diverged {
|
||||
ahead,
|
||||
behind,
|
||||
missing_fixes: missing_fix_subjects(main_ref, branch, repo_path),
|
||||
};
|
||||
}
|
||||
|
||||
let missing_fixes = missing_fix_subjects(main_ref, branch, repo_path);
|
||||
BranchFreshness::Stale {
|
||||
commits_behind: behind,
|
||||
missing_fixes,
|
||||
}
|
||||
}
|
||||
|
||||
fn format_missing_fixes(missing_fixes: &[String]) -> String {
|
||||
if missing_fixes.is_empty() {
|
||||
"(none)".to_string()
|
||||
} else {
|
||||
missing_fixes.join("; ")
|
||||
}
|
||||
}
|
||||
|
||||
fn rev_list_count(a: &str, b: &str, repo_path: &Path) -> usize {
|
||||
let output = Command::new("git")
|
||||
.args(["rev-list", "--count", &format!("{b}..{a}")])
|
||||
.current_dir(repo_path)
|
||||
.output();
|
||||
match output {
|
||||
Ok(o) if o.status.success() => String::from_utf8_lossy(&o.stdout)
|
||||
.trim()
|
||||
.parse::<usize>()
|
||||
.unwrap_or(0),
|
||||
_ => 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn missing_fix_subjects(a: &str, b: &str, repo_path: &Path) -> Vec<String> {
|
||||
let output = Command::new("git")
|
||||
.args(["log", "--format=%s", &format!("{b}..{a}")])
|
||||
.current_dir(repo_path)
|
||||
.output();
|
||||
match output {
|
||||
Ok(o) if o.status.success() => String::from_utf8_lossy(&o.stdout)
|
||||
.lines()
|
||||
.filter(|l| !l.is_empty())
|
||||
.map(String::from)
|
||||
.collect(),
|
||||
_ => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::fs;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
fn temp_dir() -> std::path::PathBuf {
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("time should be after epoch")
|
||||
.as_nanos();
|
||||
std::env::temp_dir().join(format!("runtime-stale-branch-{nanos}"))
|
||||
}
|
||||
|
||||
fn init_repo(path: &Path) {
|
||||
fs::create_dir_all(path).expect("create repo dir");
|
||||
run(path, &["init", "--quiet", "-b", "main"]);
|
||||
run(path, &["config", "user.email", "tests@example.com"]);
|
||||
run(path, &["config", "user.name", "Stale Branch Tests"]);
|
||||
fs::write(path.join("init.txt"), "initial\n").expect("write init file");
|
||||
run(path, &["add", "."]);
|
||||
run(path, &["commit", "-m", "initial commit", "--quiet"]);
|
||||
}
|
||||
|
||||
fn run(cwd: &Path, args: &[&str]) {
|
||||
let status = Command::new("git")
|
||||
.args(args)
|
||||
.current_dir(cwd)
|
||||
.status()
|
||||
.unwrap_or_else(|e| panic!("git {} failed to execute: {e}", args.join(" ")));
|
||||
assert!(
|
||||
status.success(),
|
||||
"git {} exited with {status}",
|
||||
args.join(" ")
|
||||
);
|
||||
}
|
||||
|
||||
fn commit_file(repo: &Path, name: &str, msg: &str) {
|
||||
fs::write(repo.join(name), format!("{msg}\n")).expect("write file");
|
||||
run(repo, &["add", name]);
|
||||
run(repo, &["commit", "-m", msg, "--quiet"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fresh_branch_passes() {
|
||||
let root = temp_dir();
|
||||
init_repo(&root);
|
||||
|
||||
// given
|
||||
run(&root, &["checkout", "-b", "topic"]);
|
||||
|
||||
// when
|
||||
let freshness = check_freshness_in("topic", "main", &root);
|
||||
|
||||
// then
|
||||
assert_eq!(freshness, BranchFreshness::Fresh);
|
||||
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fresh_branch_ahead_of_main_still_fresh() {
|
||||
let root = temp_dir();
|
||||
init_repo(&root);
|
||||
|
||||
// given
|
||||
run(&root, &["checkout", "-b", "topic"]);
|
||||
commit_file(&root, "feature.txt", "add feature");
|
||||
|
||||
// when
|
||||
let freshness = check_freshness_in("topic", "main", &root);
|
||||
|
||||
// then
|
||||
assert_eq!(freshness, BranchFreshness::Fresh);
|
||||
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stale_branch_detected_with_correct_behind_count_and_missing_fixes() {
|
||||
let root = temp_dir();
|
||||
init_repo(&root);
|
||||
|
||||
// given
|
||||
run(&root, &["checkout", "-b", "topic"]);
|
||||
run(&root, &["checkout", "main"]);
|
||||
commit_file(&root, "fix1.txt", "fix: resolve timeout");
|
||||
commit_file(&root, "fix2.txt", "fix: handle null pointer");
|
||||
|
||||
// when
|
||||
let freshness = check_freshness_in("topic", "main", &root);
|
||||
|
||||
// then
|
||||
match freshness {
|
||||
BranchFreshness::Stale {
|
||||
commits_behind,
|
||||
missing_fixes,
|
||||
} => {
|
||||
assert_eq!(commits_behind, 2);
|
||||
assert_eq!(missing_fixes.len(), 2);
|
||||
assert_eq!(missing_fixes[0], "fix: handle null pointer");
|
||||
assert_eq!(missing_fixes[1], "fix: resolve timeout");
|
||||
}
|
||||
other => panic!("expected Stale, got {other:?}"),
|
||||
}
|
||||
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn diverged_branch_detection() {
|
||||
let root = temp_dir();
|
||||
init_repo(&root);
|
||||
|
||||
// given
|
||||
run(&root, &["checkout", "-b", "topic"]);
|
||||
commit_file(&root, "topic_work.txt", "topic work");
|
||||
run(&root, &["checkout", "main"]);
|
||||
commit_file(&root, "main_fix.txt", "main fix");
|
||||
|
||||
// when
|
||||
let freshness = check_freshness_in("topic", "main", &root);
|
||||
|
||||
// then
|
||||
match freshness {
|
||||
BranchFreshness::Diverged {
|
||||
ahead,
|
||||
behind,
|
||||
missing_fixes,
|
||||
} => {
|
||||
assert_eq!(ahead, 1);
|
||||
assert_eq!(behind, 1);
|
||||
assert_eq!(missing_fixes, vec!["main fix".to_string()]);
|
||||
}
|
||||
other => panic!("expected Diverged, got {other:?}"),
|
||||
}
|
||||
|
||||
fs::remove_dir_all(&root).expect("cleanup");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn policy_noop_for_fresh_branch() {
|
||||
// given
|
||||
let freshness = BranchFreshness::Fresh;
|
||||
|
||||
// when
|
||||
let action = apply_policy(&freshness, StaleBranchPolicy::WarnOnly);
|
||||
|
||||
// then
|
||||
assert_eq!(action, StaleBranchAction::Noop);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn policy_warn_for_stale_branch() {
|
||||
// given
|
||||
let freshness = BranchFreshness::Stale {
|
||||
commits_behind: 3,
|
||||
missing_fixes: vec!["fix: timeout".into(), "fix: null ptr".into()],
|
||||
};
|
||||
|
||||
// when
|
||||
let action = apply_policy(&freshness, StaleBranchPolicy::WarnOnly);
|
||||
|
||||
// then
|
||||
match action {
|
||||
StaleBranchAction::Warn { message } => {
|
||||
assert!(message.contains("3 commit(s) behind"));
|
||||
assert!(message.contains("fix: timeout"));
|
||||
assert!(message.contains("fix: null ptr"));
|
||||
}
|
||||
other => panic!("expected Warn, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn policy_block_for_stale_branch() {
|
||||
// given
|
||||
let freshness = BranchFreshness::Stale {
|
||||
commits_behind: 1,
|
||||
missing_fixes: vec!["hotfix".into()],
|
||||
};
|
||||
|
||||
// when
|
||||
let action = apply_policy(&freshness, StaleBranchPolicy::Block);
|
||||
|
||||
// then
|
||||
match action {
|
||||
StaleBranchAction::Block { message } => {
|
||||
assert!(message.contains("1 commit(s) behind"));
|
||||
}
|
||||
other => panic!("expected Block, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn policy_auto_rebase_for_stale_branch() {
|
||||
// given
|
||||
let freshness = BranchFreshness::Stale {
|
||||
commits_behind: 2,
|
||||
missing_fixes: vec![],
|
||||
};
|
||||
|
||||
// when
|
||||
let action = apply_policy(&freshness, StaleBranchPolicy::AutoRebase);
|
||||
|
||||
// then
|
||||
assert_eq!(action, StaleBranchAction::Rebase);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn policy_auto_merge_forward_for_diverged_branch() {
|
||||
// given
|
||||
let freshness = BranchFreshness::Diverged {
|
||||
ahead: 5,
|
||||
behind: 2,
|
||||
missing_fixes: vec!["fix: merge main".into()],
|
||||
};
|
||||
|
||||
// when
|
||||
let action = apply_policy(&freshness, StaleBranchPolicy::AutoMergeForward);
|
||||
|
||||
// then
|
||||
assert_eq!(action, StaleBranchAction::MergeForward);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn policy_warn_for_diverged_branch() {
|
||||
// given
|
||||
let freshness = BranchFreshness::Diverged {
|
||||
ahead: 3,
|
||||
behind: 1,
|
||||
missing_fixes: vec!["main hotfix".into()],
|
||||
};
|
||||
|
||||
// when
|
||||
let action = apply_policy(&freshness, StaleBranchPolicy::WarnOnly);
|
||||
|
||||
// then
|
||||
match action {
|
||||
StaleBranchAction::Warn { message } => {
|
||||
assert!(message.contains("diverged"));
|
||||
assert!(message.contains("3 commit(s) ahead"));
|
||||
assert!(message.contains("1 commit(s) behind"));
|
||||
assert!(message.contains("main hotfix"));
|
||||
}
|
||||
other => panic!("expected Warn, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
300
crates/runtime/src/summary_compression.rs
Normal file
300
crates/runtime/src/summary_compression.rs
Normal file
@ -0,0 +1,300 @@
|
||||
use std::collections::BTreeSet;
|
||||
|
||||
const DEFAULT_MAX_CHARS: usize = 1_200;
|
||||
const DEFAULT_MAX_LINES: usize = 24;
|
||||
const DEFAULT_MAX_LINE_CHARS: usize = 160;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct SummaryCompressionBudget {
|
||||
pub max_chars: usize,
|
||||
pub max_lines: usize,
|
||||
pub max_line_chars: usize,
|
||||
}
|
||||
|
||||
impl Default for SummaryCompressionBudget {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_chars: DEFAULT_MAX_CHARS,
|
||||
max_lines: DEFAULT_MAX_LINES,
|
||||
max_line_chars: DEFAULT_MAX_LINE_CHARS,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SummaryCompressionResult {
|
||||
pub summary: String,
|
||||
pub original_chars: usize,
|
||||
pub compressed_chars: usize,
|
||||
pub original_lines: usize,
|
||||
pub compressed_lines: usize,
|
||||
pub removed_duplicate_lines: usize,
|
||||
pub omitted_lines: usize,
|
||||
pub truncated: bool,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn compress_summary(
|
||||
summary: &str,
|
||||
budget: SummaryCompressionBudget,
|
||||
) -> SummaryCompressionResult {
|
||||
let original_chars = summary.chars().count();
|
||||
let original_lines = summary.lines().count();
|
||||
|
||||
let normalized = normalize_lines(summary, budget.max_line_chars);
|
||||
if normalized.lines.is_empty() || budget.max_chars == 0 || budget.max_lines == 0 {
|
||||
return SummaryCompressionResult {
|
||||
summary: String::new(),
|
||||
original_chars,
|
||||
compressed_chars: 0,
|
||||
original_lines,
|
||||
compressed_lines: 0,
|
||||
removed_duplicate_lines: normalized.removed_duplicate_lines,
|
||||
omitted_lines: normalized.lines.len(),
|
||||
truncated: original_chars > 0,
|
||||
};
|
||||
}
|
||||
|
||||
let selected = select_line_indexes(&normalized.lines, budget);
|
||||
let mut compressed_lines = selected
|
||||
.iter()
|
||||
.map(|index| normalized.lines[*index].clone())
|
||||
.collect::<Vec<_>>();
|
||||
if compressed_lines.is_empty() {
|
||||
compressed_lines.push(truncate_line(&normalized.lines[0], budget.max_chars));
|
||||
}
|
||||
let omitted_lines = normalized
|
||||
.lines
|
||||
.len()
|
||||
.saturating_sub(compressed_lines.len());
|
||||
|
||||
if omitted_lines > 0 {
|
||||
let omission_notice = omission_notice(omitted_lines);
|
||||
push_line_with_budget(&mut compressed_lines, omission_notice, budget);
|
||||
}
|
||||
|
||||
let compressed_summary = compressed_lines.join("\n");
|
||||
|
||||
SummaryCompressionResult {
|
||||
compressed_chars: compressed_summary.chars().count(),
|
||||
compressed_lines: compressed_lines.len(),
|
||||
removed_duplicate_lines: normalized.removed_duplicate_lines,
|
||||
omitted_lines,
|
||||
truncated: compressed_summary != summary.trim(),
|
||||
summary: compressed_summary,
|
||||
original_chars,
|
||||
original_lines,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn compress_summary_text(summary: &str) -> String {
|
||||
compress_summary(summary, SummaryCompressionBudget::default()).summary
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct NormalizedSummary {
|
||||
lines: Vec<String>,
|
||||
removed_duplicate_lines: usize,
|
||||
}
|
||||
|
||||
fn normalize_lines(summary: &str, max_line_chars: usize) -> NormalizedSummary {
|
||||
let mut seen = BTreeSet::new();
|
||||
let mut lines = Vec::new();
|
||||
let mut removed_duplicate_lines = 0;
|
||||
|
||||
for raw_line in summary.lines() {
|
||||
let normalized = collapse_inline_whitespace(raw_line);
|
||||
if normalized.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let truncated = truncate_line(&normalized, max_line_chars);
|
||||
let dedupe_key = dedupe_key(&truncated);
|
||||
if !seen.insert(dedupe_key) {
|
||||
removed_duplicate_lines += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
lines.push(truncated);
|
||||
}
|
||||
|
||||
NormalizedSummary {
|
||||
lines,
|
||||
removed_duplicate_lines,
|
||||
}
|
||||
}
|
||||
|
||||
fn select_line_indexes(lines: &[String], budget: SummaryCompressionBudget) -> Vec<usize> {
|
||||
let mut selected = BTreeSet::<usize>::new();
|
||||
|
||||
for priority in 0..=3 {
|
||||
for (index, line) in lines.iter().enumerate() {
|
||||
if selected.contains(&index) || line_priority(line) != priority {
|
||||
continue;
|
||||
}
|
||||
|
||||
let candidate = selected
|
||||
.iter()
|
||||
.map(|selected_index| lines[*selected_index].as_str())
|
||||
.chain(std::iter::once(line.as_str()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if candidate.len() > budget.max_lines {
|
||||
continue;
|
||||
}
|
||||
|
||||
if joined_char_count(&candidate) > budget.max_chars {
|
||||
continue;
|
||||
}
|
||||
|
||||
selected.insert(index);
|
||||
}
|
||||
}
|
||||
|
||||
selected.into_iter().collect()
|
||||
}
|
||||
|
||||
fn push_line_with_budget(lines: &mut Vec<String>, line: String, budget: SummaryCompressionBudget) {
|
||||
let candidate = lines
|
||||
.iter()
|
||||
.map(String::as_str)
|
||||
.chain(std::iter::once(line.as_str()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if candidate.len() <= budget.max_lines && joined_char_count(&candidate) <= budget.max_chars {
|
||||
lines.push(line);
|
||||
}
|
||||
}
|
||||
|
||||
fn joined_char_count(lines: &[&str]) -> usize {
|
||||
lines.iter().map(|line| line.chars().count()).sum::<usize>() + lines.len().saturating_sub(1)
|
||||
}
|
||||
|
||||
fn line_priority(line: &str) -> usize {
|
||||
if line == "Summary:" || line == "Conversation summary:" || is_core_detail(line) {
|
||||
0
|
||||
} else if is_section_header(line) {
|
||||
1
|
||||
} else if line.starts_with("- ") || line.starts_with(" - ") {
|
||||
2
|
||||
} else {
|
||||
3
|
||||
}
|
||||
}
|
||||
|
||||
fn is_core_detail(line: &str) -> bool {
|
||||
[
|
||||
"- Scope:",
|
||||
"- Current work:",
|
||||
"- Pending work:",
|
||||
"- Key files referenced:",
|
||||
"- Tools mentioned:",
|
||||
"- Recent user requests:",
|
||||
"- Previously compacted context:",
|
||||
"- Newly compacted context:",
|
||||
]
|
||||
.iter()
|
||||
.any(|prefix| line.starts_with(prefix))
|
||||
}
|
||||
|
||||
fn is_section_header(line: &str) -> bool {
|
||||
line.ends_with(':')
|
||||
}
|
||||
|
||||
fn omission_notice(omitted_lines: usize) -> String {
|
||||
format!("- … {omitted_lines} additional line(s) omitted.")
|
||||
}
|
||||
|
||||
fn collapse_inline_whitespace(line: &str) -> String {
|
||||
line.split_whitespace().collect::<Vec<_>>().join(" ")
|
||||
}
|
||||
|
||||
fn truncate_line(line: &str, max_chars: usize) -> String {
|
||||
if max_chars == 0 || line.chars().count() <= max_chars {
|
||||
return line.to_string();
|
||||
}
|
||||
|
||||
if max_chars == 1 {
|
||||
return "…".to_string();
|
||||
}
|
||||
|
||||
let mut truncated = line
|
||||
.chars()
|
||||
.take(max_chars.saturating_sub(1))
|
||||
.collect::<String>();
|
||||
truncated.push('…');
|
||||
truncated
|
||||
}
|
||||
|
||||
fn dedupe_key(line: &str) -> String {
|
||||
line.to_ascii_lowercase()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{compress_summary, compress_summary_text, SummaryCompressionBudget};
|
||||
|
||||
#[test]
|
||||
fn collapses_whitespace_and_duplicate_lines() {
|
||||
// given
|
||||
let summary = "Conversation summary:\n\n- Scope: compact earlier messages.\n- Scope: compact earlier messages.\n- Current work: update runtime module.\n";
|
||||
|
||||
// when
|
||||
let result = compress_summary(summary, SummaryCompressionBudget::default());
|
||||
|
||||
// then
|
||||
assert_eq!(result.removed_duplicate_lines, 1);
|
||||
assert!(result
|
||||
.summary
|
||||
.contains("- Scope: compact earlier messages."));
|
||||
assert!(!result.summary.contains(" compact earlier"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keeps_core_lines_when_budget_is_tight() {
|
||||
// given
|
||||
let summary = [
|
||||
"Conversation summary:",
|
||||
"- Scope: 18 earlier messages compacted.",
|
||||
"- Current work: finish summary compression.",
|
||||
"- Key timeline:",
|
||||
" - user: asked for a working implementation.",
|
||||
" - assistant: inspected runtime compaction flow.",
|
||||
" - tool: cargo check succeeded.",
|
||||
]
|
||||
.join("\n");
|
||||
|
||||
// when
|
||||
let result = compress_summary(
|
||||
&summary,
|
||||
SummaryCompressionBudget {
|
||||
max_chars: 120,
|
||||
max_lines: 3,
|
||||
max_line_chars: 80,
|
||||
},
|
||||
);
|
||||
|
||||
// then
|
||||
assert!(result.summary.contains("Conversation summary:"));
|
||||
assert!(result
|
||||
.summary
|
||||
.contains("- Scope: 18 earlier messages compacted."));
|
||||
assert!(result
|
||||
.summary
|
||||
.contains("- Current work: finish summary compression."));
|
||||
assert!(result.omitted_lines > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provides_a_default_text_only_helper() {
|
||||
// given
|
||||
let summary = "Summary:\n\nA short line.";
|
||||
|
||||
// when
|
||||
let compressed = compress_summary_text(summary);
|
||||
|
||||
// then
|
||||
assert_eq!(compressed, "Summary:\nA short line.");
|
||||
}
|
||||
}
|
||||
158
crates/runtime/src/task_packet.rs
Normal file
158
crates/runtime/src/task_packet.rs
Normal file
@ -0,0 +1,158 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::{Display, Formatter};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct TaskPacket {
|
||||
pub objective: String,
|
||||
pub scope: String,
|
||||
pub repo: String,
|
||||
pub branch_policy: String,
|
||||
pub acceptance_tests: Vec<String>,
|
||||
pub commit_policy: String,
|
||||
pub reporting_contract: String,
|
||||
pub escalation_policy: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct TaskPacketValidationError {
|
||||
errors: Vec<String>,
|
||||
}
|
||||
|
||||
impl TaskPacketValidationError {
|
||||
#[must_use]
|
||||
pub fn new(errors: Vec<String>) -> Self {
|
||||
Self { errors }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn errors(&self) -> &[String] {
|
||||
&self.errors
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for TaskPacketValidationError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.errors.join("; "))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for TaskPacketValidationError {}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ValidatedPacket(TaskPacket);
|
||||
|
||||
impl ValidatedPacket {
|
||||
#[must_use]
|
||||
pub fn packet(&self) -> &TaskPacket {
|
||||
&self.0
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn into_inner(self) -> TaskPacket {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
pub fn validate_packet(packet: TaskPacket) -> Result<ValidatedPacket, TaskPacketValidationError> {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
validate_required("objective", &packet.objective, &mut errors);
|
||||
validate_required("scope", &packet.scope, &mut errors);
|
||||
validate_required("repo", &packet.repo, &mut errors);
|
||||
validate_required("branch_policy", &packet.branch_policy, &mut errors);
|
||||
validate_required("commit_policy", &packet.commit_policy, &mut errors);
|
||||
validate_required(
|
||||
"reporting_contract",
|
||||
&packet.reporting_contract,
|
||||
&mut errors,
|
||||
);
|
||||
validate_required("escalation_policy", &packet.escalation_policy, &mut errors);
|
||||
|
||||
for (index, test) in packet.acceptance_tests.iter().enumerate() {
|
||||
if test.trim().is_empty() {
|
||||
errors.push(format!(
|
||||
"acceptance_tests contains an empty value at index {index}"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
Ok(ValidatedPacket(packet))
|
||||
} else {
|
||||
Err(TaskPacketValidationError::new(errors))
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_required(field: &str, value: &str, errors: &mut Vec<String>) {
|
||||
if value.trim().is_empty() {
|
||||
errors.push(format!("{field} must not be empty"));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn sample_packet() -> TaskPacket {
|
||||
TaskPacket {
|
||||
objective: "Implement typed task packet format".to_string(),
|
||||
scope: "runtime/task system".to_string(),
|
||||
repo: "claw-code-parity".to_string(),
|
||||
branch_policy: "origin/main only".to_string(),
|
||||
acceptance_tests: vec![
|
||||
"cargo build --workspace".to_string(),
|
||||
"cargo test --workspace".to_string(),
|
||||
],
|
||||
commit_policy: "single verified commit".to_string(),
|
||||
reporting_contract: "print build result, test result, commit sha".to_string(),
|
||||
escalation_policy: "stop only on destructive ambiguity".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_packet_passes_validation() {
|
||||
let packet = sample_packet();
|
||||
let validated = validate_packet(packet.clone()).expect("packet should validate");
|
||||
assert_eq!(validated.packet(), &packet);
|
||||
assert_eq!(validated.into_inner(), packet);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_packet_accumulates_errors() {
|
||||
let packet = TaskPacket {
|
||||
objective: " ".to_string(),
|
||||
scope: String::new(),
|
||||
repo: String::new(),
|
||||
branch_policy: "\t".to_string(),
|
||||
acceptance_tests: vec!["ok".to_string(), " ".to_string()],
|
||||
commit_policy: String::new(),
|
||||
reporting_contract: String::new(),
|
||||
escalation_policy: String::new(),
|
||||
};
|
||||
|
||||
let error = validate_packet(packet).expect_err("packet should be rejected");
|
||||
|
||||
assert!(error.errors().len() >= 7);
|
||||
assert!(error
|
||||
.errors()
|
||||
.contains(&"objective must not be empty".to_string()));
|
||||
assert!(error
|
||||
.errors()
|
||||
.contains(&"scope must not be empty".to_string()));
|
||||
assert!(error
|
||||
.errors()
|
||||
.contains(&"repo must not be empty".to_string()));
|
||||
assert!(error
|
||||
.errors()
|
||||
.contains(&"acceptance_tests contains an empty value at index 1".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialization_roundtrip_preserves_packet() {
|
||||
let packet = sample_packet();
|
||||
let serialized = serde_json::to_string(&packet).expect("packet should serialize");
|
||||
let deserialized: TaskPacket =
|
||||
serde_json::from_str(&serialized).expect("packet should deserialize");
|
||||
assert_eq!(deserialized, packet);
|
||||
}
|
||||
}
|
||||
503
crates/runtime/src/task_registry.rs
Normal file
503
crates/runtime/src/task_registry.rs
Normal file
@ -0,0 +1,503 @@
|
||||
#![allow(clippy::must_use_candidate, clippy::unnecessary_map_or)]
|
||||
//! In-memory task registry for sub-agent task lifecycle management.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{validate_packet, TaskPacket, TaskPacketValidationError};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TaskStatus {
|
||||
Created,
|
||||
Running,
|
||||
Completed,
|
||||
Failed,
|
||||
Stopped,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TaskStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Created => write!(f, "created"),
|
||||
Self::Running => write!(f, "running"),
|
||||
Self::Completed => write!(f, "completed"),
|
||||
Self::Failed => write!(f, "failed"),
|
||||
Self::Stopped => write!(f, "stopped"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Task {
|
||||
pub task_id: String,
|
||||
pub prompt: String,
|
||||
pub description: Option<String>,
|
||||
pub task_packet: Option<TaskPacket>,
|
||||
pub status: TaskStatus,
|
||||
pub created_at: u64,
|
||||
pub updated_at: u64,
|
||||
pub messages: Vec<TaskMessage>,
|
||||
pub output: String,
|
||||
pub team_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TaskMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
pub timestamp: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct TaskRegistry {
|
||||
inner: Arc<Mutex<RegistryInner>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct RegistryInner {
|
||||
tasks: HashMap<String, Task>,
|
||||
counter: u64,
|
||||
}
|
||||
|
||||
fn now_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
impl TaskRegistry {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn create(&self, prompt: &str, description: Option<&str>) -> Task {
|
||||
self.create_task(prompt.to_owned(), description.map(str::to_owned), None)
|
||||
}
|
||||
|
||||
pub fn create_from_packet(
|
||||
&self,
|
||||
packet: TaskPacket,
|
||||
) -> Result<Task, TaskPacketValidationError> {
|
||||
let packet = validate_packet(packet)?.into_inner();
|
||||
Ok(self.create_task(
|
||||
packet.objective.clone(),
|
||||
Some(packet.scope.clone()),
|
||||
Some(packet),
|
||||
))
|
||||
}
|
||||
|
||||
fn create_task(
|
||||
&self,
|
||||
prompt: String,
|
||||
description: Option<String>,
|
||||
task_packet: Option<TaskPacket>,
|
||||
) -> Task {
|
||||
let mut inner = self.inner.lock().expect("registry lock poisoned");
|
||||
inner.counter += 1;
|
||||
let ts = now_secs();
|
||||
let task_id = format!("task_{:08x}_{}", ts, inner.counter);
|
||||
let task = Task {
|
||||
task_id: task_id.clone(),
|
||||
prompt,
|
||||
description,
|
||||
task_packet,
|
||||
status: TaskStatus::Created,
|
||||
created_at: ts,
|
||||
updated_at: ts,
|
||||
messages: Vec::new(),
|
||||
output: String::new(),
|
||||
team_id: None,
|
||||
};
|
||||
inner.tasks.insert(task_id, task.clone());
|
||||
task
|
||||
}
|
||||
|
||||
pub fn get(&self, task_id: &str) -> Option<Task> {
|
||||
let inner = self.inner.lock().expect("registry lock poisoned");
|
||||
inner.tasks.get(task_id).cloned()
|
||||
}
|
||||
|
||||
pub fn list(&self, status_filter: Option<TaskStatus>) -> Vec<Task> {
|
||||
let inner = self.inner.lock().expect("registry lock poisoned");
|
||||
inner
|
||||
.tasks
|
||||
.values()
|
||||
.filter(|t| status_filter.map_or(true, |s| t.status == s))
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn stop(&self, task_id: &str) -> Result<Task, String> {
|
||||
let mut inner = self.inner.lock().expect("registry lock poisoned");
|
||||
let task = inner
|
||||
.tasks
|
||||
.get_mut(task_id)
|
||||
.ok_or_else(|| format!("task not found: {task_id}"))?;
|
||||
|
||||
match task.status {
|
||||
TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Stopped => {
|
||||
return Err(format!(
|
||||
"task {task_id} is already in terminal state: {}",
|
||||
task.status
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
task.status = TaskStatus::Stopped;
|
||||
task.updated_at = now_secs();
|
||||
Ok(task.clone())
|
||||
}
|
||||
|
||||
pub fn update(&self, task_id: &str, message: &str) -> Result<Task, String> {
|
||||
let mut inner = self.inner.lock().expect("registry lock poisoned");
|
||||
let task = inner
|
||||
.tasks
|
||||
.get_mut(task_id)
|
||||
.ok_or_else(|| format!("task not found: {task_id}"))?;
|
||||
|
||||
task.messages.push(TaskMessage {
|
||||
role: String::from("user"),
|
||||
content: message.to_owned(),
|
||||
timestamp: now_secs(),
|
||||
});
|
||||
task.updated_at = now_secs();
|
||||
Ok(task.clone())
|
||||
}
|
||||
|
||||
pub fn output(&self, task_id: &str) -> Result<String, String> {
|
||||
let inner = self.inner.lock().expect("registry lock poisoned");
|
||||
let task = inner
|
||||
.tasks
|
||||
.get(task_id)
|
||||
.ok_or_else(|| format!("task not found: {task_id}"))?;
|
||||
Ok(task.output.clone())
|
||||
}
|
||||
|
||||
pub fn append_output(&self, task_id: &str, output: &str) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("registry lock poisoned");
|
||||
let task = inner
|
||||
.tasks
|
||||
.get_mut(task_id)
|
||||
.ok_or_else(|| format!("task not found: {task_id}"))?;
|
||||
task.output.push_str(output);
|
||||
task.updated_at = now_secs();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn set_status(&self, task_id: &str, status: TaskStatus) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("registry lock poisoned");
|
||||
let task = inner
|
||||
.tasks
|
||||
.get_mut(task_id)
|
||||
.ok_or_else(|| format!("task not found: {task_id}"))?;
|
||||
task.status = status;
|
||||
task.updated_at = now_secs();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn assign_team(&self, task_id: &str, team_id: &str) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("registry lock poisoned");
|
||||
let task = inner
|
||||
.tasks
|
||||
.get_mut(task_id)
|
||||
.ok_or_else(|| format!("task not found: {task_id}"))?;
|
||||
task.team_id = Some(team_id.to_owned());
|
||||
task.updated_at = now_secs();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn remove(&self, task_id: &str) -> Option<Task> {
|
||||
let mut inner = self.inner.lock().expect("registry lock poisoned");
|
||||
inner.tasks.remove(task_id)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn len(&self) -> usize {
|
||||
let inner = self.inner.lock().expect("registry lock poisoned");
|
||||
inner.tasks.len()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn creates_and_retrieves_tasks() {
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("Do something", Some("A test task"));
|
||||
assert_eq!(task.status, TaskStatus::Created);
|
||||
assert_eq!(task.prompt, "Do something");
|
||||
assert_eq!(task.description.as_deref(), Some("A test task"));
|
||||
assert_eq!(task.task_packet, None);
|
||||
|
||||
let fetched = registry.get(&task.task_id).expect("task should exist");
|
||||
assert_eq!(fetched.task_id, task.task_id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_task_from_packet() {
|
||||
let registry = TaskRegistry::new();
|
||||
let packet = TaskPacket {
|
||||
objective: "Ship task packet support".to_string(),
|
||||
scope: "runtime/task system".to_string(),
|
||||
repo: "claw-code-parity".to_string(),
|
||||
branch_policy: "origin/main only".to_string(),
|
||||
acceptance_tests: vec!["cargo test --workspace".to_string()],
|
||||
commit_policy: "single commit".to_string(),
|
||||
reporting_contract: "print commit sha".to_string(),
|
||||
escalation_policy: "manual escalation".to_string(),
|
||||
};
|
||||
|
||||
let task = registry
|
||||
.create_from_packet(packet.clone())
|
||||
.expect("packet-backed task should be created");
|
||||
|
||||
assert_eq!(task.prompt, packet.objective);
|
||||
assert_eq!(task.description.as_deref(), Some("runtime/task system"));
|
||||
assert_eq!(task.task_packet, Some(packet.clone()));
|
||||
|
||||
let fetched = registry.get(&task.task_id).expect("task should exist");
|
||||
assert_eq!(fetched.task_packet, Some(packet));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lists_tasks_with_optional_filter() {
|
||||
let registry = TaskRegistry::new();
|
||||
registry.create("Task A", None);
|
||||
let task_b = registry.create("Task B", None);
|
||||
registry
|
||||
.set_status(&task_b.task_id, TaskStatus::Running)
|
||||
.expect("set status should succeed");
|
||||
|
||||
let all = registry.list(None);
|
||||
assert_eq!(all.len(), 2);
|
||||
|
||||
let running = registry.list(Some(TaskStatus::Running));
|
||||
assert_eq!(running.len(), 1);
|
||||
assert_eq!(running[0].task_id, task_b.task_id);
|
||||
|
||||
let created = registry.list(Some(TaskStatus::Created));
|
||||
assert_eq!(created.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stops_running_task() {
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("Stoppable", None);
|
||||
registry
|
||||
.set_status(&task.task_id, TaskStatus::Running)
|
||||
.unwrap();
|
||||
|
||||
let stopped = registry.stop(&task.task_id).expect("stop should succeed");
|
||||
assert_eq!(stopped.status, TaskStatus::Stopped);
|
||||
|
||||
// Stopping again should fail
|
||||
let result = registry.stop(&task.task_id);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn updates_task_with_messages() {
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("Messageable", None);
|
||||
let updated = registry
|
||||
.update(&task.task_id, "Here's more context")
|
||||
.expect("update should succeed");
|
||||
assert_eq!(updated.messages.len(), 1);
|
||||
assert_eq!(updated.messages[0].content, "Here's more context");
|
||||
assert_eq!(updated.messages[0].role, "user");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn appends_and_retrieves_output() {
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("Output task", None);
|
||||
registry
|
||||
.append_output(&task.task_id, "line 1\n")
|
||||
.expect("append should succeed");
|
||||
registry
|
||||
.append_output(&task.task_id, "line 2\n")
|
||||
.expect("append should succeed");
|
||||
|
||||
let output = registry.output(&task.task_id).expect("output should exist");
|
||||
assert_eq!(output, "line 1\nline 2\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn assigns_team_and_removes_task() {
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("Team task", None);
|
||||
registry
|
||||
.assign_team(&task.task_id, "team_abc")
|
||||
.expect("assign should succeed");
|
||||
|
||||
let fetched = registry.get(&task.task_id).unwrap();
|
||||
assert_eq!(fetched.team_id.as_deref(), Some("team_abc"));
|
||||
|
||||
let removed = registry.remove(&task.task_id);
|
||||
assert!(removed.is_some());
|
||||
assert!(registry.get(&task.task_id).is_none());
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_operations_on_missing_task() {
|
||||
let registry = TaskRegistry::new();
|
||||
assert!(registry.stop("nonexistent").is_err());
|
||||
assert!(registry.update("nonexistent", "msg").is_err());
|
||||
assert!(registry.output("nonexistent").is_err());
|
||||
assert!(registry.append_output("nonexistent", "data").is_err());
|
||||
assert!(registry
|
||||
.set_status("nonexistent", TaskStatus::Running)
|
||||
.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn task_status_display_all_variants() {
|
||||
// given
|
||||
let cases = [
|
||||
(TaskStatus::Created, "created"),
|
||||
(TaskStatus::Running, "running"),
|
||||
(TaskStatus::Completed, "completed"),
|
||||
(TaskStatus::Failed, "failed"),
|
||||
(TaskStatus::Stopped, "stopped"),
|
||||
];
|
||||
|
||||
// when
|
||||
let rendered: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(status, expected)| (status.to_string(), expected))
|
||||
.collect();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
rendered,
|
||||
vec![
|
||||
("created".to_string(), "created"),
|
||||
("running".to_string(), "running"),
|
||||
("completed".to_string(), "completed"),
|
||||
("failed".to_string(), "failed"),
|
||||
("stopped".to_string(), "stopped"),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stop_rejects_completed_task() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("done", None);
|
||||
registry
|
||||
.set_status(&task.task_id, TaskStatus::Completed)
|
||||
.expect("set status should succeed");
|
||||
|
||||
// when
|
||||
let result = registry.stop(&task.task_id);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("completed task should be rejected");
|
||||
assert!(error.contains("already in terminal state"));
|
||||
assert!(error.contains("completed"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stop_rejects_failed_task() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("failed", None);
|
||||
registry
|
||||
.set_status(&task.task_id, TaskStatus::Failed)
|
||||
.expect("set status should succeed");
|
||||
|
||||
// when
|
||||
let result = registry.stop(&task.task_id);
|
||||
|
||||
// then
|
||||
let error = result.expect_err("failed task should be rejected");
|
||||
assert!(error.contains("already in terminal state"));
|
||||
assert!(error.contains("failed"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stop_succeeds_from_created_state() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
let task = registry.create("created task", None);
|
||||
|
||||
// when
|
||||
let stopped = registry.stop(&task.task_id).expect("stop should succeed");
|
||||
|
||||
// then
|
||||
assert_eq!(stopped.status, TaskStatus::Stopped);
|
||||
assert!(stopped.updated_at >= task.updated_at);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_registry_is_empty() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
|
||||
// when
|
||||
let all_tasks = registry.list(None);
|
||||
|
||||
// then
|
||||
assert!(registry.is_empty());
|
||||
assert_eq!(registry.len(), 0);
|
||||
assert!(all_tasks.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn create_without_description() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
|
||||
// when
|
||||
let task = registry.create("Do the thing", None);
|
||||
|
||||
// then
|
||||
assert!(task.task_id.starts_with("task_"));
|
||||
assert_eq!(task.description, None);
|
||||
assert_eq!(task.task_packet, None);
|
||||
assert!(task.messages.is_empty());
|
||||
assert!(task.output.is_empty());
|
||||
assert_eq!(task.team_id, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_nonexistent_returns_none() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
|
||||
// when
|
||||
let removed = registry.remove("missing");
|
||||
|
||||
// then
|
||||
assert!(removed.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn assign_team_rejects_missing_task() {
|
||||
// given
|
||||
let registry = TaskRegistry::new();
|
||||
|
||||
// when
|
||||
let result = registry.assign_team("missing", "team_123");
|
||||
|
||||
// then
|
||||
let error = result.expect_err("missing task should be rejected");
|
||||
assert_eq!(error, "task not found: missing");
|
||||
}
|
||||
}
|
||||
509
crates/runtime/src/team_cron_registry.rs
Normal file
509
crates/runtime/src/team_cron_registry.rs
Normal file
@ -0,0 +1,509 @@
|
||||
#![allow(clippy::must_use_candidate)]
|
||||
//! In-memory registries for Team and Cron lifecycle management.
|
||||
//!
|
||||
//! Provides TeamCreate/Delete and CronCreate/Delete/List runtime backing
|
||||
//! to replace the stub implementations in the tools crate.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
fn now_secs() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Team {
|
||||
pub team_id: String,
|
||||
pub name: String,
|
||||
pub task_ids: Vec<String>,
|
||||
pub status: TeamStatus,
|
||||
pub created_at: u64,
|
||||
pub updated_at: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TeamStatus {
|
||||
Created,
|
||||
Running,
|
||||
Completed,
|
||||
Deleted,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TeamStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Created => write!(f, "created"),
|
||||
Self::Running => write!(f, "running"),
|
||||
Self::Completed => write!(f, "completed"),
|
||||
Self::Deleted => write!(f, "deleted"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct TeamRegistry {
|
||||
inner: Arc<Mutex<TeamInner>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct TeamInner {
|
||||
teams: HashMap<String, Team>,
|
||||
counter: u64,
|
||||
}
|
||||
|
||||
impl TeamRegistry {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn create(&self, name: &str, task_ids: Vec<String>) -> Team {
|
||||
let mut inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
inner.counter += 1;
|
||||
let ts = now_secs();
|
||||
let team_id = format!("team_{:08x}_{}", ts, inner.counter);
|
||||
let team = Team {
|
||||
team_id: team_id.clone(),
|
||||
name: name.to_owned(),
|
||||
task_ids,
|
||||
status: TeamStatus::Created,
|
||||
created_at: ts,
|
||||
updated_at: ts,
|
||||
};
|
||||
inner.teams.insert(team_id, team.clone());
|
||||
team
|
||||
}
|
||||
|
||||
pub fn get(&self, team_id: &str) -> Option<Team> {
|
||||
let inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
inner.teams.get(team_id).cloned()
|
||||
}
|
||||
|
||||
pub fn list(&self) -> Vec<Team> {
|
||||
let inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
inner.teams.values().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn delete(&self, team_id: &str) -> Result<Team, String> {
|
||||
let mut inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
let team = inner
|
||||
.teams
|
||||
.get_mut(team_id)
|
||||
.ok_or_else(|| format!("team not found: {team_id}"))?;
|
||||
team.status = TeamStatus::Deleted;
|
||||
team.updated_at = now_secs();
|
||||
Ok(team.clone())
|
||||
}
|
||||
|
||||
pub fn remove(&self, team_id: &str) -> Option<Team> {
|
||||
let mut inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
inner.teams.remove(team_id)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn len(&self) -> usize {
|
||||
let inner = self.inner.lock().expect("team registry lock poisoned");
|
||||
inner.teams.len()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CronEntry {
|
||||
pub cron_id: String,
|
||||
pub schedule: String,
|
||||
pub prompt: String,
|
||||
pub description: Option<String>,
|
||||
pub enabled: bool,
|
||||
pub created_at: u64,
|
||||
pub updated_at: u64,
|
||||
pub last_run_at: Option<u64>,
|
||||
pub run_count: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct CronRegistry {
|
||||
inner: Arc<Mutex<CronInner>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct CronInner {
|
||||
entries: HashMap<String, CronEntry>,
|
||||
counter: u64,
|
||||
}
|
||||
|
||||
impl CronRegistry {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn create(&self, schedule: &str, prompt: &str, description: Option<&str>) -> CronEntry {
|
||||
let mut inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
inner.counter += 1;
|
||||
let ts = now_secs();
|
||||
let cron_id = format!("cron_{:08x}_{}", ts, inner.counter);
|
||||
let entry = CronEntry {
|
||||
cron_id: cron_id.clone(),
|
||||
schedule: schedule.to_owned(),
|
||||
prompt: prompt.to_owned(),
|
||||
description: description.map(str::to_owned),
|
||||
enabled: true,
|
||||
created_at: ts,
|
||||
updated_at: ts,
|
||||
last_run_at: None,
|
||||
run_count: 0,
|
||||
};
|
||||
inner.entries.insert(cron_id, entry.clone());
|
||||
entry
|
||||
}
|
||||
|
||||
pub fn get(&self, cron_id: &str) -> Option<CronEntry> {
|
||||
let inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
inner.entries.get(cron_id).cloned()
|
||||
}
|
||||
|
||||
pub fn list(&self, enabled_only: bool) -> Vec<CronEntry> {
|
||||
let inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
inner
|
||||
.entries
|
||||
.values()
|
||||
.filter(|e| !enabled_only || e.enabled)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn delete(&self, cron_id: &str) -> Result<CronEntry, String> {
|
||||
let mut inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
inner
|
||||
.entries
|
||||
.remove(cron_id)
|
||||
.ok_or_else(|| format!("cron not found: {cron_id}"))
|
||||
}
|
||||
|
||||
/// Disable a cron entry without removing it.
|
||||
pub fn disable(&self, cron_id: &str) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
let entry = inner
|
||||
.entries
|
||||
.get_mut(cron_id)
|
||||
.ok_or_else(|| format!("cron not found: {cron_id}"))?;
|
||||
entry.enabled = false;
|
||||
entry.updated_at = now_secs();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Record a cron run.
|
||||
pub fn record_run(&self, cron_id: &str) -> Result<(), String> {
|
||||
let mut inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
let entry = inner
|
||||
.entries
|
||||
.get_mut(cron_id)
|
||||
.ok_or_else(|| format!("cron not found: {cron_id}"))?;
|
||||
entry.last_run_at = Some(now_secs());
|
||||
entry.run_count += 1;
|
||||
entry.updated_at = now_secs();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn len(&self) -> usize {
|
||||
let inner = self.inner.lock().expect("cron registry lock poisoned");
|
||||
inner.entries.len()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── Team tests ──────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn creates_and_retrieves_team() {
|
||||
let registry = TeamRegistry::new();
|
||||
let team = registry.create("Alpha Squad", vec!["task_001".into(), "task_002".into()]);
|
||||
assert_eq!(team.name, "Alpha Squad");
|
||||
assert_eq!(team.task_ids.len(), 2);
|
||||
assert_eq!(team.status, TeamStatus::Created);
|
||||
|
||||
let fetched = registry.get(&team.team_id).expect("team should exist");
|
||||
assert_eq!(fetched.team_id, team.team_id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lists_and_deletes_teams() {
|
||||
let registry = TeamRegistry::new();
|
||||
let t1 = registry.create("Team A", vec![]);
|
||||
let t2 = registry.create("Team B", vec![]);
|
||||
|
||||
let all = registry.list();
|
||||
assert_eq!(all.len(), 2);
|
||||
|
||||
let deleted = registry.delete(&t1.team_id).expect("delete should succeed");
|
||||
assert_eq!(deleted.status, TeamStatus::Deleted);
|
||||
|
||||
// Team is still listable (soft delete)
|
||||
let still_there = registry.get(&t1.team_id).unwrap();
|
||||
assert_eq!(still_there.status, TeamStatus::Deleted);
|
||||
|
||||
// Hard remove
|
||||
registry.remove(&t2.team_id);
|
||||
assert_eq!(registry.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_missing_team_operations() {
|
||||
let registry = TeamRegistry::new();
|
||||
assert!(registry.delete("nonexistent").is_err());
|
||||
assert!(registry.get("nonexistent").is_none());
|
||||
}
|
||||
|
||||
// ── Cron tests ──────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn creates_and_retrieves_cron() {
|
||||
let registry = CronRegistry::new();
|
||||
let entry = registry.create("0 * * * *", "Check status", Some("hourly check"));
|
||||
assert_eq!(entry.schedule, "0 * * * *");
|
||||
assert_eq!(entry.prompt, "Check status");
|
||||
assert!(entry.enabled);
|
||||
assert_eq!(entry.run_count, 0);
|
||||
assert!(entry.last_run_at.is_none());
|
||||
|
||||
let fetched = registry.get(&entry.cron_id).expect("cron should exist");
|
||||
assert_eq!(fetched.cron_id, entry.cron_id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lists_with_enabled_filter() {
|
||||
let registry = CronRegistry::new();
|
||||
let c1 = registry.create("* * * * *", "Task 1", None);
|
||||
let c2 = registry.create("0 * * * *", "Task 2", None);
|
||||
registry
|
||||
.disable(&c1.cron_id)
|
||||
.expect("disable should succeed");
|
||||
|
||||
let all = registry.list(false);
|
||||
assert_eq!(all.len(), 2);
|
||||
|
||||
let enabled_only = registry.list(true);
|
||||
assert_eq!(enabled_only.len(), 1);
|
||||
assert_eq!(enabled_only[0].cron_id, c2.cron_id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deletes_cron_entry() {
|
||||
let registry = CronRegistry::new();
|
||||
let entry = registry.create("* * * * *", "To delete", None);
|
||||
let deleted = registry
|
||||
.delete(&entry.cron_id)
|
||||
.expect("delete should succeed");
|
||||
assert_eq!(deleted.cron_id, entry.cron_id);
|
||||
assert!(registry.get(&entry.cron_id).is_none());
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn records_cron_runs() {
|
||||
let registry = CronRegistry::new();
|
||||
let entry = registry.create("*/5 * * * *", "Recurring", None);
|
||||
registry.record_run(&entry.cron_id).unwrap();
|
||||
registry.record_run(&entry.cron_id).unwrap();
|
||||
|
||||
let fetched = registry.get(&entry.cron_id).unwrap();
|
||||
assert_eq!(fetched.run_count, 2);
|
||||
assert!(fetched.last_run_at.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_missing_cron_operations() {
|
||||
let registry = CronRegistry::new();
|
||||
assert!(registry.delete("nonexistent").is_err());
|
||||
assert!(registry.disable("nonexistent").is_err());
|
||||
assert!(registry.record_run("nonexistent").is_err());
|
||||
assert!(registry.get("nonexistent").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn team_status_display_all_variants() {
|
||||
// given
|
||||
let cases = [
|
||||
(TeamStatus::Created, "created"),
|
||||
(TeamStatus::Running, "running"),
|
||||
(TeamStatus::Completed, "completed"),
|
||||
(TeamStatus::Deleted, "deleted"),
|
||||
];
|
||||
|
||||
// when
|
||||
let rendered: Vec<_> = cases
|
||||
.into_iter()
|
||||
.map(|(status, expected)| (status.to_string(), expected))
|
||||
.collect();
|
||||
|
||||
// then
|
||||
assert_eq!(
|
||||
rendered,
|
||||
vec![
|
||||
("created".to_string(), "created"),
|
||||
("running".to_string(), "running"),
|
||||
("completed".to_string(), "completed"),
|
||||
("deleted".to_string(), "deleted"),
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_team_registry_is_empty() {
|
||||
// given
|
||||
let registry = TeamRegistry::new();
|
||||
|
||||
// when
|
||||
let teams = registry.list();
|
||||
|
||||
// then
|
||||
assert!(registry.is_empty());
|
||||
assert_eq!(registry.len(), 0);
|
||||
assert!(teams.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn team_remove_nonexistent_returns_none() {
|
||||
// given
|
||||
let registry = TeamRegistry::new();
|
||||
|
||||
// when
|
||||
let removed = registry.remove("missing");
|
||||
|
||||
// then
|
||||
assert!(removed.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn team_len_transitions() {
|
||||
// given
|
||||
let registry = TeamRegistry::new();
|
||||
|
||||
// when
|
||||
let alpha = registry.create("Alpha", vec![]);
|
||||
let beta = registry.create("Beta", vec![]);
|
||||
let after_create = registry.len();
|
||||
registry.remove(&alpha.team_id);
|
||||
let after_first_remove = registry.len();
|
||||
registry.remove(&beta.team_id);
|
||||
|
||||
// then
|
||||
assert_eq!(after_create, 2);
|
||||
assert_eq!(after_first_remove, 1);
|
||||
assert_eq!(registry.len(), 0);
|
||||
assert!(registry.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_list_all_disabled_returns_empty_for_enabled_only() {
|
||||
// given
|
||||
let registry = CronRegistry::new();
|
||||
let first = registry.create("* * * * *", "Task 1", None);
|
||||
let second = registry.create("0 * * * *", "Task 2", None);
|
||||
registry
|
||||
.disable(&first.cron_id)
|
||||
.expect("disable should succeed");
|
||||
registry
|
||||
.disable(&second.cron_id)
|
||||
.expect("disable should succeed");
|
||||
|
||||
// when
|
||||
let enabled_only = registry.list(true);
|
||||
let all_entries = registry.list(false);
|
||||
|
||||
// then
|
||||
assert!(enabled_only.is_empty());
|
||||
assert_eq!(all_entries.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_create_without_description() {
|
||||
// given
|
||||
let registry = CronRegistry::new();
|
||||
|
||||
// when
|
||||
let entry = registry.create("*/15 * * * *", "Check health", None);
|
||||
|
||||
// then
|
||||
assert!(entry.cron_id.starts_with("cron_"));
|
||||
assert_eq!(entry.description, None);
|
||||
assert!(entry.enabled);
|
||||
assert_eq!(entry.run_count, 0);
|
||||
assert_eq!(entry.last_run_at, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_cron_registry_is_empty() {
|
||||
// given
|
||||
let registry = CronRegistry::new();
|
||||
|
||||
// when
|
||||
let enabled_only = registry.list(true);
|
||||
let all_entries = registry.list(false);
|
||||
|
||||
// then
|
||||
assert!(registry.is_empty());
|
||||
assert_eq!(registry.len(), 0);
|
||||
assert!(enabled_only.is_empty());
|
||||
assert!(all_entries.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_record_run_updates_timestamp_and_counter() {
|
||||
// given
|
||||
let registry = CronRegistry::new();
|
||||
let entry = registry.create("*/5 * * * *", "Recurring", None);
|
||||
|
||||
// when
|
||||
registry
|
||||
.record_run(&entry.cron_id)
|
||||
.expect("first run should succeed");
|
||||
registry
|
||||
.record_run(&entry.cron_id)
|
||||
.expect("second run should succeed");
|
||||
let fetched = registry.get(&entry.cron_id).expect("entry should exist");
|
||||
|
||||
// then
|
||||
assert_eq!(fetched.run_count, 2);
|
||||
assert!(fetched.last_run_at.is_some());
|
||||
assert!(fetched.updated_at >= entry.updated_at);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cron_disable_updates_timestamp() {
|
||||
// given
|
||||
let registry = CronRegistry::new();
|
||||
let entry = registry.create("0 0 * * *", "Nightly", None);
|
||||
|
||||
// when
|
||||
registry
|
||||
.disable(&entry.cron_id)
|
||||
.expect("disable should succeed");
|
||||
let fetched = registry.get(&entry.cron_id).expect("entry should exist");
|
||||
|
||||
// then
|
||||
assert!(!fetched.enabled);
|
||||
assert!(fetched.updated_at >= entry.updated_at);
|
||||
}
|
||||
}
|
||||
299
crates/runtime/src/trust_resolver.rs
Normal file
299
crates/runtime/src/trust_resolver.rs
Normal file
@ -0,0 +1,299 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
const TRUST_PROMPT_CUES: &[&str] = &[
|
||||
"do you trust the files in this folder",
|
||||
"trust the files in this folder",
|
||||
"trust this folder",
|
||||
"allow and continue",
|
||||
"yes, proceed",
|
||||
];
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TrustPolicy {
|
||||
AutoTrust,
|
||||
RequireApproval,
|
||||
Deny,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum TrustEvent {
|
||||
TrustRequired { cwd: String },
|
||||
TrustResolved { cwd: String, policy: TrustPolicy },
|
||||
TrustDenied { cwd: String, reason: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct TrustConfig {
|
||||
allowlisted: Vec<PathBuf>,
|
||||
denied: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
impl TrustConfig {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_allowlisted(mut self, path: impl Into<PathBuf>) -> Self {
|
||||
self.allowlisted.push(path.into());
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_denied(mut self, path: impl Into<PathBuf>) -> Self {
|
||||
self.denied.push(path.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum TrustDecision {
|
||||
NotRequired,
|
||||
Required {
|
||||
policy: TrustPolicy,
|
||||
events: Vec<TrustEvent>,
|
||||
},
|
||||
}
|
||||
|
||||
impl TrustDecision {
|
||||
#[must_use]
|
||||
pub fn policy(&self) -> Option<TrustPolicy> {
|
||||
match self {
|
||||
Self::NotRequired => None,
|
||||
Self::Required { policy, .. } => Some(*policy),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn events(&self) -> &[TrustEvent] {
|
||||
match self {
|
||||
Self::NotRequired => &[],
|
||||
Self::Required { events, .. } => events,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TrustResolver {
|
||||
config: TrustConfig,
|
||||
}
|
||||
|
||||
impl TrustResolver {
|
||||
#[must_use]
|
||||
pub fn new(config: TrustConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn resolve(&self, cwd: &str, screen_text: &str) -> TrustDecision {
|
||||
if !detect_trust_prompt(screen_text) {
|
||||
return TrustDecision::NotRequired;
|
||||
}
|
||||
|
||||
let mut events = vec![TrustEvent::TrustRequired {
|
||||
cwd: cwd.to_owned(),
|
||||
}];
|
||||
|
||||
if let Some(matched_root) = self
|
||||
.config
|
||||
.denied
|
||||
.iter()
|
||||
.find(|root| path_matches(cwd, root))
|
||||
{
|
||||
let reason = format!("cwd matches denied trust root: {}", matched_root.display());
|
||||
events.push(TrustEvent::TrustDenied {
|
||||
cwd: cwd.to_owned(),
|
||||
reason,
|
||||
});
|
||||
return TrustDecision::Required {
|
||||
policy: TrustPolicy::Deny,
|
||||
events,
|
||||
};
|
||||
}
|
||||
|
||||
if self
|
||||
.config
|
||||
.allowlisted
|
||||
.iter()
|
||||
.any(|root| path_matches(cwd, root))
|
||||
{
|
||||
events.push(TrustEvent::TrustResolved {
|
||||
cwd: cwd.to_owned(),
|
||||
policy: TrustPolicy::AutoTrust,
|
||||
});
|
||||
return TrustDecision::Required {
|
||||
policy: TrustPolicy::AutoTrust,
|
||||
events,
|
||||
};
|
||||
}
|
||||
|
||||
TrustDecision::Required {
|
||||
policy: TrustPolicy::RequireApproval,
|
||||
events,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn trusts(&self, cwd: &str) -> bool {
|
||||
!self
|
||||
.config
|
||||
.denied
|
||||
.iter()
|
||||
.any(|root| path_matches(cwd, root))
|
||||
&& self
|
||||
.config
|
||||
.allowlisted
|
||||
.iter()
|
||||
.any(|root| path_matches(cwd, root))
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn detect_trust_prompt(screen_text: &str) -> bool {
|
||||
let lowered = screen_text.to_ascii_lowercase();
|
||||
TRUST_PROMPT_CUES
|
||||
.iter()
|
||||
.any(|needle| lowered.contains(needle))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn path_matches_trusted_root(cwd: &str, trusted_root: &str) -> bool {
|
||||
path_matches(cwd, &normalize_path(Path::new(trusted_root)))
|
||||
}
|
||||
|
||||
fn path_matches(candidate: &str, root: &Path) -> bool {
|
||||
let candidate = normalize_path(Path::new(candidate));
|
||||
let root = normalize_path(root);
|
||||
candidate == root || candidate.starts_with(&root)
|
||||
}
|
||||
|
||||
fn normalize_path(path: &Path) -> PathBuf {
|
||||
std::fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
detect_trust_prompt, path_matches_trusted_root, TrustConfig, TrustDecision, TrustEvent,
|
||||
TrustPolicy, TrustResolver,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn detects_known_trust_prompt_copy() {
|
||||
// given
|
||||
let screen_text = "Do you trust the files in this folder?\n1. Yes, proceed\n2. No";
|
||||
|
||||
// when
|
||||
let detected = detect_trust_prompt(screen_text);
|
||||
|
||||
// then
|
||||
assert!(detected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn does_not_emit_events_when_prompt_is_absent() {
|
||||
// given
|
||||
let resolver = TrustResolver::new(TrustConfig::new().with_allowlisted("/tmp/worktrees"));
|
||||
|
||||
// when
|
||||
let decision = resolver.resolve("/tmp/worktrees/repo-a", "Ready for your input\n>");
|
||||
|
||||
// then
|
||||
assert_eq!(decision, TrustDecision::NotRequired);
|
||||
assert_eq!(decision.events(), &[]);
|
||||
assert_eq!(decision.policy(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_trusts_allowlisted_cwd_after_prompt_detection() {
|
||||
// given
|
||||
let resolver = TrustResolver::new(TrustConfig::new().with_allowlisted("/tmp/worktrees"));
|
||||
|
||||
// when
|
||||
let decision = resolver.resolve(
|
||||
"/tmp/worktrees/repo-a",
|
||||
"Do you trust the files in this folder?\n1. Yes, proceed\n2. No",
|
||||
);
|
||||
|
||||
// then
|
||||
assert_eq!(decision.policy(), Some(TrustPolicy::AutoTrust));
|
||||
assert_eq!(
|
||||
decision.events(),
|
||||
&[
|
||||
TrustEvent::TrustRequired {
|
||||
cwd: "/tmp/worktrees/repo-a".to_string(),
|
||||
},
|
||||
TrustEvent::TrustResolved {
|
||||
cwd: "/tmp/worktrees/repo-a".to_string(),
|
||||
policy: TrustPolicy::AutoTrust,
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn requires_approval_for_unknown_cwd_after_prompt_detection() {
|
||||
// given
|
||||
let resolver = TrustResolver::new(TrustConfig::new().with_allowlisted("/tmp/worktrees"));
|
||||
|
||||
// when
|
||||
let decision = resolver.resolve(
|
||||
"/tmp/other/repo-b",
|
||||
"Do you trust the files in this folder?\n1. Yes, proceed\n2. No",
|
||||
);
|
||||
|
||||
// then
|
||||
assert_eq!(decision.policy(), Some(TrustPolicy::RequireApproval));
|
||||
assert_eq!(
|
||||
decision.events(),
|
||||
&[TrustEvent::TrustRequired {
|
||||
cwd: "/tmp/other/repo-b".to_string(),
|
||||
}]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn denied_root_takes_precedence_over_allowlist() {
|
||||
// given
|
||||
let resolver = TrustResolver::new(
|
||||
TrustConfig::new()
|
||||
.with_allowlisted("/tmp/worktrees")
|
||||
.with_denied("/tmp/worktrees/repo-c"),
|
||||
);
|
||||
|
||||
// when
|
||||
let decision = resolver.resolve(
|
||||
"/tmp/worktrees/repo-c",
|
||||
"Do you trust the files in this folder?\n1. Yes, proceed\n2. No",
|
||||
);
|
||||
|
||||
// then
|
||||
assert_eq!(decision.policy(), Some(TrustPolicy::Deny));
|
||||
assert_eq!(
|
||||
decision.events(),
|
||||
&[
|
||||
TrustEvent::TrustRequired {
|
||||
cwd: "/tmp/worktrees/repo-c".to_string(),
|
||||
},
|
||||
TrustEvent::TrustDenied {
|
||||
cwd: "/tmp/worktrees/repo-c".to_string(),
|
||||
reason: "cwd matches denied trust root: /tmp/worktrees/repo-c".to_string(),
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sibling_prefix_does_not_match_trusted_root() {
|
||||
// given
|
||||
let trusted_root = "/tmp/worktrees";
|
||||
let sibling_path = "/tmp/worktrees-other/repo-d";
|
||||
|
||||
// when
|
||||
let matched = path_matches_trusted_root(sibling_path, trusted_root);
|
||||
|
||||
// then
|
||||
assert!(!matched);
|
||||
}
|
||||
}
|
||||
@ -1,11 +1,11 @@
|
||||
use crate::session::Session;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0;
|
||||
const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0;
|
||||
const DEFAULT_CACHE_CREATION_COST_PER_MILLION: f64 = 18.75;
|
||||
const DEFAULT_CACHE_READ_COST_PER_MILLION: f64 = 1.5;
|
||||
|
||||
/// Per-million-token pricing used for cost estimation.
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub struct ModelPricing {
|
||||
pub input_cost_per_million: f64,
|
||||
@ -26,7 +26,8 @@ impl ModelPricing {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)]
|
||||
/// Token counters accumulated for a conversation turn or session.
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
|
||||
pub struct TokenUsage {
|
||||
pub input_tokens: u32,
|
||||
pub output_tokens: u32,
|
||||
@ -34,6 +35,7 @@ pub struct TokenUsage {
|
||||
pub cache_read_input_tokens: u32,
|
||||
}
|
||||
|
||||
/// Estimated dollar cost derived from a [`TokenUsage`] sample.
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub struct UsageCostEstimate {
|
||||
pub input_cost_usd: f64,
|
||||
@ -52,6 +54,7 @@ impl UsageCostEstimate {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns pricing metadata for a known model alias or family.
|
||||
#[must_use]
|
||||
pub fn pricing_for_model(model: &str) -> Option<ModelPricing> {
|
||||
let normalized = model.to_ascii_lowercase();
|
||||
@ -156,10 +159,12 @@ fn cost_for_tokens(tokens: u32, usd_per_million_tokens: f64) -> f64 {
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
/// Formats a dollar-denominated value for CLI display.
|
||||
pub fn format_usd(amount: f64) -> String {
|
||||
format!("${amount:.4}")
|
||||
}
|
||||
|
||||
/// Aggregates token usage across a running session.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct UsageTracker {
|
||||
latest_turn: TokenUsage,
|
||||
@ -250,9 +255,9 @@ mod tests {
|
||||
let cost = usage.estimate_cost_usd();
|
||||
assert_eq!(format_usd(cost.input_cost_usd), "$15.0000");
|
||||
assert_eq!(format_usd(cost.output_cost_usd), "$37.5000");
|
||||
let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-6"));
|
||||
let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-20250514"));
|
||||
assert!(lines[0].contains("estimated_cost=$54.6750"));
|
||||
assert!(lines[0].contains("model=claude-sonnet-4-6"));
|
||||
assert!(lines[0].contains("model=claude-sonnet-4-20250514"));
|
||||
assert!(lines[1].contains("cache_read=$0.3000"));
|
||||
}
|
||||
|
||||
@ -265,7 +270,7 @@ mod tests {
|
||||
cache_read_input_tokens: 0,
|
||||
};
|
||||
|
||||
let haiku = pricing_for_model("claude-haiku-4-5-20251213").expect("haiku pricing");
|
||||
let haiku = pricing_for_model("claude-haiku-4-5-20251001").expect("haiku pricing");
|
||||
let opus = pricing_for_model("claude-opus-4-6").expect("opus pricing");
|
||||
let haiku_cost = usage.estimate_cost_usd_with_pricing(haiku);
|
||||
let opus_cost = usage.estimate_cost_usd_with_pricing(opus);
|
||||
@ -287,9 +292,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn reconstructs_usage_from_session_messages() {
|
||||
let session = Session {
|
||||
version: 1,
|
||||
messages: vec![ConversationMessage {
|
||||
let mut session = Session::new();
|
||||
session.messages = vec![ConversationMessage {
|
||||
role: MessageRole::Assistant,
|
||||
blocks: vec![ContentBlock::Text {
|
||||
text: "done".to_string(),
|
||||
@ -300,8 +304,7 @@ mod tests {
|
||||
cache_creation_input_tokens: 1,
|
||||
cache_read_input_tokens: 0,
|
||||
}),
|
||||
}],
|
||||
};
|
||||
}];
|
||||
|
||||
let tracker = UsageTracker::from_session(&session);
|
||||
assert_eq!(tracker.turns(), 1);
|
||||
|
||||
1180
crates/runtime/src/worker_boot.rs
Normal file
1180
crates/runtime/src/worker_boot.rs
Normal file
File diff suppressed because it is too large
Load Diff
386
crates/runtime/tests/integration_tests.rs
Normal file
386
crates/runtime/tests/integration_tests.rs
Normal file
@ -0,0 +1,386 @@
|
||||
#![allow(clippy::doc_markdown, clippy::uninlined_format_args, unused_imports)]
|
||||
//! Integration tests for cross-module wiring.
|
||||
//!
|
||||
//! These tests verify that adjacent modules in the runtime crate actually
|
||||
//! connect correctly — catching wiring gaps that unit tests miss.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use runtime::green_contract::{GreenContract, GreenContractOutcome, GreenLevel};
|
||||
use runtime::{
|
||||
apply_policy, BranchFreshness, DiffScope, LaneBlocker, LaneContext, PolicyAction,
|
||||
PolicyCondition, PolicyEngine, PolicyRule, ReconcileReason, ReviewStatus, StaleBranchAction,
|
||||
StaleBranchPolicy,
|
||||
};
|
||||
|
||||
/// stale_branch + policy_engine integration:
|
||||
/// When a branch is detected stale, does it correctly flow through
|
||||
/// PolicyCondition::StaleBranch to generate the expected action?
|
||||
#[test]
|
||||
fn stale_branch_detection_flows_into_policy_engine() {
|
||||
// given — a stale branch context (2 hours behind main, threshold is 1 hour)
|
||||
let stale_context = LaneContext::new(
|
||||
"stale-lane",
|
||||
0,
|
||||
Duration::from_secs(2 * 60 * 60), // 2 hours stale
|
||||
LaneBlocker::None,
|
||||
ReviewStatus::Pending,
|
||||
DiffScope::Full,
|
||||
false,
|
||||
);
|
||||
|
||||
let engine = PolicyEngine::new(vec![PolicyRule::new(
|
||||
"stale-merge-forward",
|
||||
PolicyCondition::StaleBranch,
|
||||
PolicyAction::MergeForward,
|
||||
10,
|
||||
)]);
|
||||
|
||||
// when
|
||||
let actions = engine.evaluate(&stale_context);
|
||||
|
||||
// then
|
||||
assert_eq!(actions, vec![PolicyAction::MergeForward]);
|
||||
}
|
||||
|
||||
/// stale_branch + policy_engine: Fresh branch does NOT trigger stale rules
|
||||
#[test]
|
||||
fn fresh_branch_does_not_trigger_stale_policy() {
|
||||
let fresh_context = LaneContext::new(
|
||||
"fresh-lane",
|
||||
0,
|
||||
Duration::from_secs(30 * 60), // 30 min stale — under 1 hour threshold
|
||||
LaneBlocker::None,
|
||||
ReviewStatus::Pending,
|
||||
DiffScope::Full,
|
||||
false,
|
||||
);
|
||||
|
||||
let engine = PolicyEngine::new(vec![PolicyRule::new(
|
||||
"stale-merge-forward",
|
||||
PolicyCondition::StaleBranch,
|
||||
PolicyAction::MergeForward,
|
||||
10,
|
||||
)]);
|
||||
|
||||
let actions = engine.evaluate(&fresh_context);
|
||||
assert!(actions.is_empty());
|
||||
}
|
||||
|
||||
/// green_contract + policy_engine integration:
|
||||
/// A lane that meets its green contract should be mergeable
|
||||
#[test]
|
||||
fn green_contract_satisfied_allows_merge() {
|
||||
let contract = GreenContract::new(GreenLevel::Workspace);
|
||||
let satisfied = contract.is_satisfied_by(GreenLevel::Workspace);
|
||||
assert!(satisfied);
|
||||
|
||||
let exceeded = contract.is_satisfied_by(GreenLevel::MergeReady);
|
||||
assert!(exceeded);
|
||||
|
||||
let insufficient = contract.is_satisfied_by(GreenLevel::Package);
|
||||
assert!(!insufficient);
|
||||
}
|
||||
|
||||
/// green_contract + policy_engine:
|
||||
/// Lane with green level below contract requirement gets blocked
|
||||
#[test]
|
||||
fn green_contract_unsatisfied_blocks_merge() {
|
||||
let context = LaneContext::new(
|
||||
"partial-green-lane",
|
||||
1, // GreenLevel::Package as u8
|
||||
Duration::from_secs(0),
|
||||
LaneBlocker::None,
|
||||
ReviewStatus::Pending,
|
||||
DiffScope::Full,
|
||||
false,
|
||||
);
|
||||
|
||||
// This is a conceptual test — we need a way to express "requires workspace green"
|
||||
// Currently LaneContext has raw green_level: u8, not a contract
|
||||
// For now we just verify the policy condition works
|
||||
let engine = PolicyEngine::new(vec![PolicyRule::new(
|
||||
"workspace-green-required",
|
||||
PolicyCondition::GreenAt { level: 3 }, // GreenLevel::Workspace
|
||||
PolicyAction::MergeToDev,
|
||||
10,
|
||||
)]);
|
||||
|
||||
let actions = engine.evaluate(&context);
|
||||
assert!(actions.is_empty()); // level 1 < 3, so no merge
|
||||
}
|
||||
|
||||
/// reconciliation + policy_engine integration:
|
||||
/// A reconciled lane should be handled by reconcile rules, not generic closeout
|
||||
#[test]
|
||||
fn reconciled_lane_matches_reconcile_condition() {
|
||||
let context = LaneContext::reconciled("reconciled-lane");
|
||||
|
||||
let engine = PolicyEngine::new(vec![
|
||||
PolicyRule::new(
|
||||
"reconcile-first",
|
||||
PolicyCondition::LaneReconciled,
|
||||
PolicyAction::Reconcile {
|
||||
reason: ReconcileReason::AlreadyMerged,
|
||||
},
|
||||
5,
|
||||
),
|
||||
PolicyRule::new(
|
||||
"generic-closeout",
|
||||
PolicyCondition::LaneCompleted,
|
||||
PolicyAction::CloseoutLane,
|
||||
30,
|
||||
),
|
||||
]);
|
||||
|
||||
let actions = engine.evaluate(&context);
|
||||
|
||||
// Both rules fire — reconcile (priority 5) first, then closeout (priority 30)
|
||||
assert_eq!(
|
||||
actions,
|
||||
vec![
|
||||
PolicyAction::Reconcile {
|
||||
reason: ReconcileReason::AlreadyMerged,
|
||||
},
|
||||
PolicyAction::CloseoutLane,
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
/// stale_branch module: apply_policy generates correct actions
|
||||
#[test]
|
||||
fn stale_branch_apply_policy_produces_rebase_action() {
|
||||
let stale = BranchFreshness::Stale {
|
||||
commits_behind: 5,
|
||||
missing_fixes: vec!["fix-123".to_string()],
|
||||
};
|
||||
|
||||
let action = apply_policy(&stale, StaleBranchPolicy::AutoRebase);
|
||||
assert_eq!(action, StaleBranchAction::Rebase);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stale_branch_apply_policy_produces_merge_forward_action() {
|
||||
let stale = BranchFreshness::Stale {
|
||||
commits_behind: 3,
|
||||
missing_fixes: vec![],
|
||||
};
|
||||
|
||||
let action = apply_policy(&stale, StaleBranchPolicy::AutoMergeForward);
|
||||
assert_eq!(action, StaleBranchAction::MergeForward);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stale_branch_apply_policy_warn_only() {
|
||||
let stale = BranchFreshness::Stale {
|
||||
commits_behind: 2,
|
||||
missing_fixes: vec!["fix-456".to_string()],
|
||||
};
|
||||
|
||||
let action = apply_policy(&stale, StaleBranchPolicy::WarnOnly);
|
||||
match action {
|
||||
StaleBranchAction::Warn { message } => {
|
||||
assert!(message.contains("2 commit(s) behind main"));
|
||||
assert!(message.contains("fix-456"));
|
||||
}
|
||||
_ => panic!("expected Warn action, got {:?}", action),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stale_branch_fresh_produces_noop() {
|
||||
let fresh = BranchFreshness::Fresh;
|
||||
let action = apply_policy(&fresh, StaleBranchPolicy::AutoRebase);
|
||||
assert_eq!(action, StaleBranchAction::Noop);
|
||||
}
|
||||
|
||||
/// Combined flow: stale detection + policy + action
|
||||
#[test]
|
||||
fn end_to_end_stale_lane_gets_merge_forward_action() {
|
||||
// Simulating what a harness would do:
|
||||
// 1. Detect branch freshness
|
||||
// 2. Build lane context from freshness + other signals
|
||||
// 3. Run policy engine
|
||||
// 4. Return actions
|
||||
|
||||
// given: detected stale state
|
||||
let _freshness = BranchFreshness::Stale {
|
||||
commits_behind: 5,
|
||||
missing_fixes: vec!["fix-123".to_string()],
|
||||
};
|
||||
|
||||
// when: build context and evaluate policy
|
||||
let context = LaneContext::new(
|
||||
"lane-9411",
|
||||
3, // Workspace green
|
||||
Duration::from_secs(5 * 60 * 60), // 5 hours stale, definitely over threshold
|
||||
LaneBlocker::None,
|
||||
ReviewStatus::Approved,
|
||||
DiffScope::Scoped,
|
||||
false,
|
||||
);
|
||||
|
||||
let engine = PolicyEngine::new(vec![
|
||||
// Priority 5: Check if stale first
|
||||
PolicyRule::new(
|
||||
"auto-merge-forward-if-stale-and-approved",
|
||||
PolicyCondition::And(vec![
|
||||
PolicyCondition::StaleBranch,
|
||||
PolicyCondition::ReviewPassed,
|
||||
]),
|
||||
PolicyAction::MergeForward,
|
||||
5,
|
||||
),
|
||||
// Priority 10: Normal stale handling
|
||||
PolicyRule::new(
|
||||
"stale-warning",
|
||||
PolicyCondition::StaleBranch,
|
||||
PolicyAction::Notify {
|
||||
channel: "#build-status".to_string(),
|
||||
},
|
||||
10,
|
||||
),
|
||||
]);
|
||||
|
||||
let actions = engine.evaluate(&context);
|
||||
|
||||
// then: both rules should fire (stale + approved matches both)
|
||||
assert_eq!(
|
||||
actions,
|
||||
vec![
|
||||
PolicyAction::MergeForward,
|
||||
PolicyAction::Notify {
|
||||
channel: "#build-status".to_string(),
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
/// Fresh branch with approved review should merge (not stale-blocked)
|
||||
#[test]
|
||||
fn fresh_approved_lane_gets_merge_action() {
|
||||
let context = LaneContext::new(
|
||||
"fresh-approved-lane",
|
||||
3, // Workspace green
|
||||
Duration::from_secs(30 * 60), // 30 min — under 1 hour threshold = fresh
|
||||
LaneBlocker::None,
|
||||
ReviewStatus::Approved,
|
||||
DiffScope::Scoped,
|
||||
false,
|
||||
);
|
||||
|
||||
let engine = PolicyEngine::new(vec![PolicyRule::new(
|
||||
"merge-if-green-approved-not-stale",
|
||||
PolicyCondition::And(vec![
|
||||
PolicyCondition::GreenAt { level: 3 },
|
||||
PolicyCondition::ReviewPassed,
|
||||
// NOT PolicyCondition::StaleBranch — fresh lanes bypass this
|
||||
]),
|
||||
PolicyAction::MergeToDev,
|
||||
5,
|
||||
)]);
|
||||
|
||||
let actions = engine.evaluate(&context);
|
||||
assert_eq!(actions, vec![PolicyAction::MergeToDev]);
|
||||
}
|
||||
|
||||
/// worker_boot + recovery_recipes + policy_engine integration:
|
||||
/// When a session completes with a provider failure, does the worker
|
||||
/// status transition trigger the correct recovery recipe, and does
|
||||
/// the resulting recovery state feed into policy decisions?
|
||||
#[test]
|
||||
fn worker_provider_failure_flows_through_recovery_to_policy() {
|
||||
use runtime::recovery_recipes::{
|
||||
attempt_recovery, FailureScenario, RecoveryContext, RecoveryResult, RecoveryStep,
|
||||
};
|
||||
use runtime::worker_boot::{WorkerFailureKind, WorkerRegistry, WorkerStatus};
|
||||
|
||||
// given — a worker that encounters a provider failure during session completion
|
||||
let registry = WorkerRegistry::new();
|
||||
let worker = registry.create("/tmp/repo-recovery-test", &[], true);
|
||||
|
||||
// Worker reaches ready state
|
||||
registry
|
||||
.observe(&worker.worker_id, "Ready for your input\n>")
|
||||
.expect("ready observe should succeed");
|
||||
registry
|
||||
.send_prompt(&worker.worker_id, Some("Run analysis"))
|
||||
.expect("prompt send should succeed");
|
||||
|
||||
// Session completes with provider failure (finish="unknown", tokens=0)
|
||||
let failed_worker = registry
|
||||
.observe_completion(&worker.worker_id, "unknown", 0)
|
||||
.expect("completion observe should succeed");
|
||||
assert_eq!(failed_worker.status, WorkerStatus::Failed);
|
||||
let failure = failed_worker
|
||||
.last_error
|
||||
.expect("worker should have recorded error");
|
||||
assert_eq!(failure.kind, WorkerFailureKind::Provider);
|
||||
|
||||
// Bridge: WorkerFailureKind -> FailureScenario
|
||||
let scenario = FailureScenario::from_worker_failure_kind(failure.kind);
|
||||
assert_eq!(scenario, FailureScenario::ProviderFailure);
|
||||
|
||||
// Recovery recipe lookup and execution
|
||||
let mut ctx = RecoveryContext::new();
|
||||
let result = attempt_recovery(&scenario, &mut ctx);
|
||||
|
||||
// then — recovery should recommend RestartWorker step
|
||||
assert!(
|
||||
matches!(result, RecoveryResult::Recovered { steps_taken: 1 }),
|
||||
"provider failure should recover via single RestartWorker step, got: {result:?}"
|
||||
);
|
||||
assert!(
|
||||
ctx.events().iter().any(|e| {
|
||||
matches!(
|
||||
e,
|
||||
runtime::recovery_recipes::RecoveryEvent::RecoveryAttempted {
|
||||
result: RecoveryResult::Recovered { steps_taken: 1 },
|
||||
..
|
||||
}
|
||||
)
|
||||
}),
|
||||
"recovery should emit structured attempt event"
|
||||
);
|
||||
|
||||
// Policy integration: recovery success + green status = merge-ready
|
||||
// (Simulating the policy check that would happen after successful recovery)
|
||||
let recovery_success = matches!(result, RecoveryResult::Recovered { .. });
|
||||
let green_level = 3; // Workspace green
|
||||
let not_stale = Duration::from_secs(30 * 60); // 30 min — fresh
|
||||
|
||||
let post_recovery_context = LaneContext::new(
|
||||
"recovered-lane",
|
||||
green_level,
|
||||
not_stale,
|
||||
LaneBlocker::None,
|
||||
ReviewStatus::Approved,
|
||||
DiffScope::Scoped,
|
||||
false,
|
||||
);
|
||||
|
||||
let policy_engine = PolicyEngine::new(vec![
|
||||
// Rule: if recovered from failure + green + approved -> merge
|
||||
PolicyRule::new(
|
||||
"merge-after-successful-recovery",
|
||||
PolicyCondition::And(vec![
|
||||
PolicyCondition::GreenAt { level: 3 },
|
||||
PolicyCondition::ReviewPassed,
|
||||
]),
|
||||
PolicyAction::MergeToDev,
|
||||
10,
|
||||
),
|
||||
]);
|
||||
|
||||
// Recovery success is a pre-condition; policy evaluates post-recovery context
|
||||
assert!(
|
||||
recovery_success,
|
||||
"recovery must succeed for lane to proceed"
|
||||
);
|
||||
let actions = policy_engine.evaluate(&post_recovery_context);
|
||||
assert_eq!(
|
||||
actions,
|
||||
vec![PolicyAction::MergeToDev],
|
||||
"post-recovery green+approved lane should be merge-ready"
|
||||
);
|
||||
}
|
||||
@ -15,12 +15,20 @@ commands = { path = "../commands" }
|
||||
compat-harness = { path = "../compat-harness" }
|
||||
crossterm = "0.28"
|
||||
pulldown-cmark = "0.13"
|
||||
plugins = { path = "../plugins" }
|
||||
rustyline = "15"
|
||||
runtime = { path = "../runtime" }
|
||||
serde_json = "1"
|
||||
plugins = { path = "../plugins" }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json.workspace = true
|
||||
syntect = "5"
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "time"] }
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "signal", "time"] }
|
||||
tools = { path = "../tools" }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
mock-anthropic-service = { path = "../mock-anthropic-service" }
|
||||
serde_json.workspace = true
|
||||
tokio = { version = "1", features = ["rt-multi-thread"] }
|
||||
|
||||
|
||||
38
crates/rusty-claude-cli/build.rs
Normal file
38
crates/rusty-claude-cli/build.rs
Normal file
@ -0,0 +1,38 @@
|
||||
use std::env;
|
||||
use std::process::Command;
|
||||
|
||||
fn main() {
|
||||
// Get git SHA (short hash)
|
||||
let git_sha = Command::new("git")
|
||||
.args(["rev-parse", "--short", "HEAD"])
|
||||
.output()
|
||||
.ok()
|
||||
.and_then(|output| {
|
||||
if output.status.success() {
|
||||
String::from_utf8(output.stdout).ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.map_or_else(|| "unknown".to_string(), |s| s.trim().to_string());
|
||||
|
||||
println!("cargo:rustc-env=GIT_SHA={git_sha}");
|
||||
|
||||
// TARGET is always set by Cargo during build
|
||||
let target = env::var("TARGET").unwrap_or_else(|_| "unknown".to_string());
|
||||
println!("cargo:rustc-env=TARGET={target}");
|
||||
|
||||
// Build date from SOURCE_DATE_EPOCH (reproducible builds) or current date.
|
||||
let build_date = std::env::var("SOURCE_DATE_EPOCH")
|
||||
.ok()
|
||||
.and_then(|epoch| epoch.parse::<i64>().ok())
|
||||
.map_or_else(
|
||||
|| "unknown".to_string(),
|
||||
|_| std::env::var("BUILD_DATE").unwrap_or_else(|_| "unknown".to_string()),
|
||||
);
|
||||
println!("cargo:rustc-env=BUILD_DATE={build_date}");
|
||||
|
||||
// Rerun if git state changes
|
||||
println!("cargo:rerun-if-changed=.git/HEAD");
|
||||
println!("cargo:rerun-if-changed=.git/refs");
|
||||
}
|
||||
@ -1,398 +0,0 @@
|
||||
use std::io::{self, Write};
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::args::{OutputFormat, PermissionMode};
|
||||
use crate::input::{LineEditor, ReadOutcome};
|
||||
use crate::render::{Spinner, TerminalRenderer};
|
||||
use runtime::{ConversationClient, ConversationMessage, RuntimeError, StreamEvent, UsageSummary};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SessionConfig {
|
||||
pub model: String,
|
||||
pub permission_mode: PermissionMode,
|
||||
pub config: Option<PathBuf>,
|
||||
pub output_format: OutputFormat,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SessionState {
|
||||
pub turns: usize,
|
||||
pub compacted_messages: usize,
|
||||
pub last_model: String,
|
||||
pub last_usage: UsageSummary,
|
||||
}
|
||||
|
||||
impl SessionState {
|
||||
#[must_use]
|
||||
pub fn new(model: impl Into<String>) -> Self {
|
||||
Self {
|
||||
turns: 0,
|
||||
compacted_messages: 0,
|
||||
last_model: model.into(),
|
||||
last_usage: UsageSummary::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum CommandResult {
|
||||
Continue,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum SlashCommand {
|
||||
Help,
|
||||
Status,
|
||||
Compact,
|
||||
Unknown(String),
|
||||
}
|
||||
|
||||
impl SlashCommand {
|
||||
#[must_use]
|
||||
pub fn parse(input: &str) -> Option<Self> {
|
||||
let trimmed = input.trim();
|
||||
if !trimmed.starts_with('/') {
|
||||
return None;
|
||||
}
|
||||
|
||||
let command = trimmed
|
||||
.trim_start_matches('/')
|
||||
.split_whitespace()
|
||||
.next()
|
||||
.unwrap_or_default();
|
||||
Some(match command {
|
||||
"help" => Self::Help,
|
||||
"status" => Self::Status,
|
||||
"compact" => Self::Compact,
|
||||
other => Self::Unknown(other.to_string()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct SlashCommandHandler {
|
||||
command: SlashCommand,
|
||||
summary: &'static str,
|
||||
}
|
||||
|
||||
const SLASH_COMMAND_HANDLERS: &[SlashCommandHandler] = &[
|
||||
SlashCommandHandler {
|
||||
command: SlashCommand::Help,
|
||||
summary: "Show command help",
|
||||
},
|
||||
SlashCommandHandler {
|
||||
command: SlashCommand::Status,
|
||||
summary: "Show current session status",
|
||||
},
|
||||
SlashCommandHandler {
|
||||
command: SlashCommand::Compact,
|
||||
summary: "Compact local session history",
|
||||
},
|
||||
];
|
||||
|
||||
pub struct CliApp {
|
||||
config: SessionConfig,
|
||||
renderer: TerminalRenderer,
|
||||
state: SessionState,
|
||||
conversation_client: ConversationClient,
|
||||
conversation_history: Vec<ConversationMessage>,
|
||||
}
|
||||
|
||||
impl CliApp {
|
||||
pub fn new(config: SessionConfig) -> Result<Self, RuntimeError> {
|
||||
let state = SessionState::new(config.model.clone());
|
||||
let conversation_client = ConversationClient::from_env(config.model.clone())?;
|
||||
Ok(Self {
|
||||
config,
|
||||
renderer: TerminalRenderer::new(),
|
||||
state,
|
||||
conversation_client,
|
||||
conversation_history: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn run_repl(&mut self) -> io::Result<()> {
|
||||
let mut editor = LineEditor::new("› ", Vec::new());
|
||||
println!("Rusty Claude CLI interactive mode");
|
||||
println!("Type /help for commands. Shift+Enter or Ctrl+J inserts a newline.");
|
||||
|
||||
loop {
|
||||
match editor.read_line()? {
|
||||
ReadOutcome::Submit(input) => {
|
||||
if input.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
self.handle_submission(&input, &mut io::stdout())?;
|
||||
}
|
||||
ReadOutcome::Cancel => continue,
|
||||
ReadOutcome::Exit => break,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn run_prompt(&mut self, prompt: &str, out: &mut impl Write) -> io::Result<()> {
|
||||
self.render_response(prompt, out)
|
||||
}
|
||||
|
||||
pub fn handle_submission(
|
||||
&mut self,
|
||||
input: &str,
|
||||
out: &mut impl Write,
|
||||
) -> io::Result<CommandResult> {
|
||||
if let Some(command) = SlashCommand::parse(input) {
|
||||
return self.dispatch_slash_command(command, out);
|
||||
}
|
||||
|
||||
self.state.turns += 1;
|
||||
self.render_response(input, out)?;
|
||||
Ok(CommandResult::Continue)
|
||||
}
|
||||
|
||||
fn dispatch_slash_command(
|
||||
&mut self,
|
||||
command: SlashCommand,
|
||||
out: &mut impl Write,
|
||||
) -> io::Result<CommandResult> {
|
||||
match command {
|
||||
SlashCommand::Help => Self::handle_help(out),
|
||||
SlashCommand::Status => self.handle_status(out),
|
||||
SlashCommand::Compact => self.handle_compact(out),
|
||||
SlashCommand::Unknown(name) => {
|
||||
writeln!(out, "Unknown slash command: /{name}")?;
|
||||
Ok(CommandResult::Continue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_help(out: &mut impl Write) -> io::Result<CommandResult> {
|
||||
writeln!(out, "Available commands:")?;
|
||||
for handler in SLASH_COMMAND_HANDLERS {
|
||||
let name = match handler.command {
|
||||
SlashCommand::Help => "/help",
|
||||
SlashCommand::Status => "/status",
|
||||
SlashCommand::Compact => "/compact",
|
||||
SlashCommand::Unknown(_) => continue,
|
||||
};
|
||||
writeln!(out, " {name:<9} {}", handler.summary)?;
|
||||
}
|
||||
Ok(CommandResult::Continue)
|
||||
}
|
||||
|
||||
fn handle_status(&mut self, out: &mut impl Write) -> io::Result<CommandResult> {
|
||||
writeln!(
|
||||
out,
|
||||
"status: turns={} model={} permission-mode={:?} output-format={:?} last-usage={} in/{} out config={}",
|
||||
self.state.turns,
|
||||
self.state.last_model,
|
||||
self.config.permission_mode,
|
||||
self.config.output_format,
|
||||
self.state.last_usage.input_tokens,
|
||||
self.state.last_usage.output_tokens,
|
||||
self.config
|
||||
.config
|
||||
.as_ref()
|
||||
.map_or_else(|| String::from("<none>"), |path| path.display().to_string())
|
||||
)?;
|
||||
Ok(CommandResult::Continue)
|
||||
}
|
||||
|
||||
fn handle_compact(&mut self, out: &mut impl Write) -> io::Result<CommandResult> {
|
||||
self.state.compacted_messages += self.state.turns;
|
||||
self.state.turns = 0;
|
||||
self.conversation_history.clear();
|
||||
writeln!(
|
||||
out,
|
||||
"Compacted session history into a local summary ({} messages total compacted).",
|
||||
self.state.compacted_messages
|
||||
)?;
|
||||
Ok(CommandResult::Continue)
|
||||
}
|
||||
|
||||
fn handle_stream_event(
|
||||
renderer: &TerminalRenderer,
|
||||
event: StreamEvent,
|
||||
stream_spinner: &mut Spinner,
|
||||
tool_spinner: &mut Spinner,
|
||||
saw_text: &mut bool,
|
||||
turn_usage: &mut UsageSummary,
|
||||
out: &mut impl Write,
|
||||
) {
|
||||
match event {
|
||||
StreamEvent::TextDelta(delta) => {
|
||||
if !*saw_text {
|
||||
let _ =
|
||||
stream_spinner.finish("Streaming response", renderer.color_theme(), out);
|
||||
*saw_text = true;
|
||||
}
|
||||
let _ = write!(out, "{delta}");
|
||||
let _ = out.flush();
|
||||
}
|
||||
StreamEvent::ToolCallStart { name, input } => {
|
||||
if *saw_text {
|
||||
let _ = writeln!(out);
|
||||
}
|
||||
let _ = tool_spinner.tick(
|
||||
&format!("Running tool `{name}` with {input}"),
|
||||
renderer.color_theme(),
|
||||
out,
|
||||
);
|
||||
}
|
||||
StreamEvent::ToolCallResult {
|
||||
name,
|
||||
output,
|
||||
is_error,
|
||||
} => {
|
||||
let label = if is_error {
|
||||
format!("Tool `{name}` failed")
|
||||
} else {
|
||||
format!("Tool `{name}` completed")
|
||||
};
|
||||
let _ = tool_spinner.finish(&label, renderer.color_theme(), out);
|
||||
let rendered_output = format!("### Tool `{name}`\n\n```text\n{output}\n```\n");
|
||||
let _ = renderer.stream_markdown(&rendered_output, out);
|
||||
}
|
||||
StreamEvent::Usage(usage) => {
|
||||
*turn_usage = usage;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn write_turn_output(
|
||||
&self,
|
||||
summary: &runtime::TurnSummary,
|
||||
out: &mut impl Write,
|
||||
) -> io::Result<()> {
|
||||
match self.config.output_format {
|
||||
OutputFormat::Text => {
|
||||
writeln!(
|
||||
out,
|
||||
"\nToken usage: {} input / {} output",
|
||||
self.state.last_usage.input_tokens, self.state.last_usage.output_tokens
|
||||
)?;
|
||||
}
|
||||
OutputFormat::Json => {
|
||||
writeln!(
|
||||
out,
|
||||
"{}",
|
||||
serde_json::json!({
|
||||
"message": summary.assistant_text,
|
||||
"usage": {
|
||||
"input_tokens": self.state.last_usage.input_tokens,
|
||||
"output_tokens": self.state.last_usage.output_tokens,
|
||||
}
|
||||
})
|
||||
)?;
|
||||
}
|
||||
OutputFormat::Ndjson => {
|
||||
writeln!(
|
||||
out,
|
||||
"{}",
|
||||
serde_json::json!({
|
||||
"type": "message",
|
||||
"text": summary.assistant_text,
|
||||
"usage": {
|
||||
"input_tokens": self.state.last_usage.input_tokens,
|
||||
"output_tokens": self.state.last_usage.output_tokens,
|
||||
}
|
||||
})
|
||||
)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn render_response(&mut self, input: &str, out: &mut impl Write) -> io::Result<()> {
|
||||
let mut stream_spinner = Spinner::new();
|
||||
stream_spinner.tick(
|
||||
"Opening conversation stream",
|
||||
self.renderer.color_theme(),
|
||||
out,
|
||||
)?;
|
||||
|
||||
let mut turn_usage = UsageSummary::default();
|
||||
let mut tool_spinner = Spinner::new();
|
||||
let mut saw_text = false;
|
||||
let renderer = &self.renderer;
|
||||
|
||||
let result =
|
||||
self.conversation_client
|
||||
.run_turn(&mut self.conversation_history, input, |event| {
|
||||
Self::handle_stream_event(
|
||||
renderer,
|
||||
event,
|
||||
&mut stream_spinner,
|
||||
&mut tool_spinner,
|
||||
&mut saw_text,
|
||||
&mut turn_usage,
|
||||
out,
|
||||
);
|
||||
});
|
||||
|
||||
let summary = match result {
|
||||
Ok(summary) => summary,
|
||||
Err(error) => {
|
||||
stream_spinner.fail(
|
||||
"Streaming response failed",
|
||||
self.renderer.color_theme(),
|
||||
out,
|
||||
)?;
|
||||
return Err(io::Error::other(error));
|
||||
}
|
||||
};
|
||||
self.state.last_usage = summary.usage.clone();
|
||||
if saw_text {
|
||||
writeln!(out)?;
|
||||
} else {
|
||||
stream_spinner.finish("Streaming response", self.renderer.color_theme(), out)?;
|
||||
}
|
||||
|
||||
self.write_turn_output(&summary, out)?;
|
||||
let _ = turn_usage;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::args::{OutputFormat, PermissionMode};
|
||||
|
||||
use super::{CommandResult, SessionConfig, SlashCommand};
|
||||
|
||||
#[test]
|
||||
fn parses_required_slash_commands() {
|
||||
assert_eq!(SlashCommand::parse("/help"), Some(SlashCommand::Help));
|
||||
assert_eq!(SlashCommand::parse(" /status "), Some(SlashCommand::Status));
|
||||
assert_eq!(
|
||||
SlashCommand::parse("/compact now"),
|
||||
Some(SlashCommand::Compact)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn help_output_lists_commands() {
|
||||
let mut out = Vec::new();
|
||||
let result = super::CliApp::handle_help(&mut out).expect("help succeeds");
|
||||
assert_eq!(result, CommandResult::Continue);
|
||||
let output = String::from_utf8_lossy(&out);
|
||||
assert!(output.contains("/help"));
|
||||
assert!(output.contains("/status"));
|
||||
assert!(output.contains("/compact"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_state_tracks_config_values() {
|
||||
let config = SessionConfig {
|
||||
model: "claude".into(),
|
||||
permission_mode: PermissionMode::WorkspaceWrite,
|
||||
config: Some(PathBuf::from("settings.toml")),
|
||||
output_format: OutputFormat::Text,
|
||||
};
|
||||
|
||||
assert_eq!(config.model, "claude");
|
||||
assert_eq!(config.permission_mode, PermissionMode::WorkspaceWrite);
|
||||
assert_eq!(config.config, Some(PathBuf::from("settings.toml")));
|
||||
}
|
||||
}
|
||||
@ -1,102 +0,0 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use clap::{Parser, Subcommand, ValueEnum};
|
||||
|
||||
#[derive(Debug, Clone, Parser, PartialEq, Eq)]
|
||||
#[command(
|
||||
name = "rusty-claude-cli",
|
||||
version,
|
||||
about = "Rust Claude CLI prototype"
|
||||
)]
|
||||
pub struct Cli {
|
||||
#[arg(long, default_value = "claude-opus-4-6")]
|
||||
pub model: String,
|
||||
|
||||
#[arg(long, value_enum, default_value_t = PermissionMode::WorkspaceWrite)]
|
||||
pub permission_mode: PermissionMode,
|
||||
|
||||
#[arg(long)]
|
||||
pub config: Option<PathBuf>,
|
||||
|
||||
#[arg(long, value_enum, default_value_t = OutputFormat::Text)]
|
||||
pub output_format: OutputFormat,
|
||||
|
||||
#[command(subcommand)]
|
||||
pub command: Option<Command>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Subcommand, PartialEq, Eq)]
|
||||
pub enum Command {
|
||||
/// Read upstream TS sources and print extracted counts
|
||||
DumpManifests,
|
||||
/// Print the current bootstrap phase skeleton
|
||||
BootstrapPlan,
|
||||
/// Start the OAuth login flow
|
||||
Login,
|
||||
/// Clear saved OAuth credentials
|
||||
Logout,
|
||||
/// Run a non-interactive prompt and exit
|
||||
Prompt { prompt: Vec<String> },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)]
|
||||
pub enum PermissionMode {
|
||||
ReadOnly,
|
||||
WorkspaceWrite,
|
||||
DangerFullAccess,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)]
|
||||
pub enum OutputFormat {
|
||||
Text,
|
||||
Json,
|
||||
Ndjson,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use clap::Parser;
|
||||
|
||||
use super::{Cli, Command, OutputFormat, PermissionMode};
|
||||
|
||||
#[test]
|
||||
fn parses_requested_flags() {
|
||||
let cli = Cli::parse_from([
|
||||
"rusty-claude-cli",
|
||||
"--model",
|
||||
"claude-3-5-haiku",
|
||||
"--permission-mode",
|
||||
"read-only",
|
||||
"--config",
|
||||
"/tmp/config.toml",
|
||||
"--output-format",
|
||||
"ndjson",
|
||||
"prompt",
|
||||
"hello",
|
||||
"world",
|
||||
]);
|
||||
|
||||
assert_eq!(cli.model, "claude-3-5-haiku");
|
||||
assert_eq!(cli.permission_mode, PermissionMode::ReadOnly);
|
||||
assert_eq!(
|
||||
cli.config.as_deref(),
|
||||
Some(std::path::Path::new("/tmp/config.toml"))
|
||||
);
|
||||
assert_eq!(cli.output_format, OutputFormat::Ndjson);
|
||||
assert_eq!(
|
||||
cli.command,
|
||||
Some(Command::Prompt {
|
||||
prompt: vec!["hello".into(), "world".into()]
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_login_and_logout_commands() {
|
||||
let login = Cli::parse_from(["rusty-claude-cli", "login"]);
|
||||
assert_eq!(login.command, Some(Command::Login));
|
||||
|
||||
let logout = Cli::parse_from(["rusty-claude-cli", "logout"]);
|
||||
assert_eq!(logout.command, Some(Command::Logout));
|
||||
}
|
||||
}
|
||||
@ -1,15 +1,15 @@
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
const STARTER_CLAUDE_JSON: &str = concat!(
|
||||
const STARTER_CLAW_JSON: &str = concat!(
|
||||
"{\n",
|
||||
" \"permissions\": {\n",
|
||||
" \"defaultMode\": \"acceptEdits\"\n",
|
||||
" \"defaultMode\": \"dontAsk\"\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
);
|
||||
const GITIGNORE_COMMENT: &str = "# Claude Code local artifacts";
|
||||
const GITIGNORE_ENTRIES: [&str; 2] = [".claude/settings.local.json", ".claude/sessions/"];
|
||||
const GITIGNORE_COMMENT: &str = "# Claw Code local artifacts";
|
||||
const GITIGNORE_ENTRIES: [&str; 2] = [".claw/settings.local.json", ".claw/sessions/"];
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum InitStatus {
|
||||
@ -80,16 +80,16 @@ struct RepoDetection {
|
||||
pub(crate) fn initialize_repo(cwd: &Path) -> Result<InitReport, Box<dyn std::error::Error>> {
|
||||
let mut artifacts = Vec::new();
|
||||
|
||||
let claude_dir = cwd.join(".claude");
|
||||
let claw_dir = cwd.join(".claw");
|
||||
artifacts.push(InitArtifact {
|
||||
name: ".claude/",
|
||||
status: ensure_dir(&claude_dir)?,
|
||||
name: ".claw/",
|
||||
status: ensure_dir(&claw_dir)?,
|
||||
});
|
||||
|
||||
let claude_json = cwd.join(".claude.json");
|
||||
let claw_json = cwd.join(".claw.json");
|
||||
artifacts.push(InitArtifact {
|
||||
name: ".claude.json",
|
||||
status: write_file_if_missing(&claude_json, STARTER_CLAUDE_JSON)?,
|
||||
name: ".claw.json",
|
||||
status: write_file_if_missing(&claw_json, STARTER_CLAW_JSON)?,
|
||||
});
|
||||
|
||||
let gitignore = cwd.join(".gitignore");
|
||||
@ -164,7 +164,7 @@ pub(crate) fn render_init_claude_md(cwd: &Path) -> String {
|
||||
let mut lines = vec![
|
||||
"# CLAUDE.md".to_string(),
|
||||
String::new(),
|
||||
"This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.".to_string(),
|
||||
"This file provides guidance to Claw Code (clawcode.dev) when working with code in this repository.".to_string(),
|
||||
String::new(),
|
||||
];
|
||||
|
||||
@ -209,7 +209,7 @@ pub(crate) fn render_init_claude_md(cwd: &Path) -> String {
|
||||
|
||||
lines.push("## Working agreement".to_string());
|
||||
lines.push("- Prefer small, reviewable changes and keep generated bootstrap files aligned with actual repo workflows.".to_string());
|
||||
lines.push("- Keep shared defaults in `.claude.json`; reserve `.claude/settings.local.json` for machine-local overrides.".to_string());
|
||||
lines.push("- Keep shared defaults in `.claw.json`; reserve `.claw/settings.local.json` for machine-local overrides.".to_string());
|
||||
lines.push("- Do not overwrite existing `CLAUDE.md` content automatically; update it intentionally when repo workflows change.".to_string());
|
||||
lines.push(String::new());
|
||||
|
||||
@ -354,26 +354,27 @@ mod tests {
|
||||
|
||||
let report = initialize_repo(&root).expect("init should succeed");
|
||||
let rendered = report.render();
|
||||
assert!(rendered.contains(".claude/ created"));
|
||||
assert!(rendered.contains(".claude.json created"));
|
||||
assert!(rendered.contains(".claw/"));
|
||||
assert!(rendered.contains(".claw.json"));
|
||||
assert!(rendered.contains("created"));
|
||||
assert!(rendered.contains(".gitignore created"));
|
||||
assert!(rendered.contains("CLAUDE.md created"));
|
||||
assert!(root.join(".claude").is_dir());
|
||||
assert!(root.join(".claude.json").is_file());
|
||||
assert!(root.join(".claw").is_dir());
|
||||
assert!(root.join(".claw.json").is_file());
|
||||
assert!(root.join("CLAUDE.md").is_file());
|
||||
assert_eq!(
|
||||
fs::read_to_string(root.join(".claude.json")).expect("read claude json"),
|
||||
fs::read_to_string(root.join(".claw.json")).expect("read claw json"),
|
||||
concat!(
|
||||
"{\n",
|
||||
" \"permissions\": {\n",
|
||||
" \"defaultMode\": \"acceptEdits\"\n",
|
||||
" \"defaultMode\": \"dontAsk\"\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
)
|
||||
);
|
||||
let gitignore = fs::read_to_string(root.join(".gitignore")).expect("read gitignore");
|
||||
assert!(gitignore.contains(".claude/settings.local.json"));
|
||||
assert!(gitignore.contains(".claude/sessions/"));
|
||||
assert!(gitignore.contains(".claw/settings.local.json"));
|
||||
assert!(gitignore.contains(".claw/sessions/"));
|
||||
let claude_md = fs::read_to_string(root.join("CLAUDE.md")).expect("read claude md");
|
||||
assert!(claude_md.contains("Languages: Rust."));
|
||||
assert!(claude_md.contains("cargo clippy --workspace --all-targets -- -D warnings"));
|
||||
@ -386,8 +387,7 @@ mod tests {
|
||||
let root = temp_dir();
|
||||
fs::create_dir_all(&root).expect("create root");
|
||||
fs::write(root.join("CLAUDE.md"), "custom guidance\n").expect("write existing claude md");
|
||||
fs::write(root.join(".gitignore"), ".claude/settings.local.json\n")
|
||||
.expect("write gitignore");
|
||||
fs::write(root.join(".gitignore"), ".claw/settings.local.json\n").expect("write gitignore");
|
||||
|
||||
let first = initialize_repo(&root).expect("first init should succeed");
|
||||
assert!(first
|
||||
@ -395,8 +395,9 @@ mod tests {
|
||||
.contains("CLAUDE.md skipped (already exists)"));
|
||||
let second = initialize_repo(&root).expect("second init should succeed");
|
||||
let second_rendered = second.render();
|
||||
assert!(second_rendered.contains(".claude/ skipped (already exists)"));
|
||||
assert!(second_rendered.contains(".claude.json skipped (already exists)"));
|
||||
assert!(second_rendered.contains(".claw/"));
|
||||
assert!(second_rendered.contains(".claw.json"));
|
||||
assert!(second_rendered.contains("skipped (already exists)"));
|
||||
assert!(second_rendered.contains(".gitignore skipped (already exists)"));
|
||||
assert!(second_rendered.contains("CLAUDE.md skipped (already exists)"));
|
||||
assert_eq!(
|
||||
@ -404,8 +405,8 @@ mod tests {
|
||||
"custom guidance\n"
|
||||
);
|
||||
let gitignore = fs::read_to_string(root.join(".gitignore")).expect("read gitignore");
|
||||
assert_eq!(gitignore.matches(".claude/settings.local.json").count(), 1);
|
||||
assert_eq!(gitignore.matches(".claude/sessions/").count(), 1);
|
||||
assert_eq!(gitignore.matches(".claw/settings.local.json").count(), 1);
|
||||
assert_eq!(gitignore.matches(".claw/sessions/").count(), 1);
|
||||
|
||||
fs::remove_dir_all(root).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
@ -1,166 +1,17 @@
|
||||
use std::borrow::Cow;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::BTreeSet;
|
||||
use std::io::{self, IsTerminal, Write};
|
||||
|
||||
use crossterm::cursor::{MoveDown, MoveToColumn, MoveUp};
|
||||
use crossterm::event::{self, Event, KeyCode, KeyEvent, KeyModifiers};
|
||||
use crossterm::queue;
|
||||
use crossterm::terminal::{disable_raw_mode, enable_raw_mode, Clear, ClearType};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct InputBuffer {
|
||||
buffer: String,
|
||||
cursor: usize,
|
||||
}
|
||||
|
||||
impl InputBuffer {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffer: String::new(),
|
||||
cursor: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, ch: char) {
|
||||
self.buffer.insert(self.cursor, ch);
|
||||
self.cursor += ch.len_utf8();
|
||||
}
|
||||
|
||||
pub fn insert_newline(&mut self) {
|
||||
self.insert('\n');
|
||||
}
|
||||
|
||||
pub fn backspace(&mut self) {
|
||||
if self.cursor == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let previous = self.buffer[..self.cursor]
|
||||
.char_indices()
|
||||
.last()
|
||||
.map_or(0, |(idx, _)| idx);
|
||||
self.buffer.drain(previous..self.cursor);
|
||||
self.cursor = previous;
|
||||
}
|
||||
|
||||
pub fn move_left(&mut self) {
|
||||
if self.cursor == 0 {
|
||||
return;
|
||||
}
|
||||
self.cursor = self.buffer[..self.cursor]
|
||||
.char_indices()
|
||||
.last()
|
||||
.map_or(0, |(idx, _)| idx);
|
||||
}
|
||||
|
||||
pub fn move_right(&mut self) {
|
||||
if self.cursor >= self.buffer.len() {
|
||||
return;
|
||||
}
|
||||
if let Some(next) = self.buffer[self.cursor..].chars().next() {
|
||||
self.cursor += next.len_utf8();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn move_home(&mut self) {
|
||||
self.cursor = 0;
|
||||
}
|
||||
|
||||
pub fn move_end(&mut self) {
|
||||
self.cursor = self.buffer.len();
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[must_use]
|
||||
pub fn cursor(&self) -> usize {
|
||||
self.cursor
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.buffer.clear();
|
||||
self.cursor = 0;
|
||||
}
|
||||
|
||||
pub fn replace(&mut self, value: impl Into<String>) {
|
||||
self.buffer = value.into();
|
||||
self.cursor = self.buffer.len();
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn current_command_prefix(&self) -> Option<&str> {
|
||||
if self.cursor != self.buffer.len() {
|
||||
return None;
|
||||
}
|
||||
let prefix = &self.buffer[..self.cursor];
|
||||
if prefix.contains(char::is_whitespace) || !prefix.starts_with('/') {
|
||||
return None;
|
||||
}
|
||||
Some(prefix)
|
||||
}
|
||||
|
||||
pub fn complete_slash_command(&mut self, candidates: &[String]) -> bool {
|
||||
let Some(prefix) = self.current_command_prefix() else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let matches = candidates
|
||||
.iter()
|
||||
.filter(|candidate| candidate.starts_with(prefix))
|
||||
.map(String::as_str)
|
||||
.collect::<Vec<_>>();
|
||||
if matches.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let replacement = longest_common_prefix(&matches);
|
||||
if replacement == prefix {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.replace(replacement);
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RenderedBuffer {
|
||||
lines: Vec<String>,
|
||||
cursor_row: u16,
|
||||
cursor_col: u16,
|
||||
}
|
||||
|
||||
impl RenderedBuffer {
|
||||
#[must_use]
|
||||
pub fn line_count(&self) -> usize {
|
||||
self.lines.len()
|
||||
}
|
||||
|
||||
fn write(&self, out: &mut impl Write) -> io::Result<()> {
|
||||
for (index, line) in self.lines.iter().enumerate() {
|
||||
if index > 0 {
|
||||
writeln!(out)?;
|
||||
}
|
||||
write!(out, "{line}")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[must_use]
|
||||
pub fn lines(&self) -> &[String] {
|
||||
&self.lines
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[must_use]
|
||||
pub fn cursor_position(&self) -> (u16, u16) {
|
||||
(self.cursor_row, self.cursor_col)
|
||||
}
|
||||
}
|
||||
use rustyline::completion::{Completer, Pair};
|
||||
use rustyline::error::ReadlineError;
|
||||
use rustyline::highlight::{CmdKind, Highlighter};
|
||||
use rustyline::hint::Hinter;
|
||||
use rustyline::history::DefaultHistory;
|
||||
use rustyline::validate::Validator;
|
||||
use rustyline::{
|
||||
Cmd, CompletionType, Config, Context, EditMode, Editor, Helper, KeyCode, KeyEvent, Modifiers,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ReadOutcome {
|
||||
@ -169,25 +20,105 @@ pub enum ReadOutcome {
|
||||
Exit,
|
||||
}
|
||||
|
||||
struct SlashCommandHelper {
|
||||
completions: Vec<String>,
|
||||
current_line: RefCell<String>,
|
||||
}
|
||||
|
||||
impl SlashCommandHelper {
|
||||
fn new(completions: Vec<String>) -> Self {
|
||||
Self {
|
||||
completions: normalize_completions(completions),
|
||||
current_line: RefCell::new(String::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn reset_current_line(&self) {
|
||||
self.current_line.borrow_mut().clear();
|
||||
}
|
||||
|
||||
fn current_line(&self) -> String {
|
||||
self.current_line.borrow().clone()
|
||||
}
|
||||
|
||||
fn set_current_line(&self, line: &str) {
|
||||
let mut current = self.current_line.borrow_mut();
|
||||
current.clear();
|
||||
current.push_str(line);
|
||||
}
|
||||
|
||||
fn set_completions(&mut self, completions: Vec<String>) {
|
||||
self.completions = normalize_completions(completions);
|
||||
}
|
||||
}
|
||||
|
||||
impl Completer for SlashCommandHelper {
|
||||
type Candidate = Pair;
|
||||
|
||||
fn complete(
|
||||
&self,
|
||||
line: &str,
|
||||
pos: usize,
|
||||
_ctx: &Context<'_>,
|
||||
) -> rustyline::Result<(usize, Vec<Self::Candidate>)> {
|
||||
let Some(prefix) = slash_command_prefix(line, pos) else {
|
||||
return Ok((0, Vec::new()));
|
||||
};
|
||||
|
||||
let matches = self
|
||||
.completions
|
||||
.iter()
|
||||
.filter(|candidate| candidate.starts_with(prefix))
|
||||
.map(|candidate| Pair {
|
||||
display: candidate.clone(),
|
||||
replacement: candidate.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok((0, matches))
|
||||
}
|
||||
}
|
||||
|
||||
impl Hinter for SlashCommandHelper {
|
||||
type Hint = String;
|
||||
}
|
||||
|
||||
impl Highlighter for SlashCommandHelper {
|
||||
fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> {
|
||||
self.set_current_line(line);
|
||||
Cow::Borrowed(line)
|
||||
}
|
||||
|
||||
fn highlight_char(&self, line: &str, _pos: usize, _kind: CmdKind) -> bool {
|
||||
self.set_current_line(line);
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl Validator for SlashCommandHelper {}
|
||||
impl Helper for SlashCommandHelper {}
|
||||
|
||||
pub struct LineEditor {
|
||||
prompt: String,
|
||||
continuation_prompt: String,
|
||||
history: Vec<String>,
|
||||
history_index: Option<usize>,
|
||||
draft: Option<String>,
|
||||
completions: Vec<String>,
|
||||
editor: Editor<SlashCommandHelper, DefaultHistory>,
|
||||
}
|
||||
|
||||
impl LineEditor {
|
||||
#[must_use]
|
||||
pub fn new(prompt: impl Into<String>, completions: Vec<String>) -> Self {
|
||||
let config = Config::builder()
|
||||
.completion_type(CompletionType::List)
|
||||
.edit_mode(EditMode::Emacs)
|
||||
.build();
|
||||
let mut editor = Editor::<SlashCommandHelper, DefaultHistory>::with_config(config)
|
||||
.expect("rustyline editor should initialize");
|
||||
editor.set_helper(Some(SlashCommandHelper::new(completions)));
|
||||
editor.bind_sequence(KeyEvent(KeyCode::Char('J'), Modifiers::CTRL), Cmd::Newline);
|
||||
editor.bind_sequence(KeyEvent(KeyCode::Enter, Modifiers::SHIFT), Cmd::Newline);
|
||||
|
||||
Self {
|
||||
prompt: prompt.into(),
|
||||
continuation_prompt: String::from("> "),
|
||||
history: Vec::new(),
|
||||
history_index: None,
|
||||
draft: None,
|
||||
completions,
|
||||
editor,
|
||||
}
|
||||
}
|
||||
|
||||
@ -196,9 +127,14 @@ impl LineEditor {
|
||||
if entry.trim().is_empty() {
|
||||
return;
|
||||
}
|
||||
self.history.push(entry);
|
||||
self.history_index = None;
|
||||
self.draft = None;
|
||||
|
||||
let _ = self.editor.add_history_entry(entry);
|
||||
}
|
||||
|
||||
pub fn set_completions(&mut self, completions: Vec<String>) {
|
||||
if let Some(helper) = self.editor.helper_mut() {
|
||||
helper.set_completions(completions);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_line(&mut self) -> io::Result<ReadOutcome> {
|
||||
@ -206,43 +142,41 @@ impl LineEditor {
|
||||
return self.read_line_fallback();
|
||||
}
|
||||
|
||||
enable_raw_mode()?;
|
||||
let mut stdout = io::stdout();
|
||||
let mut input = InputBuffer::new();
|
||||
let mut rendered_lines = 1usize;
|
||||
self.redraw(&mut stdout, &input, rendered_lines)?;
|
||||
if let Some(helper) = self.editor.helper_mut() {
|
||||
helper.reset_current_line();
|
||||
}
|
||||
|
||||
loop {
|
||||
let event = event::read()?;
|
||||
if let Event::Key(key) = event {
|
||||
match self.handle_key(key, &mut input) {
|
||||
EditorAction::Continue => {
|
||||
rendered_lines = self.redraw(&mut stdout, &input, rendered_lines)?;
|
||||
}
|
||||
EditorAction::Submit => {
|
||||
disable_raw_mode()?;
|
||||
writeln!(stdout)?;
|
||||
self.history_index = None;
|
||||
self.draft = None;
|
||||
return Ok(ReadOutcome::Submit(input.as_str().to_owned()));
|
||||
}
|
||||
EditorAction::Cancel => {
|
||||
disable_raw_mode()?;
|
||||
writeln!(stdout)?;
|
||||
self.history_index = None;
|
||||
self.draft = None;
|
||||
return Ok(ReadOutcome::Cancel);
|
||||
}
|
||||
EditorAction::Exit => {
|
||||
disable_raw_mode()?;
|
||||
writeln!(stdout)?;
|
||||
self.history_index = None;
|
||||
self.draft = None;
|
||||
return Ok(ReadOutcome::Exit);
|
||||
match self.editor.readline(&self.prompt) {
|
||||
Ok(line) => Ok(ReadOutcome::Submit(line)),
|
||||
Err(ReadlineError::Interrupted) => {
|
||||
let has_input = !self.current_line().is_empty();
|
||||
self.finish_interrupted_read()?;
|
||||
if has_input {
|
||||
Ok(ReadOutcome::Cancel)
|
||||
} else {
|
||||
Ok(ReadOutcome::Exit)
|
||||
}
|
||||
}
|
||||
Err(ReadlineError::Eof) => {
|
||||
self.finish_interrupted_read()?;
|
||||
Ok(ReadOutcome::Exit)
|
||||
}
|
||||
Err(error) => Err(io::Error::other(error)),
|
||||
}
|
||||
}
|
||||
|
||||
fn current_line(&self) -> String {
|
||||
self.editor
|
||||
.helper()
|
||||
.map_or_else(String::new, SlashCommandHelper::current_line)
|
||||
}
|
||||
|
||||
fn finish_interrupted_read(&mut self) -> io::Result<()> {
|
||||
if let Some(helper) = self.editor.helper_mut() {
|
||||
helper.reset_current_line();
|
||||
}
|
||||
let mut stdout = io::stdout();
|
||||
writeln!(stdout)
|
||||
}
|
||||
|
||||
fn read_line_fallback(&self) -> io::Result<ReadOutcome> {
|
||||
@ -261,388 +195,136 @@ impl LineEditor {
|
||||
}
|
||||
Ok(ReadOutcome::Submit(buffer))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn handle_key(&mut self, key: KeyEvent, input: &mut InputBuffer) -> EditorAction {
|
||||
match key {
|
||||
KeyEvent {
|
||||
code: KeyCode::Char('c'),
|
||||
modifiers,
|
||||
..
|
||||
} if modifiers.contains(KeyModifiers::CONTROL) => {
|
||||
if input.as_str().is_empty() {
|
||||
EditorAction::Exit
|
||||
} else {
|
||||
input.clear();
|
||||
self.history_index = None;
|
||||
self.draft = None;
|
||||
EditorAction::Cancel
|
||||
}
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Char('j'),
|
||||
modifiers,
|
||||
..
|
||||
} if modifiers.contains(KeyModifiers::CONTROL) => {
|
||||
input.insert_newline();
|
||||
EditorAction::Continue
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Enter,
|
||||
modifiers,
|
||||
..
|
||||
} if modifiers.contains(KeyModifiers::SHIFT) => {
|
||||
input.insert_newline();
|
||||
EditorAction::Continue
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Enter,
|
||||
..
|
||||
} => EditorAction::Submit,
|
||||
KeyEvent {
|
||||
code: KeyCode::Backspace,
|
||||
..
|
||||
} => {
|
||||
input.backspace();
|
||||
EditorAction::Continue
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Left,
|
||||
..
|
||||
} => {
|
||||
input.move_left();
|
||||
EditorAction::Continue
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Right,
|
||||
..
|
||||
} => {
|
||||
input.move_right();
|
||||
EditorAction::Continue
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Up, ..
|
||||
} => {
|
||||
self.navigate_history_up(input);
|
||||
EditorAction::Continue
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Down,
|
||||
..
|
||||
} => {
|
||||
self.navigate_history_down(input);
|
||||
EditorAction::Continue
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Tab, ..
|
||||
} => {
|
||||
input.complete_slash_command(&self.completions);
|
||||
EditorAction::Continue
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Home,
|
||||
..
|
||||
} => {
|
||||
input.move_home();
|
||||
EditorAction::Continue
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::End, ..
|
||||
} => {
|
||||
input.move_end();
|
||||
EditorAction::Continue
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Esc, ..
|
||||
} => {
|
||||
input.clear();
|
||||
self.history_index = None;
|
||||
self.draft = None;
|
||||
EditorAction::Cancel
|
||||
}
|
||||
KeyEvent {
|
||||
code: KeyCode::Char(ch),
|
||||
modifiers,
|
||||
..
|
||||
} if modifiers.is_empty() || modifiers == KeyModifiers::SHIFT => {
|
||||
input.insert(ch);
|
||||
self.history_index = None;
|
||||
self.draft = None;
|
||||
EditorAction::Continue
|
||||
}
|
||||
_ => EditorAction::Continue,
|
||||
}
|
||||
}
|
||||
|
||||
fn navigate_history_up(&mut self, input: &mut InputBuffer) {
|
||||
if self.history.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
match self.history_index {
|
||||
Some(0) => {}
|
||||
Some(index) => {
|
||||
let next_index = index - 1;
|
||||
input.replace(self.history[next_index].clone());
|
||||
self.history_index = Some(next_index);
|
||||
}
|
||||
None => {
|
||||
self.draft = Some(input.as_str().to_owned());
|
||||
let next_index = self.history.len() - 1;
|
||||
input.replace(self.history[next_index].clone());
|
||||
self.history_index = Some(next_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn navigate_history_down(&mut self, input: &mut InputBuffer) {
|
||||
let Some(index) = self.history_index else {
|
||||
return;
|
||||
};
|
||||
|
||||
if index + 1 < self.history.len() {
|
||||
let next_index = index + 1;
|
||||
input.replace(self.history[next_index].clone());
|
||||
self.history_index = Some(next_index);
|
||||
return;
|
||||
}
|
||||
|
||||
input.replace(self.draft.take().unwrap_or_default());
|
||||
self.history_index = None;
|
||||
}
|
||||
|
||||
fn redraw(
|
||||
&self,
|
||||
out: &mut impl Write,
|
||||
input: &InputBuffer,
|
||||
previous_line_count: usize,
|
||||
) -> io::Result<usize> {
|
||||
let rendered = render_buffer(&self.prompt, &self.continuation_prompt, input);
|
||||
if previous_line_count > 1 {
|
||||
queue!(out, MoveUp(saturating_u16(previous_line_count - 1)))?;
|
||||
}
|
||||
queue!(out, MoveToColumn(0), Clear(ClearType::FromCursorDown),)?;
|
||||
rendered.write(out)?;
|
||||
queue!(
|
||||
out,
|
||||
MoveUp(saturating_u16(rendered.line_count().saturating_sub(1))),
|
||||
MoveToColumn(0),
|
||||
)?;
|
||||
if rendered.cursor_row > 0 {
|
||||
queue!(out, MoveDown(rendered.cursor_row))?;
|
||||
}
|
||||
queue!(out, MoveToColumn(rendered.cursor_col))?;
|
||||
out.flush()?;
|
||||
Ok(rendered.line_count())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum EditorAction {
|
||||
Continue,
|
||||
Submit,
|
||||
Cancel,
|
||||
Exit,
|
||||
fn slash_command_prefix(line: &str, pos: usize) -> Option<&str> {
|
||||
if pos != line.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let prefix = &line[..pos];
|
||||
if !prefix.starts_with('/') {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(prefix)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn render_buffer(
|
||||
prompt: &str,
|
||||
continuation_prompt: &str,
|
||||
input: &InputBuffer,
|
||||
) -> RenderedBuffer {
|
||||
let before_cursor = &input.as_str()[..input.cursor];
|
||||
let cursor_row = saturating_u16(before_cursor.chars().filter(|ch| *ch == '\n').count());
|
||||
let cursor_line = before_cursor.rsplit('\n').next().unwrap_or_default();
|
||||
let cursor_prompt = if cursor_row == 0 {
|
||||
prompt
|
||||
} else {
|
||||
continuation_prompt
|
||||
};
|
||||
let cursor_col = saturating_u16(cursor_prompt.chars().count() + cursor_line.chars().count());
|
||||
|
||||
let mut lines = Vec::new();
|
||||
for (index, line) in input.as_str().split('\n').enumerate() {
|
||||
let prefix = if index == 0 {
|
||||
prompt
|
||||
} else {
|
||||
continuation_prompt
|
||||
};
|
||||
lines.push(format!("{prefix}{line}"));
|
||||
}
|
||||
if lines.is_empty() {
|
||||
lines.push(prompt.to_string());
|
||||
}
|
||||
|
||||
RenderedBuffer {
|
||||
lines,
|
||||
cursor_row,
|
||||
cursor_col,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn longest_common_prefix(values: &[&str]) -> String {
|
||||
let Some(first) = values.first() else {
|
||||
return String::new();
|
||||
};
|
||||
|
||||
let mut prefix = (*first).to_string();
|
||||
for value in values.iter().skip(1) {
|
||||
while !value.starts_with(&prefix) {
|
||||
prefix.pop();
|
||||
if prefix.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
prefix
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn saturating_u16(value: usize) -> u16 {
|
||||
u16::try_from(value).unwrap_or(u16::MAX)
|
||||
fn normalize_completions(completions: Vec<String>) -> Vec<String> {
|
||||
let mut seen = BTreeSet::new();
|
||||
completions
|
||||
.into_iter()
|
||||
.filter(|candidate| candidate.starts_with('/'))
|
||||
.filter(|candidate| seen.insert(candidate.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{render_buffer, InputBuffer, LineEditor};
|
||||
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||
use super::{slash_command_prefix, LineEditor, SlashCommandHelper};
|
||||
use rustyline::completion::Completer;
|
||||
use rustyline::highlight::Highlighter;
|
||||
use rustyline::history::{DefaultHistory, History};
|
||||
use rustyline::Context;
|
||||
|
||||
fn key(code: KeyCode) -> KeyEvent {
|
||||
KeyEvent::new(code, KeyModifiers::NONE)
|
||||
#[test]
|
||||
fn extracts_terminal_slash_command_prefixes_with_arguments() {
|
||||
assert_eq!(slash_command_prefix("/he", 3), Some("/he"));
|
||||
assert_eq!(slash_command_prefix("/help me", 8), Some("/help me"));
|
||||
assert_eq!(
|
||||
slash_command_prefix("/session switch ses", 19),
|
||||
Some("/session switch ses")
|
||||
);
|
||||
assert_eq!(slash_command_prefix("hello", 5), None);
|
||||
assert_eq!(slash_command_prefix("/help", 2), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supports_basic_line_editing() {
|
||||
let mut input = InputBuffer::new();
|
||||
input.insert('h');
|
||||
input.insert('i');
|
||||
input.move_end();
|
||||
input.insert_newline();
|
||||
input.insert('x');
|
||||
|
||||
assert_eq!(input.as_str(), "hi\nx");
|
||||
assert_eq!(input.cursor(), 4);
|
||||
|
||||
input.move_left();
|
||||
input.backspace();
|
||||
assert_eq!(input.as_str(), "hix");
|
||||
assert_eq!(input.cursor(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn completes_unique_slash_command() {
|
||||
let mut input = InputBuffer::new();
|
||||
for ch in "/he".chars() {
|
||||
input.insert(ch);
|
||||
}
|
||||
|
||||
assert!(input.complete_slash_command(&[
|
||||
fn completes_matching_slash_commands() {
|
||||
let helper = SlashCommandHelper::new(vec![
|
||||
"/help".to_string(),
|
||||
"/hello".to_string(),
|
||||
"/status".to_string(),
|
||||
]));
|
||||
assert_eq!(input.as_str(), "/hel");
|
||||
]);
|
||||
let history = DefaultHistory::new();
|
||||
let ctx = Context::new(&history);
|
||||
let (start, matches) = helper
|
||||
.complete("/he", 3, &ctx)
|
||||
.expect("completion should work");
|
||||
|
||||
assert!(input.complete_slash_command(&["/help".to_string(), "/status".to_string()]));
|
||||
assert_eq!(input.as_str(), "/help");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ignores_completion_when_prefix_is_not_a_slash_command() {
|
||||
let mut input = InputBuffer::new();
|
||||
for ch in "hello".chars() {
|
||||
input.insert(ch);
|
||||
}
|
||||
|
||||
assert!(!input.complete_slash_command(&["/help".to_string()]));
|
||||
assert_eq!(input.as_str(), "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn history_navigation_restores_current_draft() {
|
||||
let mut editor = LineEditor::new("› ", vec![]);
|
||||
editor.push_history("/help");
|
||||
editor.push_history("status report");
|
||||
|
||||
let mut input = InputBuffer::new();
|
||||
for ch in "draft".chars() {
|
||||
input.insert(ch);
|
||||
}
|
||||
|
||||
let _ = editor.handle_key(key(KeyCode::Up), &mut input);
|
||||
assert_eq!(input.as_str(), "status report");
|
||||
|
||||
let _ = editor.handle_key(key(KeyCode::Up), &mut input);
|
||||
assert_eq!(input.as_str(), "/help");
|
||||
|
||||
let _ = editor.handle_key(key(KeyCode::Down), &mut input);
|
||||
assert_eq!(input.as_str(), "status report");
|
||||
|
||||
let _ = editor.handle_key(key(KeyCode::Down), &mut input);
|
||||
assert_eq!(input.as_str(), "draft");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tab_key_completes_from_editor_candidates() {
|
||||
let mut editor = LineEditor::new(
|
||||
"› ",
|
||||
vec![
|
||||
"/help".to_string(),
|
||||
"/status".to_string(),
|
||||
"/session".to_string(),
|
||||
],
|
||||
);
|
||||
let mut input = InputBuffer::new();
|
||||
for ch in "/st".chars() {
|
||||
input.insert(ch);
|
||||
}
|
||||
|
||||
let _ = editor.handle_key(key(KeyCode::Tab), &mut input);
|
||||
assert_eq!(input.as_str(), "/status");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn renders_multiline_buffers_with_continuation_prompt() {
|
||||
let mut input = InputBuffer::new();
|
||||
for ch in "hello\nworld".chars() {
|
||||
if ch == '\n' {
|
||||
input.insert_newline();
|
||||
} else {
|
||||
input.insert(ch);
|
||||
}
|
||||
}
|
||||
|
||||
let rendered = render_buffer("› ", "> ", &input);
|
||||
assert_eq!(start, 0);
|
||||
assert_eq!(
|
||||
rendered.lines(),
|
||||
&["› hello".to_string(), "> world".to_string()]
|
||||
matches
|
||||
.into_iter()
|
||||
.map(|candidate| candidate.replacement)
|
||||
.collect::<Vec<_>>(),
|
||||
vec!["/help".to_string(), "/hello".to_string()]
|
||||
);
|
||||
assert_eq!(rendered.cursor_position(), (1, 7));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ctrl_c_exits_only_when_buffer_is_empty() {
|
||||
let mut editor = LineEditor::new("› ", vec![]);
|
||||
let mut empty = InputBuffer::new();
|
||||
assert!(matches!(
|
||||
editor.handle_key(
|
||||
KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL),
|
||||
&mut empty,
|
||||
),
|
||||
super::EditorAction::Exit
|
||||
));
|
||||
fn completes_matching_slash_command_arguments() {
|
||||
let helper = SlashCommandHelper::new(vec![
|
||||
"/model".to_string(),
|
||||
"/model opus".to_string(),
|
||||
"/model sonnet".to_string(),
|
||||
"/session switch alpha".to_string(),
|
||||
]);
|
||||
let history = DefaultHistory::new();
|
||||
let ctx = Context::new(&history);
|
||||
let (start, matches) = helper
|
||||
.complete("/model o", 8, &ctx)
|
||||
.expect("completion should work");
|
||||
|
||||
let mut filled = InputBuffer::new();
|
||||
filled.insert('x');
|
||||
assert!(matches!(
|
||||
editor.handle_key(
|
||||
KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL),
|
||||
&mut filled,
|
||||
),
|
||||
super::EditorAction::Cancel
|
||||
));
|
||||
assert!(filled.as_str().is_empty());
|
||||
assert_eq!(start, 0);
|
||||
assert_eq!(
|
||||
matches
|
||||
.into_iter()
|
||||
.map(|candidate| candidate.replacement)
|
||||
.collect::<Vec<_>>(),
|
||||
vec!["/model opus".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ignores_non_slash_command_completion_requests() {
|
||||
let helper = SlashCommandHelper::new(vec!["/help".to_string()]);
|
||||
let history = DefaultHistory::new();
|
||||
let ctx = Context::new(&history);
|
||||
let (_, matches) = helper
|
||||
.complete("hello", 5, &ctx)
|
||||
.expect("completion should work");
|
||||
|
||||
assert!(matches.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tracks_current_buffer_through_highlighter() {
|
||||
let helper = SlashCommandHelper::new(Vec::new());
|
||||
let _ = helper.highlight("draft", 5);
|
||||
|
||||
assert_eq!(helper.current_line(), "draft");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn push_history_ignores_blank_entries() {
|
||||
let mut editor = LineEditor::new("> ", vec!["/help".to_string()]);
|
||||
editor.push_history(" ");
|
||||
editor.push_history("/help");
|
||||
|
||||
assert_eq!(editor.editor.history().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_completions_replaces_and_normalizes_candidates() {
|
||||
let mut editor = LineEditor::new("> ", vec!["/help".to_string()]);
|
||||
editor.set_completions(vec![
|
||||
"/model opus".to_string(),
|
||||
"/model opus".to_string(),
|
||||
"status".to_string(),
|
||||
]);
|
||||
|
||||
let helper = editor.editor.helper().expect("helper should exist");
|
||||
assert_eq!(helper.completions, vec!["/model opus".to_string()]);
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,7 +1,5 @@
|
||||
use std::fmt::Write as FmtWrite;
|
||||
use std::io::{self, Write};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
use crossterm::cursor::{MoveToColumn, RestorePosition, SavePosition};
|
||||
use crossterm::style::{Color, Print, ResetColor, SetForegroundColor, Stylize};
|
||||
@ -22,6 +20,7 @@ pub struct ColorTheme {
|
||||
link: Color,
|
||||
quote: Color,
|
||||
table_border: Color,
|
||||
code_block_border: Color,
|
||||
spinner_active: Color,
|
||||
spinner_done: Color,
|
||||
spinner_failed: Color,
|
||||
@ -37,6 +36,7 @@ impl Default for ColorTheme {
|
||||
link: Color::Blue,
|
||||
quote: Color::DarkGrey,
|
||||
table_border: Color::DarkCyan,
|
||||
code_block_border: Color::DarkGrey,
|
||||
spinner_active: Color::Blue,
|
||||
spinner_done: Color::Green,
|
||||
spinner_failed: Color::Red,
|
||||
@ -154,32 +154,63 @@ impl TableState {
|
||||
struct RenderState {
|
||||
emphasis: usize,
|
||||
strong: usize,
|
||||
heading_level: Option<u8>,
|
||||
quote: usize,
|
||||
list_stack: Vec<ListKind>,
|
||||
link_stack: Vec<LinkState>,
|
||||
table: Option<TableState>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct LinkState {
|
||||
destination: String,
|
||||
text: String,
|
||||
}
|
||||
|
||||
impl RenderState {
|
||||
fn style_text(&self, text: &str, theme: &ColorTheme) -> String {
|
||||
let mut styled = text.to_string();
|
||||
if self.strong > 0 {
|
||||
styled = format!("{}", styled.bold().with(theme.strong));
|
||||
let mut style = text.stylize();
|
||||
|
||||
if matches!(self.heading_level, Some(1 | 2)) || self.strong > 0 {
|
||||
style = style.bold();
|
||||
}
|
||||
if self.emphasis > 0 {
|
||||
styled = format!("{}", styled.italic().with(theme.emphasis));
|
||||
}
|
||||
if self.quote > 0 {
|
||||
styled = format!("{}", styled.with(theme.quote));
|
||||
}
|
||||
styled
|
||||
style = style.italic();
|
||||
}
|
||||
|
||||
fn capture_target_mut<'a>(&'a mut self, output: &'a mut String) -> &'a mut String {
|
||||
if let Some(table) = self.table.as_mut() {
|
||||
&mut table.current_cell
|
||||
} else {
|
||||
output
|
||||
if let Some(level) = self.heading_level {
|
||||
style = match level {
|
||||
1 => style.with(theme.heading),
|
||||
2 => style.white(),
|
||||
3 => style.with(Color::Blue),
|
||||
_ => style.with(Color::Grey),
|
||||
};
|
||||
} else if self.strong > 0 {
|
||||
style = style.with(theme.strong);
|
||||
} else if self.emphasis > 0 {
|
||||
style = style.with(theme.emphasis);
|
||||
}
|
||||
|
||||
if self.quote > 0 {
|
||||
style = style.with(theme.quote);
|
||||
}
|
||||
|
||||
format!("{style}")
|
||||
}
|
||||
|
||||
fn append_raw(&mut self, output: &mut String, text: &str) {
|
||||
if let Some(link) = self.link_stack.last_mut() {
|
||||
link.text.push_str(text);
|
||||
} else if let Some(table) = self.table.as_mut() {
|
||||
table.current_cell.push_str(text);
|
||||
} else {
|
||||
output.push_str(text);
|
||||
}
|
||||
}
|
||||
|
||||
fn append_styled(&mut self, output: &mut String, text: &str, theme: &ColorTheme) {
|
||||
let styled = self.style_text(text, theme);
|
||||
self.append_raw(output, &styled);
|
||||
}
|
||||
}
|
||||
|
||||
@ -218,13 +249,14 @@ impl TerminalRenderer {
|
||||
|
||||
#[must_use]
|
||||
pub fn render_markdown(&self, markdown: &str) -> String {
|
||||
let normalized = normalize_nested_fences(markdown);
|
||||
let mut output = String::new();
|
||||
let mut state = RenderState::default();
|
||||
let mut code_language = String::new();
|
||||
let mut code_buffer = String::new();
|
||||
let mut in_code_block = false;
|
||||
|
||||
for event in Parser::new_ext(markdown, Options::all()) {
|
||||
for event in Parser::new_ext(&normalized, Options::all()) {
|
||||
self.render_event(
|
||||
event,
|
||||
&mut state,
|
||||
@ -238,6 +270,11 @@ impl TerminalRenderer {
|
||||
output.trim_end().to_string()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn markdown_to_ansi(&self, markdown: &str) -> String {
|
||||
self.render_markdown(markdown)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn render_event(
|
||||
&self,
|
||||
@ -249,15 +286,21 @@ impl TerminalRenderer {
|
||||
in_code_block: &mut bool,
|
||||
) {
|
||||
match event {
|
||||
Event::Start(Tag::Heading { level, .. }) => self.start_heading(level as u8, output),
|
||||
Event::End(TagEnd::Heading(..) | TagEnd::Paragraph) => output.push_str("\n\n"),
|
||||
Event::Start(Tag::Heading { level, .. }) => {
|
||||
Self::start_heading(state, level as u8, output);
|
||||
}
|
||||
Event::End(TagEnd::Paragraph) => output.push_str("\n\n"),
|
||||
Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output),
|
||||
Event::End(TagEnd::BlockQuote(..)) => {
|
||||
state.quote = state.quote.saturating_sub(1);
|
||||
output.push('\n');
|
||||
}
|
||||
Event::End(TagEnd::Heading(..)) => {
|
||||
state.heading_level = None;
|
||||
output.push_str("\n\n");
|
||||
}
|
||||
Event::End(TagEnd::Item) | Event::SoftBreak | Event::HardBreak => {
|
||||
state.capture_target_mut(output).push('\n');
|
||||
state.append_raw(output, "\n");
|
||||
}
|
||||
Event::Start(Tag::List(first_item)) => {
|
||||
let kind = match first_item {
|
||||
@ -293,41 +336,52 @@ impl TerminalRenderer {
|
||||
Event::Code(code) => {
|
||||
let rendered =
|
||||
format!("{}", format!("`{code}`").with(self.color_theme.inline_code));
|
||||
state.capture_target_mut(output).push_str(&rendered);
|
||||
state.append_raw(output, &rendered);
|
||||
}
|
||||
Event::Rule => output.push_str("---\n"),
|
||||
Event::Text(text) => {
|
||||
self.push_text(text.as_ref(), state, output, code_buffer, *in_code_block);
|
||||
}
|
||||
Event::Html(html) | Event::InlineHtml(html) => {
|
||||
state.capture_target_mut(output).push_str(&html);
|
||||
state.append_raw(output, &html);
|
||||
}
|
||||
Event::FootnoteReference(reference) => {
|
||||
let _ = write!(state.capture_target_mut(output), "[{reference}]");
|
||||
state.append_raw(output, &format!("[{reference}]"));
|
||||
}
|
||||
Event::TaskListMarker(done) => {
|
||||
state
|
||||
.capture_target_mut(output)
|
||||
.push_str(if done { "[x] " } else { "[ ] " });
|
||||
state.append_raw(output, if done { "[x] " } else { "[ ] " });
|
||||
}
|
||||
Event::InlineMath(math) | Event::DisplayMath(math) => {
|
||||
state.capture_target_mut(output).push_str(&math);
|
||||
state.append_raw(output, &math);
|
||||
}
|
||||
Event::Start(Tag::Link { dest_url, .. }) => {
|
||||
state.link_stack.push(LinkState {
|
||||
destination: dest_url.to_string(),
|
||||
text: String::new(),
|
||||
});
|
||||
}
|
||||
Event::End(TagEnd::Link) => {
|
||||
if let Some(link) = state.link_stack.pop() {
|
||||
let label = if link.text.is_empty() {
|
||||
link.destination.clone()
|
||||
} else {
|
||||
link.text
|
||||
};
|
||||
let rendered = format!(
|
||||
"{}",
|
||||
format!("[{dest_url}]")
|
||||
format!("[{label}]({})", link.destination)
|
||||
.underlined()
|
||||
.with(self.color_theme.link)
|
||||
);
|
||||
state.capture_target_mut(output).push_str(&rendered);
|
||||
state.append_raw(output, &rendered);
|
||||
}
|
||||
}
|
||||
Event::Start(Tag::Image { dest_url, .. }) => {
|
||||
let rendered = format!(
|
||||
"{}",
|
||||
format!("[image:{dest_url}]").with(self.color_theme.link)
|
||||
);
|
||||
state.capture_target_mut(output).push_str(&rendered);
|
||||
state.append_raw(output, &rendered);
|
||||
}
|
||||
Event::Start(Tag::Table(..)) => state.table = Some(TableState::default()),
|
||||
Event::End(TagEnd::Table) => {
|
||||
@ -369,19 +423,15 @@ impl TerminalRenderer {
|
||||
}
|
||||
}
|
||||
Event::Start(Tag::Paragraph | Tag::MetadataBlock(..) | _)
|
||||
| Event::End(TagEnd::Link | TagEnd::Image | TagEnd::MetadataBlock(..) | _) => {}
|
||||
| Event::End(TagEnd::Image | TagEnd::MetadataBlock(..) | _) => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn start_heading(&self, level: u8, output: &mut String) {
|
||||
fn start_heading(state: &mut RenderState, level: u8, output: &mut String) {
|
||||
state.heading_level = Some(level);
|
||||
if !output.is_empty() {
|
||||
output.push('\n');
|
||||
let prefix = match level {
|
||||
1 => "# ",
|
||||
2 => "## ",
|
||||
3 => "### ",
|
||||
_ => "#### ",
|
||||
};
|
||||
let _ = write!(output, "{}", prefix.bold().with(self.color_theme.heading));
|
||||
}
|
||||
}
|
||||
|
||||
fn start_quote(&self, state: &mut RenderState, output: &mut String) {
|
||||
@ -405,20 +455,27 @@ impl TerminalRenderer {
|
||||
}
|
||||
|
||||
fn start_code_block(&self, code_language: &str, output: &mut String) {
|
||||
if !code_language.is_empty() {
|
||||
let label = if code_language.is_empty() {
|
||||
"code".to_string()
|
||||
} else {
|
||||
code_language.to_string()
|
||||
};
|
||||
let _ = writeln!(
|
||||
output,
|
||||
"{}",
|
||||
format!("╭─ {code_language}").with(self.color_theme.heading)
|
||||
format!("╭─ {label}")
|
||||
.bold()
|
||||
.with(self.color_theme.code_block_border)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn finish_code_block(&self, code_buffer: &str, code_language: &str, output: &mut String) {
|
||||
output.push_str(&self.highlight_code(code_buffer, code_language));
|
||||
if !code_language.is_empty() {
|
||||
let _ = write!(output, "{}", "╰─".with(self.color_theme.heading));
|
||||
}
|
||||
let _ = write!(
|
||||
output,
|
||||
"{}",
|
||||
"╰─".bold().with(self.color_theme.code_block_border)
|
||||
);
|
||||
output.push_str("\n\n");
|
||||
}
|
||||
|
||||
@ -433,8 +490,7 @@ impl TerminalRenderer {
|
||||
if in_code_block {
|
||||
code_buffer.push_str(text);
|
||||
} else {
|
||||
let rendered = state.style_text(text, &self.color_theme);
|
||||
state.capture_target_mut(output).push_str(&rendered);
|
||||
state.append_styled(output, text, &self.color_theme);
|
||||
}
|
||||
}
|
||||
|
||||
@ -521,9 +577,10 @@ impl TerminalRenderer {
|
||||
for line in LinesWithEndings::from(code) {
|
||||
match syntax_highlighter.highlight_line(line, &self.syntax_set) {
|
||||
Ok(ranges) => {
|
||||
colored_output.push_str(&as_24_bit_terminal_escaped(&ranges[..], false));
|
||||
let escaped = as_24_bit_terminal_escaped(&ranges[..], false);
|
||||
colored_output.push_str(&apply_code_block_background(&escaped));
|
||||
}
|
||||
Err(_) => colored_output.push_str(line),
|
||||
Err(_) => colored_output.push_str(&apply_code_block_background(line)),
|
||||
}
|
||||
}
|
||||
|
||||
@ -531,16 +588,296 @@ impl TerminalRenderer {
|
||||
}
|
||||
|
||||
pub fn stream_markdown(&self, markdown: &str, out: &mut impl Write) -> io::Result<()> {
|
||||
let rendered_markdown = self.render_markdown(markdown);
|
||||
for chunk in rendered_markdown.split_inclusive(char::is_whitespace) {
|
||||
write!(out, "{chunk}")?;
|
||||
out.flush()?;
|
||||
thread::sleep(Duration::from_millis(8));
|
||||
let rendered_markdown = self.markdown_to_ansi(markdown);
|
||||
write!(out, "{rendered_markdown}")?;
|
||||
if !rendered_markdown.ends_with('\n') {
|
||||
writeln!(out)?;
|
||||
}
|
||||
writeln!(out)
|
||||
out.flush()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone, PartialEq, Eq)]
|
||||
pub struct MarkdownStreamState {
|
||||
pending: String,
|
||||
}
|
||||
|
||||
impl MarkdownStreamState {
|
||||
#[must_use]
|
||||
pub fn push(&mut self, renderer: &TerminalRenderer, delta: &str) -> Option<String> {
|
||||
self.pending.push_str(delta);
|
||||
let split = find_stream_safe_boundary(&self.pending)?;
|
||||
let ready = self.pending[..split].to_string();
|
||||
self.pending.drain(..split);
|
||||
Some(renderer.markdown_to_ansi(&ready))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn flush(&mut self, renderer: &TerminalRenderer) -> Option<String> {
|
||||
if self.pending.trim().is_empty() {
|
||||
self.pending.clear();
|
||||
None
|
||||
} else {
|
||||
let pending = std::mem::take(&mut self.pending);
|
||||
Some(renderer.markdown_to_ansi(&pending))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_code_block_background(line: &str) -> String {
|
||||
let trimmed = line.trim_end_matches('\n');
|
||||
let trailing_newline = if trimmed.len() == line.len() {
|
||||
""
|
||||
} else {
|
||||
"\n"
|
||||
};
|
||||
let with_background = trimmed.replace("\u{1b}[0m", "\u{1b}[0;48;5;236m");
|
||||
format!("\u{1b}[48;5;236m{with_background}\u{1b}[0m{trailing_newline}")
|
||||
}
|
||||
|
||||
/// Pre-process raw markdown so that fenced code blocks whose body contains
|
||||
/// fence markers of equal or greater length are wrapped with a longer fence.
|
||||
///
|
||||
/// LLMs frequently emit triple-backtick code blocks that contain triple-backtick
|
||||
/// examples. CommonMark (and pulldown-cmark) treats the inner marker as the
|
||||
/// closing fence, breaking the render. This function detects the situation and
|
||||
/// upgrades the outer fence to use enough backticks (or tildes) that the inner
|
||||
/// markers become ordinary content.
|
||||
fn normalize_nested_fences(markdown: &str) -> String {
|
||||
// A fence line is either "labeled" (has an info string ⇒ always an opener)
|
||||
// or "bare" (no info string ⇒ could be opener or closer).
|
||||
#[derive(Debug, Clone)]
|
||||
struct FenceLine {
|
||||
char: char,
|
||||
len: usize,
|
||||
has_info: bool,
|
||||
indent: usize,
|
||||
}
|
||||
|
||||
fn parse_fence_line(line: &str) -> Option<FenceLine> {
|
||||
let trimmed = line.trim_end_matches('\n').trim_end_matches('\r');
|
||||
let indent = trimmed.chars().take_while(|c| *c == ' ').count();
|
||||
if indent > 3 {
|
||||
return None;
|
||||
}
|
||||
let rest = &trimmed[indent..];
|
||||
let ch = rest.chars().next()?;
|
||||
if ch != '`' && ch != '~' {
|
||||
return None;
|
||||
}
|
||||
let len = rest.chars().take_while(|c| *c == ch).count();
|
||||
if len < 3 {
|
||||
return None;
|
||||
}
|
||||
let after = &rest[len..];
|
||||
if ch == '`' && after.contains('`') {
|
||||
return None;
|
||||
}
|
||||
let has_info = !after.trim().is_empty();
|
||||
Some(FenceLine {
|
||||
char: ch,
|
||||
len,
|
||||
has_info,
|
||||
indent,
|
||||
})
|
||||
}
|
||||
|
||||
let lines: Vec<&str> = markdown.split_inclusive('\n').collect();
|
||||
// Handle final line that may lack trailing newline.
|
||||
// split_inclusive already keeps the original chunks, including a
|
||||
// final chunk without '\n' if the input doesn't end with one.
|
||||
|
||||
// First pass: classify every line.
|
||||
let fence_info: Vec<Option<FenceLine>> = lines.iter().map(|l| parse_fence_line(l)).collect();
|
||||
|
||||
// Second pass: pair openers with closers using a stack, recording
|
||||
// (opener_idx, closer_idx) pairs plus the max fence length found between
|
||||
// them.
|
||||
struct StackEntry {
|
||||
line_idx: usize,
|
||||
fence: FenceLine,
|
||||
}
|
||||
|
||||
let mut stack: Vec<StackEntry> = Vec::new();
|
||||
// Paired blocks: (opener_line, closer_line, max_inner_fence_len)
|
||||
let mut pairs: Vec<(usize, usize, usize)> = Vec::new();
|
||||
|
||||
for (i, fi) in fence_info.iter().enumerate() {
|
||||
let Some(fl) = fi else { continue };
|
||||
|
||||
if fl.has_info {
|
||||
// Labeled fence ⇒ always an opener.
|
||||
stack.push(StackEntry {
|
||||
line_idx: i,
|
||||
fence: fl.clone(),
|
||||
});
|
||||
} else {
|
||||
// Bare fence ⇒ try to close the top of the stack if compatible.
|
||||
let closes_top = stack
|
||||
.last()
|
||||
.is_some_and(|top| top.fence.char == fl.char && fl.len >= top.fence.len);
|
||||
if closes_top {
|
||||
let opener = stack.pop().unwrap();
|
||||
// Find max fence length of any fence line strictly between
|
||||
// opener and closer (these are the nested fences).
|
||||
let inner_max = fence_info[opener.line_idx + 1..i]
|
||||
.iter()
|
||||
.filter_map(|fi| fi.as_ref().map(|f| f.len))
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
pairs.push((opener.line_idx, i, inner_max));
|
||||
} else {
|
||||
// Treat as opener.
|
||||
stack.push(StackEntry {
|
||||
line_idx: i,
|
||||
fence: fl.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine which lines need rewriting. A pair needs rewriting when
|
||||
// its opener length <= max inner fence length.
|
||||
struct Rewrite {
|
||||
char: char,
|
||||
new_len: usize,
|
||||
indent: usize,
|
||||
}
|
||||
let mut rewrites: std::collections::HashMap<usize, Rewrite> = std::collections::HashMap::new();
|
||||
|
||||
for (opener_idx, closer_idx, inner_max) in &pairs {
|
||||
let opener_fl = fence_info[*opener_idx].as_ref().unwrap();
|
||||
if opener_fl.len <= *inner_max {
|
||||
let new_len = inner_max + 1;
|
||||
let info_part = {
|
||||
let trimmed = lines[*opener_idx]
|
||||
.trim_end_matches('\n')
|
||||
.trim_end_matches('\r');
|
||||
let rest = &trimmed[opener_fl.indent..];
|
||||
rest[opener_fl.len..].to_string()
|
||||
};
|
||||
rewrites.insert(
|
||||
*opener_idx,
|
||||
Rewrite {
|
||||
char: opener_fl.char,
|
||||
new_len,
|
||||
indent: opener_fl.indent,
|
||||
},
|
||||
);
|
||||
let closer_fl = fence_info[*closer_idx].as_ref().unwrap();
|
||||
rewrites.insert(
|
||||
*closer_idx,
|
||||
Rewrite {
|
||||
char: closer_fl.char,
|
||||
new_len,
|
||||
indent: closer_fl.indent,
|
||||
},
|
||||
);
|
||||
// Store info string only in the opener; closer keeps the trailing
|
||||
// portion which is already handled through the original line.
|
||||
// Actually, we rebuild both lines from scratch below, including
|
||||
// the info string for the opener.
|
||||
let _ = info_part; // consumed in rebuild
|
||||
}
|
||||
}
|
||||
|
||||
if rewrites.is_empty() {
|
||||
return markdown.to_string();
|
||||
}
|
||||
|
||||
// Rebuild.
|
||||
let mut out = String::with_capacity(markdown.len() + rewrites.len() * 4);
|
||||
for (i, line) in lines.iter().enumerate() {
|
||||
if let Some(rw) = rewrites.get(&i) {
|
||||
let fence_str: String = std::iter::repeat(rw.char).take(rw.new_len).collect();
|
||||
let indent_str: String = std::iter::repeat(' ').take(rw.indent).collect();
|
||||
// Recover the original info string (if any) and trailing newline.
|
||||
let trimmed = line.trim_end_matches('\n').trim_end_matches('\r');
|
||||
let fi = fence_info[i].as_ref().unwrap();
|
||||
let info = &trimmed[fi.indent + fi.len..];
|
||||
let trailing = &line[trimmed.len()..];
|
||||
out.push_str(&indent_str);
|
||||
out.push_str(&fence_str);
|
||||
out.push_str(info);
|
||||
out.push_str(trailing);
|
||||
} else {
|
||||
out.push_str(line);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn find_stream_safe_boundary(markdown: &str) -> Option<usize> {
|
||||
let mut open_fence: Option<FenceMarker> = None;
|
||||
let mut last_boundary = None;
|
||||
|
||||
for (offset, line) in markdown.split_inclusive('\n').scan(0usize, |cursor, line| {
|
||||
let start = *cursor;
|
||||
*cursor += line.len();
|
||||
Some((start, line))
|
||||
}) {
|
||||
let line_without_newline = line.trim_end_matches('\n');
|
||||
if let Some(opener) = open_fence {
|
||||
if line_closes_fence(line_without_newline, opener) {
|
||||
open_fence = None;
|
||||
last_boundary = Some(offset + line.len());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(opener) = parse_fence_opener(line_without_newline) {
|
||||
open_fence = Some(opener);
|
||||
continue;
|
||||
}
|
||||
|
||||
if line_without_newline.trim().is_empty() {
|
||||
last_boundary = Some(offset + line.len());
|
||||
}
|
||||
}
|
||||
|
||||
last_boundary
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
struct FenceMarker {
|
||||
character: char,
|
||||
length: usize,
|
||||
}
|
||||
|
||||
fn parse_fence_opener(line: &str) -> Option<FenceMarker> {
|
||||
let indent = line.chars().take_while(|c| *c == ' ').count();
|
||||
if indent > 3 {
|
||||
return None;
|
||||
}
|
||||
let rest = &line[indent..];
|
||||
let character = rest.chars().next()?;
|
||||
if character != '`' && character != '~' {
|
||||
return None;
|
||||
}
|
||||
let length = rest.chars().take_while(|c| *c == character).count();
|
||||
if length < 3 {
|
||||
return None;
|
||||
}
|
||||
let info_string = &rest[length..];
|
||||
if character == '`' && info_string.contains('`') {
|
||||
return None;
|
||||
}
|
||||
Some(FenceMarker { character, length })
|
||||
}
|
||||
|
||||
fn line_closes_fence(line: &str, opener: FenceMarker) -> bool {
|
||||
let indent = line.chars().take_while(|c| *c == ' ').count();
|
||||
if indent > 3 {
|
||||
return false;
|
||||
}
|
||||
let rest = &line[indent..];
|
||||
let length = rest.chars().take_while(|c| *c == opener.character).count();
|
||||
if length < opener.length {
|
||||
return false;
|
||||
}
|
||||
rest[length..].chars().all(|c| c == ' ' || c == '\t')
|
||||
}
|
||||
|
||||
fn visible_width(input: &str) -> usize {
|
||||
strip_ansi(input).chars().count()
|
||||
}
|
||||
@ -569,7 +906,7 @@ fn strip_ansi(input: &str) -> String {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{strip_ansi, Spinner, TerminalRenderer};
|
||||
use super::{strip_ansi, MarkdownStreamState, Spinner, TerminalRenderer};
|
||||
|
||||
#[test]
|
||||
fn renders_markdown_with_styling_and_lists() {
|
||||
@ -583,16 +920,28 @@ mod tests {
|
||||
assert!(markdown_output.contains('\u{1b}'));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn renders_links_as_colored_markdown_labels() {
|
||||
let terminal_renderer = TerminalRenderer::new();
|
||||
let markdown_output =
|
||||
terminal_renderer.render_markdown("See [Claw](https://example.com/docs) now.");
|
||||
let plain_text = strip_ansi(&markdown_output);
|
||||
|
||||
assert!(plain_text.contains("[Claw](https://example.com/docs)"));
|
||||
assert!(markdown_output.contains('\u{1b}'));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn highlights_fenced_code_blocks() {
|
||||
let terminal_renderer = TerminalRenderer::new();
|
||||
let markdown_output =
|
||||
terminal_renderer.render_markdown("```rust\nfn hi() { println!(\"hi\"); }\n```");
|
||||
terminal_renderer.markdown_to_ansi("```rust\nfn hi() { println!(\"hi\"); }\n```");
|
||||
let plain_text = strip_ansi(&markdown_output);
|
||||
|
||||
assert!(plain_text.contains("╭─ rust"));
|
||||
assert!(plain_text.contains("fn hi"));
|
||||
assert!(markdown_output.contains('\u{1b}'));
|
||||
assert!(markdown_output.contains("[48;5;236m"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -623,6 +972,80 @@ mod tests {
|
||||
assert!(markdown_output.contains('\u{1b}'));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn streaming_state_waits_for_complete_blocks() {
|
||||
let renderer = TerminalRenderer::new();
|
||||
let mut state = MarkdownStreamState::default();
|
||||
|
||||
assert_eq!(state.push(&renderer, "# Heading"), None);
|
||||
let flushed = state
|
||||
.push(&renderer, "\n\nParagraph\n\n")
|
||||
.expect("completed block");
|
||||
let plain_text = strip_ansi(&flushed);
|
||||
assert!(plain_text.contains("Heading"));
|
||||
assert!(plain_text.contains("Paragraph"));
|
||||
|
||||
assert_eq!(state.push(&renderer, "```rust\nfn main() {}\n"), None);
|
||||
let code = state
|
||||
.push(&renderer, "```\n")
|
||||
.expect("closed code fence flushes");
|
||||
assert!(strip_ansi(&code).contains("fn main()"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn streaming_state_holds_outer_fence_with_nested_inner_fence() {
|
||||
let renderer = TerminalRenderer::new();
|
||||
let mut state = MarkdownStreamState::default();
|
||||
|
||||
assert_eq!(
|
||||
state.push(&renderer, "````markdown\n```rust\nfn inner() {}\n"),
|
||||
None,
|
||||
"inner triple backticks must not close the outer four-backtick fence"
|
||||
);
|
||||
assert_eq!(
|
||||
state.push(&renderer, "```\n"),
|
||||
None,
|
||||
"closing the inner fence must not flush the outer fence"
|
||||
);
|
||||
let flushed = state
|
||||
.push(&renderer, "````\n")
|
||||
.expect("closing the outer four-backtick fence flushes the buffered block");
|
||||
let plain_text = strip_ansi(&flushed);
|
||||
assert!(plain_text.contains("fn inner()"));
|
||||
assert!(plain_text.contains("```rust"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn streaming_state_distinguishes_backtick_and_tilde_fences() {
|
||||
let renderer = TerminalRenderer::new();
|
||||
let mut state = MarkdownStreamState::default();
|
||||
|
||||
assert_eq!(state.push(&renderer, "~~~text\n"), None);
|
||||
assert_eq!(
|
||||
state.push(&renderer, "```\nstill inside tilde fence\n"),
|
||||
None,
|
||||
"a backtick fence cannot close a tilde-opened fence"
|
||||
);
|
||||
assert_eq!(state.push(&renderer, "```\n"), None);
|
||||
let flushed = state
|
||||
.push(&renderer, "~~~\n")
|
||||
.expect("matching tilde marker closes the fence");
|
||||
let plain_text = strip_ansi(&flushed);
|
||||
assert!(plain_text.contains("still inside tilde fence"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn renders_nested_fenced_code_block_preserves_inner_markers() {
|
||||
let terminal_renderer = TerminalRenderer::new();
|
||||
let markdown_output =
|
||||
terminal_renderer.markdown_to_ansi("````markdown\n```rust\nfn nested() {}\n```\n````");
|
||||
let plain_text = strip_ansi(&markdown_output);
|
||||
|
||||
assert!(plain_text.contains("╭─ markdown"));
|
||||
assert!(plain_text.contains("```rust"));
|
||||
assert!(plain_text.contains("fn nested()"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spinner_advances_frames() {
|
||||
let terminal_renderer = TerminalRenderer::new();
|
||||
|
||||
298
crates/rusty-claude-cli/tests/cli_flags_and_config_defaults.rs
Normal file
298
crates/rusty-claude-cli/tests/cli_flags_and_config_defaults.rs
Normal file
@ -0,0 +1,298 @@
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::{Command, Output};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use runtime::Session;
|
||||
|
||||
static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
#[test]
|
||||
fn status_command_applies_model_and_permission_mode_flags() {
|
||||
// given
|
||||
let temp_dir = unique_temp_dir("status-flags");
|
||||
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
|
||||
|
||||
// when
|
||||
let output = Command::new(env!("CARGO_BIN_EXE_claw"))
|
||||
.current_dir(&temp_dir)
|
||||
.args([
|
||||
"--model",
|
||||
"sonnet",
|
||||
"--permission-mode",
|
||||
"read-only",
|
||||
"status",
|
||||
])
|
||||
.output()
|
||||
.expect("claw should launch");
|
||||
|
||||
// then
|
||||
assert_success(&output);
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
|
||||
assert!(stdout.contains("Status"));
|
||||
assert!(stdout.contains("Model claude-sonnet-4-6"));
|
||||
assert!(stdout.contains("Permission mode read-only"));
|
||||
|
||||
fs::remove_dir_all(temp_dir).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resume_flag_loads_a_saved_session_and_dispatches_status() {
|
||||
// given
|
||||
let temp_dir = unique_temp_dir("resume-status");
|
||||
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
|
||||
let session_path = write_session(&temp_dir, "resume-status");
|
||||
|
||||
// when
|
||||
let output = Command::new(env!("CARGO_BIN_EXE_claw"))
|
||||
.current_dir(&temp_dir)
|
||||
.args([
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 path"),
|
||||
"/status",
|
||||
])
|
||||
.output()
|
||||
.expect("claw should launch");
|
||||
|
||||
// then
|
||||
assert_success(&output);
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
|
||||
assert!(stdout.contains("Status"));
|
||||
assert!(stdout.contains("Messages 1"));
|
||||
assert!(stdout.contains("Session "));
|
||||
assert!(stdout.contains(session_path.to_str().expect("utf8 path")));
|
||||
|
||||
fs::remove_dir_all(temp_dir).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slash_command_names_match_known_commands_and_suggest_nearby_unknown_ones() {
|
||||
// given
|
||||
let temp_dir = unique_temp_dir("slash-dispatch");
|
||||
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
|
||||
|
||||
// when
|
||||
let help_output = Command::new(env!("CARGO_BIN_EXE_claw"))
|
||||
.current_dir(&temp_dir)
|
||||
.arg("/help")
|
||||
.output()
|
||||
.expect("claw should launch");
|
||||
let unknown_output = Command::new(env!("CARGO_BIN_EXE_claw"))
|
||||
.current_dir(&temp_dir)
|
||||
.arg("/zstats")
|
||||
.output()
|
||||
.expect("claw should launch");
|
||||
|
||||
// then
|
||||
assert_success(&help_output);
|
||||
let help_stdout = String::from_utf8(help_output.stdout).expect("stdout should be utf8");
|
||||
assert!(help_stdout.contains("Interactive slash commands:"));
|
||||
assert!(help_stdout.contains("/status"));
|
||||
|
||||
assert!(
|
||||
!unknown_output.status.success(),
|
||||
"stdout:\n{}\n\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&unknown_output.stdout),
|
||||
String::from_utf8_lossy(&unknown_output.stderr)
|
||||
);
|
||||
let stderr = String::from_utf8(unknown_output.stderr).expect("stderr should be utf8");
|
||||
assert!(stderr.contains("unknown slash command outside the REPL: /zstats"));
|
||||
assert!(stderr.contains("Did you mean"));
|
||||
assert!(stderr.contains("/status"));
|
||||
|
||||
fs::remove_dir_all(temp_dir).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn omc_namespaced_slash_commands_surface_a_targeted_compatibility_hint() {
|
||||
let temp_dir = unique_temp_dir("slash-dispatch-omc");
|
||||
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
|
||||
|
||||
let output = Command::new(env!("CARGO_BIN_EXE_claw"))
|
||||
.current_dir(&temp_dir)
|
||||
.arg("/oh-my-claudecode:hud")
|
||||
.output()
|
||||
.expect("claw should launch");
|
||||
|
||||
assert!(
|
||||
!output.status.success(),
|
||||
"stdout:\n{}\n\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
let stderr = String::from_utf8(output.stderr).expect("stderr should be utf8");
|
||||
assert!(stderr.contains("unknown slash command outside the REPL: /oh-my-claudecode:hud"));
|
||||
assert!(stderr.contains("Claude Code/OMC plugin command"));
|
||||
assert!(stderr.contains("does not yet load plugin slash commands"));
|
||||
|
||||
fs::remove_dir_all(temp_dir).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_command_loads_defaults_from_standard_config_locations() {
|
||||
// given
|
||||
let temp_dir = unique_temp_dir("config-defaults");
|
||||
let config_home = temp_dir.join("home").join(".claw");
|
||||
fs::create_dir_all(temp_dir.join(".claw")).expect("project config dir should exist");
|
||||
fs::create_dir_all(&config_home).expect("home config dir should exist");
|
||||
|
||||
fs::write(config_home.join("settings.json"), r#"{"model":"haiku"}"#)
|
||||
.expect("write user settings");
|
||||
fs::write(temp_dir.join(".claw.json"), r#"{"model":"sonnet"}"#)
|
||||
.expect("write project settings");
|
||||
fs::write(
|
||||
temp_dir.join(".claw").join("settings.local.json"),
|
||||
r#"{"model":"opus"}"#,
|
||||
)
|
||||
.expect("write local settings");
|
||||
let session_path = write_session(&temp_dir, "config-defaults");
|
||||
|
||||
// when
|
||||
let output = command_in(&temp_dir)
|
||||
.env("CLAW_CONFIG_HOME", &config_home)
|
||||
.args([
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 path"),
|
||||
"/config",
|
||||
"model",
|
||||
])
|
||||
.output()
|
||||
.expect("claw should launch");
|
||||
|
||||
// then
|
||||
assert_success(&output);
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
|
||||
assert!(stdout.contains("Config"));
|
||||
assert!(stdout.contains("Loaded files 3"));
|
||||
assert!(stdout.contains("Merged section: model"));
|
||||
assert!(stdout.contains("opus"));
|
||||
assert!(stdout.contains(
|
||||
config_home
|
||||
.join("settings.json")
|
||||
.to_str()
|
||||
.expect("utf8 path")
|
||||
));
|
||||
assert!(stdout.contains(temp_dir.join(".claw.json").to_str().expect("utf8 path")));
|
||||
assert!(stdout.contains(
|
||||
temp_dir
|
||||
.join(".claw")
|
||||
.join("settings.local.json")
|
||||
.to_str()
|
||||
.expect("utf8 path")
|
||||
));
|
||||
|
||||
fs::remove_dir_all(temp_dir).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn doctor_command_runs_as_a_local_shell_entrypoint() {
|
||||
// given
|
||||
let temp_dir = unique_temp_dir("doctor-entrypoint");
|
||||
let config_home = temp_dir.join("home").join(".claw");
|
||||
fs::create_dir_all(&config_home).expect("config home should exist");
|
||||
|
||||
// when
|
||||
let output = command_in(&temp_dir)
|
||||
.env("CLAW_CONFIG_HOME", &config_home)
|
||||
.env_remove("ANTHROPIC_API_KEY")
|
||||
.env_remove("ANTHROPIC_AUTH_TOKEN")
|
||||
.env("ANTHROPIC_BASE_URL", "http://127.0.0.1:9")
|
||||
.arg("doctor")
|
||||
.output()
|
||||
.expect("claw doctor should launch");
|
||||
|
||||
// then
|
||||
assert_success(&output);
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
|
||||
assert!(stdout.contains("Doctor"));
|
||||
assert!(stdout.contains("Auth"));
|
||||
assert!(stdout.contains("Config"));
|
||||
assert!(stdout.contains("Workspace"));
|
||||
assert!(stdout.contains("Sandbox"));
|
||||
assert!(!stdout.contains("Thinking"));
|
||||
|
||||
fs::remove_dir_all(temp_dir).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_subcommand_help_does_not_fall_through_to_runtime_or_provider_calls() {
|
||||
let temp_dir = unique_temp_dir("subcommand-help");
|
||||
let config_home = temp_dir.join("home").join(".claw");
|
||||
fs::create_dir_all(&config_home).expect("config home should exist");
|
||||
|
||||
let doctor_help = command_in(&temp_dir)
|
||||
.env("CLAW_CONFIG_HOME", &config_home)
|
||||
.env_remove("ANTHROPIC_API_KEY")
|
||||
.env_remove("ANTHROPIC_AUTH_TOKEN")
|
||||
.env("ANTHROPIC_BASE_URL", "http://127.0.0.1:9")
|
||||
.args(["doctor", "--help"])
|
||||
.output()
|
||||
.expect("doctor help should launch");
|
||||
let status_help = command_in(&temp_dir)
|
||||
.env("CLAW_CONFIG_HOME", &config_home)
|
||||
.env_remove("ANTHROPIC_API_KEY")
|
||||
.env_remove("ANTHROPIC_AUTH_TOKEN")
|
||||
.env("ANTHROPIC_BASE_URL", "http://127.0.0.1:9")
|
||||
.args(["status", "--help"])
|
||||
.output()
|
||||
.expect("status help should launch");
|
||||
|
||||
assert_success(&doctor_help);
|
||||
let doctor_stdout = String::from_utf8(doctor_help.stdout).expect("stdout should be utf8");
|
||||
assert!(doctor_stdout.contains("Usage claw doctor"));
|
||||
assert!(doctor_stdout.contains("local-only health report"));
|
||||
assert!(!doctor_stdout.contains("Thinking"));
|
||||
|
||||
assert_success(&status_help);
|
||||
let status_stdout = String::from_utf8(status_help.stdout).expect("stdout should be utf8");
|
||||
assert!(status_stdout.contains("Usage claw status"));
|
||||
assert!(status_stdout.contains("local workspace snapshot"));
|
||||
assert!(!status_stdout.contains("Thinking"));
|
||||
|
||||
let doctor_stderr = String::from_utf8(doctor_help.stderr).expect("stderr should be utf8");
|
||||
let status_stderr = String::from_utf8(status_help.stderr).expect("stderr should be utf8");
|
||||
assert!(!doctor_stderr.contains("auth_unavailable"));
|
||||
assert!(!status_stderr.contains("auth_unavailable"));
|
||||
|
||||
fs::remove_dir_all(temp_dir).expect("cleanup temp dir");
|
||||
}
|
||||
|
||||
fn command_in(cwd: &Path) -> Command {
|
||||
let mut command = Command::new(env!("CARGO_BIN_EXE_claw"));
|
||||
command.current_dir(cwd);
|
||||
command
|
||||
}
|
||||
|
||||
fn write_session(root: &Path, label: &str) -> PathBuf {
|
||||
let session_path = root.join(format!("{label}.jsonl"));
|
||||
let mut session = Session::new();
|
||||
session
|
||||
.push_user_text(format!("session fixture for {label}"))
|
||||
.expect("session write should succeed");
|
||||
session
|
||||
.save_to_path(&session_path)
|
||||
.expect("session should persist");
|
||||
session_path
|
||||
}
|
||||
|
||||
fn assert_success(output: &Output) {
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"stdout:\n{}\n\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
}
|
||||
|
||||
fn unique_temp_dir(label: &str) -> PathBuf {
|
||||
let millis = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("clock should be after epoch")
|
||||
.as_millis();
|
||||
let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
std::env::temp_dir().join(format!(
|
||||
"claw-{label}-{}-{millis}-{counter}",
|
||||
std::process::id()
|
||||
))
|
||||
}
|
||||
159
crates/rusty-claude-cli/tests/compact_output.rs
Normal file
159
crates/rusty-claude-cli/tests/compact_output.rs
Normal file
@ -0,0 +1,159 @@
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::process::{Command, Output};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use mock_anthropic_service::{MockAnthropicService, SCENARIO_PREFIX};
|
||||
|
||||
static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
#[test]
|
||||
fn compact_flag_prints_only_final_assistant_text_without_tool_call_details() {
|
||||
// given a workspace pointed at the mock Anthropic service and a fixture file
|
||||
// that the read_file_roundtrip scenario will fetch through a tool call
|
||||
let runtime = tokio::runtime::Runtime::new().expect("tokio runtime should build");
|
||||
let server = runtime
|
||||
.block_on(MockAnthropicService::spawn())
|
||||
.expect("mock service should start");
|
||||
let base_url = server.base_url();
|
||||
|
||||
let workspace = unique_temp_dir("compact-read-file");
|
||||
let config_home = workspace.join("config-home");
|
||||
let home = workspace.join("home");
|
||||
fs::create_dir_all(&workspace).expect("workspace should exist");
|
||||
fs::create_dir_all(&config_home).expect("config home should exist");
|
||||
fs::create_dir_all(&home).expect("home should exist");
|
||||
fs::write(workspace.join("fixture.txt"), "alpha parity line\n").expect("fixture should write");
|
||||
|
||||
// when we run claw in compact text mode against a tool-using scenario
|
||||
let prompt = format!("{SCENARIO_PREFIX}read_file_roundtrip");
|
||||
let output = run_claw(
|
||||
&workspace,
|
||||
&config_home,
|
||||
&home,
|
||||
&base_url,
|
||||
&[
|
||||
"--model",
|
||||
"sonnet",
|
||||
"--permission-mode",
|
||||
"read-only",
|
||||
"--allowedTools",
|
||||
"read_file",
|
||||
"--compact",
|
||||
&prompt,
|
||||
],
|
||||
);
|
||||
|
||||
// then the command exits successfully and stdout contains exactly the final
|
||||
// assistant text with no tool call IDs, JSON envelopes, or spinner output
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"compact run should succeed\nstdout:\n{}\n\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
);
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
|
||||
let trimmed = stdout.trim_end_matches('\n');
|
||||
assert_eq!(
|
||||
trimmed, "read_file roundtrip complete: alpha parity line",
|
||||
"compact stdout should contain only the final assistant text"
|
||||
);
|
||||
assert!(
|
||||
!stdout.contains("toolu_"),
|
||||
"compact stdout must not leak tool_use_id ({stdout:?})"
|
||||
);
|
||||
assert!(
|
||||
!stdout.contains("\"tool_uses\""),
|
||||
"compact stdout must not leak json envelopes ({stdout:?})"
|
||||
);
|
||||
assert!(
|
||||
!stdout.contains("Thinking"),
|
||||
"compact stdout must not include the spinner banner ({stdout:?})"
|
||||
);
|
||||
|
||||
fs::remove_dir_all(&workspace).expect("workspace cleanup should succeed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compact_flag_streaming_text_only_emits_final_message_text() {
|
||||
// given a workspace pointed at the mock Anthropic service running the
|
||||
// streaming_text scenario which only emits a single assistant text block
|
||||
let runtime = tokio::runtime::Runtime::new().expect("tokio runtime should build");
|
||||
let server = runtime
|
||||
.block_on(MockAnthropicService::spawn())
|
||||
.expect("mock service should start");
|
||||
let base_url = server.base_url();
|
||||
|
||||
let workspace = unique_temp_dir("compact-streaming-text");
|
||||
let config_home = workspace.join("config-home");
|
||||
let home = workspace.join("home");
|
||||
fs::create_dir_all(&workspace).expect("workspace should exist");
|
||||
fs::create_dir_all(&config_home).expect("config home should exist");
|
||||
fs::create_dir_all(&home).expect("home should exist");
|
||||
|
||||
// when we invoke claw with --compact for the streaming text scenario
|
||||
let prompt = format!("{SCENARIO_PREFIX}streaming_text");
|
||||
let output = run_claw(
|
||||
&workspace,
|
||||
&config_home,
|
||||
&home,
|
||||
&base_url,
|
||||
&[
|
||||
"--model",
|
||||
"sonnet",
|
||||
"--permission-mode",
|
||||
"read-only",
|
||||
"--compact",
|
||||
&prompt,
|
||||
],
|
||||
);
|
||||
|
||||
// then stdout should be exactly the assistant text followed by a newline
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"compact streaming run should succeed\nstdout:\n{}\n\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
);
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
|
||||
assert_eq!(
|
||||
stdout, "Mock streaming says hello from the parity harness.\n",
|
||||
"compact streaming stdout should contain only the final assistant text"
|
||||
);
|
||||
|
||||
fs::remove_dir_all(&workspace).expect("workspace cleanup should succeed");
|
||||
}
|
||||
|
||||
fn run_claw(
|
||||
cwd: &std::path::Path,
|
||||
config_home: &std::path::Path,
|
||||
home: &std::path::Path,
|
||||
base_url: &str,
|
||||
args: &[&str],
|
||||
) -> Output {
|
||||
let mut command = Command::new(env!("CARGO_BIN_EXE_claw"));
|
||||
command
|
||||
.current_dir(cwd)
|
||||
.env_clear()
|
||||
.env("ANTHROPIC_API_KEY", "test-compact-key")
|
||||
.env("ANTHROPIC_BASE_URL", base_url)
|
||||
.env("CLAW_CONFIG_HOME", config_home)
|
||||
.env("HOME", home)
|
||||
.env("NO_COLOR", "1")
|
||||
.env("PATH", "/usr/bin:/bin")
|
||||
.args(args);
|
||||
command.output().expect("claw should launch")
|
||||
}
|
||||
|
||||
fn unique_temp_dir(label: &str) -> PathBuf {
|
||||
let millis = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("clock should be after epoch")
|
||||
.as_millis();
|
||||
let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
std::env::temp_dir().join(format!(
|
||||
"claw-compact-{label}-{}-{millis}-{counter}",
|
||||
std::process::id()
|
||||
))
|
||||
}
|
||||
884
crates/rusty-claude-cli/tests/mock_parity_harness.rs
Normal file
884
crates/rusty-claude-cli/tests/mock_parity_harness.rs
Normal file
@ -0,0 +1,884 @@
|
||||
#![cfg(unix)]
|
||||
use std::collections::BTreeMap;
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::{Command, Output, Stdio};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use mock_anthropic_service::{MockAnthropicService, SCENARIO_PREFIX};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn clean_env_cli_reaches_mock_anthropic_service_across_scripted_parity_scenarios() {
|
||||
let manifest_entries = load_scenario_manifest();
|
||||
let manifest = manifest_entries
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|entry| (entry.name.clone(), entry))
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
let runtime = tokio::runtime::Runtime::new().expect("tokio runtime should build");
|
||||
let server = runtime
|
||||
.block_on(MockAnthropicService::spawn())
|
||||
.expect("mock service should start");
|
||||
let base_url = server.base_url();
|
||||
|
||||
let cases = [
|
||||
ScenarioCase {
|
||||
name: "streaming_text",
|
||||
permission_mode: "read-only",
|
||||
allowed_tools: None,
|
||||
stdin: None,
|
||||
prepare: prepare_noop,
|
||||
assert: assert_streaming_text,
|
||||
extra_env: None,
|
||||
resume_session: None,
|
||||
},
|
||||
ScenarioCase {
|
||||
name: "read_file_roundtrip",
|
||||
permission_mode: "read-only",
|
||||
allowed_tools: Some("read_file"),
|
||||
stdin: None,
|
||||
prepare: prepare_read_fixture,
|
||||
assert: assert_read_file_roundtrip,
|
||||
extra_env: None,
|
||||
resume_session: None,
|
||||
},
|
||||
ScenarioCase {
|
||||
name: "grep_chunk_assembly",
|
||||
permission_mode: "read-only",
|
||||
allowed_tools: Some("grep_search"),
|
||||
stdin: None,
|
||||
prepare: prepare_grep_fixture,
|
||||
assert: assert_grep_chunk_assembly,
|
||||
extra_env: None,
|
||||
resume_session: None,
|
||||
},
|
||||
ScenarioCase {
|
||||
name: "write_file_allowed",
|
||||
permission_mode: "workspace-write",
|
||||
allowed_tools: Some("write_file"),
|
||||
stdin: None,
|
||||
prepare: prepare_noop,
|
||||
assert: assert_write_file_allowed,
|
||||
extra_env: None,
|
||||
resume_session: None,
|
||||
},
|
||||
ScenarioCase {
|
||||
name: "write_file_denied",
|
||||
permission_mode: "read-only",
|
||||
allowed_tools: Some("write_file"),
|
||||
stdin: None,
|
||||
prepare: prepare_noop,
|
||||
assert: assert_write_file_denied,
|
||||
extra_env: None,
|
||||
resume_session: None,
|
||||
},
|
||||
ScenarioCase {
|
||||
name: "multi_tool_turn_roundtrip",
|
||||
permission_mode: "read-only",
|
||||
allowed_tools: Some("read_file,grep_search"),
|
||||
stdin: None,
|
||||
prepare: prepare_multi_tool_fixture,
|
||||
assert: assert_multi_tool_turn_roundtrip,
|
||||
extra_env: None,
|
||||
resume_session: None,
|
||||
},
|
||||
ScenarioCase {
|
||||
name: "bash_stdout_roundtrip",
|
||||
permission_mode: "danger-full-access",
|
||||
allowed_tools: Some("bash"),
|
||||
stdin: None,
|
||||
prepare: prepare_noop,
|
||||
assert: assert_bash_stdout_roundtrip,
|
||||
extra_env: None,
|
||||
resume_session: None,
|
||||
},
|
||||
ScenarioCase {
|
||||
name: "bash_permission_prompt_approved",
|
||||
permission_mode: "workspace-write",
|
||||
allowed_tools: Some("bash"),
|
||||
stdin: Some("y\n"),
|
||||
prepare: prepare_noop,
|
||||
assert: assert_bash_permission_prompt_approved,
|
||||
extra_env: None,
|
||||
resume_session: None,
|
||||
},
|
||||
ScenarioCase {
|
||||
name: "bash_permission_prompt_denied",
|
||||
permission_mode: "workspace-write",
|
||||
allowed_tools: Some("bash"),
|
||||
stdin: Some("n\n"),
|
||||
prepare: prepare_noop,
|
||||
assert: assert_bash_permission_prompt_denied,
|
||||
extra_env: None,
|
||||
resume_session: None,
|
||||
},
|
||||
ScenarioCase {
|
||||
name: "plugin_tool_roundtrip",
|
||||
permission_mode: "workspace-write",
|
||||
allowed_tools: None,
|
||||
stdin: None,
|
||||
prepare: prepare_plugin_fixture,
|
||||
assert: assert_plugin_tool_roundtrip,
|
||||
extra_env: None,
|
||||
resume_session: None,
|
||||
},
|
||||
ScenarioCase {
|
||||
name: "auto_compact_triggered",
|
||||
permission_mode: "read-only",
|
||||
allowed_tools: None,
|
||||
stdin: None,
|
||||
prepare: prepare_noop,
|
||||
assert: assert_auto_compact_triggered,
|
||||
extra_env: None,
|
||||
resume_session: None,
|
||||
},
|
||||
ScenarioCase {
|
||||
name: "token_cost_reporting",
|
||||
permission_mode: "read-only",
|
||||
allowed_tools: None,
|
||||
stdin: None,
|
||||
prepare: prepare_noop,
|
||||
assert: assert_token_cost_reporting,
|
||||
extra_env: None,
|
||||
resume_session: None,
|
||||
},
|
||||
];
|
||||
|
||||
let case_names = cases.iter().map(|case| case.name).collect::<Vec<_>>();
|
||||
let manifest_names = manifest_entries
|
||||
.iter()
|
||||
.map(|entry| entry.name.as_str())
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(
|
||||
case_names, manifest_names,
|
||||
"manifest and harness cases must stay aligned"
|
||||
);
|
||||
|
||||
let mut scenario_reports = Vec::new();
|
||||
|
||||
for case in cases {
|
||||
let workspace = HarnessWorkspace::new(unique_temp_dir(case.name));
|
||||
workspace.create().expect("workspace should exist");
|
||||
(case.prepare)(&workspace);
|
||||
|
||||
let run = run_case(case, &workspace, &base_url);
|
||||
(case.assert)(&workspace, &run);
|
||||
|
||||
let manifest_entry = manifest
|
||||
.get(case.name)
|
||||
.unwrap_or_else(|| panic!("missing manifest entry for {}", case.name));
|
||||
scenario_reports.push(build_scenario_report(
|
||||
case.name,
|
||||
manifest_entry,
|
||||
&run.response,
|
||||
));
|
||||
|
||||
fs::remove_dir_all(&workspace.root).expect("workspace cleanup should succeed");
|
||||
}
|
||||
|
||||
let captured = runtime.block_on(server.captured_requests());
|
||||
// After `be561bf` added count_tokens preflight, each turn sends an
|
||||
// extra POST to `/v1/messages/count_tokens` before the messages POST.
|
||||
// The original count (21) assumed messages-only requests. We now
|
||||
// filter to `/v1/messages` and verify that subset matches the original
|
||||
// scenario expectation.
|
||||
let messages_only: Vec<_> = captured
|
||||
.iter()
|
||||
.filter(|r| r.path == "/v1/messages")
|
||||
.collect();
|
||||
assert_eq!(
|
||||
messages_only.len(),
|
||||
21,
|
||||
"twelve scenarios should produce twenty-one /v1/messages requests (total captured: {}, includes count_tokens)",
|
||||
captured.len()
|
||||
);
|
||||
assert!(messages_only.iter().all(|request| request.stream));
|
||||
|
||||
let scenarios = messages_only
|
||||
.iter()
|
||||
.map(|request| request.scenario.as_str())
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(
|
||||
scenarios,
|
||||
vec![
|
||||
"streaming_text",
|
||||
"read_file_roundtrip",
|
||||
"read_file_roundtrip",
|
||||
"grep_chunk_assembly",
|
||||
"grep_chunk_assembly",
|
||||
"write_file_allowed",
|
||||
"write_file_allowed",
|
||||
"write_file_denied",
|
||||
"write_file_denied",
|
||||
"multi_tool_turn_roundtrip",
|
||||
"multi_tool_turn_roundtrip",
|
||||
"bash_stdout_roundtrip",
|
||||
"bash_stdout_roundtrip",
|
||||
"bash_permission_prompt_approved",
|
||||
"bash_permission_prompt_approved",
|
||||
"bash_permission_prompt_denied",
|
||||
"bash_permission_prompt_denied",
|
||||
"plugin_tool_roundtrip",
|
||||
"plugin_tool_roundtrip",
|
||||
"auto_compact_triggered",
|
||||
"token_cost_reporting",
|
||||
]
|
||||
);
|
||||
|
||||
let mut request_counts = BTreeMap::new();
|
||||
for request in &captured {
|
||||
*request_counts
|
||||
.entry(request.scenario.as_str())
|
||||
.or_insert(0_usize) += 1;
|
||||
}
|
||||
for report in &mut scenario_reports {
|
||||
report.request_count = *request_counts
|
||||
.get(report.name.as_str())
|
||||
.unwrap_or_else(|| panic!("missing request count for {}", report.name));
|
||||
}
|
||||
|
||||
maybe_write_report(&scenario_reports);
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct ScenarioCase {
|
||||
name: &'static str,
|
||||
permission_mode: &'static str,
|
||||
allowed_tools: Option<&'static str>,
|
||||
stdin: Option<&'static str>,
|
||||
prepare: fn(&HarnessWorkspace),
|
||||
assert: fn(&HarnessWorkspace, &ScenarioRun),
|
||||
extra_env: Option<(&'static str, &'static str)>,
|
||||
resume_session: Option<&'static str>,
|
||||
}
|
||||
|
||||
struct HarnessWorkspace {
|
||||
root: PathBuf,
|
||||
config_home: PathBuf,
|
||||
home: PathBuf,
|
||||
}
|
||||
|
||||
impl HarnessWorkspace {
|
||||
fn new(root: PathBuf) -> Self {
|
||||
Self {
|
||||
config_home: root.join("config-home"),
|
||||
home: root.join("home"),
|
||||
root,
|
||||
}
|
||||
}
|
||||
|
||||
fn create(&self) -> std::io::Result<()> {
|
||||
fs::create_dir_all(&self.root)?;
|
||||
fs::create_dir_all(&self.config_home)?;
|
||||
fs::create_dir_all(&self.home)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct ScenarioRun {
|
||||
response: Value,
|
||||
stdout: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ScenarioManifestEntry {
|
||||
name: String,
|
||||
category: String,
|
||||
description: String,
|
||||
parity_refs: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ScenarioReport {
|
||||
name: String,
|
||||
category: String,
|
||||
description: String,
|
||||
parity_refs: Vec<String>,
|
||||
iterations: u64,
|
||||
request_count: usize,
|
||||
tool_uses: Vec<String>,
|
||||
tool_error_count: usize,
|
||||
final_message: String,
|
||||
}
|
||||
|
||||
fn run_case(case: ScenarioCase, workspace: &HarnessWorkspace, base_url: &str) -> ScenarioRun {
|
||||
let mut command = Command::new(env!("CARGO_BIN_EXE_claw"));
|
||||
command
|
||||
.current_dir(&workspace.root)
|
||||
.env_clear()
|
||||
.env("ANTHROPIC_API_KEY", "test-parity-key")
|
||||
.env("ANTHROPIC_BASE_URL", base_url)
|
||||
.env("CLAW_CONFIG_HOME", &workspace.config_home)
|
||||
.env("HOME", &workspace.home)
|
||||
.env("NO_COLOR", "1")
|
||||
.env("PATH", "/usr/bin:/bin")
|
||||
.args([
|
||||
"--model",
|
||||
"sonnet",
|
||||
"--permission-mode",
|
||||
case.permission_mode,
|
||||
"--output-format=json",
|
||||
]);
|
||||
|
||||
if let Some(allowed_tools) = case.allowed_tools {
|
||||
command.args(["--allowedTools", allowed_tools]);
|
||||
}
|
||||
if let Some((key, value)) = case.extra_env {
|
||||
command.env(key, value);
|
||||
}
|
||||
if let Some(session_id) = case.resume_session {
|
||||
command.args(["--resume", session_id]);
|
||||
}
|
||||
|
||||
let prompt = format!("{SCENARIO_PREFIX}{}", case.name);
|
||||
command.arg(prompt);
|
||||
|
||||
let output = if let Some(stdin) = case.stdin {
|
||||
let mut child = command
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.expect("claw should launch");
|
||||
child
|
||||
.stdin
|
||||
.as_mut()
|
||||
.expect("stdin should be piped")
|
||||
.write_all(stdin.as_bytes())
|
||||
.expect("stdin should write");
|
||||
child.wait_with_output().expect("claw should finish")
|
||||
} else {
|
||||
command.output().expect("claw should launch")
|
||||
};
|
||||
|
||||
assert_success(&output);
|
||||
let stdout = String::from_utf8_lossy(&output.stdout).into_owned();
|
||||
ScenarioRun {
|
||||
response: parse_json_output(&stdout),
|
||||
stdout,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn prepare_auto_compact_fixture(workspace: &HarnessWorkspace) {
|
||||
let sessions_dir = workspace.root.join(".claw").join("sessions");
|
||||
fs::create_dir_all(&sessions_dir).expect("sessions dir should exist");
|
||||
|
||||
// Write a pre-seeded session with 6 messages so auto-compact can remove them
|
||||
let session_id = "parity-auto-compact-seed";
|
||||
let session_jsonl = r#"{"type":"session_meta","version":3,"session_id":"parity-auto-compact-seed","created_at_ms":1743724800000,"updated_at_ms":1743724800000}
|
||||
{"type":"message","message":{"role":"user","blocks":[{"type":"text","text":"step one of the parity scenario"}]}}
|
||||
{"type":"message","message":{"role":"assistant","blocks":[{"type":"text","text":"acknowledged step one"}]}}
|
||||
{"type":"message","message":{"role":"user","blocks":[{"type":"text","text":"step two of the parity scenario"}]}}
|
||||
{"type":"message","message":{"role":"assistant","blocks":[{"type":"text","text":"acknowledged step two"}]}}
|
||||
{"type":"message","message":{"role":"user","blocks":[{"type":"text","text":"step three of the parity scenario"}]}}
|
||||
{"type":"message","message":{"role":"assistant","blocks":[{"type":"text","text":"acknowledged step three"}]}}
|
||||
"#;
|
||||
fs::write(
|
||||
sessions_dir.join(format!("{session_id}.jsonl")),
|
||||
session_jsonl,
|
||||
)
|
||||
.expect("pre-seeded session should write");
|
||||
}
|
||||
|
||||
fn prepare_noop(_: &HarnessWorkspace) {}
|
||||
|
||||
fn prepare_read_fixture(workspace: &HarnessWorkspace) {
|
||||
fs::write(workspace.root.join("fixture.txt"), "alpha parity line\n")
|
||||
.expect("fixture should write");
|
||||
}
|
||||
|
||||
fn prepare_grep_fixture(workspace: &HarnessWorkspace) {
|
||||
fs::write(
|
||||
workspace.root.join("fixture.txt"),
|
||||
"alpha parity line\nbeta line\ngamma parity line\n",
|
||||
)
|
||||
.expect("grep fixture should write");
|
||||
}
|
||||
|
||||
fn prepare_multi_tool_fixture(workspace: &HarnessWorkspace) {
|
||||
fs::write(
|
||||
workspace.root.join("fixture.txt"),
|
||||
"alpha parity line\nbeta line\ngamma parity line\n",
|
||||
)
|
||||
.expect("multi tool fixture should write");
|
||||
}
|
||||
|
||||
fn prepare_plugin_fixture(workspace: &HarnessWorkspace) {
|
||||
let plugin_root = workspace
|
||||
.root
|
||||
.join("external-plugins")
|
||||
.join("parity-plugin");
|
||||
let tool_dir = plugin_root.join("tools");
|
||||
let manifest_dir = plugin_root.join(".claude-plugin");
|
||||
fs::create_dir_all(&tool_dir).expect("plugin tools dir");
|
||||
fs::create_dir_all(&manifest_dir).expect("plugin manifest dir");
|
||||
|
||||
let script_path = tool_dir.join("echo-json.sh");
|
||||
fs::write(
|
||||
&script_path,
|
||||
"#!/bin/sh\nINPUT=$(cat)\nprintf '{\"plugin\":\"%s\",\"tool\":\"%s\",\"input\":%s}\\n' \"$CLAWD_PLUGIN_ID\" \"$CLAWD_TOOL_NAME\" \"$INPUT\"\n",
|
||||
)
|
||||
.expect("plugin script should write");
|
||||
let mut permissions = fs::metadata(&script_path)
|
||||
.expect("plugin script metadata")
|
||||
.permissions();
|
||||
permissions.set_mode(0o755);
|
||||
fs::set_permissions(&script_path, permissions).expect("plugin script should be executable");
|
||||
|
||||
fs::write(
|
||||
manifest_dir.join("plugin.json"),
|
||||
r#"{
|
||||
"name": "parity-plugin",
|
||||
"version": "1.0.0",
|
||||
"description": "mock parity plugin",
|
||||
"tools": [
|
||||
{
|
||||
"name": "plugin_echo",
|
||||
"description": "Echo JSON input",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": { "type": "string" }
|
||||
},
|
||||
"required": ["message"],
|
||||
"additionalProperties": false
|
||||
},
|
||||
"command": "./tools/echo-json.sh",
|
||||
"requiredPermission": "workspace-write"
|
||||
}
|
||||
]
|
||||
}"#,
|
||||
)
|
||||
.expect("plugin manifest should write");
|
||||
|
||||
fs::write(
|
||||
workspace.config_home.join("settings.json"),
|
||||
json!({
|
||||
"enabledPlugins": {
|
||||
"parity-plugin@external": true
|
||||
},
|
||||
"plugins": {
|
||||
"externalDirectories": [plugin_root.parent().expect("plugin parent").display().to_string()]
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
)
|
||||
.expect("plugin settings should write");
|
||||
}
|
||||
|
||||
fn assert_streaming_text(_: &HarnessWorkspace, run: &ScenarioRun) {
|
||||
assert_eq!(
|
||||
run.response["message"],
|
||||
Value::String("Mock streaming says hello from the parity harness.".to_string())
|
||||
);
|
||||
assert_eq!(run.response["iterations"], Value::from(1));
|
||||
assert_eq!(run.response["tool_uses"], Value::Array(Vec::new()));
|
||||
assert_eq!(run.response["tool_results"], Value::Array(Vec::new()));
|
||||
}
|
||||
|
||||
fn assert_read_file_roundtrip(workspace: &HarnessWorkspace, run: &ScenarioRun) {
|
||||
assert_eq!(run.response["iterations"], Value::from(2));
|
||||
assert_eq!(
|
||||
run.response["tool_uses"][0]["name"],
|
||||
Value::String("read_file".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
run.response["tool_uses"][0]["input"],
|
||||
Value::String(r#"{"path":"fixture.txt"}"#.to_string())
|
||||
);
|
||||
assert!(run.response["message"]
|
||||
.as_str()
|
||||
.expect("message text")
|
||||
.contains("alpha parity line"));
|
||||
let output = run.response["tool_results"][0]["output"]
|
||||
.as_str()
|
||||
.expect("tool output");
|
||||
assert!(output.contains(&workspace.root.join("fixture.txt").display().to_string()));
|
||||
assert!(output.contains("alpha parity line"));
|
||||
}
|
||||
|
||||
fn assert_grep_chunk_assembly(_: &HarnessWorkspace, run: &ScenarioRun) {
|
||||
assert_eq!(run.response["iterations"], Value::from(2));
|
||||
assert_eq!(
|
||||
run.response["tool_uses"][0]["name"],
|
||||
Value::String("grep_search".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
run.response["tool_uses"][0]["input"],
|
||||
Value::String(
|
||||
r#"{"pattern":"parity","path":"fixture.txt","output_mode":"count"}"#.to_string()
|
||||
)
|
||||
);
|
||||
assert!(run.response["message"]
|
||||
.as_str()
|
||||
.expect("message text")
|
||||
.contains("2 occurrences"));
|
||||
assert_eq!(
|
||||
run.response["tool_results"][0]["is_error"],
|
||||
Value::Bool(false)
|
||||
);
|
||||
}
|
||||
|
||||
fn assert_write_file_allowed(workspace: &HarnessWorkspace, run: &ScenarioRun) {
|
||||
assert_eq!(run.response["iterations"], Value::from(2));
|
||||
assert_eq!(
|
||||
run.response["tool_uses"][0]["name"],
|
||||
Value::String("write_file".to_string())
|
||||
);
|
||||
assert!(run.response["message"]
|
||||
.as_str()
|
||||
.expect("message text")
|
||||
.contains("generated/output.txt"));
|
||||
let generated = workspace.root.join("generated").join("output.txt");
|
||||
let contents = fs::read_to_string(&generated).expect("generated file should exist");
|
||||
assert_eq!(contents, "created by mock service\n");
|
||||
assert_eq!(
|
||||
run.response["tool_results"][0]["is_error"],
|
||||
Value::Bool(false)
|
||||
);
|
||||
}
|
||||
|
||||
fn assert_write_file_denied(workspace: &HarnessWorkspace, run: &ScenarioRun) {
|
||||
assert_eq!(run.response["iterations"], Value::from(2));
|
||||
assert_eq!(
|
||||
run.response["tool_uses"][0]["name"],
|
||||
Value::String("write_file".to_string())
|
||||
);
|
||||
let tool_output = run.response["tool_results"][0]["output"]
|
||||
.as_str()
|
||||
.expect("tool output");
|
||||
assert!(tool_output.contains("requires workspace-write permission"));
|
||||
assert_eq!(
|
||||
run.response["tool_results"][0]["is_error"],
|
||||
Value::Bool(true)
|
||||
);
|
||||
assert!(run.response["message"]
|
||||
.as_str()
|
||||
.expect("message text")
|
||||
.contains("denied as expected"));
|
||||
assert!(!workspace.root.join("generated").join("denied.txt").exists());
|
||||
}
|
||||
|
||||
fn assert_multi_tool_turn_roundtrip(_: &HarnessWorkspace, run: &ScenarioRun) {
|
||||
assert_eq!(run.response["iterations"], Value::from(2));
|
||||
let tool_uses = run.response["tool_uses"]
|
||||
.as_array()
|
||||
.expect("tool uses array");
|
||||
assert_eq!(
|
||||
tool_uses.len(),
|
||||
2,
|
||||
"expected two tool uses in a single turn"
|
||||
);
|
||||
assert_eq!(tool_uses[0]["name"], Value::String("read_file".to_string()));
|
||||
assert_eq!(
|
||||
tool_uses[1]["name"],
|
||||
Value::String("grep_search".to_string())
|
||||
);
|
||||
let tool_results = run.response["tool_results"]
|
||||
.as_array()
|
||||
.expect("tool results array");
|
||||
assert_eq!(
|
||||
tool_results.len(),
|
||||
2,
|
||||
"expected two tool results in a single turn"
|
||||
);
|
||||
assert!(run.response["message"]
|
||||
.as_str()
|
||||
.expect("message text")
|
||||
.contains("alpha parity line"));
|
||||
assert!(run.response["message"]
|
||||
.as_str()
|
||||
.expect("message text")
|
||||
.contains("2 occurrences"));
|
||||
}
|
||||
|
||||
fn assert_bash_stdout_roundtrip(_: &HarnessWorkspace, run: &ScenarioRun) {
|
||||
assert_eq!(run.response["iterations"], Value::from(2));
|
||||
assert_eq!(
|
||||
run.response["tool_uses"][0]["name"],
|
||||
Value::String("bash".to_string())
|
||||
);
|
||||
let tool_output = run.response["tool_results"][0]["output"]
|
||||
.as_str()
|
||||
.expect("tool output");
|
||||
let parsed: Value = serde_json::from_str(tool_output).expect("bash output json");
|
||||
assert_eq!(
|
||||
parsed["stdout"],
|
||||
Value::String("alpha from bash".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
run.response["tool_results"][0]["is_error"],
|
||||
Value::Bool(false)
|
||||
);
|
||||
assert!(run.response["message"]
|
||||
.as_str()
|
||||
.expect("message text")
|
||||
.contains("alpha from bash"));
|
||||
}
|
||||
|
||||
fn assert_bash_permission_prompt_approved(_: &HarnessWorkspace, run: &ScenarioRun) {
|
||||
assert!(run.stdout.contains("Permission approval required"));
|
||||
assert!(run.stdout.contains("Approve this tool call? [y/N]:"));
|
||||
assert_eq!(run.response["iterations"], Value::from(2));
|
||||
assert_eq!(
|
||||
run.response["tool_results"][0]["is_error"],
|
||||
Value::Bool(false)
|
||||
);
|
||||
let tool_output = run.response["tool_results"][0]["output"]
|
||||
.as_str()
|
||||
.expect("tool output");
|
||||
let parsed: Value = serde_json::from_str(tool_output).expect("bash output json");
|
||||
assert_eq!(
|
||||
parsed["stdout"],
|
||||
Value::String("approved via prompt".to_string())
|
||||
);
|
||||
assert!(run.response["message"]
|
||||
.as_str()
|
||||
.expect("message text")
|
||||
.contains("approved and executed"));
|
||||
}
|
||||
|
||||
fn assert_bash_permission_prompt_denied(_: &HarnessWorkspace, run: &ScenarioRun) {
|
||||
assert!(run.stdout.contains("Permission approval required"));
|
||||
assert!(run.stdout.contains("Approve this tool call? [y/N]:"));
|
||||
assert_eq!(run.response["iterations"], Value::from(2));
|
||||
let tool_output = run.response["tool_results"][0]["output"]
|
||||
.as_str()
|
||||
.expect("tool output");
|
||||
assert!(tool_output.contains("denied by user approval prompt"));
|
||||
assert_eq!(
|
||||
run.response["tool_results"][0]["is_error"],
|
||||
Value::Bool(true)
|
||||
);
|
||||
assert!(run.response["message"]
|
||||
.as_str()
|
||||
.expect("message text")
|
||||
.contains("denied as expected"));
|
||||
}
|
||||
|
||||
fn assert_plugin_tool_roundtrip(_: &HarnessWorkspace, run: &ScenarioRun) {
|
||||
assert_eq!(run.response["iterations"], Value::from(2));
|
||||
assert_eq!(
|
||||
run.response["tool_uses"][0]["name"],
|
||||
Value::String("plugin_echo".to_string())
|
||||
);
|
||||
let tool_output = run.response["tool_results"][0]["output"]
|
||||
.as_str()
|
||||
.expect("tool output");
|
||||
let parsed: Value = serde_json::from_str(tool_output).expect("plugin output json");
|
||||
assert_eq!(
|
||||
parsed["plugin"],
|
||||
Value::String("parity-plugin@external".to_string())
|
||||
);
|
||||
assert_eq!(parsed["tool"], Value::String("plugin_echo".to_string()));
|
||||
assert_eq!(
|
||||
parsed["input"]["message"],
|
||||
Value::String("hello from plugin parity".to_string())
|
||||
);
|
||||
assert!(run.response["message"]
|
||||
.as_str()
|
||||
.expect("message text")
|
||||
.contains("hello from plugin parity"));
|
||||
}
|
||||
|
||||
fn assert_auto_compact_triggered(_: &HarnessWorkspace, run: &ScenarioRun) {
|
||||
// Validates that the auto_compaction field is present in JSON output (format parity).
|
||||
// Trigger behavior is covered by conversation::tests::auto_compacts_when_cumulative_input_threshold_is_crossed.
|
||||
assert_eq!(run.response["iterations"], Value::from(1));
|
||||
assert_eq!(run.response["tool_uses"], Value::Array(Vec::new()));
|
||||
assert!(
|
||||
run.response["message"]
|
||||
.as_str()
|
||||
.expect("message text")
|
||||
.contains("auto compact parity complete."),
|
||||
"expected auto compact message in response"
|
||||
);
|
||||
// auto_compaction key must be present in JSON (may be null for below-threshold sessions)
|
||||
assert!(
|
||||
run.response
|
||||
.as_object()
|
||||
.expect("response object")
|
||||
.contains_key("auto_compaction"),
|
||||
"auto_compaction key must be present in JSON output"
|
||||
);
|
||||
// Verify input_tokens field reflects the large mock token counts
|
||||
let input_tokens = run.response["usage"]["input_tokens"]
|
||||
.as_u64()
|
||||
.expect("input_tokens should be present");
|
||||
assert!(
|
||||
input_tokens >= 50_000,
|
||||
"input_tokens should reflect mock service value (got {input_tokens})"
|
||||
);
|
||||
}
|
||||
|
||||
fn assert_token_cost_reporting(_: &HarnessWorkspace, run: &ScenarioRun) {
|
||||
assert_eq!(run.response["iterations"], Value::from(1));
|
||||
assert!(run.response["message"]
|
||||
.as_str()
|
||||
.expect("message text")
|
||||
.contains("token cost reporting parity complete."),);
|
||||
let usage = &run.response["usage"];
|
||||
assert!(
|
||||
usage["input_tokens"].as_u64().unwrap_or(0) > 0,
|
||||
"input_tokens should be non-zero"
|
||||
);
|
||||
assert!(
|
||||
usage["output_tokens"].as_u64().unwrap_or(0) > 0,
|
||||
"output_tokens should be non-zero"
|
||||
);
|
||||
assert!(
|
||||
run.response["estimated_cost"]
|
||||
.as_str()
|
||||
.is_some_and(|cost| cost.starts_with('$')),
|
||||
"estimated_cost should be a dollar-prefixed string"
|
||||
);
|
||||
}
|
||||
|
||||
fn parse_json_output(stdout: &str) -> Value {
|
||||
if let Some(index) = stdout.rfind("{\"auto_compaction\"") {
|
||||
return serde_json::from_str(&stdout[index..]).unwrap_or_else(|error| {
|
||||
panic!("failed to parse JSON response from stdout: {error}\n{stdout}")
|
||||
});
|
||||
}
|
||||
|
||||
stdout
|
||||
.lines()
|
||||
.rev()
|
||||
.find_map(|line| {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.starts_with('{') && trimmed.ends_with('}') {
|
||||
serde_json::from_str(trimmed).ok()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or_else(|| panic!("no JSON response line found in stdout:\n{stdout}"))
|
||||
}
|
||||
|
||||
fn build_scenario_report(
|
||||
name: &str,
|
||||
manifest_entry: &ScenarioManifestEntry,
|
||||
response: &Value,
|
||||
) -> ScenarioReport {
|
||||
ScenarioReport {
|
||||
name: name.to_string(),
|
||||
category: manifest_entry.category.clone(),
|
||||
description: manifest_entry.description.clone(),
|
||||
parity_refs: manifest_entry.parity_refs.clone(),
|
||||
iterations: response["iterations"]
|
||||
.as_u64()
|
||||
.expect("iterations should exist"),
|
||||
request_count: 0,
|
||||
tool_uses: response["tool_uses"]
|
||||
.as_array()
|
||||
.expect("tool uses array")
|
||||
.iter()
|
||||
.filter_map(|value| value["name"].as_str().map(ToOwned::to_owned))
|
||||
.collect(),
|
||||
tool_error_count: response["tool_results"]
|
||||
.as_array()
|
||||
.expect("tool results array")
|
||||
.iter()
|
||||
.filter(|value| value["is_error"].as_bool().unwrap_or(false))
|
||||
.count(),
|
||||
final_message: response["message"]
|
||||
.as_str()
|
||||
.expect("message text")
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn maybe_write_report(reports: &[ScenarioReport]) {
|
||||
let Some(path) = std::env::var_os("MOCK_PARITY_REPORT_PATH") else {
|
||||
return;
|
||||
};
|
||||
|
||||
let payload = json!({
|
||||
"scenario_count": reports.len(),
|
||||
"request_count": reports.iter().map(|report| report.request_count).sum::<usize>(),
|
||||
"scenarios": reports.iter().map(scenario_report_json).collect::<Vec<_>>(),
|
||||
});
|
||||
fs::write(
|
||||
path,
|
||||
serde_json::to_vec_pretty(&payload).expect("report json should serialize"),
|
||||
)
|
||||
.expect("report should write");
|
||||
}
|
||||
|
||||
fn load_scenario_manifest() -> Vec<ScenarioManifestEntry> {
|
||||
let manifest_path =
|
||||
Path::new(env!("CARGO_MANIFEST_DIR")).join("../../mock_parity_scenarios.json");
|
||||
let manifest = fs::read_to_string(&manifest_path).expect("scenario manifest should exist");
|
||||
serde_json::from_str::<Vec<Value>>(&manifest)
|
||||
.expect("scenario manifest should parse")
|
||||
.into_iter()
|
||||
.map(|entry| ScenarioManifestEntry {
|
||||
name: entry["name"]
|
||||
.as_str()
|
||||
.expect("scenario name should be a string")
|
||||
.to_string(),
|
||||
category: entry["category"]
|
||||
.as_str()
|
||||
.expect("scenario category should be a string")
|
||||
.to_string(),
|
||||
description: entry["description"]
|
||||
.as_str()
|
||||
.expect("scenario description should be a string")
|
||||
.to_string(),
|
||||
parity_refs: entry["parity_refs"]
|
||||
.as_array()
|
||||
.expect("parity refs should be an array")
|
||||
.iter()
|
||||
.map(|value| {
|
||||
value
|
||||
.as_str()
|
||||
.expect("parity ref should be a string")
|
||||
.to_string()
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn scenario_report_json(report: &ScenarioReport) -> Value {
|
||||
json!({
|
||||
"name": report.name,
|
||||
"category": report.category,
|
||||
"description": report.description,
|
||||
"parity_refs": report.parity_refs,
|
||||
"iterations": report.iterations,
|
||||
"request_count": report.request_count,
|
||||
"tool_uses": report.tool_uses,
|
||||
"tool_error_count": report.tool_error_count,
|
||||
"final_message": report.final_message,
|
||||
})
|
||||
}
|
||||
|
||||
fn assert_success(output: &Output) {
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"stdout:\n{}\n\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
}
|
||||
|
||||
fn unique_temp_dir(label: &str) -> PathBuf {
|
||||
let millis = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("clock should be after epoch")
|
||||
.as_millis();
|
||||
let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
std::env::temp_dir().join(format!(
|
||||
"claw-mock-parity-{label}-{}-{millis}-{counter}",
|
||||
std::process::id()
|
||||
))
|
||||
}
|
||||
429
crates/rusty-claude-cli/tests/output_format_contract.rs
Normal file
429
crates/rusty-claude-cli/tests/output_format_contract.rs
Normal file
@ -0,0 +1,429 @@
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::{Command, Output};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
#[test]
|
||||
fn help_emits_json_when_requested() {
|
||||
let root = unique_temp_dir("help-json");
|
||||
fs::create_dir_all(&root).expect("temp dir should exist");
|
||||
|
||||
let parsed = assert_json_command(&root, &["--output-format", "json", "help"]);
|
||||
assert_eq!(parsed["kind"], "help");
|
||||
assert!(parsed["message"]
|
||||
.as_str()
|
||||
.expect("help text")
|
||||
.contains("Usage:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn version_emits_json_when_requested() {
|
||||
let root = unique_temp_dir("version-json");
|
||||
fs::create_dir_all(&root).expect("temp dir should exist");
|
||||
|
||||
let parsed = assert_json_command(&root, &["--output-format", "json", "version"]);
|
||||
assert_eq!(parsed["kind"], "version");
|
||||
assert_eq!(parsed["version"], env!("CARGO_PKG_VERSION"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn status_and_sandbox_emit_json_when_requested() {
|
||||
let root = unique_temp_dir("status-sandbox-json");
|
||||
fs::create_dir_all(&root).expect("temp dir should exist");
|
||||
|
||||
let status = assert_json_command(&root, &["--output-format", "json", "status"]);
|
||||
assert_eq!(status["kind"], "status");
|
||||
assert!(status["workspace"]["cwd"].as_str().is_some());
|
||||
|
||||
let sandbox = assert_json_command(&root, &["--output-format", "json", "sandbox"]);
|
||||
assert_eq!(sandbox["kind"], "sandbox");
|
||||
assert!(sandbox["filesystem_mode"].as_str().is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inventory_commands_emit_structured_json_when_requested() {
|
||||
let root = unique_temp_dir("inventory-json");
|
||||
fs::create_dir_all(&root).expect("temp dir should exist");
|
||||
|
||||
let isolated_home = root.join("home");
|
||||
let isolated_config = root.join("config-home");
|
||||
let isolated_codex = root.join("codex-home");
|
||||
fs::create_dir_all(&isolated_home).expect("isolated home should exist");
|
||||
|
||||
let agents = assert_json_command_with_env(
|
||||
&root,
|
||||
&["--output-format", "json", "agents"],
|
||||
&[
|
||||
("HOME", isolated_home.to_str().expect("utf8 home")),
|
||||
(
|
||||
"CLAW_CONFIG_HOME",
|
||||
isolated_config.to_str().expect("utf8 config home"),
|
||||
),
|
||||
(
|
||||
"CODEX_HOME",
|
||||
isolated_codex.to_str().expect("utf8 codex home"),
|
||||
),
|
||||
],
|
||||
);
|
||||
assert_eq!(agents["kind"], "agents");
|
||||
assert_eq!(agents["action"], "list");
|
||||
assert_eq!(agents["count"], 0);
|
||||
assert_eq!(agents["summary"]["active"], 0);
|
||||
assert!(agents["agents"]
|
||||
.as_array()
|
||||
.expect("agents array")
|
||||
.is_empty());
|
||||
|
||||
let mcp = assert_json_command(&root, &["--output-format", "json", "mcp"]);
|
||||
assert_eq!(mcp["kind"], "mcp");
|
||||
assert_eq!(mcp["action"], "list");
|
||||
|
||||
let skills = assert_json_command(&root, &["--output-format", "json", "skills"]);
|
||||
assert_eq!(skills["kind"], "skills");
|
||||
assert_eq!(skills["action"], "list");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn agents_command_emits_structured_agent_entries_when_requested() {
|
||||
let root = unique_temp_dir("agents-json-populated");
|
||||
let workspace = root.join("workspace");
|
||||
let project_agents = workspace.join(".codex").join("agents");
|
||||
let home = root.join("home");
|
||||
let user_agents = home.join(".codex").join("agents");
|
||||
let isolated_config = root.join("config-home");
|
||||
let isolated_codex = root.join("codex-home");
|
||||
fs::create_dir_all(&workspace).expect("workspace should exist");
|
||||
write_agent(
|
||||
&project_agents,
|
||||
"planner",
|
||||
"Project planner",
|
||||
"gpt-5.4",
|
||||
"medium",
|
||||
);
|
||||
write_agent(
|
||||
&project_agents,
|
||||
"verifier",
|
||||
"Verification agent",
|
||||
"gpt-5.4-mini",
|
||||
"high",
|
||||
);
|
||||
write_agent(
|
||||
&user_agents,
|
||||
"planner",
|
||||
"User planner",
|
||||
"gpt-5.4-mini",
|
||||
"high",
|
||||
);
|
||||
|
||||
let parsed = assert_json_command_with_env(
|
||||
&workspace,
|
||||
&["--output-format", "json", "agents"],
|
||||
&[
|
||||
("HOME", home.to_str().expect("utf8 home")),
|
||||
(
|
||||
"CLAW_CONFIG_HOME",
|
||||
isolated_config.to_str().expect("utf8 config home"),
|
||||
),
|
||||
(
|
||||
"CODEX_HOME",
|
||||
isolated_codex.to_str().expect("utf8 codex home"),
|
||||
),
|
||||
],
|
||||
);
|
||||
|
||||
assert_eq!(parsed["kind"], "agents");
|
||||
assert_eq!(parsed["action"], "list");
|
||||
assert_eq!(parsed["count"], 3);
|
||||
assert_eq!(parsed["summary"]["active"], 2);
|
||||
assert_eq!(parsed["summary"]["shadowed"], 1);
|
||||
assert_eq!(parsed["agents"][0]["name"], "planner");
|
||||
assert_eq!(parsed["agents"][0]["source"]["id"], "project_claw");
|
||||
assert_eq!(parsed["agents"][0]["active"], true);
|
||||
assert_eq!(parsed["agents"][1]["name"], "verifier");
|
||||
assert_eq!(parsed["agents"][2]["name"], "planner");
|
||||
assert_eq!(parsed["agents"][2]["active"], false);
|
||||
assert_eq!(parsed["agents"][2]["shadowed_by"]["id"], "project_claw");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bootstrap_and_system_prompt_emit_json_when_requested() {
|
||||
let root = unique_temp_dir("bootstrap-system-prompt-json");
|
||||
fs::create_dir_all(&root).expect("temp dir should exist");
|
||||
|
||||
let plan = assert_json_command(&root, &["--output-format", "json", "bootstrap-plan"]);
|
||||
assert_eq!(plan["kind"], "bootstrap-plan");
|
||||
assert!(plan["phases"].as_array().expect("phases").len() > 1);
|
||||
|
||||
let prompt = assert_json_command(&root, &["--output-format", "json", "system-prompt"]);
|
||||
assert_eq!(prompt["kind"], "system-prompt");
|
||||
assert!(prompt["message"]
|
||||
.as_str()
|
||||
.expect("prompt text")
|
||||
.contains("interactive agent"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dump_manifests_and_init_emit_json_when_requested() {
|
||||
let root = unique_temp_dir("manifest-init-json");
|
||||
fs::create_dir_all(&root).expect("temp dir should exist");
|
||||
|
||||
let upstream = write_upstream_fixture(&root);
|
||||
let manifests = assert_json_command_with_env(
|
||||
&root,
|
||||
&["--output-format", "json", "dump-manifests"],
|
||||
&[(
|
||||
"CLAUDE_CODE_UPSTREAM",
|
||||
upstream.to_str().expect("utf8 upstream"),
|
||||
)],
|
||||
);
|
||||
assert_eq!(manifests["kind"], "dump-manifests");
|
||||
assert_eq!(manifests["commands"], 1);
|
||||
assert_eq!(manifests["tools"], 1);
|
||||
|
||||
let workspace = root.join("workspace");
|
||||
fs::create_dir_all(&workspace).expect("workspace should exist");
|
||||
let init = assert_json_command(&workspace, &["--output-format", "json", "init"]);
|
||||
assert_eq!(init["kind"], "init");
|
||||
assert!(workspace.join("CLAUDE.md").exists());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn doctor_and_resume_status_emit_json_when_requested() {
|
||||
let root = unique_temp_dir("doctor-resume-json");
|
||||
fs::create_dir_all(&root).expect("temp dir should exist");
|
||||
|
||||
let doctor = assert_json_command(&root, &["--output-format", "json", "doctor"]);
|
||||
assert_eq!(doctor["kind"], "doctor");
|
||||
assert!(doctor["message"].is_string());
|
||||
let summary = doctor["summary"].as_object().expect("doctor summary");
|
||||
assert!(summary["ok"].as_u64().is_some());
|
||||
assert!(summary["warnings"].as_u64().is_some());
|
||||
assert!(summary["failures"].as_u64().is_some());
|
||||
|
||||
let checks = doctor["checks"].as_array().expect("doctor checks");
|
||||
assert_eq!(checks.len(), 5);
|
||||
let check_names = checks
|
||||
.iter()
|
||||
.map(|check| {
|
||||
assert!(check["status"].as_str().is_some());
|
||||
assert!(check["summary"].as_str().is_some());
|
||||
assert!(check["details"].is_array());
|
||||
check["name"].as_str().expect("doctor check name")
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(
|
||||
check_names,
|
||||
vec!["auth", "config", "workspace", "sandbox", "system"]
|
||||
);
|
||||
|
||||
let workspace = checks
|
||||
.iter()
|
||||
.find(|check| check["name"] == "workspace")
|
||||
.expect("workspace check");
|
||||
assert!(workspace["cwd"].as_str().is_some());
|
||||
assert!(workspace["in_git_repo"].is_boolean());
|
||||
|
||||
let sandbox = checks
|
||||
.iter()
|
||||
.find(|check| check["name"] == "sandbox")
|
||||
.expect("sandbox check");
|
||||
assert!(sandbox["filesystem_mode"].as_str().is_some());
|
||||
assert!(sandbox["enabled"].is_boolean());
|
||||
assert!(sandbox["fallback_reason"].is_null() || sandbox["fallback_reason"].is_string());
|
||||
|
||||
let session_path = root.join("session.jsonl");
|
||||
fs::write(
|
||||
&session_path,
|
||||
"{\"type\":\"session_meta\",\"version\":3,\"session_id\":\"resume-json\",\"created_at_ms\":0,\"updated_at_ms\":0}\n{\"type\":\"message\",\"message\":{\"role\":\"user\",\"blocks\":[{\"type\":\"text\",\"text\":\"hello\"}]}}\n",
|
||||
)
|
||||
.expect("session should write");
|
||||
let resumed = assert_json_command(
|
||||
&root,
|
||||
&[
|
||||
"--output-format",
|
||||
"json",
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 session path"),
|
||||
"/status",
|
||||
],
|
||||
);
|
||||
assert_eq!(resumed["kind"], "status");
|
||||
// model is null in resume mode (not known without --model flag)
|
||||
assert!(resumed["model"].is_null());
|
||||
assert_eq!(resumed["usage"]["messages"], 1);
|
||||
assert!(resumed["workspace"]["cwd"].as_str().is_some());
|
||||
assert!(resumed["sandbox"]["filesystem_mode"].as_str().is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resumed_inventory_commands_emit_structured_json_when_requested() {
|
||||
let root = unique_temp_dir("resume-inventory-json");
|
||||
let config_home = root.join("config-home");
|
||||
let home = root.join("home");
|
||||
fs::create_dir_all(&config_home).expect("config home should exist");
|
||||
fs::create_dir_all(&home).expect("home should exist");
|
||||
|
||||
let session_path = root.join("session.jsonl");
|
||||
fs::write(
|
||||
&session_path,
|
||||
"{\"type\":\"session_meta\",\"version\":3,\"session_id\":\"resume-inventory-json\",\"created_at_ms\":0,\"updated_at_ms\":0}\n{\"type\":\"message\",\"message\":{\"role\":\"user\",\"blocks\":[{\"type\":\"text\",\"text\":\"inventory\"}]}}\n",
|
||||
)
|
||||
.expect("session should write");
|
||||
|
||||
let mcp = assert_json_command_with_env(
|
||||
&root,
|
||||
&[
|
||||
"--output-format",
|
||||
"json",
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 session path"),
|
||||
"/mcp",
|
||||
],
|
||||
&[
|
||||
(
|
||||
"CLAW_CONFIG_HOME",
|
||||
config_home.to_str().expect("utf8 config home"),
|
||||
),
|
||||
("HOME", home.to_str().expect("utf8 home")),
|
||||
],
|
||||
);
|
||||
assert_eq!(mcp["kind"], "mcp");
|
||||
assert_eq!(mcp["action"], "list");
|
||||
assert!(mcp["servers"].is_array());
|
||||
|
||||
let skills = assert_json_command_with_env(
|
||||
&root,
|
||||
&[
|
||||
"--output-format",
|
||||
"json",
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 session path"),
|
||||
"/skills",
|
||||
],
|
||||
&[
|
||||
(
|
||||
"CLAW_CONFIG_HOME",
|
||||
config_home.to_str().expect("utf8 config home"),
|
||||
),
|
||||
("HOME", home.to_str().expect("utf8 home")),
|
||||
],
|
||||
);
|
||||
assert_eq!(skills["kind"], "skills");
|
||||
assert_eq!(skills["action"], "list");
|
||||
assert!(skills["summary"]["total"].is_number());
|
||||
assert!(skills["skills"].is_array());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resumed_version_and_init_emit_structured_json_when_requested() {
|
||||
let root = unique_temp_dir("resume-version-init-json");
|
||||
fs::create_dir_all(&root).expect("temp dir should exist");
|
||||
|
||||
let session_path = root.join("session.jsonl");
|
||||
fs::write(
|
||||
&session_path,
|
||||
"{\"type\":\"session_meta\",\"version\":3,\"session_id\":\"resume-version-init-json\",\"created_at_ms\":0,\"updated_at_ms\":0}\n",
|
||||
)
|
||||
.expect("session should write");
|
||||
|
||||
let version = assert_json_command(
|
||||
&root,
|
||||
&[
|
||||
"--output-format",
|
||||
"json",
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 session path"),
|
||||
"/version",
|
||||
],
|
||||
);
|
||||
assert_eq!(version["kind"], "version");
|
||||
assert_eq!(version["version"], env!("CARGO_PKG_VERSION"));
|
||||
|
||||
let init = assert_json_command(
|
||||
&root,
|
||||
&[
|
||||
"--output-format",
|
||||
"json",
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 session path"),
|
||||
"/init",
|
||||
],
|
||||
);
|
||||
assert_eq!(init["kind"], "init");
|
||||
assert!(root.join("CLAUDE.md").exists());
|
||||
}
|
||||
|
||||
fn assert_json_command(current_dir: &Path, args: &[&str]) -> Value {
|
||||
assert_json_command_with_env(current_dir, args, &[])
|
||||
}
|
||||
|
||||
fn assert_json_command_with_env(current_dir: &Path, args: &[&str], envs: &[(&str, &str)]) -> Value {
|
||||
let output = run_claw(current_dir, args, envs);
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"stdout:\n{}\n\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
serde_json::from_slice(&output.stdout).expect("stdout should be valid json")
|
||||
}
|
||||
|
||||
fn run_claw(current_dir: &Path, args: &[&str], envs: &[(&str, &str)]) -> Output {
|
||||
let mut command = Command::new(env!("CARGO_BIN_EXE_claw"));
|
||||
command.current_dir(current_dir).args(args);
|
||||
for (key, value) in envs {
|
||||
command.env(key, value);
|
||||
}
|
||||
command.output().expect("claw should launch")
|
||||
}
|
||||
|
||||
fn write_upstream_fixture(root: &Path) -> PathBuf {
|
||||
let upstream = root.join("claw-code");
|
||||
let src = upstream.join("src");
|
||||
let entrypoints = src.join("entrypoints");
|
||||
fs::create_dir_all(&entrypoints).expect("upstream entrypoints dir should exist");
|
||||
fs::write(
|
||||
src.join("commands.ts"),
|
||||
"import FooCommand from './commands/foo'\n",
|
||||
)
|
||||
.expect("commands fixture should write");
|
||||
fs::write(
|
||||
src.join("tools.ts"),
|
||||
"import ReadTool from './tools/read'\n",
|
||||
)
|
||||
.expect("tools fixture should write");
|
||||
fs::write(
|
||||
entrypoints.join("cli.tsx"),
|
||||
"if (args[0] === '--version') {}\nstartupProfiler()\n",
|
||||
)
|
||||
.expect("cli fixture should write");
|
||||
upstream
|
||||
}
|
||||
|
||||
fn write_agent(root: &Path, name: &str, description: &str, model: &str, reasoning: &str) {
|
||||
fs::create_dir_all(root).expect("agent root should exist");
|
||||
fs::write(
|
||||
root.join(format!("{name}.toml")),
|
||||
format!(
|
||||
"name = \"{name}\"\ndescription = \"{description}\"\nmodel = \"{model}\"\nmodel_reasoning_effort = \"{reasoning}\"\n"
|
||||
),
|
||||
)
|
||||
.expect("agent fixture should write");
|
||||
}
|
||||
|
||||
fn unique_temp_dir(label: &str) -> PathBuf {
|
||||
let millis = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("clock should be after epoch")
|
||||
.as_millis();
|
||||
let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
std::env::temp_dir().join(format!(
|
||||
"claw-output-format-{label}-{}-{millis}-{counter}",
|
||||
std::process::id()
|
||||
))
|
||||
}
|
||||
555
crates/rusty-claude-cli/tests/resume_slash_commands.rs
Normal file
555
crates/rusty-claude-cli/tests/resume_slash_commands.rs
Normal file
@ -0,0 +1,555 @@
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
use std::process::{Command, Output};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use runtime::ContentBlock;
|
||||
use runtime::Session;
|
||||
use serde_json::Value;
|
||||
|
||||
static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
#[test]
|
||||
fn resumed_binary_accepts_slash_commands_with_arguments() {
|
||||
// given
|
||||
let temp_dir = unique_temp_dir("resume-slash-commands");
|
||||
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
|
||||
|
||||
let session_path = temp_dir.join("session.jsonl");
|
||||
let export_path = temp_dir.join("notes.txt");
|
||||
|
||||
let mut session = Session::new();
|
||||
session
|
||||
.push_user_text("ship the slash command harness")
|
||||
.expect("session write should succeed");
|
||||
session
|
||||
.save_to_path(&session_path)
|
||||
.expect("session should persist");
|
||||
|
||||
// when
|
||||
let output = run_claw(
|
||||
&temp_dir,
|
||||
&[
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 path"),
|
||||
"/export",
|
||||
export_path.to_str().expect("utf8 path"),
|
||||
"/clear",
|
||||
"--confirm",
|
||||
],
|
||||
);
|
||||
|
||||
// then
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"stdout:\n{}\n\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
|
||||
assert!(stdout.contains("Export"));
|
||||
assert!(stdout.contains("wrote transcript"));
|
||||
assert!(stdout.contains(export_path.to_str().expect("utf8 path")));
|
||||
assert!(stdout.contains("Session cleared"));
|
||||
assert!(stdout.contains("Mode resumed session reset"));
|
||||
assert!(stdout.contains("Previous session"));
|
||||
assert!(stdout.contains("Resume previous claw --resume"));
|
||||
assert!(stdout.contains("Backup "));
|
||||
assert!(stdout.contains("Session file "));
|
||||
|
||||
let export = fs::read_to_string(&export_path).expect("export file should exist");
|
||||
assert!(export.contains("# Conversation Export"));
|
||||
assert!(export.contains("ship the slash command harness"));
|
||||
|
||||
let restored = Session::load_from_path(&session_path).expect("cleared session should load");
|
||||
assert!(restored.messages.is_empty());
|
||||
|
||||
let backup_path = stdout
|
||||
.lines()
|
||||
.find_map(|line| line.strip_prefix(" Backup "))
|
||||
.map(PathBuf::from)
|
||||
.expect("clear output should include backup path");
|
||||
let backup = Session::load_from_path(&backup_path).expect("backup session should load");
|
||||
assert_eq!(backup.messages.len(), 1);
|
||||
assert!(matches!(
|
||||
backup.messages[0].blocks.first(),
|
||||
Some(ContentBlock::Text { text }) if text == "ship the slash command harness"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn status_command_applies_cli_flags_end_to_end() {
|
||||
// given
|
||||
let temp_dir = unique_temp_dir("status-command-flags");
|
||||
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
|
||||
|
||||
// when
|
||||
let output = run_claw(
|
||||
&temp_dir,
|
||||
&[
|
||||
"--model",
|
||||
"sonnet",
|
||||
"--permission-mode",
|
||||
"read-only",
|
||||
"status",
|
||||
],
|
||||
);
|
||||
|
||||
// then
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"stdout:\n{}\n\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
|
||||
assert!(stdout.contains("Status"));
|
||||
assert!(stdout.contains("Model claude-sonnet-4-6"));
|
||||
assert!(stdout.contains("Permission mode read-only"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resumed_config_command_loads_settings_files_end_to_end() {
|
||||
// given
|
||||
let temp_dir = unique_temp_dir("resume-config");
|
||||
let project_dir = temp_dir.join("project");
|
||||
let config_home = temp_dir.join("home").join(".claw");
|
||||
fs::create_dir_all(project_dir.join(".claw")).expect("project config dir should exist");
|
||||
fs::create_dir_all(&config_home).expect("config home should exist");
|
||||
|
||||
let session_path = project_dir.join("session.jsonl");
|
||||
Session::new()
|
||||
.with_persistence_path(&session_path)
|
||||
.save_to_path(&session_path)
|
||||
.expect("session should persist");
|
||||
|
||||
fs::write(config_home.join("settings.json"), r#"{"model":"haiku"}"#)
|
||||
.expect("user config should write");
|
||||
fs::write(
|
||||
project_dir.join(".claw").join("settings.local.json"),
|
||||
r#"{"model":"opus"}"#,
|
||||
)
|
||||
.expect("local config should write");
|
||||
|
||||
// when
|
||||
let output = run_claw_with_env(
|
||||
&project_dir,
|
||||
&[
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 path"),
|
||||
"/config",
|
||||
"model",
|
||||
],
|
||||
&[("CLAW_CONFIG_HOME", config_home.to_str().expect("utf8 path"))],
|
||||
);
|
||||
|
||||
// then
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"stdout:\n{}\n\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
|
||||
assert!(stdout.contains("Config"));
|
||||
assert!(stdout.contains("Loaded files 2"));
|
||||
assert!(stdout.contains(
|
||||
config_home
|
||||
.join("settings.json")
|
||||
.to_str()
|
||||
.expect("utf8 path")
|
||||
));
|
||||
assert!(stdout.contains(
|
||||
project_dir
|
||||
.join(".claw")
|
||||
.join("settings.local.json")
|
||||
.to_str()
|
||||
.expect("utf8 path")
|
||||
));
|
||||
assert!(stdout.contains("Merged section: model"));
|
||||
assert!(stdout.contains("opus"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resume_latest_restores_the_most_recent_managed_session() {
|
||||
// given
|
||||
let temp_dir = unique_temp_dir("resume-latest");
|
||||
let project_dir = temp_dir.join("project");
|
||||
let sessions_dir = project_dir.join(".claw").join("sessions");
|
||||
fs::create_dir_all(&sessions_dir).expect("sessions dir should exist");
|
||||
|
||||
let older_path = sessions_dir.join("session-older.jsonl");
|
||||
let newer_path = sessions_dir.join("session-newer.jsonl");
|
||||
|
||||
let mut older = Session::new().with_persistence_path(&older_path);
|
||||
older
|
||||
.push_user_text("older session")
|
||||
.expect("older session write should succeed");
|
||||
older
|
||||
.save_to_path(&older_path)
|
||||
.expect("older session should persist");
|
||||
|
||||
let mut newer = Session::new().with_persistence_path(&newer_path);
|
||||
newer
|
||||
.push_user_text("newer session")
|
||||
.expect("newer session write should succeed");
|
||||
newer
|
||||
.push_user_text("resume me")
|
||||
.expect("newer session write should succeed");
|
||||
newer
|
||||
.save_to_path(&newer_path)
|
||||
.expect("newer session should persist");
|
||||
|
||||
// when
|
||||
let output = run_claw(&project_dir, &["--resume", "latest", "/status"]);
|
||||
|
||||
// then
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"stdout:\n{}\n\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
|
||||
assert!(stdout.contains("Status"));
|
||||
assert!(stdout.contains("Messages 2"));
|
||||
assert!(stdout.contains(newer_path.to_str().expect("utf8 path")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resumed_status_command_emits_structured_json_when_requested() {
|
||||
// given
|
||||
let temp_dir = unique_temp_dir("resume-status-json");
|
||||
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
|
||||
let session_path = temp_dir.join("session.jsonl");
|
||||
|
||||
let mut session = Session::new();
|
||||
session
|
||||
.push_user_text("resume status json fixture")
|
||||
.expect("session write should succeed");
|
||||
session
|
||||
.save_to_path(&session_path)
|
||||
.expect("session should persist");
|
||||
|
||||
// when
|
||||
let output = run_claw(
|
||||
&temp_dir,
|
||||
&[
|
||||
"--output-format",
|
||||
"json",
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 path"),
|
||||
"/status",
|
||||
],
|
||||
);
|
||||
|
||||
// then
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"stdout:\n{}\n\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
|
||||
let parsed: Value =
|
||||
serde_json::from_str(stdout.trim()).expect("resume status output should be json");
|
||||
assert_eq!(parsed["kind"], "status");
|
||||
// model is null in resume mode (not known without --model flag)
|
||||
assert!(parsed["model"].is_null());
|
||||
assert_eq!(parsed["permission_mode"], "danger-full-access");
|
||||
assert_eq!(parsed["usage"]["messages"], 1);
|
||||
assert!(parsed["usage"]["turns"].is_number());
|
||||
assert!(parsed["workspace"]["cwd"].as_str().is_some());
|
||||
assert_eq!(
|
||||
parsed["workspace"]["session"],
|
||||
session_path.to_str().expect("utf8 path")
|
||||
);
|
||||
assert!(parsed["workspace"]["changed_files"].is_number());
|
||||
assert_eq!(parsed["workspace"]["loaded_config_files"].as_u64(), Some(0));
|
||||
assert!(parsed["sandbox"]["filesystem_mode"].as_str().is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resumed_status_surfaces_persisted_model() {
|
||||
// given — create a session with model already set
|
||||
let temp_dir = unique_temp_dir("resume-status-model");
|
||||
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
|
||||
let session_path = temp_dir.join("session.jsonl");
|
||||
|
||||
let mut session = Session::new();
|
||||
session.model = Some("claude-sonnet-4-6".to_string());
|
||||
session
|
||||
.push_user_text("model persistence fixture")
|
||||
.expect("write ok");
|
||||
session.save_to_path(&session_path).expect("persist ok");
|
||||
|
||||
// when
|
||||
let output = run_claw(
|
||||
&temp_dir,
|
||||
&[
|
||||
"--output-format",
|
||||
"json",
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 path"),
|
||||
"/status",
|
||||
],
|
||||
);
|
||||
|
||||
// then
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"stderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
let stdout = String::from_utf8(output.stdout).expect("utf8");
|
||||
let parsed: Value = serde_json::from_str(stdout.trim()).expect("should be json");
|
||||
assert_eq!(parsed["kind"], "status");
|
||||
assert_eq!(
|
||||
parsed["model"], "claude-sonnet-4-6",
|
||||
"model should round-trip through session metadata"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resumed_sandbox_command_emits_structured_json_when_requested() {
|
||||
// given
|
||||
let temp_dir = unique_temp_dir("resume-sandbox-json");
|
||||
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
|
||||
let session_path = temp_dir.join("session.jsonl");
|
||||
|
||||
Session::new()
|
||||
.save_to_path(&session_path)
|
||||
.expect("session should persist");
|
||||
|
||||
// when
|
||||
let output = run_claw(
|
||||
&temp_dir,
|
||||
&[
|
||||
"--output-format",
|
||||
"json",
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 path"),
|
||||
"/sandbox",
|
||||
],
|
||||
);
|
||||
|
||||
// then
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"stdout:\n{}\n\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
|
||||
let stdout = String::from_utf8(output.stdout).expect("stdout should be utf8");
|
||||
let parsed: Value =
|
||||
serde_json::from_str(stdout.trim()).expect("resume sandbox output should be json");
|
||||
assert_eq!(parsed["kind"], "sandbox");
|
||||
assert!(parsed["enabled"].is_boolean());
|
||||
assert!(parsed["active"].is_boolean());
|
||||
assert!(parsed["supported"].is_boolean());
|
||||
assert!(parsed["filesystem_mode"].as_str().is_some());
|
||||
assert!(parsed["allowed_mounts"].is_array());
|
||||
assert!(parsed["markers"].is_array());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resumed_version_command_emits_structured_json() {
|
||||
let temp_dir = unique_temp_dir("resume-version-json");
|
||||
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
|
||||
let session_path = temp_dir.join("session.jsonl");
|
||||
Session::new()
|
||||
.save_to_path(&session_path)
|
||||
.expect("session should persist");
|
||||
|
||||
let output = run_claw(
|
||||
&temp_dir,
|
||||
&[
|
||||
"--output-format",
|
||||
"json",
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 path"),
|
||||
"/version",
|
||||
],
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"stderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
let stdout = String::from_utf8(output.stdout).expect("utf8");
|
||||
let parsed: Value = serde_json::from_str(stdout.trim()).expect("should be json");
|
||||
assert_eq!(parsed["kind"], "version");
|
||||
assert!(parsed["version"].as_str().is_some());
|
||||
assert!(parsed["git_sha"].as_str().is_some());
|
||||
assert!(parsed["target"].as_str().is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resumed_export_command_emits_structured_json() {
|
||||
let temp_dir = unique_temp_dir("resume-export-json");
|
||||
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
|
||||
let session_path = temp_dir.join("session.jsonl");
|
||||
let mut session = Session::new();
|
||||
session
|
||||
.push_user_text("export json fixture")
|
||||
.expect("write ok");
|
||||
session.save_to_path(&session_path).expect("persist ok");
|
||||
|
||||
let output = run_claw(
|
||||
&temp_dir,
|
||||
&[
|
||||
"--output-format",
|
||||
"json",
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 path"),
|
||||
"/export",
|
||||
],
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"stderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
let stdout = String::from_utf8(output.stdout).expect("utf8");
|
||||
let parsed: Value = serde_json::from_str(stdout.trim()).expect("should be json");
|
||||
assert_eq!(parsed["kind"], "export");
|
||||
assert!(parsed["file"].as_str().is_some());
|
||||
assert_eq!(parsed["message_count"], 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resumed_help_command_emits_structured_json() {
|
||||
let temp_dir = unique_temp_dir("resume-help-json");
|
||||
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
|
||||
let session_path = temp_dir.join("session.jsonl");
|
||||
Session::new()
|
||||
.save_to_path(&session_path)
|
||||
.expect("persist ok");
|
||||
|
||||
let output = run_claw(
|
||||
&temp_dir,
|
||||
&[
|
||||
"--output-format",
|
||||
"json",
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 path"),
|
||||
"/help",
|
||||
],
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"stderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
let stdout = String::from_utf8(output.stdout).expect("utf8");
|
||||
let parsed: Value = serde_json::from_str(stdout.trim()).expect("should be json");
|
||||
assert_eq!(parsed["kind"], "help");
|
||||
assert!(parsed["text"].as_str().is_some());
|
||||
let text = parsed["text"].as_str().unwrap();
|
||||
assert!(text.contains("/status"), "help text should list /status");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resumed_no_command_emits_restored_json() {
|
||||
let temp_dir = unique_temp_dir("resume-no-cmd-json");
|
||||
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
|
||||
let session_path = temp_dir.join("session.jsonl");
|
||||
let mut session = Session::new();
|
||||
session
|
||||
.push_user_text("restored json fixture")
|
||||
.expect("write ok");
|
||||
session.save_to_path(&session_path).expect("persist ok");
|
||||
|
||||
let output = run_claw(
|
||||
&temp_dir,
|
||||
&[
|
||||
"--output-format",
|
||||
"json",
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 path"),
|
||||
],
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"stderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
let stdout = String::from_utf8(output.stdout).expect("utf8");
|
||||
let parsed: Value = serde_json::from_str(stdout.trim()).expect("should be json");
|
||||
assert_eq!(parsed["kind"], "restored");
|
||||
assert!(parsed["session_id"].as_str().is_some());
|
||||
assert!(parsed["path"].as_str().is_some());
|
||||
assert_eq!(parsed["message_count"], 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resumed_stub_command_emits_not_implemented_json() {
|
||||
let temp_dir = unique_temp_dir("resume-stub-json");
|
||||
fs::create_dir_all(&temp_dir).expect("temp dir should exist");
|
||||
let session_path = temp_dir.join("session.jsonl");
|
||||
Session::new()
|
||||
.save_to_path(&session_path)
|
||||
.expect("persist ok");
|
||||
|
||||
let output = run_claw(
|
||||
&temp_dir,
|
||||
&[
|
||||
"--output-format",
|
||||
"json",
|
||||
"--resume",
|
||||
session_path.to_str().expect("utf8 path"),
|
||||
"/allowed-tools",
|
||||
],
|
||||
);
|
||||
|
||||
// Stub commands exit with code 2
|
||||
assert!(!output.status.success());
|
||||
let stderr = String::from_utf8(output.stderr).expect("utf8");
|
||||
let parsed: Value = serde_json::from_str(stderr.trim()).expect("should be json");
|
||||
assert_eq!(parsed["type"], "error");
|
||||
assert!(
|
||||
parsed["error"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.contains("not yet implemented"),
|
||||
"error should say not yet implemented: {:?}",
|
||||
parsed["error"]
|
||||
);
|
||||
}
|
||||
|
||||
fn run_claw(current_dir: &Path, args: &[&str]) -> Output {
|
||||
run_claw_with_env(current_dir, args, &[])
|
||||
}
|
||||
|
||||
fn run_claw_with_env(current_dir: &Path, args: &[&str], envs: &[(&str, &str)]) -> Output {
|
||||
let mut command = Command::new(env!("CARGO_BIN_EXE_claw"));
|
||||
command.current_dir(current_dir).args(args);
|
||||
for (key, value) in envs {
|
||||
command.env(key, value);
|
||||
}
|
||||
command.output().expect("claw should launch")
|
||||
}
|
||||
|
||||
fn unique_temp_dir(label: &str) -> PathBuf {
|
||||
let millis = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("clock should be after epoch")
|
||||
.as_millis();
|
||||
let counter = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
std::env::temp_dir().join(format!(
|
||||
"claw-{label}-{}-{millis}-{counter}",
|
||||
std::process::id()
|
||||
))
|
||||
}
|
||||
@ -20,7 +20,7 @@ use runtime::{
|
||||
ToolError, ToolExecutor,
|
||||
};
|
||||
use api::{
|
||||
max_tokens_for_model, resolve_startup_auth_source, AuthSource, ClawApiClient,
|
||||
max_tokens_for_model, resolve_startup_auth_source, AnthropicClient, AuthSource,
|
||||
ContentBlockDelta, InputContentBlock, InputMessage, MessageRequest,
|
||||
MessageResponse, OutputContentBlock,
|
||||
StreamEvent as ApiStreamEvent, ToolChoice, ToolResultContentBlock,
|
||||
@ -104,7 +104,7 @@ impl SessionEvent {
|
||||
// ── ServerApiClient:实现 runtime::ApiClient trait ────────────────────
|
||||
|
||||
struct ServerApiClient {
|
||||
client: ClawApiClient,
|
||||
client: AnthropicClient,
|
||||
model: String,
|
||||
tool_registry: GlobalToolRegistry,
|
||||
allowed_tools: Option<BTreeSet<String>>,
|
||||
@ -127,6 +127,12 @@ impl ApiClient for ServerApiClient {
|
||||
.then(|| self.tool_registry.definitions(self.allowed_tools.as_ref())),
|
||||
tool_choice: self.enable_tools.then_some(ToolChoice::Auto),
|
||||
stream: true,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
frequency_penalty: None,
|
||||
presence_penalty: None,
|
||||
stop: None,
|
||||
reasoning_effort: None,
|
||||
};
|
||||
|
||||
let rt = tokio::runtime::Runtime::new()
|
||||
@ -192,7 +198,6 @@ impl ApiClient for ServerApiClient {
|
||||
session_id: self.session_id.clone(),
|
||||
thinking: thinking.clone(),
|
||||
});
|
||||
events.push(AssistantEvent::ThinkingDelta(thinking));
|
||||
}
|
||||
}
|
||||
ContentBlockDelta::SignatureDelta { .. } => {}
|
||||
@ -303,7 +308,6 @@ fn push_output_block(
|
||||
session_id: session_id.to_string(),
|
||||
thinking: thinking.clone(),
|
||||
});
|
||||
events.push(AssistantEvent::ThinkingDelta(thinking));
|
||||
}
|
||||
}
|
||||
OutputContentBlock::RedactedThinking { .. } => {}
|
||||
@ -412,17 +416,6 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec<InputMessage> {
|
||||
ContentBlock::Text { text } => {
|
||||
InputContentBlock::Text { text: text.clone() }
|
||||
}
|
||||
ContentBlock::Thinking {
|
||||
thinking,
|
||||
signature,
|
||||
} => InputContentBlock::Thinking {
|
||||
thinking: thinking.clone(),
|
||||
signature: signature.clone(),
|
||||
},
|
||||
ContentBlock::RedactedThinking { data } => InputContentBlock::RedactedThinking {
|
||||
data: serde_json::from_str(&data.render())
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
},
|
||||
ContentBlock::ToolUse { id, name, input } => InputContentBlock::ToolUse {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
@ -454,6 +447,7 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec<InputMessage> {
|
||||
fn permission_policy(mode: PermissionMode, tool_registry: &GlobalToolRegistry) -> PermissionPolicy {
|
||||
tool_registry
|
||||
.permission_specs(None)
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.fold(PermissionPolicy::new(mode), |policy, (name, req)| {
|
||||
policy.with_tool_requirement(name, req)
|
||||
@ -632,7 +626,7 @@ impl Session {
|
||||
let (events_tx, _) = broadcast::channel(BROADCAST_CAPACITY);
|
||||
|
||||
let auth = resolve_server_auth_source(cwd)?;
|
||||
let client = ClawApiClient::from_auth(auth).with_base_url(api::read_base_url());
|
||||
let client = AnthropicClient::from_auth(auth).with_base_url(api::read_base_url());
|
||||
|
||||
let api_client = ServerApiClient {
|
||||
client,
|
||||
@ -658,7 +652,7 @@ impl Session {
|
||||
tool_executor,
|
||||
policy,
|
||||
system_prompt,
|
||||
feature_config.clone(),
|
||||
&feature_config,
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
|
||||
13
crates/telemetry/Cargo.toml
Normal file
13
crates/telemetry/Cargo.toml
Normal file
@ -0,0 +1,13 @@
|
||||
[package]
|
||||
name = "telemetry"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
publish.workspace = true
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
526
crates/telemetry/src/lib.rs
Normal file
526
crates/telemetry/src/lib.rs
Normal file
@ -0,0 +1,526 @@
|
||||
use std::fmt::{Debug, Formatter};
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::io::Write;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Map, Value};
|
||||
|
||||
pub const DEFAULT_ANTHROPIC_VERSION: &str = "2023-06-01";
|
||||
pub const DEFAULT_APP_NAME: &str = "claude-code";
|
||||
pub const DEFAULT_RUNTIME: &str = "rust";
|
||||
pub const DEFAULT_AGENTIC_BETA: &str = "claude-code-20250219";
|
||||
pub const DEFAULT_PROMPT_CACHING_SCOPE_BETA: &str = "prompt-caching-scope-2026-01-05";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct ClientIdentity {
|
||||
pub app_name: String,
|
||||
pub app_version: String,
|
||||
pub runtime: String,
|
||||
}
|
||||
|
||||
impl ClientIdentity {
|
||||
#[must_use]
|
||||
pub fn new(app_name: impl Into<String>, app_version: impl Into<String>) -> Self {
|
||||
Self {
|
||||
app_name: app_name.into(),
|
||||
app_version: app_version.into(),
|
||||
runtime: DEFAULT_RUNTIME.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_runtime(mut self, runtime: impl Into<String>) -> Self {
|
||||
self.runtime = runtime.into();
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn user_agent(&self) -> String {
|
||||
format!("{}/{}", self.app_name, self.app_version)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClientIdentity {
|
||||
fn default() -> Self {
|
||||
Self::new(DEFAULT_APP_NAME, env!("CARGO_PKG_VERSION"))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct AnthropicRequestProfile {
|
||||
pub anthropic_version: String,
|
||||
pub client_identity: ClientIdentity,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub betas: Vec<String>,
|
||||
#[serde(default, skip_serializing_if = "Map::is_empty")]
|
||||
pub extra_body: Map<String, Value>,
|
||||
}
|
||||
|
||||
impl AnthropicRequestProfile {
|
||||
#[must_use]
|
||||
pub fn new(client_identity: ClientIdentity) -> Self {
|
||||
Self {
|
||||
anthropic_version: DEFAULT_ANTHROPIC_VERSION.to_string(),
|
||||
client_identity,
|
||||
betas: vec![
|
||||
DEFAULT_AGENTIC_BETA.to_string(),
|
||||
DEFAULT_PROMPT_CACHING_SCOPE_BETA.to_string(),
|
||||
],
|
||||
extra_body: Map::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_beta(mut self, beta: impl Into<String>) -> Self {
|
||||
let beta = beta.into();
|
||||
if !self.betas.contains(&beta) {
|
||||
self.betas.push(beta);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_extra_body(mut self, key: impl Into<String>, value: Value) -> Self {
|
||||
self.extra_body.insert(key.into(), value);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn header_pairs(&self) -> Vec<(String, String)> {
|
||||
let mut headers = vec![
|
||||
(
|
||||
"anthropic-version".to_string(),
|
||||
self.anthropic_version.clone(),
|
||||
),
|
||||
("user-agent".to_string(), self.client_identity.user_agent()),
|
||||
];
|
||||
if !self.betas.is_empty() {
|
||||
headers.push(("anthropic-beta".to_string(), self.betas.join(",")));
|
||||
}
|
||||
headers
|
||||
}
|
||||
|
||||
pub fn render_json_body<T: Serialize>(&self, request: &T) -> Result<Value, serde_json::Error> {
|
||||
let mut body = serde_json::to_value(request)?;
|
||||
let object = body.as_object_mut().ok_or_else(|| {
|
||||
serde_json::Error::io(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"request body must serialize to a JSON object",
|
||||
))
|
||||
})?;
|
||||
for (key, value) in &self.extra_body {
|
||||
object.insert(key.clone(), value.clone());
|
||||
}
|
||||
if !self.betas.is_empty() {
|
||||
object.insert(
|
||||
"betas".to_string(),
|
||||
Value::Array(self.betas.iter().cloned().map(Value::String).collect()),
|
||||
);
|
||||
}
|
||||
Ok(body)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AnthropicRequestProfile {
|
||||
fn default() -> Self {
|
||||
Self::new(ClientIdentity::default())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct AnalyticsEvent {
|
||||
pub namespace: String,
|
||||
pub action: String,
|
||||
#[serde(default, skip_serializing_if = "Map::is_empty")]
|
||||
pub properties: Map<String, Value>,
|
||||
}
|
||||
|
||||
impl AnalyticsEvent {
|
||||
#[must_use]
|
||||
pub fn new(namespace: impl Into<String>, action: impl Into<String>) -> Self {
|
||||
Self {
|
||||
namespace: namespace.into(),
|
||||
action: action.into(),
|
||||
properties: Map::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_property(mut self, key: impl Into<String>, value: Value) -> Self {
|
||||
self.properties.insert(key.into(), value);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SessionTraceRecord {
|
||||
pub session_id: String,
|
||||
pub sequence: u64,
|
||||
pub name: String,
|
||||
pub timestamp_ms: u64,
|
||||
#[serde(default, skip_serializing_if = "Map::is_empty")]
|
||||
pub attributes: Map<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum TelemetryEvent {
|
||||
HttpRequestStarted {
|
||||
session_id: String,
|
||||
attempt: u32,
|
||||
method: String,
|
||||
path: String,
|
||||
#[serde(default, skip_serializing_if = "Map::is_empty")]
|
||||
attributes: Map<String, Value>,
|
||||
},
|
||||
HttpRequestSucceeded {
|
||||
session_id: String,
|
||||
attempt: u32,
|
||||
method: String,
|
||||
path: String,
|
||||
status: u16,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
request_id: Option<String>,
|
||||
#[serde(default, skip_serializing_if = "Map::is_empty")]
|
||||
attributes: Map<String, Value>,
|
||||
},
|
||||
HttpRequestFailed {
|
||||
session_id: String,
|
||||
attempt: u32,
|
||||
method: String,
|
||||
path: String,
|
||||
error: String,
|
||||
retryable: bool,
|
||||
#[serde(default, skip_serializing_if = "Map::is_empty")]
|
||||
attributes: Map<String, Value>,
|
||||
},
|
||||
Analytics(AnalyticsEvent),
|
||||
SessionTrace(SessionTraceRecord),
|
||||
}
|
||||
|
||||
pub trait TelemetrySink: Send + Sync {
|
||||
fn record(&self, event: TelemetryEvent);
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct MemoryTelemetrySink {
|
||||
events: Mutex<Vec<TelemetryEvent>>,
|
||||
}
|
||||
|
||||
impl MemoryTelemetrySink {
|
||||
#[must_use]
|
||||
pub fn events(&self) -> Vec<TelemetryEvent> {
|
||||
self.events
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl TelemetrySink for MemoryTelemetrySink {
|
||||
fn record(&self, event: TelemetryEvent) {
|
||||
self.events
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.push(event);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct JsonlTelemetrySink {
|
||||
path: PathBuf,
|
||||
file: Mutex<File>,
|
||||
}
|
||||
|
||||
impl Debug for JsonlTelemetrySink {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("JsonlTelemetrySink")
|
||||
.field("path", &self.path)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonlTelemetrySink {
|
||||
pub fn new(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
|
||||
let path = path.as_ref().to_path_buf();
|
||||
if let Some(parent) = path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
let file = OpenOptions::new().create(true).append(true).open(&path)?;
|
||||
Ok(Self {
|
||||
path,
|
||||
file: Mutex::new(file),
|
||||
})
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn path(&self) -> &Path {
|
||||
&self.path
|
||||
}
|
||||
}
|
||||
|
||||
impl TelemetrySink for JsonlTelemetrySink {
|
||||
fn record(&self, event: TelemetryEvent) {
|
||||
let Ok(line) = serde_json::to_string(&event) else {
|
||||
return;
|
||||
};
|
||||
let mut file = self
|
||||
.file
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let _ = writeln!(file, "{line}");
|
||||
let _ = file.flush();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SessionTracer {
|
||||
session_id: String,
|
||||
sequence: Arc<AtomicU64>,
|
||||
sink: Arc<dyn TelemetrySink>,
|
||||
}
|
||||
|
||||
impl Debug for SessionTracer {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("SessionTracer")
|
||||
.field("session_id", &self.session_id)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl SessionTracer {
|
||||
#[must_use]
|
||||
pub fn new(session_id: impl Into<String>, sink: Arc<dyn TelemetrySink>) -> Self {
|
||||
Self {
|
||||
session_id: session_id.into(),
|
||||
sequence: Arc::new(AtomicU64::new(0)),
|
||||
sink,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn session_id(&self) -> &str {
|
||||
&self.session_id
|
||||
}
|
||||
|
||||
pub fn record(&self, name: impl Into<String>, attributes: Map<String, Value>) {
|
||||
let record = SessionTraceRecord {
|
||||
session_id: self.session_id.clone(),
|
||||
sequence: self.sequence.fetch_add(1, Ordering::Relaxed),
|
||||
name: name.into(),
|
||||
timestamp_ms: current_timestamp_ms(),
|
||||
attributes,
|
||||
};
|
||||
self.sink.record(TelemetryEvent::SessionTrace(record));
|
||||
}
|
||||
|
||||
pub fn record_http_request_started(
|
||||
&self,
|
||||
attempt: u32,
|
||||
method: impl Into<String>,
|
||||
path: impl Into<String>,
|
||||
attributes: Map<String, Value>,
|
||||
) {
|
||||
let method = method.into();
|
||||
let path = path.into();
|
||||
self.sink.record(TelemetryEvent::HttpRequestStarted {
|
||||
session_id: self.session_id.clone(),
|
||||
attempt,
|
||||
method: method.clone(),
|
||||
path: path.clone(),
|
||||
attributes: attributes.clone(),
|
||||
});
|
||||
self.record(
|
||||
"http_request_started",
|
||||
merge_trace_fields(method, path, attempt, attributes),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn record_http_request_succeeded(
|
||||
&self,
|
||||
attempt: u32,
|
||||
method: impl Into<String>,
|
||||
path: impl Into<String>,
|
||||
status: u16,
|
||||
request_id: Option<String>,
|
||||
attributes: Map<String, Value>,
|
||||
) {
|
||||
let method = method.into();
|
||||
let path = path.into();
|
||||
self.sink.record(TelemetryEvent::HttpRequestSucceeded {
|
||||
session_id: self.session_id.clone(),
|
||||
attempt,
|
||||
method: method.clone(),
|
||||
path: path.clone(),
|
||||
status,
|
||||
request_id: request_id.clone(),
|
||||
attributes: attributes.clone(),
|
||||
});
|
||||
let mut trace_attributes = merge_trace_fields(method, path, attempt, attributes);
|
||||
trace_attributes.insert("status".to_string(), Value::from(status));
|
||||
if let Some(request_id) = request_id {
|
||||
trace_attributes.insert("request_id".to_string(), Value::String(request_id));
|
||||
}
|
||||
self.record("http_request_succeeded", trace_attributes);
|
||||
}
|
||||
|
||||
pub fn record_http_request_failed(
|
||||
&self,
|
||||
attempt: u32,
|
||||
method: impl Into<String>,
|
||||
path: impl Into<String>,
|
||||
error: impl Into<String>,
|
||||
retryable: bool,
|
||||
attributes: Map<String, Value>,
|
||||
) {
|
||||
let method = method.into();
|
||||
let path = path.into();
|
||||
let error = error.into();
|
||||
self.sink.record(TelemetryEvent::HttpRequestFailed {
|
||||
session_id: self.session_id.clone(),
|
||||
attempt,
|
||||
method: method.clone(),
|
||||
path: path.clone(),
|
||||
error: error.clone(),
|
||||
retryable,
|
||||
attributes: attributes.clone(),
|
||||
});
|
||||
let mut trace_attributes = merge_trace_fields(method, path, attempt, attributes);
|
||||
trace_attributes.insert("error".to_string(), Value::String(error));
|
||||
trace_attributes.insert("retryable".to_string(), Value::Bool(retryable));
|
||||
self.record("http_request_failed", trace_attributes);
|
||||
}
|
||||
|
||||
pub fn record_analytics(&self, event: AnalyticsEvent) {
|
||||
let mut attributes = event.properties.clone();
|
||||
attributes.insert(
|
||||
"namespace".to_string(),
|
||||
Value::String(event.namespace.clone()),
|
||||
);
|
||||
attributes.insert("action".to_string(), Value::String(event.action.clone()));
|
||||
self.sink.record(TelemetryEvent::Analytics(event));
|
||||
self.record("analytics", attributes);
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_trace_fields(
|
||||
method: String,
|
||||
path: String,
|
||||
attempt: u32,
|
||||
mut attributes: Map<String, Value>,
|
||||
) -> Map<String, Value> {
|
||||
attributes.insert("method".to_string(), Value::String(method));
|
||||
attributes.insert("path".to_string(), Value::String(path));
|
||||
attributes.insert("attempt".to_string(), Value::from(attempt));
|
||||
attributes
|
||||
}
|
||||
|
||||
fn current_timestamp_ms() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis()
|
||||
.try_into()
|
||||
.unwrap_or(u64::MAX)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn request_profile_emits_headers_and_merges_body() {
|
||||
let profile = AnthropicRequestProfile::new(
|
||||
ClientIdentity::new("claude-code", "1.2.3").with_runtime("rust-cli"),
|
||||
)
|
||||
.with_beta("tools-2026-04-01")
|
||||
.with_extra_body("metadata", serde_json::json!({"source": "test"}));
|
||||
|
||||
assert_eq!(
|
||||
profile.header_pairs(),
|
||||
vec![
|
||||
(
|
||||
"anthropic-version".to_string(),
|
||||
DEFAULT_ANTHROPIC_VERSION.to_string()
|
||||
),
|
||||
("user-agent".to_string(), "claude-code/1.2.3".to_string()),
|
||||
(
|
||||
"anthropic-beta".to_string(),
|
||||
"claude-code-20250219,prompt-caching-scope-2026-01-05,tools-2026-04-01"
|
||||
.to_string(),
|
||||
),
|
||||
]
|
||||
);
|
||||
|
||||
let body = profile
|
||||
.render_json_body(&serde_json::json!({"model": "claude-sonnet"}))
|
||||
.expect("body should serialize");
|
||||
assert_eq!(
|
||||
body["metadata"]["source"],
|
||||
Value::String("test".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
body["betas"],
|
||||
serde_json::json!([
|
||||
"claude-code-20250219",
|
||||
"prompt-caching-scope-2026-01-05",
|
||||
"tools-2026-04-01"
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_tracer_records_structured_events_and_trace_sequence() {
|
||||
let sink = Arc::new(MemoryTelemetrySink::default());
|
||||
let tracer = SessionTracer::new("session-123", sink.clone());
|
||||
|
||||
tracer.record_http_request_started(1, "POST", "/v1/messages", Map::new());
|
||||
tracer.record_analytics(
|
||||
AnalyticsEvent::new("cli", "prompt_sent")
|
||||
.with_property("model", Value::String("claude-opus".to_string())),
|
||||
);
|
||||
|
||||
let events = sink.events();
|
||||
assert!(matches!(
|
||||
&events[0],
|
||||
TelemetryEvent::HttpRequestStarted {
|
||||
session_id,
|
||||
attempt: 1,
|
||||
method,
|
||||
path,
|
||||
..
|
||||
} if session_id == "session-123" && method == "POST" && path == "/v1/messages"
|
||||
));
|
||||
assert!(matches!(
|
||||
&events[1],
|
||||
TelemetryEvent::SessionTrace(SessionTraceRecord { sequence: 0, name, .. })
|
||||
if name == "http_request_started"
|
||||
));
|
||||
assert!(matches!(&events[2], TelemetryEvent::Analytics(_)));
|
||||
assert!(matches!(
|
||||
&events[3],
|
||||
TelemetryEvent::SessionTrace(SessionTraceRecord { sequence: 1, name, .. })
|
||||
if name == "analytics"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jsonl_sink_persists_events() {
|
||||
let path =
|
||||
std::env::temp_dir().join(format!("telemetry-jsonl-{}.log", current_timestamp_ms()));
|
||||
let sink = JsonlTelemetrySink::new(&path).expect("sink should create file");
|
||||
|
||||
sink.record(TelemetryEvent::Analytics(
|
||||
AnalyticsEvent::new("cli", "turn_completed").with_property("ok", Value::Bool(true)),
|
||||
));
|
||||
|
||||
let contents = std::fs::read_to_string(&path).expect("telemetry log should be readable");
|
||||
assert!(contents.contains("\"type\":\"analytics\""));
|
||||
assert!(contents.contains("\"action\":\"turn_completed\""));
|
||||
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
}
|
||||
@ -7,6 +7,8 @@ publish.workspace = true
|
||||
|
||||
[dependencies]
|
||||
api = { path = "../api" }
|
||||
commands = { path = "../commands" }
|
||||
flate2 = "1"
|
||||
plugins = { path = "../plugins" }
|
||||
runtime = { path = "../runtime" }
|
||||
reqwest = { version = "0.12", default-features = false, features = ["blocking", "rustls-tls"] }
|
||||
|
||||
181
crates/tools/src/lane_completion.rs
Normal file
181
crates/tools/src/lane_completion.rs
Normal file
@ -0,0 +1,181 @@
|
||||
//! Lane completion detector — automatically marks lanes as completed when
|
||||
//! session finishes successfully with green tests and pushed code.
|
||||
//!
|
||||
//! This bridges the gap where `LaneContext::completed` was a passive bool
|
||||
//! that nothing automatically set. Now completion is detected from:
|
||||
//! - Agent output shows Finished status
|
||||
//! - No errors/blockers present
|
||||
//! - Tests passed (green status)
|
||||
//! - Code pushed (has output file)
|
||||
|
||||
use runtime::{
|
||||
evaluate, LaneBlocker, LaneContext, PolicyAction, PolicyCondition, PolicyEngine, PolicyRule,
|
||||
ReviewStatus,
|
||||
};
|
||||
|
||||
use crate::AgentOutput;
|
||||
|
||||
/// Detects if a lane should be automatically marked as completed.
|
||||
///
|
||||
/// Returns `Some(LaneContext)` with `completed = true` if all conditions met,
|
||||
/// `None` if lane should remain active.
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn detect_lane_completion(
|
||||
output: &AgentOutput,
|
||||
test_green: bool,
|
||||
has_pushed: bool,
|
||||
) -> Option<LaneContext> {
|
||||
// Must be finished without errors
|
||||
if output.error.is_some() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Must have finished status
|
||||
if !output.status.eq_ignore_ascii_case("completed")
|
||||
&& !output.status.eq_ignore_ascii_case("finished")
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
// Must have no current blocker
|
||||
if output.current_blocker.is_some() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Must have green tests
|
||||
if !test_green {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Must have pushed code
|
||||
if !has_pushed {
|
||||
return None;
|
||||
}
|
||||
|
||||
// All conditions met — create completed context
|
||||
Some(LaneContext {
|
||||
lane_id: output.agent_id.clone(),
|
||||
green_level: 3, // Workspace green
|
||||
branch_freshness: std::time::Duration::from_secs(0),
|
||||
blocker: LaneBlocker::None,
|
||||
review_status: ReviewStatus::Approved,
|
||||
diff_scope: runtime::DiffScope::Scoped,
|
||||
completed: true,
|
||||
reconciled: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Evaluates policy actions for a completed lane.
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn evaluate_completed_lane(context: &LaneContext) -> Vec<PolicyAction> {
|
||||
let engine = PolicyEngine::new(vec![
|
||||
PolicyRule::new(
|
||||
"closeout-completed-lane",
|
||||
PolicyCondition::And(vec![
|
||||
PolicyCondition::LaneCompleted,
|
||||
PolicyCondition::GreenAt { level: 3 },
|
||||
]),
|
||||
PolicyAction::CloseoutLane,
|
||||
10,
|
||||
),
|
||||
PolicyRule::new(
|
||||
"cleanup-completed-session",
|
||||
PolicyCondition::LaneCompleted,
|
||||
PolicyAction::CleanupSession,
|
||||
5,
|
||||
),
|
||||
]);
|
||||
|
||||
evaluate(&engine, context)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use runtime::{DiffScope, LaneBlocker};
|
||||
|
||||
fn test_output() -> AgentOutput {
|
||||
AgentOutput {
|
||||
agent_id: "test-lane-1".to_string(),
|
||||
name: "Test Agent".to_string(),
|
||||
description: "Test".to_string(),
|
||||
subagent_type: None,
|
||||
model: None,
|
||||
status: "Finished".to_string(),
|
||||
output_file: "/tmp/test.output".to_string(),
|
||||
manifest_file: "/tmp/test.manifest".to_string(),
|
||||
created_at: "2024-01-01T00:00:00Z".to_string(),
|
||||
started_at: Some("2024-01-01T00:00:00Z".to_string()),
|
||||
completed_at: Some("2024-01-01T00:00:00Z".to_string()),
|
||||
lane_events: vec![],
|
||||
derived_state: "working".to_string(),
|
||||
current_blocker: None,
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn detects_completion_when_all_conditions_met() {
|
||||
let output = test_output();
|
||||
let result = detect_lane_completion(&output, true, true);
|
||||
|
||||
assert!(result.is_some());
|
||||
let context = result.unwrap();
|
||||
assert!(context.completed);
|
||||
assert_eq!(context.green_level, 3);
|
||||
assert_eq!(context.blocker, LaneBlocker::None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_completion_when_error_present() {
|
||||
let mut output = test_output();
|
||||
output.error = Some("Build failed".to_string());
|
||||
|
||||
let result = detect_lane_completion(&output, true, true);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_completion_when_not_finished() {
|
||||
let mut output = test_output();
|
||||
output.status = "Running".to_string();
|
||||
|
||||
let result = detect_lane_completion(&output, true, true);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_completion_when_tests_not_green() {
|
||||
let output = test_output();
|
||||
|
||||
let result = detect_lane_completion(&output, false, true);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_completion_when_not_pushed() {
|
||||
let output = test_output();
|
||||
|
||||
let result = detect_lane_completion(&output, true, false);
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn evaluate_triggers_closeout_for_completed_lane() {
|
||||
let context = LaneContext {
|
||||
lane_id: "completed-lane".to_string(),
|
||||
green_level: 3,
|
||||
branch_freshness: std::time::Duration::from_secs(0),
|
||||
blocker: LaneBlocker::None,
|
||||
review_status: ReviewStatus::Approved,
|
||||
diff_scope: DiffScope::Scoped,
|
||||
completed: true,
|
||||
reconciled: false,
|
||||
};
|
||||
|
||||
let actions = evaluate_completed_lane(&context);
|
||||
|
||||
assert!(actions.contains(&PolicyAction::CloseoutLane));
|
||||
assert!(actions.contains(&PolicyAction::CleanupSession));
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
548
crates/tools/src/pdf_extract.rs
Normal file
548
crates/tools/src/pdf_extract.rs
Normal file
@ -0,0 +1,548 @@
|
||||
//! Minimal PDF text extraction.
|
||||
//!
|
||||
//! Reads a PDF file, locates `/Contents` stream objects, decompresses with
|
||||
//! flate2 when the stream uses `/FlateDecode`, and extracts text operators
|
||||
//! found between `BT` / `ET` markers.
|
||||
|
||||
use std::io::Read as _;
|
||||
use std::path::Path;
|
||||
|
||||
/// Extract all readable text from a PDF file.
|
||||
///
|
||||
/// Returns the concatenated text found inside BT/ET operators across all
|
||||
/// content streams. Non-text pages or encrypted PDFs yield an empty string
|
||||
/// rather than an error.
|
||||
pub fn extract_text(path: &Path) -> Result<String, String> {
|
||||
let data = std::fs::read(path).map_err(|e| format!("failed to read PDF: {e}"))?;
|
||||
Ok(extract_text_from_bytes(&data))
|
||||
}
|
||||
|
||||
/// Core extraction from raw PDF bytes — useful for testing without touching the
|
||||
/// filesystem.
|
||||
pub(crate) fn extract_text_from_bytes(data: &[u8]) -> String {
|
||||
let mut all_text = String::new();
|
||||
let mut offset = 0;
|
||||
|
||||
while offset < data.len() {
|
||||
let Some(stream_start) = find_subsequence(&data[offset..], b"stream") else {
|
||||
break;
|
||||
};
|
||||
let abs_start = offset + stream_start;
|
||||
|
||||
// Determine the byte offset right after "stream\r\n" or "stream\n".
|
||||
let content_start = skip_stream_eol(data, abs_start + b"stream".len());
|
||||
|
||||
let Some(end_rel) = find_subsequence(&data[content_start..], b"endstream") else {
|
||||
break;
|
||||
};
|
||||
let content_end = content_start + end_rel;
|
||||
|
||||
// Look backwards from "stream" for a FlateDecode hint in the object
|
||||
// dictionary. We scan at most 512 bytes before the stream keyword.
|
||||
let dict_window_start = abs_start.saturating_sub(512);
|
||||
let dict_window = &data[dict_window_start..abs_start];
|
||||
let is_flate = find_subsequence(dict_window, b"FlateDecode").is_some();
|
||||
|
||||
// Only process streams whose parent dictionary references /Contents or
|
||||
// looks like a page content stream (contains /Length). We intentionally
|
||||
// keep this loose to cover both inline and referenced content streams.
|
||||
let raw = &data[content_start..content_end];
|
||||
let decompressed;
|
||||
let stream_bytes: &[u8] = if is_flate {
|
||||
if let Ok(buf) = inflate(raw) {
|
||||
decompressed = buf;
|
||||
&decompressed
|
||||
} else {
|
||||
offset = content_end;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
raw
|
||||
};
|
||||
|
||||
let text = extract_bt_et_text(stream_bytes);
|
||||
if !text.is_empty() {
|
||||
if !all_text.is_empty() {
|
||||
all_text.push('\n');
|
||||
}
|
||||
all_text.push_str(&text);
|
||||
}
|
||||
|
||||
offset = content_end;
|
||||
}
|
||||
|
||||
all_text
|
||||
}
|
||||
|
||||
/// Inflate (zlib / deflate) compressed data via `flate2`.
|
||||
fn inflate(data: &[u8]) -> Result<Vec<u8>, String> {
|
||||
let mut decoder = flate2::read::ZlibDecoder::new(data);
|
||||
let mut buf = Vec::new();
|
||||
decoder
|
||||
.read_to_end(&mut buf)
|
||||
.map_err(|e| format!("flate2 inflate error: {e}"))?;
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
/// Extract text from PDF content-stream operators between BT and ET markers.
|
||||
///
|
||||
/// Handles the common text-showing operators:
|
||||
/// - `Tj` — show a string
|
||||
/// - `TJ` — show an array of strings/numbers
|
||||
/// - `'` — move to next line and show string
|
||||
/// - `"` — set spacing, move to next line and show string
|
||||
fn extract_bt_et_text(stream: &[u8]) -> String {
|
||||
let text = String::from_utf8_lossy(stream);
|
||||
let mut result = String::new();
|
||||
let mut in_bt = false;
|
||||
|
||||
for line in text.lines() {
|
||||
let trimmed = line.trim();
|
||||
if trimmed == "BT" {
|
||||
in_bt = true;
|
||||
continue;
|
||||
}
|
||||
if trimmed == "ET" {
|
||||
in_bt = false;
|
||||
continue;
|
||||
}
|
||||
if !in_bt {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Tj operator: (text) Tj
|
||||
if trimmed.ends_with("Tj") {
|
||||
if let Some(s) = extract_parenthesized_string(trimmed) {
|
||||
if !result.is_empty() && !result.ends_with('\n') {
|
||||
result.push(' ');
|
||||
}
|
||||
result.push_str(&s);
|
||||
}
|
||||
}
|
||||
// TJ operator: [ (text) 123 (text) ] TJ
|
||||
else if trimmed.ends_with("TJ") {
|
||||
let extracted = extract_tj_array(trimmed);
|
||||
if !extracted.is_empty() {
|
||||
if !result.is_empty() && !result.ends_with('\n') {
|
||||
result.push(' ');
|
||||
}
|
||||
result.push_str(&extracted);
|
||||
}
|
||||
}
|
||||
// ' operator: (text) ' and " operator: aw ac (text) "
|
||||
else if is_newline_show_operator(trimmed) {
|
||||
if let Some(s) = extract_parenthesized_string(trimmed) {
|
||||
if !result.is_empty() {
|
||||
result.push('\n');
|
||||
}
|
||||
result.push_str(&s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Returns `true` when `trimmed` looks like a `'` or `"` text-show operator.
|
||||
fn is_newline_show_operator(trimmed: &str) -> bool {
|
||||
(trimmed.ends_with('\'') && trimmed.len() > 1)
|
||||
|| (trimmed.ends_with('"') && trimmed.contains('('))
|
||||
}
|
||||
|
||||
/// Pull the text from the first `(…)` group, handling escaped parens and
|
||||
/// common PDF escape sequences.
|
||||
fn extract_parenthesized_string(input: &str) -> Option<String> {
|
||||
let open = input.find('(')?;
|
||||
let bytes = input.as_bytes();
|
||||
let mut depth = 0;
|
||||
let mut result = String::new();
|
||||
let mut i = open;
|
||||
|
||||
while i < bytes.len() {
|
||||
match bytes[i] {
|
||||
b'(' => {
|
||||
if depth > 0 {
|
||||
result.push('(');
|
||||
}
|
||||
depth += 1;
|
||||
}
|
||||
b')' => {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
return Some(result);
|
||||
}
|
||||
result.push(')');
|
||||
}
|
||||
b'\\' if i + 1 < bytes.len() => {
|
||||
i += 1;
|
||||
match bytes[i] {
|
||||
b'n' => result.push('\n'),
|
||||
b'r' => result.push('\r'),
|
||||
b't' => result.push('\t'),
|
||||
b'\\' => result.push('\\'),
|
||||
b'(' => result.push('('),
|
||||
b')' => result.push(')'),
|
||||
// Octal sequences — up to 3 digits.
|
||||
d @ b'0'..=b'7' => {
|
||||
let mut octal = u32::from(d - b'0');
|
||||
for _ in 0..2 {
|
||||
if i + 1 < bytes.len()
|
||||
&& bytes[i + 1].is_ascii_digit()
|
||||
&& bytes[i + 1] <= b'7'
|
||||
{
|
||||
i += 1;
|
||||
octal = octal * 8 + u32::from(bytes[i] - b'0');
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if let Some(ch) = char::from_u32(octal) {
|
||||
result.push(ch);
|
||||
}
|
||||
}
|
||||
other => result.push(char::from(other)),
|
||||
}
|
||||
}
|
||||
ch => result.push(char::from(ch)),
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
|
||||
None // unbalanced
|
||||
}
|
||||
|
||||
/// Extract concatenated strings from a TJ array like `[ (Hello) -120 (World) ] TJ`.
|
||||
fn extract_tj_array(input: &str) -> String {
|
||||
let mut result = String::new();
|
||||
let Some(bracket_start) = input.find('[') else {
|
||||
return result;
|
||||
};
|
||||
let Some(bracket_end) = input.rfind(']') else {
|
||||
return result;
|
||||
};
|
||||
let inner = &input[bracket_start + 1..bracket_end];
|
||||
|
||||
let mut i = 0;
|
||||
let bytes = inner.as_bytes();
|
||||
while i < bytes.len() {
|
||||
if bytes[i] == b'(' {
|
||||
// Reconstruct the parenthesized string and extract it.
|
||||
if let Some(s) = extract_parenthesized_string(&inner[i..]) {
|
||||
result.push_str(&s);
|
||||
// Skip past the closing paren.
|
||||
let mut depth = 0u32;
|
||||
for &b in &bytes[i..] {
|
||||
i += 1;
|
||||
if b == b'(' {
|
||||
depth += 1;
|
||||
} else if b == b')' {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Skip past the end-of-line marker that immediately follows the `stream`
|
||||
/// keyword. Per the PDF spec this is either `\r\n` or `\n`.
|
||||
fn skip_stream_eol(data: &[u8], pos: usize) -> usize {
|
||||
if pos < data.len() && data[pos] == b'\r' {
|
||||
if pos + 1 < data.len() && data[pos + 1] == b'\n' {
|
||||
return pos + 2;
|
||||
}
|
||||
return pos + 1;
|
||||
}
|
||||
if pos < data.len() && data[pos] == b'\n' {
|
||||
return pos + 1;
|
||||
}
|
||||
pos
|
||||
}
|
||||
|
||||
/// Simple byte-subsequence search.
|
||||
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
|
||||
haystack
|
||||
.windows(needle.len())
|
||||
.position(|window| window == needle)
|
||||
}
|
||||
|
||||
/// Check if a user-supplied path looks like a PDF file reference.
|
||||
#[must_use]
|
||||
pub fn looks_like_pdf_path(text: &str) -> Option<&str> {
|
||||
for token in text.split_whitespace() {
|
||||
let cleaned = token.trim_matches(|c: char| c == '\'' || c == '"' || c == '`');
|
||||
if let Some(dot_pos) = cleaned.rfind('.') {
|
||||
if cleaned[dot_pos + 1..].eq_ignore_ascii_case("pdf") && dot_pos > 0 {
|
||||
return Some(cleaned);
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Auto-extract text from a PDF path mentioned in a user prompt.
|
||||
///
|
||||
/// Returns `Some((path, extracted_text))` when a `.pdf` path is detected and
|
||||
/// the file exists, otherwise `None`.
|
||||
#[must_use]
|
||||
pub fn maybe_extract_pdf_from_prompt(prompt: &str) -> Option<(String, String)> {
|
||||
let pdf_path = looks_like_pdf_path(prompt)?;
|
||||
let path = Path::new(pdf_path);
|
||||
if !path.exists() {
|
||||
return None;
|
||||
}
|
||||
let text = extract_text(path).ok()?;
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some((pdf_path.to_string(), text))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// Build a minimal valid PDF with a single page containing uncompressed
|
||||
/// text. This is the smallest PDF structure that exercises the BT/ET
|
||||
/// extraction path.
|
||||
fn build_simple_pdf(text: &str) -> Vec<u8> {
|
||||
let content_stream = format!("BT\n/F1 12 Tf\n({text}) Tj\nET");
|
||||
let stream_bytes = content_stream.as_bytes();
|
||||
let mut pdf = Vec::new();
|
||||
|
||||
// Header
|
||||
pdf.extend_from_slice(b"%PDF-1.4\n");
|
||||
|
||||
// Object 1 — Catalog
|
||||
let obj1_offset = pdf.len();
|
||||
pdf.extend_from_slice(b"1 0 obj\n<< /Type /Catalog /Pages 2 0 R >>\nendobj\n");
|
||||
|
||||
// Object 2 — Pages
|
||||
let obj2_offset = pdf.len();
|
||||
pdf.extend_from_slice(b"2 0 obj\n<< /Type /Pages /Kids [3 0 R] /Count 1 >>\nendobj\n");
|
||||
|
||||
// Object 3 — Page
|
||||
let obj3_offset = pdf.len();
|
||||
pdf.extend_from_slice(
|
||||
b"3 0 obj\n<< /Type /Page /Parent 2 0 R /Contents 4 0 R >>\nendobj\n",
|
||||
);
|
||||
|
||||
// Object 4 — Content stream (uncompressed)
|
||||
let obj4_offset = pdf.len();
|
||||
let length = stream_bytes.len();
|
||||
let header = format!("4 0 obj\n<< /Length {length} >>\nstream\n");
|
||||
pdf.extend_from_slice(header.as_bytes());
|
||||
pdf.extend_from_slice(stream_bytes);
|
||||
pdf.extend_from_slice(b"\nendstream\nendobj\n");
|
||||
|
||||
// Cross-reference table
|
||||
let xref_offset = pdf.len();
|
||||
pdf.extend_from_slice(b"xref\n0 5\n");
|
||||
pdf.extend_from_slice(b"0000000000 65535 f \n");
|
||||
pdf.extend_from_slice(format!("{obj1_offset:010} 00000 n \n").as_bytes());
|
||||
pdf.extend_from_slice(format!("{obj2_offset:010} 00000 n \n").as_bytes());
|
||||
pdf.extend_from_slice(format!("{obj3_offset:010} 00000 n \n").as_bytes());
|
||||
pdf.extend_from_slice(format!("{obj4_offset:010} 00000 n \n").as_bytes());
|
||||
|
||||
// Trailer
|
||||
pdf.extend_from_slice(b"trailer\n<< /Size 5 /Root 1 0 R >>\n");
|
||||
pdf.extend_from_slice(format!("startxref\n{xref_offset}\n%%EOF\n").as_bytes());
|
||||
|
||||
pdf
|
||||
}
|
||||
|
||||
/// Build a minimal PDF with flate-compressed content stream.
|
||||
fn build_flate_pdf(text: &str) -> Vec<u8> {
|
||||
use flate2::write::ZlibEncoder;
|
||||
use flate2::Compression;
|
||||
use std::io::Write as _;
|
||||
|
||||
let content_stream = format!("BT\n/F1 12 Tf\n({text}) Tj\nET");
|
||||
let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
|
||||
encoder
|
||||
.write_all(content_stream.as_bytes())
|
||||
.expect("compress");
|
||||
let compressed = encoder.finish().expect("finish");
|
||||
|
||||
let mut pdf = Vec::new();
|
||||
pdf.extend_from_slice(b"%PDF-1.4\n");
|
||||
|
||||
let obj1_offset = pdf.len();
|
||||
pdf.extend_from_slice(b"1 0 obj\n<< /Type /Catalog /Pages 2 0 R >>\nendobj\n");
|
||||
|
||||
let obj2_offset = pdf.len();
|
||||
pdf.extend_from_slice(b"2 0 obj\n<< /Type /Pages /Kids [3 0 R] /Count 1 >>\nendobj\n");
|
||||
|
||||
let obj3_offset = pdf.len();
|
||||
pdf.extend_from_slice(
|
||||
b"3 0 obj\n<< /Type /Page /Parent 2 0 R /Contents 4 0 R >>\nendobj\n",
|
||||
);
|
||||
|
||||
let obj4_offset = pdf.len();
|
||||
let length = compressed.len();
|
||||
let header = format!("4 0 obj\n<< /Length {length} /Filter /FlateDecode >>\nstream\n");
|
||||
pdf.extend_from_slice(header.as_bytes());
|
||||
pdf.extend_from_slice(&compressed);
|
||||
pdf.extend_from_slice(b"\nendstream\nendobj\n");
|
||||
|
||||
let xref_offset = pdf.len();
|
||||
pdf.extend_from_slice(b"xref\n0 5\n");
|
||||
pdf.extend_from_slice(b"0000000000 65535 f \n");
|
||||
pdf.extend_from_slice(format!("{obj1_offset:010} 00000 n \n").as_bytes());
|
||||
pdf.extend_from_slice(format!("{obj2_offset:010} 00000 n \n").as_bytes());
|
||||
pdf.extend_from_slice(format!("{obj3_offset:010} 00000 n \n").as_bytes());
|
||||
pdf.extend_from_slice(format!("{obj4_offset:010} 00000 n \n").as_bytes());
|
||||
|
||||
pdf.extend_from_slice(b"trailer\n<< /Size 5 /Root 1 0 R >>\n");
|
||||
pdf.extend_from_slice(format!("startxref\n{xref_offset}\n%%EOF\n").as_bytes());
|
||||
|
||||
pdf
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extracts_uncompressed_text_from_minimal_pdf() {
|
||||
// given
|
||||
let pdf_bytes = build_simple_pdf("Hello World");
|
||||
|
||||
// when
|
||||
let text = extract_text_from_bytes(&pdf_bytes);
|
||||
|
||||
// then
|
||||
assert_eq!(text, "Hello World");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extracts_text_from_flate_compressed_stream() {
|
||||
// given
|
||||
let pdf_bytes = build_flate_pdf("Compressed PDF Text");
|
||||
|
||||
// when
|
||||
let text = extract_text_from_bytes(&pdf_bytes);
|
||||
|
||||
// then
|
||||
assert_eq!(text, "Compressed PDF Text");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handles_tj_array_operator() {
|
||||
// given
|
||||
let stream = b"BT\n/F1 12 Tf\n[ (Hello) -120 ( World) ] TJ\nET";
|
||||
// Build a raw PDF with TJ array operator instead of simple Tj.
|
||||
let content_stream = std::str::from_utf8(stream).unwrap();
|
||||
let raw = format!(
|
||||
"%PDF-1.4\n1 0 obj\n<< /Type /Catalog >>\nendobj\n\
|
||||
2 0 obj\n<< /Length {} >>\nstream\n{}\nendstream\nendobj\n%%EOF\n",
|
||||
content_stream.len(),
|
||||
content_stream
|
||||
);
|
||||
let pdf_bytes = raw.into_bytes();
|
||||
|
||||
// when
|
||||
let text = extract_text_from_bytes(&pdf_bytes);
|
||||
|
||||
// then
|
||||
assert_eq!(text, "Hello World");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handles_escaped_parentheses() {
|
||||
// given
|
||||
let content = b"BT\n(Hello \\(World\\)) Tj\nET";
|
||||
let raw = format!(
|
||||
"%PDF-1.4\n1 0 obj\n<< /Length {} >>\nstream\n",
|
||||
content.len()
|
||||
);
|
||||
let mut pdf_bytes = raw.into_bytes();
|
||||
pdf_bytes.extend_from_slice(content);
|
||||
pdf_bytes.extend_from_slice(b"\nendstream\nendobj\n%%EOF\n");
|
||||
|
||||
// when
|
||||
let text = extract_text_from_bytes(&pdf_bytes);
|
||||
|
||||
// then
|
||||
assert_eq!(text, "Hello (World)");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_empty_for_non_pdf_data() {
|
||||
// given
|
||||
let data = b"This is not a PDF file at all";
|
||||
|
||||
// when
|
||||
let text = extract_text_from_bytes(data);
|
||||
|
||||
// then
|
||||
assert!(text.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extracts_text_from_file_on_disk() {
|
||||
// given
|
||||
let pdf_bytes = build_simple_pdf("Disk Test");
|
||||
let dir = std::env::temp_dir().join("clawd-pdf-extract-test");
|
||||
std::fs::create_dir_all(&dir).unwrap();
|
||||
let pdf_path = dir.join("test.pdf");
|
||||
std::fs::write(&pdf_path, &pdf_bytes).unwrap();
|
||||
|
||||
// when
|
||||
let text = extract_text(&pdf_path).unwrap();
|
||||
|
||||
// then
|
||||
assert_eq!(text, "Disk Test");
|
||||
|
||||
// cleanup
|
||||
let _ = std::fs::remove_dir_all(&dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn looks_like_pdf_path_detects_pdf_references() {
|
||||
// given / when / then
|
||||
assert_eq!(
|
||||
looks_like_pdf_path("Please read /tmp/report.pdf"),
|
||||
Some("/tmp/report.pdf")
|
||||
);
|
||||
assert_eq!(looks_like_pdf_path("Check file.PDF now"), Some("file.PDF"));
|
||||
assert_eq!(looks_like_pdf_path("no pdf here"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn maybe_extract_pdf_from_prompt_returns_none_for_missing_file() {
|
||||
// given
|
||||
let prompt = "Read /tmp/nonexistent-abc123.pdf please";
|
||||
|
||||
// when
|
||||
let result = maybe_extract_pdf_from_prompt(prompt);
|
||||
|
||||
// then
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn maybe_extract_pdf_from_prompt_extracts_existing_file() {
|
||||
// given
|
||||
let pdf_bytes = build_simple_pdf("Auto Extracted");
|
||||
let dir = std::env::temp_dir().join("clawd-pdf-auto-extract-test");
|
||||
std::fs::create_dir_all(&dir).unwrap();
|
||||
let pdf_path = dir.join("auto.pdf");
|
||||
std::fs::write(&pdf_path, &pdf_bytes).unwrap();
|
||||
let prompt = format!("Summarize {}", pdf_path.display());
|
||||
|
||||
// when
|
||||
let result = maybe_extract_pdf_from_prompt(&prompt);
|
||||
|
||||
// then
|
||||
let (path, text) = result.expect("should extract");
|
||||
assert_eq!(path, pdf_path.display().to_string());
|
||||
assert_eq!(text, "Auto Extracted");
|
||||
|
||||
// cleanup
|
||||
let _ = std::fs::remove_dir_all(&dir);
|
||||
}
|
||||
}
|
||||
@ -5,9 +5,9 @@ import { theme, Skeleton, Spin, Popover } from 'antd';
|
||||
import { XMarkdown } from '@ant-design/x-markdown';
|
||||
|
||||
// 助手气泡 body 撑满可用宽度,避免 Mermaid 等内容宽度受文本行长度影响
|
||||
const bubbleStyle = document.createElement('style');
|
||||
bubbleStyle.textContent = '.ant-bubble-start > .ant-bubble-body { width: 80%; }';
|
||||
document.head.appendChild(bubbleStyle);
|
||||
// const bubbleStyle = document.createElement('style');
|
||||
// bubbleStyle.textContent = '.ant-bubble-start > .ant-bubble-body { width: 100%; }';
|
||||
// document.head.appendChild(bubbleStyle);
|
||||
import type { ComponentProps, Token } from '@ant-design/x-markdown';
|
||||
import Latex from '@ant-design/x-markdown/plugins/latex';
|
||||
import '@ant-design/x-markdown/themes/light.css';
|
||||
@ -18,8 +18,6 @@ import WelcomeScreen from './WelcomeScreen';
|
||||
|
||||
// ── XMarkdown 插件配置 ────────────────────────────────────────────────
|
||||
|
||||
// LaTeX 数学公式插件:解析 $...$ / $$...$$ / \(...\) / \[...\]
|
||||
// 自定义脚注插件:解析 [^1] 语法 → <footnote> 标签
|
||||
const footnoteExtension = {
|
||||
name: 'footnote',
|
||||
level: 'inline' as const,
|
||||
@ -281,14 +279,12 @@ const ChatView: React.FC<ChatViewProps> = ({
|
||||
variant: 'filled' as const,
|
||||
shape: 'round' as const,
|
||||
avatar: <UserOutlined />,
|
||||
// styles: { content: { width: '80%' } },
|
||||
},
|
||||
assistant: {
|
||||
placement: 'start' as const,
|
||||
variant: 'borderless' as const,
|
||||
avatar: <RobotOutlined />,
|
||||
streaming: true,
|
||||
styles: { content: { width: '80%' } },
|
||||
header: (_content: unknown, { status }: { status?: string }) => {
|
||||
if (status === 'loading' || status === 'updating') {
|
||||
return (
|
||||
|
||||
Loading…
Reference in New Issue
Block a user